gensbi.models.flux1
Submodules
Classes
Flux1 | Transformer model for flow matching on sequences.
Flux1Params | Parameters for the Flux1 model.
Package Contents
- class gensbi.models.flux1.Flux1(params)
Bases: flax.nnx.Module
Transformer model for flow matching on sequences.
- Parameters:
params (Flux1Params)
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
obs_ids (jax.Array)
cond (jax.Array)
cond_ids (jax.Array)
conditioned (bool | jax.Array)
guidance (jax.Array | None)
- Return type:
jax.Array
- cond_in
- condition_embedding
- condition_null
- double_blocks
- final_layer
- in_channels
- num_heads
- obs_in
- out_channels
- params
- pe_embedder
- qkv_features
- single_blocks
- time_in
- vector_in
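The `__call__` signature above lists the inputs but not their shapes. The following is a minimal sketch of how the arguments might be assembled, assuming (by analogy with the parameter names, not confirmed by this page) that `obs` is `(batch, obs_dim, in_channels)`, `cond` is `(batch, cond_dim, context_in_dim)`, `t` is `(batch,)`, and that the `*_ids` arrays carry one integer position per entry of `axes_dim`. The helper `make_ids` and all sizes are hypothetical.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only; none of these are library defaults.
batch, obs_dim, cond_dim = 4, 2, 3
in_channels, context_in_dim = 1, 1
n_axes = 2  # assumed to equal len(axes_dim)


def make_ids(batch, n_tokens, n_axes):
    """Hypothetical helper: integer position ids of shape (batch, n_tokens, n_axes).

    The token index is written to axis 0; the remaining axes stay zero.
    """
    ids = jnp.zeros((n_tokens, n_axes), dtype=jnp.int32)
    ids = ids.at[:, 0].set(jnp.arange(n_tokens, dtype=jnp.int32))
    return jnp.broadcast_to(ids, (batch, n_tokens, n_axes))


key_t, key_obs, key_cond = jax.random.split(jax.random.PRNGKey(0), 3)

t = jax.random.uniform(key_t, (batch,))                       # assumed shape (batch,)
obs = jax.random.normal(key_obs, (batch, obs_dim, in_channels))
cond = jax.random.normal(key_cond, (batch, cond_dim, context_in_dim))
obs_ids = make_ids(batch, obs_dim, n_axes)
cond_ids = make_ids(batch, cond_dim, n_axes)

# Keyword names match the documented signature.
inputs = dict(
    t=t,
    obs=obs,
    obs_ids=obs_ids,
    cond=cond,
    cond_ids=cond_ids,
    conditioned=True,
    guidance=None,
)

# With a constructed model (not built here), the call would look like:
# out = model(**inputs)  # documented return type: jax.Array
```

Passing the arguments by keyword keeps the call robust to the long positional list; check the shapes against the library source before relying on them.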
- class gensbi.models.flux1.Flux1Params
Parameters for the Flux1 model.
- Parameters:
in_channels (int) – Number of input channels.
vec_in_dim (int | None) – Dimension of the vector input, if applicable.
context_in_dim (int) – Dimension of the context input.
mlp_ratio (float) – Ratio for the MLP layers.
num_heads (int) – Number of attention heads.
depth (int) – Number of double stream blocks.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
obs_dim (int) – Observation dimension.
cond_dim (int) – Condition dimension.
theta (int) – Scaling factor for positional encoding. Defaults to 10000.
param_dtype (DTypeLike) – Data type for model parameters.
- axes_dim: list[int]
- cond_dim: int
- context_in_dim: int
- depth: int
- depth_single_blocks: int
- guidance_embed: bool = False
- in_channels: int
- mlp_ratio: float
- num_heads: int
- obs_dim: int
- param_dtype: jax.typing.DTypeLike
- qkv_bias: bool
- rngs: flax.nnx.Rngs
- theta: int = 10000
- vec_in_dim: int | None
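Of the fields listed above, only `guidance_embed` and `theta` carry defaults; every other field has to be supplied. The sketch below mirrors that layout with a standard-library dataclass so the required/optional split can be inspected. `Flux1ParamsSketch` is a hypothetical stand-in, not the gensbi class, and all example values are illustrative assumptions. In the real class, `rngs` is a `flax.nnx.Rngs` and `param_dtype` is a `jax.typing.DTypeLike`.

```python
from dataclasses import MISSING, dataclass, fields
from typing import Any, Optional


@dataclass
class Flux1ParamsSketch:
    """Hypothetical stand-in mirroring the documented fields of Flux1Params."""

    in_channels: int
    vec_in_dim: Optional[int]
    context_in_dim: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    qkv_bias: bool
    rngs: Any          # flax.nnx.Rngs in the real class
    obs_dim: int
    cond_dim: int
    param_dtype: Any   # jax.typing.DTypeLike in the real class
    theta: int = 10000
    guidance_embed: bool = False


# Keyword arguments avoid depending on field order, which this page does not specify.
params = Flux1ParamsSketch(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=4.0,
    num_heads=4,
    depth=2,
    depth_single_blocks=4,
    axes_dim=[8, 8],
    qkv_bias=True,
    rngs=None,             # placeholder; the real class expects nnx.Rngs
    obs_dim=2,
    cond_dim=3,
    param_dtype="float32",
)

# Collect only the fields that declare a default value.
defaults = {f.name: f.default for f in fields(params) if f.default is not MISSING}
print(defaults)  # {'theta': 10000, 'guidance_embed': False}
```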