gensbi.models.flux1.model

Classes

Flux1

Transformer model for flow matching on sequences.

Flux1Params

Parameters for the Flux1 model.

Module Contents

class gensbi.models.flux1.model.Flux1(params)

Bases: flax.nnx.Module

Transformer model for flow matching on sequences.

Parameters:

params (Flux1Params) – Configuration of the model.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
Parameters:
  • t (jax.Array)

  • obs (jax.Array)

  • obs_ids (jax.Array)

  • cond (jax.Array)

  • cond_ids (jax.Array)

  • conditioned (bool | jax.Array)

  • guidance (jax.Array | None)

Return type:

jax.Array
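The conditioned argument accepts either a Python bool or a jax.Array. This suggests that conditioning can be switched on or off for each sample, as is done when training for classifier-free guidance. The sketch below shows one common way to implement such a switch: wherever the flag is False, the condition-token embeddings are replaced with a null embedding. This illustrates the general technique under that assumption and is not the gensbi implementation. The Flux1 class does have a condition_null attribute, but its exact role is not documented here. The helper select_condition and the array shapes are hypothetical, and the code uses NumPy for brevity.

```python
import numpy as np


def select_condition(cond_emb, null_emb, conditioned):
    """Hypothetical helper: keep condition tokens where `conditioned` is True,
    otherwise substitute a shared null embedding.

    cond_emb:    (batch, n_cond_tokens, hidden) condition-token embeddings
    null_emb:    (hidden,) null embedding used for every token
    conditioned: Python bool (whole batch) or (batch,) boolean array (per sample)
    """
    # Turn a scalar flag into one flag per sample; leave an array flag as-is.
    mask = np.broadcast_to(np.asarray(conditioned, dtype=bool), cond_emb.shape[:1])
    # Broadcast the per-sample flag over the token and feature axes.
    return np.where(mask[:, None, None], cond_emb, null_emb[None, None, :])


cond = np.ones((2, 3, 4))
null = np.zeros(4)
out = select_condition(cond, null, np.array([True, False]))
# Sample 0 keeps its condition tokens; sample 1 sees only the null embedding.
```

Accepting a per-sample array lets a single batch mix conditioned and unconditioned examples. This is what condition dropout needs during training.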

cond_in
condition_embedding
condition_null
double_blocks
final_layer
hidden_size
in_channels
num_heads
obs_in
out_channels
params
pe_embedder
qkv_features
single_blocks
time_in
vector_in
class gensbi.models.flux1.model.Flux1Params

Parameters for the Flux1 model.

Parameters:
  • in_channels (int) – Number of input channels.

  • vec_in_dim (Union[int, None]) – Dimension of the vector input, or None if the model takes no vector input.

  • context_in_dim (int) – Dimension of the context input.

  • mlp_ratio (float) – Ratio of the MLP hidden dimension to the model's hidden size.

  • num_heads (int) – Number of attention heads.

  • depth (int) – Number of double-stream blocks.

  • depth_single_blocks (int) – Number of single-stream blocks.

  • axes_dim (list[int]) – Dimension of the positional encoding along each axis.

  • qkv_bias (bool) – Whether to use bias in QKV layers.

  • guidance_embed (bool) – Whether to use a guidance embedding. Defaults to False.

  • rngs (nnx.Rngs) – Random number generators for initialization.

  • obs_dim (int) – Observation dimension.

  • cond_dim (int) – Condition dimension.

  • theta (int) – Base used to compute the positional-encoding frequencies. Defaults to 10000.

  • param_dtype (DTypeLike) – Data type for model parameters.
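Flux1Params has no explicit hidden-size field, yet Flux1 exposes hidden_size and qkv_features attributes. The width is therefore presumably derived from the other fields, for example in __post_init__. The sketch below assumes the convention of the original Flux model: each attention head's width equals sum(axes_dim), and the MLP hidden width is hidden_size * mlp_ratio. SizesSketch is a hypothetical class for illustration only and is not part of gensbi.

```python
from dataclasses import dataclass, field


@dataclass
class SizesSketch:
    """Hypothetical stand-in for Flux1Params with only the size-related fields."""

    num_heads: int
    axes_dim: list[int]
    mlp_ratio: float
    hidden_size: int = field(init=False)
    mlp_hidden_dim: int = field(init=False)

    def __post_init__(self):
        # Assumption (Flux convention): a head's width equals the total
        # positional-encoding width, sum(axes_dim).
        head_dim = sum(self.axes_dim)
        self.hidden_size = self.num_heads * head_dim
        # mlp_ratio scales the model width to obtain the MLP hidden width.
        self.mlp_hidden_dim = int(self.hidden_size * self.mlp_ratio)


sizes = SizesSketch(num_heads=4, axes_dim=[10, 6], mlp_ratio=4.0)
print(sizes.hidden_size, sizes.mlp_hidden_dim)  # 64 256
```

Under this convention, adding attention heads widens the model without changing the per-axis positional-encoding layout.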

__post_init__()
axes_dim: list[int]
cond_dim: int
context_in_dim: int
depth: int
depth_single_blocks: int
guidance_embed: bool = False
in_channels: int
mlp_ratio: float
num_heads: int
obs_dim: int
param_dtype: jax.typing.DTypeLike
qkv_bias: bool
rngs: flax.nnx.Rngs
theta: int = 10000
vec_in_dim: int | None
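The axes_dim and theta fields match the multi-axis rotary positional embedding (RoPE) used in the original Flux model. In that scheme, each position axis (one column of obs_ids or cond_ids) is assigned its own slice of every head's channels, and theta sets the base of the rotation frequencies. The sketch below illustrates that technique under the assumption that gensbi follows the same scheme; it is not the gensbi implementation. The helpers axis_angles, multi_axis_angles and apply_rope are hypothetical, and the code uses NumPy for brevity.

```python
import numpy as np


def axis_angles(pos, dim, theta=10000):
    """Rotation angles for one position axis.

    pos: (n_tokens,) positions along this axis
    dim: channels assigned to the axis (must be even: channels rotate in pairs)
    Returns (n_tokens, dim // 2) angles, one per channel pair.
    """
    assert dim % 2 == 0, "each axis needs an even number of channels"
    # Frequencies theta^(-2k/dim) for k = 0 .. dim/2 - 1; larger theta -> slower rotation.
    omega = 1.0 / theta ** (np.arange(0, dim, 2) / dim)
    return np.outer(pos, omega)


def multi_axis_angles(ids, axes_dim, theta=10000):
    """ids: (n_tokens, n_axes) position IDs; axes_dim: channels per axis."""
    return np.concatenate(
        [axis_angles(ids[:, i], d, theta) for i, d in enumerate(axes_dim)], axis=-1
    )


def apply_rope(x, angles):
    """Rotate consecutive channel pairs (x0, x1) of a float array x by `angles`.

    x: (n_tokens, sum(axes_dim)); angles: (n_tokens, sum(axes_dim) // 2)
    """
    x0, x1 = x[..., 0::2], x[..., 1::2]
    cos, sin = np.cos(angles), np.sin(angles)
    out = np.empty_like(x)
    out[..., 0::2] = x0 * cos - x1 * sin
    out[..., 1::2] = x0 * sin + x1 * cos
    return out


rng = np.random.default_rng(0)
x = rng.normal(size=(3, 6))
ids = np.array([[0, 0], [3, 1], [5, 2]])  # two position axes per token
y = apply_rope(x, multi_axis_angles(ids, axes_dim=[4, 2]))
```

Because the angles grow linearly with position, the dot product of two rotated vectors depends only on the difference between their position IDs. This is what makes the encoding relative.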