gensbi.models.flux1joint

This package provides the Flux1Joint transformer-based model and related loss functions for simulation-based inference. The architecture is derived from the following foundational work:

Submodules

Classes

Flux1Joint – Flux1Joint model for joint density estimation.

Flux1JointParams – Parameters for the Flux1Joint model.

Package Contents

class gensbi.models.flux1joint.Flux1Joint(params)

Bases: flax.nnx.Module

Flux1Joint model for joint density estimation.

Parameters:

params (Flux1JointParams) – Parameters for the Flux1Joint model.

__call__(t, obs, node_ids, condition_mask, guidance=None, edge_mask=None)
Parameters:
  • t (jax.Array)

  • obs (jax.Array)

  • node_ids (jax.Array)

  • condition_mask (jax.Array)

  • guidance (jax.Array | None)

  • edge_mask (jax.Array | None)

Return type:

jax.Array
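For joint density estimation, a single network sees all variables at once and the condition_mask decides which of them are observed. The sketch below only prepares inputs for __call__; the layout (parameters first, then data), the shapes of obs, node_ids and condition_mask, and the boolean mask convention (True means "conditioned on") are assumptions for illustration, not taken from gensbi.

```python
import jax.numpy as jnp

# Hypothetical problem: 3 simulator parameters followed by 2 data dimensions,
# stacked into one joint vector (an assumed layout, not taken from gensbi).
batch, n_theta, n_x, in_channels = 4, 3, 2, 1
joint_dim = n_theta + n_x

# One token per variable; observed positions would hold the actual data.
obs = jnp.zeros((batch, joint_dim, in_channels))

# An integer id per token so the network can tell the variables apart (assumed shape).
node_ids = jnp.broadcast_to(jnp.arange(joint_dim), (batch, joint_dim))

# Posterior-style query: the data x are observed (True), theta is generated.
posterior_mask = jnp.concatenate(
    [jnp.zeros((batch, n_theta), dtype=bool), jnp.ones((batch, n_x), dtype=bool)],
    axis=1,
)[..., None]

# Likelihood-style query: flip the mask so theta is observed and x is generated.
likelihood_mask = ~posterior_mask

# One flow/diffusion time per batch element.
t = jnp.full((batch,), 0.5)
```

Under these assumptions, a posterior query would be `model(t, obs, node_ids, posterior_mask)` and a likelihood query the same call with `likelihood_mask`; only the mask changes, which is what makes a single joint model reusable across conditioning patterns.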

condition_embedding
final_layer
hidden_size
in_channels
num_heads
obs_in
out_channels
params
pe_embedder
qkv_features
single_blocks
time_in
vector_in
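The time_in attribute (and, when guidance_embed is enabled, the guidance input) implies that scalar inputs are mapped to vectors before entering the transformer. Assuming Flux1Joint follows the convention of the original FLUX.1 reference code, which this page does not confirm, the first stage is a sinusoidal embedding like the one below; the function name and the max_period/time_factor defaults are hypothetical here.

```python
import math

import jax.numpy as jnp


def timestep_embedding(t, dim, max_period=10000, time_factor=1000.0):
    """Sinusoidal embedding of a batch of scalar times (hypothetical helper)."""
    t = time_factor * t
    half = dim // 2
    # Geometrically spaced frequencies from 1 down to about 1/max_period.
    freqs = jnp.exp(-math.log(max_period) * jnp.arange(half, dtype=jnp.float32) / half)
    args = t[:, None].astype(jnp.float32) * freqs[None, :]
    emb = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)
    if dim % 2:
        # Pad with a zero column so the output width is exactly `dim`.
        emb = jnp.concatenate([emb, jnp.zeros_like(emb[:, :1])], axis=-1)
    return emb


emb = timestep_embedding(jnp.array([0.0, 0.5, 1.0]), dim=256)
```

In such a design the sinusoidal features are then passed through a small MLP (presumably what time_in holds), so the network can learn its own mixing of frequencies.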
class gensbi.models.flux1joint.Flux1JointParams

Parameters for the Flux1Joint model.

Parameters:
  • in_channels (int) – Number of input channels.

  • vec_in_dim (int | None) – Dimension of the vector input, if applicable.

  • mlp_ratio (float) – Ratio of the MLP hidden width to the model's hidden size.

  • num_heads (int) – Number of attention heads.

  • depth_single_blocks (int) – Number of single stream blocks.

  • axes_dim (list[int]) – Dimensions of the axes for positional encoding.

  • qkv_bias (bool) – Whether to use bias in QKV layers.

  • rngs (nnx.Rngs) – Random number generators for initialization.

  • joint_dim (int) – Joint dimension.

  • condition_dim (list[int]) – Condition dimension.

  • theta (int) – Base frequency for the positional encoding.

  • guidance_embed (bool) – Whether to use guidance embedding.

  • param_dtype (DTypeLike) – Data type for model parameters.
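The axes_dim and theta parameters configure the positional encoding held in pe_embedder. Assuming it follows the rotary positional embedding (RoPE) scheme of the original FLUX.1 reference code (an assumption this page does not confirm), each id axis i gets axes_dim[i] / 2 rotation frequencies theta^(-2k / axes_dim[i]), and the per-axis tables are concatenated. The helpers rope and embed_nd below are hypothetical names for a minimal sketch.

```python
import jax.numpy as jnp


def rope(pos, dim, theta):
    """2x2 rotation matrices for integer positions along one id axis."""
    assert dim % 2 == 0, "each axis dimension must be even"
    # Frequencies omega_k = theta^(-2k/dim) for k = 0 .. dim/2 - 1.
    scale = jnp.arange(0, dim, 2, dtype=jnp.float32) / dim
    omega = 1.0 / (theta ** scale)
    angles = pos[..., None].astype(jnp.float32) * omega  # (..., n, dim/2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    # One rotation [[cos, -sin], [sin, cos]] per frequency.
    rot = jnp.stack([cos, -sin, sin, cos], axis=-1)
    return rot.reshape(*rot.shape[:-1], 2, 2)


def embed_nd(ids, axes_dim, theta=10000):
    """Concatenate per-axis tables; ids[..., i] holds positions along axis i."""
    return jnp.concatenate(
        [rope(ids[..., i], d, theta) for i, d in enumerate(axes_dim)], axis=-3
    )


# Hypothetical: 5 tokens, a single id axis, axes_dim=[8].
ids = jnp.arange(5)[None, :, None]  # (batch=1, tokens=5, axes=1)
pe = embed_nd(ids, axes_dim=[8])
```

The result has sum(axes_dim) / 2 rotations per token, so under this scheme sum(axes_dim) would have to match the per-head query/key width; a larger theta spreads the frequencies over longer position ranges.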

__post_init__()
axes_dim: list[int]
condition_dim: list[int]
depth_single_blocks: int
guidance_embed: bool = False
in_channels: int
joint_dim: int
mlp_ratio: float
num_heads: int
param_dtype: jax.typing.DTypeLike
qkv_bias: bool
rngs: flax.nnx.Rngs
theta: int = 10000
vec_in_dim: int | None