gensbi.models.flux1.model#
Classes#

| Class | Description |
|---|---|
| `Flux` | Transformer model for flow matching on sequences. |
| `FluxParams` | Parameters for the Flux model. |
| `FluxWrapper` | This class is used to wrap around another model. We define a call method which returns the model output. |
| `Identity` | Base class for all neural network modules. |
Module Contents#
- class gensbi.models.flux1.model.Flux(params)[source]#
Bases:
flax.nnx.Module
Transformer model for flow matching on sequences.
- Parameters:
params (FluxParams)
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)[source]#
- Parameters:
t (jax.Array)
obs (jax.Array)
obs_ids (jax.Array)
cond (jax.Array)
cond_ids (jax.Array)
conditioned (bool | jax.Array)
guidance (jax.Array | None)
- Return type:
jax.Array
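As a hedged illustration of the `__call__` signature above, the sketch below prepares plausible inputs; the concrete shapes (batch size, sequence lengths, channel counts) and the role of the id arrays are assumptions for illustration, not taken from the source:

```python
import jax.numpy as jnp

# Hypothetical shapes -- the real model's expectations may differ.
batch, obs_len, cond_len, channels = 4, 8, 16, 2

t = jnp.linspace(0.0, 1.0, batch)             # one time scalar per batch element
obs = jnp.zeros((batch, obs_len, channels))   # observation sequence
obs_ids = jnp.arange(obs_len)[None, :].repeat(batch, axis=0)    # positional ids
cond = jnp.zeros((batch, cond_len, channels))                   # conditioning sequence
cond_ids = jnp.arange(cond_len)[None, :].repeat(batch, axis=0)

# A Flux instance would then be called as:
# v = model(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
```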
- class gensbi.models.flux1.model.FluxParams[source]#
Parameters for the Flux model.
- Parameters:
in_channels (int) – Number of input channels.
vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.
context_in_dim (int) – Dimension of the context input.
mlp_ratio (float) – Ratio for the MLP layers.
num_heads (int) – Number of attention heads.
depth (int) – Number of double stream blocks.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
obs_dim (int) – Observation dimension.
cond_dim (int) – Condition dimension.
theta (int) – Scaling factor for positional encoding.
guidance_embed (bool) – Whether to use guidance embedding.
qkv_multiplier (int) – Multiplier for QKV features.
param_dtype (DTypeLike) – Data type for model parameters.
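To make the field list above concrete, here is a minimal stand-in sketch of the parameter container. It only mirrors the documented field names; it is not the real `FluxParams` class (which takes `nnx.Rngs` and a JAX `DTypeLike`), and the example values are arbitrary assumptions:

```python
from dataclasses import dataclass
from typing import Any, List, Optional

# Hypothetical stand-in mirroring the documented fields; the real
# gensbi.models.flux1.model.FluxParams may differ in types and defaults.
@dataclass
class FluxParamsSketch:
    in_channels: int            # number of input channels
    vec_in_dim: Optional[int]   # vector-input dimension, if applicable
    context_in_dim: int         # dimension of the context input
    mlp_ratio: float            # width ratio for the MLP layers
    num_heads: int              # number of attention heads
    depth: int                  # number of double stream blocks
    depth_single_blocks: int    # number of single stream blocks
    axes_dim: List[int]         # per-axis dims for positional encoding
    qkv_bias: bool              # whether QKV layers use a bias
    rngs: Any                   # nnx.Rngs in the real class
    obs_dim: int                # observation dimension
    cond_dim: int               # condition dimension
    theta: int                  # positional-encoding scaling factor
    guidance_embed: bool        # whether to use guidance embedding
    qkv_multiplier: int         # multiplier for QKV features
    param_dtype: Any            # DTypeLike in the real class

params = FluxParamsSketch(
    in_channels=2, vec_in_dim=None, context_in_dim=2, mlp_ratio=4.0,
    num_heads=4, depth=2, depth_single_blocks=2, axes_dim=[16],
    qkv_bias=True, rngs=None, obs_dim=8, cond_dim=16, theta=10_000,
    guidance_embed=False, qkv_multiplier=1, param_dtype="float32",
)
```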
- class gensbi.models.flux1.model.FluxWrapper(model)[source]#
Bases:
gensbi.utils.model_wrapping.ModelWrapper
This class is used to wrap around another model. We define a call method which returns the model output. Furthermore, we define a vector_field method, which computes the vector field of the model, and a divergence method, which computes the divergence of the model in a form useful for diffrax. This is useful for ODE solvers that require both the vector field and the divergence of the model.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)[source]#
This method defines how inputs should be passed through the wrapped model. Here, we’re assuming that the wrapped model takes both \(obs\) and \(t\) as input, along with any additional keyword arguments.
Optional things to do here:
- check that t has the dimensions the model expects.
- add custom forward-pass logic.
- call the wrapped model.

Given obs and t, returns the model output for input obs at time t, with extra information extras.
- Parameters:
obs (Array) – input data to the model (batch_size, …).
t (Array) – time (batch_size).
**extras – additional information forwarded to the model, e.g., text condition.
obs_ids (jax.Array)
cond (jax.Array)
cond_ids (jax.Array)
conditioned (bool | jax.Array)
guidance (jax.Array | None)
- Returns:
model output.
- Return type:
Array
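The vector_field/divergence pattern described above can be sketched generically in JAX. This is an illustrative re-implementation under stated assumptions, not gensbi's actual `ModelWrapper`: it assumes a `(t, obs, args)` call order as used by diffrax-style ODE solvers, and computes an exact trace-based divergence, which is only practical for small state dimensions:

```python
import jax
import jax.numpy as jnp

class WrapperSketch:
    """Hypothetical wrapper exposing a model as a vector field
    plus its divergence, in a (t, obs, args) diffrax-style call order."""

    def __init__(self, model):
        self.model = model

    def vector_field(self, t, obs, args=None):
        # Forward the call to the wrapped model.
        return self.model(t, obs)

    def divergence(self, t, obs, args=None):
        # Exact divergence: trace of the Jacobian of the field w.r.t. obs.
        jac = jax.jacfwd(lambda x: self.model(t, x))(obs)
        return jnp.trace(jac)

# Toy model with a known divergence: v(t, x) = 3x  =>  div v = 3 * dim.
wrapper = WrapperSketch(lambda t, x: 3.0 * x)
x = jnp.ones(2)
div = wrapper.divergence(0.0, x)  # 3 * 2 = 6.0
```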
- class gensbi.models.flux1.model.Identity[source]#
Bases:
flax.nnx.Module
Base class for all neural network modules.
Layers and models should subclass this class.
`Module`s can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the `__init__` method.
You can define arbitrary "forward pass" methods on your `Module` subclass. While no methods are special-cased, `__call__` is a popular choice since you can call the `Module` directly:

```python
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
```