gensbi.models.flux1joint
This package provides the transformer-based Flux1Joint model and related loss functions for simulation-based inference. Its architecture is derived from the following work:
Gloeckler et al. “All-in-one simulation-based inference.” arXiv:2404.09636
Submodules
Classes
Flux1Joint | Flux1Joint model for joint density estimation.
Flux1JointParams | Parameters for the Flux1Joint model.
Package Contents
- class gensbi.models.flux1joint.Flux1Joint(params)
Bases: flax.nnx.Module
Flux1Joint model for joint density estimation.
- Parameters:
params (Flux1JointParams) – Parameters for the Flux1Joint model.
- __call__(t, obs, node_ids, condition_mask, guidance=None, edge_mask=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
node_ids (jax.Array)
condition_mask (jax.Array)
guidance (jax.Array | None)
edge_mask (jax.Array | None)
- Return type:
jax.Array
- condition_embedding
- final_layer
- in_channels
- num_heads
- obs_in
- out_channels
- params
- pe_embedder
- qkv_features
- single_blocks
- time_in
- vector_in
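In the all-in-one formulation of Gloeckler et al., parameters and data are stacked into one joint vector, every entry carries an identifier, and a boolean condition mask marks which entries are observed (conditioned on) and which are generated. This is what the `node_ids` and `condition_mask` arguments of `__call__` refer to. The sketch below builds them for posterior, likelihood, and joint queries. It is a minimal sketch under stated assumptions: the helper `make_joint_inputs` is hypothetical (not part of gensbi), it assumes parameters occupy the first `dim_theta` entries and data the remaining `dim_x`, it assumes `True` in the mask means "conditioned on", and it uses plain Python lists rather than the `jax.Array` shapes and dtypes that `Flux1Joint.__call__` actually expects.

```python
from typing import List, Tuple


def make_joint_inputs(
    dim_theta: int, dim_x: int, condition_on: str
) -> Tuple[List[int], List[bool]]:
    """Build node IDs and a condition mask for one joint sample.

    Hypothetical helper. Assumes the joint vector is laid out as
    [theta_1, ..., theta_d, x_1, ..., x_m] and that True in the mask
    marks an entry that is observed (conditioned on).
    """
    # One integer identifier per entry of the joint vector.
    node_ids = list(range(dim_theta + dim_x))

    if condition_on == "x":  # posterior p(theta | x): observe the data
        mask = [False] * dim_theta + [True] * dim_x
    elif condition_on == "theta":  # likelihood p(x | theta): observe the parameters
        mask = [True] * dim_theta + [False] * dim_x
    elif condition_on == "none":  # joint p(theta, x): generate every entry
        mask = [False] * (dim_theta + dim_x)
    else:
        raise ValueError(f"unknown condition_on: {condition_on!r}")
    return node_ids, mask


ids, post_mask = make_joint_inputs(dim_theta=2, dim_x=3, condition_on="x")
print(ids)        # [0, 1, 2, 3, 4]
print(post_mask)  # [False, False, True, True, True]
```

Only the mask changes between tasks, so a single trained joint model can in principle be queried for posterior, likelihood, or joint samples; this reuse is the central idea of the all-in-one approach.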
- class gensbi.models.flux1joint.Flux1JointParams
Parameters for the Flux1Joint model.
- Parameters:
in_channels (int) – Number of input channels.
vec_in_dim (int | None) – Dimension of the vector input, or None if there is none.
context_in_dim (int) – Dimension of the context input.
mlp_ratio (float) – Expansion ratio of the hidden layer in each block's MLP.
num_heads (int) – Number of attention heads.
depth (int) – Number of double stream blocks.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Per-axis dimensions of the positional encoding.
qkv_bias (bool) – Whether the query, key, and value projections include a bias term.
rngs (nnx.Rngs) – Random number generators for initialization.
obs_dim (int) – Observation dimension.
cond_dim (int) – Condition dimension.
theta (int) – Base of the frequency scale used by the positional encoding.
guidance_embed (bool) – Whether to use guidance embedding.
param_dtype (DTypeLike) – Data type for model parameters.
- axes_dim: list[int]
- condition_dim: list[int]
- depth_single_blocks: int
- guidance_embed: bool = False
- in_channels: int
- joint_dim: int
- mlp_ratio: float
- num_heads: int
- param_dtype: jax.typing.DTypeLike
- qkv_bias: bool
- rngs: flax.nnx.Rngs
- theta: int = 10000
- vec_in_dim: int | None
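The `theta` and `axes_dim` parameters configure the positional encoding (the `pe_embedder` attribute). In the original Flux reference implementation this is a rotary positional embedding (RoPE): an axis of dimension `d` uses `d / 2` frequencies `theta ** (-2k / d)` for `k = 0, ..., d/2 - 1`, a position `p` rotates each feature pair by the angle `p * freq`, and the per-axis encodings are concatenated, so the total size is `sum(axes_dim)`. The pure-Python sketch below computes these quantities. It assumes Flux1Joint follows the Flux scheme, which should be checked against the gensbi source; the function names are hypothetical and not part of gensbi.

```python
import math
from typing import List


def rope_frequencies(dim: int, theta: float = 10000.0) -> List[float]:
    """Frequencies theta ** (-2k / dim) for k = 0, ..., dim/2 - 1 (Flux-style RoPE)."""
    if dim % 2 != 0:
        raise ValueError("each axis dimension must be even")
    return [theta ** (-(2 * k) / dim) for k in range(dim // 2)]


def rope_rotations(position: float, dim: int, theta: float = 10000.0) -> List[List[List[float]]]:
    """2x2 rotation matrices [[cos, -sin], [sin, cos]], one per feature pair."""
    rotations = []
    for freq in rope_frequencies(dim, theta):
        angle = position * freq
        rotations.append([[math.cos(angle), -math.sin(angle)],
                          [math.sin(angle), math.cos(angle)]])
    return rotations


def multi_axis_angles(ids: List[float], axes_dim: List[int], theta: float = 10000.0) -> List[float]:
    """Rotation angles for a token whose position along axis i is ids[i].

    Each axis contributes axes_dim[i] // 2 angles; the results are concatenated,
    mirroring how the Flux reference implementation combines per-axis encodings.
    """
    if len(ids) != len(axes_dim):
        raise ValueError("need one position per axis")
    angles: List[float] = []
    for pos, d in zip(ids, axes_dim):
        angles.extend(pos * f for f in rope_frequencies(d, theta))
    return angles


freqs = rope_frequencies(dim=8, theta=10000)
print(len(freqs), freqs[0])  # 4 1.0
```

A larger `theta` spreads the frequencies over a wider range: low-index feature pairs rotate quickly with position and high-index pairs slowly. `10000` is the default used by Flux and matches the default listed above.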