gensbi.models.flux1

Classes

Flux1

Transformer model for flow matching on sequences.

Flux1Params

Parameters for the Flux1 model.

Package Contents

class gensbi.models.flux1.Flux1(params)

Bases: flax.nnx.Module

Transformer model for flow matching on sequences.

Parameters:

params (Flux1Params) – Configuration for the model (see Flux1Params).

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
Parameters:
  • t (jax.Array)

  • obs (jax.Array)

  • obs_ids (jax.Array)

  • cond (jax.Array)

  • cond_ids (jax.Array)

  • conditioned (bool | jax.Array)

  • guidance (jax.Array | None)

Return type:

jax.Array
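The signature above takes a time `t` and a state `obs` and returns an array, which fits a flow-matching model that predicts a velocity field. The following is a minimal sketch of how such an output could drive sampling, assuming the return value is read as a velocity dx/dt (this page does not confirm that). `toy_velocity` and `euler_sample` are hypothetical names; `toy_velocity` only mirrors the argument names of `__call__`, and plain Python lists stand in for `jax.Array`.

```python
def toy_velocity(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None):
    """Hypothetical stand-in for a trained model (not Flux1 itself).

    Returns the velocity that moves `obs` in a straight line toward `cond`
    (or toward zero when `conditioned` is False), arriving at t = 1.
    """
    target = cond if conditioned else [0.0] * len(obs)
    return [(c - x) / (1.0 - t) for x, c in zip(obs, target)]


def euler_sample(velocity_fn, x0, cond, n_steps=10):
    """Integrate dx/dt = velocity_fn(t, x, ...) from t = 0 to t = 1 with Euler steps."""
    ids = list(range(len(x0)))  # placeholder position ids, ignored by the toy field
    x = list(x0)
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = k * dt  # t stays strictly below 1, so the toy field never divides by zero
        v = velocity_fn(t, x, ids, cond, ids)
        x = [xi + dt * vi for xi, vi in zip(x, v)]
    return x


sample = euler_sample(toy_velocity, x0=[0.0, 1.0, -2.0], cond=[1.0, 2.0, 3.0])
```

With a real model, `velocity_fn` would be the `Flux1` instance and the lists would be arrays with whatever shapes the model expects; those shapes are not specified on this page.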

cond_in
condition_embedding
condition_null
double_blocks
final_layer
hidden_size
in_channels
num_heads
obs_in
out_channels
params
pe_embedder
qkv_features
single_blocks
time_in
vector_in

class gensbi.models.flux1.Flux1Params

Parameters for the Flux1 model.

Parameters:
  • in_channels (int) – Number of input channels.

  • vec_in_dim (int | None) – Dimension of the vector input, if applicable.

  • context_in_dim (int) – Dimension of the context input.

  • mlp_ratio (float) – Ratio of the MLP hidden dimension to the model hidden size.

  • num_heads (int) – Number of attention heads.

  • depth (int) – Number of double stream blocks.

  • depth_single_blocks (int) – Number of single stream blocks.

  • axes_dim (list[int]) – Dimensions of the axes for positional encoding.

  • qkv_bias (bool) – Whether to use bias in QKV layers.

  • rngs (nnx.Rngs) – Random number generators for initialization.

  • obs_dim (int) – Observation dimension.

  • cond_dim (int) – Condition dimension.

  • theta (int) – Scaling factor for positional encoding. Defaults to 10000.

  • param_dtype (DTypeLike) – Data type for model parameters.

  • guidance_embed (bool) – Whether to use a guidance embedding. Defaults to False.
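To make the roles of `axes_dim` and `theta` concrete, here is a minimal sketch of rotary-style positional frequencies, assuming the positional encoding follows the rotary scheme used in the original Flux architecture (this page only says that `theta` is a scaling factor). The helper names `rope_frequencies` and `rope_angles` are hypothetical and not part of gensbi.

```python
import math


def rope_frequencies(dim, theta=10000):
    """Rotary frequencies theta ** (-2i / dim) for i in range(dim // 2)."""
    if dim % 2 != 0:
        raise ValueError("each axis dimension must be even")
    return [theta ** (-(2 * i) / dim) for i in range(dim // 2)]


def rope_angles(ids, axes_dim, theta=10000):
    """Concatenate position * frequency over all axes for a single token.

    `ids` holds one integer position per axis, so len(ids) == len(axes_dim).
    """
    if len(ids) != len(axes_dim):
        raise ValueError("need one position id per axis")
    angles = []
    for pos, dim in zip(ids, axes_dim):
        angles.extend(pos * freq for freq in rope_frequencies(dim, theta))
    return angles


# Two axes with 4 and 2 dimensions give 2 + 1 rotation angles per token.
angles = rope_angles(ids=[3, 0], axes_dim=[4, 2])
# Each angle defines one 2-D rotation applied to a pair of query/key features.
rotation = [(math.cos(a), math.sin(a)) for a in angles]
```

Under this scheme a larger `theta` lowers the higher-index frequencies, so positions far apart remain distinguishable; each entry of `axes_dim` sets how many feature pairs are devoted to that axis.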

__post_init__()
axes_dim: list[int]
cond_dim: int
context_in_dim: int
depth: int
depth_single_blocks: int
guidance_embed: bool = False
in_channels: int
mlp_ratio: float
num_heads: int
obs_dim: int
param_dtype: jax.typing.DTypeLike
qkv_bias: bool
rngs: flax.nnx.Rngs
theta: int = 10000
vec_in_dim: int | None
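
The field listing and the `__post_init__` hook suggest a dataclass whose only defaults are `guidance_embed` and `theta`. The sketch below is a hypothetical stand-in (`Flux1ParamsSketch`, not the gensbi class) that mirrors those fields using only the standard library. The field order and the body of `__post_init__` are assumptions: the derived `hidden_size` uses the original Flux convention that the per-head dimension equals `sum(axes_dim)`, which this page does not confirm.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Flux1ParamsSketch:
    """Hypothetical stand-in; field names and defaults follow the listing above."""

    in_channels: int
    vec_in_dim: Optional[int]
    context_in_dim: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list
    qkv_bias: bool
    rngs: Any          # flax.nnx.Rngs in the documented class
    obs_dim: int
    cond_dim: int
    param_dtype: Any   # jax.typing.DTypeLike in the documented class
    theta: int = 10000
    guidance_embed: bool = False

    def __post_init__(self):
        # Assumption: derive the model width from the positional-encoding axes,
        # so each attention head has sum(axes_dim) features.
        self.hidden_size = sum(self.axes_dim) * self.num_heads


params = Flux1ParamsSketch(
    in_channels=1,
    vec_in_dim=None,
    context_in_dim=1,
    mlp_ratio=4.0,
    num_heads=4,
    depth=2,
    depth_single_blocks=4,
    axes_dim=[8, 8],
    qkv_bias=True,
    rngs=None,              # placeholder; the documented class expects nnx.Rngs
    obs_dim=3,
    cond_dim=5,
    param_dtype="float32",  # placeholder dtype name
)
```

All values here are illustrative. With the documented class, the same keyword arguments would presumably be passed to `Flux1Params`, and the result handed to `Flux1(params)`.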