gensbi.models.flux1joint.model
Classes
- Flux1Joint – Flux1Joint model for joint density estimation.
- Flux1JointParams – Parameters for the Flux1Joint model.
Module Contents
- class gensbi.models.flux1joint.model.Flux1Joint(params)
Bases: flax.nnx.Module
Flux1Joint model for joint density estimation.
- Parameters:
params (Flux1JointParams) – Parameters for the Flux1Joint model.
- __call__(t, obs, node_ids, condition_mask, guidance=None, edge_mask=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
node_ids (jax.Array)
condition_mask (jax.Array)
guidance (jax.Array | None)
edge_mask (Optional[jax.Array])
- Return type:
jax.Array
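How these inputs fit together is not specified here. As a hedged illustration only, the sketch below shows one common way a joint flow-matching model is fed: conditioned variables keep their clean values, and the remaining variables are replaced by noisy interpolants. Several details are assumptions made for this example and are not documented gensbi behavior: the helper `build_joint_inputs`, the tensor shapes, the boolean meaning of `condition_mask` (True = conditioned), and the linear path `x_t = t * x1 + (1 - t) * x0`.

```python
import jax
import jax.numpy as jnp


def build_joint_inputs(key, x1, condition_mask, t):
    """Hypothetical helper: prepare inputs for a joint flow-matching model.

    Assumptions (not taken from the gensbi docs):
      - x1 has shape (batch, num_nodes, channels) and holds clean data.
      - condition_mask has shape (batch, num_nodes, channels); True marks a
        conditioned (observed) node, False a node to be generated.
      - t has shape (batch, 1, 1) and x_t = t * x1 + (1 - t) * x0,
        with x0 standard Gaussian noise.
    """
    x0 = jax.random.normal(key, x1.shape, x1.dtype)
    x_t = t * x1 + (1.0 - t) * x0
    # Conditioned nodes keep their clean values; the others are noised.
    obs = jnp.where(condition_mask, x1, x_t)
    # One integer identifier per node, shared across the batch.
    node_ids = jnp.broadcast_to(jnp.arange(x1.shape[1]), x1.shape[:2])
    return obs, node_ids


key = jax.random.PRNGKey(0)
x1 = jnp.ones((2, 5, 1))
# Condition on the first two nodes and generate the last three.
condition_mask = jnp.array([True, True, False, False, False]).reshape(1, 5, 1)
condition_mask = jnp.broadcast_to(condition_mask, x1.shape)
t = jnp.full((2, 1, 1), 0.3)
obs, node_ids = build_joint_inputs(key, x1, condition_mask, t)
```

In this sketch, `obs`, `node_ids` and `condition_mask` would then be passed to the model together with `t`. Whether `Flux1Joint.__call__` expects exactly these shapes and conventions should be checked against the gensbi source.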
- class gensbi.models.flux1joint.model.Flux1JointParams
Parameters for the Flux1Joint model.
- Parameters:
in_channels (int) – Number of input channels.
vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.
context_in_dim (int) – Dimension of the context input.
mlp_ratio (float) – Ratio of the MLP hidden width to the model's hidden size.
num_heads (int) – Number of attention heads.
depth (int) – Number of double stream blocks.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
obs_dim (int) – Observation dimension.
cond_dim (int) – Condition dimension.
theta (int) – Base frequency of the rotary positional encoding.
guidance_embed (bool) – Whether to use guidance embedding.
param_dtype (DTypeLike) – Data type for model parameters.
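This sketch assumes that Flux1Joint follows the original Flux design it is named after. Under that assumption, `theta` and `axes_dim` configure a multi-axis rotary positional encoding (RoPE). Each position axis gets its own block of `axes_dim[i]` features, and those features are rotated at frequencies `theta ** (-2k / axes_dim[i])`. The code implements that standard technique in plain JAX. The function names `rope_angles` and `apply_rope` and the example values are hypothetical and are not part of the gensbi API.

```python
import jax.numpy as jnp


def rope_angles(ids, axes_dim, theta=10_000):
    """Rotation angles for a multi-axis rotary positional encoding.

    ids: (..., seq, n_axes) integer positions, one column per axis.
    axes_dim: per-axis feature dimensions (each even); their sum is the
        head dimension the encoding is applied to.
    Returns angles of shape (..., seq, sum(axes_dim) // 2).
    """
    angles = []
    for i, dim in enumerate(axes_dim):
        # Frequencies theta^(-2k/dim) for k = 0, ..., dim/2 - 1.
        omega = 1.0 / theta ** (jnp.arange(0, dim, 2) / dim)
        angles.append(ids[..., i : i + 1].astype(jnp.float32) * omega)
    return jnp.concatenate(angles, axis=-1)


def apply_rope(x, angles):
    """Rotate consecutive feature pairs of x (..., seq, head_dim) by angles."""
    x_pairs = x.reshape(*x.shape[:-1], -1, 2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    x0, x1 = x_pairs[..., 0], x_pairs[..., 1]
    rotated = jnp.stack([x0 * cos - x1 * sin, x0 * sin + x1 * cos], axis=-1)
    return rotated.reshape(x.shape)


# Hypothetical configuration: two position axes, head dimension 8.
axes_dim = [4, 4]
ids = jnp.array([[0, 0], [1, 0], [2, 3]])  # (seq=3, n_axes=2)
angles = rope_angles(ids, axes_dim)
q = jnp.tile(jnp.arange(8.0), (3, 1))  # (seq=3, head_dim=8)
q_rot = apply_rope(q, angles)
```

Each feature pair is rotated by an angle proportional to its position. As a result, the dot product between a rotated query and a rotated key depends only on the relative offset between their positions, which is the property RoPE is designed to provide.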