gensbi.models.flux1.model
Classes
- Flux1 – Transformer model for flow matching on sequences.
- Flux1Params – Parameters for the Flux1 model.
Module Contents
- class gensbi.models.flux1.model.Flux1(params)
Bases: flax.nnx.Module
Transformer model for flow matching on sequences.
- Parameters:
params (Flux1Params)
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
obs_ids (jax.Array)
cond (jax.Array)
cond_ids (jax.Array)
conditioned (bool | jax.Array)
guidance (jax.Array | None)
- Return type:
jax.Array
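This page does not show how Flux1 is used at sampling time. The sketch below assumes a common flow-matching convention that this page does not confirm: the network predicts a velocity field v(t, x), and samples are drawn by integrating dx/dt = v(t, x) from t = 0 (noise) to t = 1 (data) with an ODE solver. The names `toy_velocity`, `euler_sample` and `TARGET` are hypothetical; `toy_velocity` is a closed-form stand-in for a trained Flux1 forward pass, which in practice would be called with `obs_ids`, `cond` and `cond_ids` held fixed while `t` and `obs` change.

```python
from typing import Callable, List

Vector = List[float]
VelocityFn = Callable[[float, Vector], Vector]

# Hypothetical point that the toy velocity field transports every sample to.
TARGET = [2.0, -1.0, 0.5]


def toy_velocity(t: float, x: Vector) -> Vector:
    # Hypothetical stand-in for a trained Flux1 call (assumption: the model
    # outputs a velocity). For the straight path x_t = (1 - t) * x0 + t * x1
    # toward a fixed x1, the velocity (x1 - x_t) / (1 - t) keeps a sample on
    # that straight line.
    return [(x1 - xi) / (1.0 - t) for xi, x1 in zip(x, TARGET)]


def euler_sample(velocity: VelocityFn, x0: Vector, num_steps: int = 10) -> Vector:
    """Integrate dx/dt = velocity(t, x) from t = 0 to t = 1 with explicit Euler."""
    dt = 1.0 / num_steps
    x = list(x0)
    for k in range(num_steps):
        t = k * dt  # left endpoint of each step, so t < 1 and the toy field stays finite
        v = velocity(t, x)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


sample = euler_sample(toy_velocity, x0=[0.0, 0.0, 0.0], num_steps=10)
print([round(s, 6) for s in sample])  # [2.0, -1.0, 0.5]
```

Explicit Euler is the simplest solver; with a learned velocity field, more steps or a higher-order ODE solver generally trade extra model evaluations for accuracy.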
- class gensbi.models.flux1.model.Flux1Params
Parameters for the Flux1 model.
- Parameters:
in_channels (int) – Number of input channels.
vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.
context_in_dim (int) – Dimension of the context input.
mlp_ratio (float) – Expansion ratio of the hidden layer in the MLP blocks.
num_heads (int) – Number of attention heads.
depth (int) – Number of double-stream blocks.
depth_single_blocks (int) – Number of single-stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
obs_dim (int) – Observation dimension.
cond_dim (int) – Condition dimension.
theta (int) – Base frequency for the positional encoding.
param_dtype (DTypeLike) – Data type for model parameters.
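The parameter list says only that `axes_dim` and `theta` configure the positional encoding. Assuming this port keeps the multi-axis rotary positional embedding (RoPE) of the original FLUX model, which this page does not confirm, each axis i contributes `axes_dim[i] / 2` rotation frequencies theta^(-2j / axes_dim[i]), one position id per axis selects the angles, and the per-axis bands are concatenated, so `sum(axes_dim)` would have to equal the per-head dimension. A minimal plain-Python sketch; `axis_frequencies`, `rope_angles` and `rotate_pairs` are hypothetical helpers, not gensbi functions.

```python
import math
from typing import List, Sequence


def axis_frequencies(dim: int, theta: float) -> List[float]:
    """RoPE frequencies for one axis: omega_j = theta ** (-2j / dim) for j < dim / 2."""
    if dim % 2 != 0:
        raise ValueError("each entry of axes_dim must be even")
    return [theta ** (-2 * j / dim) for j in range(dim // 2)]


def rope_angles(position_ids: Sequence[float], axes_dim: Sequence[int], theta: float) -> List[float]:
    """Rotation angles for one token: one position id per axis, bands concatenated."""
    if len(position_ids) != len(axes_dim):
        raise ValueError("need exactly one position id per axis")
    angles: List[float] = []
    for pos, dim in zip(position_ids, axes_dim):
        angles.extend(pos * omega for omega in axis_frequencies(dim, theta))
    return angles


def rotate_pairs(x: Sequence[float], angles: Sequence[float]) -> List[float]:
    """Rotate consecutive channel pairs (x[2i], x[2i+1]) by angles[i]."""
    if len(x) != 2 * len(angles):
        raise ValueError("head dimension must equal sum(axes_dim)")
    out: List[float] = []
    for i, a in enumerate(angles):
        x0, x1 = x[2 * i], x[2 * i + 1]
        c, s = math.cos(a), math.sin(a)
        out.extend([c * x0 - s * x1, s * x0 + c * x1])
    return out


# Hypothetical configuration: two positional axes, head dimension 4 + 2 = 6.
axes_dim = [4, 2]
theta = 10_000
angles = rope_angles((3, 1), axes_dim, theta)
print(len(angles))  # 3
```

Because each channel pair is rotated by an angle proportional to its position id, the dot product of a rotated query and key depends only on the offset between their ids, which is what makes RoPE a relative positional encoding.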