gensbi.models

Submodules

Classes

ConditionalCFMLoss

ConditionalCFMLoss is a class that computes the continuous flow matching loss for the Conditional model.

ConditionalDiffLoss

ConditionalDiffLoss is a class that computes the diffusion score matching loss for the Conditional model.

ConditionalWrapper

Wrapper for conditional models to handle input expansion and calling convention.

Flux1

Transformer model for flow matching on sequences.

Flux1Joint

Flux1Joint model for joint density estimation.

Flux1JointParams

Parameters for the Flux1Joint model.

Flux1Params

Parameters for the Flux1 model.

JointCFMLoss

JointCFMLoss is a class that computes the continuous flow matching loss for the Joint model.

JointDiffLoss

JointDiffLoss is a class that computes the diffusion score matching loss for the Joint model.

JointWrapper

Wrapper for joint models to handle both conditioned and unconditioned inference.

Simformer

Simformer model for joint density estimation.

SimformerParams

Parameters for the Simformer model.

UnconditionalCFMLoss

UnconditionalCFMLoss is a class that computes the continuous flow matching loss for the Unconditional model.

UnconditionalDiffLoss

UnconditionalDiffLoss is a class that computes the diffusion score matching loss for the Unconditional model.

UnconditionalWrapper

Wrapper for unconditional models to handle input expansion and calling convention.

Package Contents

class gensbi.models.ConditionalCFMLoss(path, reduction='mean', cfg_scale=None)

Bases: gensbi.flow_matching.loss.ContinuousFMLoss

ConditionalCFMLoss is a class that computes the continuous flow matching loss for the Conditional model.

Parameters:
  • path – Probability path (x-prediction training).

  • reduction (str, optional) – Reduction to apply to the output: 'none' (no reduction), 'mean' (mean over sequence elements) or 'sum' (sum over sequence elements). Defaults to 'mean'.

  • cfg_scale (optional) – Defaults to None.
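The three `reduction` modes can be illustrated on an array of per-element squared errors. This is a minimal NumPy sketch (NumPy standing in for `jax.numpy`). It assumes that 'mean' and 'sum' reduce over every element in the batch. `reduce_loss` is a hypothetical helper and is not part of gensbi.

```python
import numpy as np

def reduce_loss(per_element, reduction="mean"):
    """Hypothetical helper mirroring the documented `reduction` options.

    per_element holds one squared error per sequence element, e.g. shape (batch, seq_len).
    """
    if reduction == "none":
        return per_element  # no reduction: same shape as the input
    if reduction == "mean":
        return per_element.mean()  # average over all sequence elements (assumed)
    if reduction == "sum":
        return per_element.sum()  # total over all sequence elements (assumed)
    raise ValueError(f"unknown reduction: {reduction!r}")

errs = np.array([[1.0, 3.0], [2.0, 2.0]])
print(reduce_loss(errs, "mean"))  # 2.0
print(reduce_loss(errs, "sum"))  # 8.0
```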

__call__(vf, batch, cond, obs_ids, cond_ids)

Evaluates the continuous flow matching loss.

Parameters:
  • vf (callable) – The vector field model to evaluate.

  • batch (tuple) – A tuple containing the input data (x_0, x_1, t).

  • cond (jnp.ndarray) – The conditioning data.

  • obs_ids (jnp.ndarray) – The observation IDs.

  • cond_ids (jnp.ndarray) – The conditioning IDs.

Returns:

The computed loss.

Return type:

jnp.ndarray

cfg_scale = None
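The quantity that `__call__` evaluates can be sketched as follows. This minimal NumPy sketch rests on three assumptions that this page does not confirm: a linear path x_t = (1 - t) x_0 + t x_1, a velocity target x_1 - x_0, and a model signature `vf(x_t, t, cond)`. The documented `path` is set up for x-prediction training, so gensbi's actual regression target may differ. `cfm_loss` is a hypothetical name.

```python
import numpy as np

def cfm_loss(vf, x0, x1, t, cond):
    """Hypothetical conditional flow matching loss on a linear path.

    x0, x1: (batch, seq_len, channels) noise and data samples; t: (batch,).
    vf(x_t, t, cond) is assumed to predict the velocity x1 - x0.
    """
    t_b = t[:, None, None]  # broadcast t over (seq_len, channels)
    x_t = (1.0 - t_b) * x0 + t_b * x1  # sample on the straight path
    target = x1 - x0  # time derivative of x_t along the path
    pred = vf(x_t, t, cond)
    return np.mean((pred - target) ** 2)  # 'mean' reduction

rng = np.random.default_rng(0)
x0 = rng.standard_normal((4, 3, 2))  # base (noise) samples
x1 = rng.standard_normal((4, 3, 2))  # target (data) samples
t = rng.uniform(size=4)
cond = rng.standard_normal((4, 5, 2))  # conditioning data, ignored by the toy fields

def oracle(x_t, t, cond):
    return x1 - x0  # knows the true velocity, so the loss is zero

def zero_field(x_t, t, cond):
    return np.zeros_like(x_t)

print(cfm_loss(oracle, x0, x1, t, cond))  # 0.0
```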
class gensbi.models.ConditionalDiffLoss(path)

Bases: flax.nnx.Module

ConditionalDiffLoss is a class that computes the diffusion score matching loss for the Conditional model.

Parameters:

path – Probability path for training.

__call__(key, model, batch, cond, obs_ids, cond_ids)

Evaluate the diffusion score matching loss.

Parameters:
  • key (jax.random.PRNGKey) – Random key for stochastic operations.

  • model (Callable) – The F model to evaluate.

  • batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).

  • cond (jnp.ndarray) – The conditioning data.

  • obs_ids (jnp.ndarray) – The observation IDs.

  • cond_ids (jnp.ndarray) – The conditioning IDs.

Returns:

Computed loss.

Return type:

Array

loss_fn
path
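`ConditionalDiffLoss` takes a random `key` and a batch of clean samples and noise levels (x_1, sigma), which fits a denoising score matching objective. The sketch below assumes the EDM formulation of Karras et al.:

- Gaussian noise is added at level sigma.
- A denoiser predicts x_1.
- Each term is weighted by (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2, with sigma_data = 0.5.

For simplicity the sketch works with the denoiser output directly rather than with the raw F model. It also uses a NumPy `Generator` in place of a JAX PRNG key. These choices and the helper names are assumptions, and gensbi's `path` object may handle weighting and preconditioning differently.

```python
import numpy as np

SIGMA_DATA = 0.5  # assumed standard deviation of the data (EDM default)

def edm_weight(sigma):
    """Loss weight (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2

def denoising_loss(rng, denoiser, x1, sigma, cond):
    """Hypothetical weighted denoising score matching loss.

    x1: clean samples (batch, seq_len, channels); sigma: noise levels (batch,).
    denoiser(x_noisy, sigma, cond) is assumed to return an estimate of x1.
    """
    s = sigma[:, None, None]  # broadcast sigma over (seq_len, channels)
    noise = rng.standard_normal(x1.shape)  # the stochastic part of the loss
    x_noisy = x1 + s * noise  # perturb the clean samples at level sigma
    pred = denoiser(x_noisy, sigma, cond)
    return np.mean(edm_weight(s) * (pred - x1) ** 2)

rng = np.random.default_rng(0)
x1 = rng.standard_normal((4, 3, 2))
sigma = np.array([0.1, 0.5, 1.0, 2.0])
cond = rng.standard_normal((4, 5, 2))

def perfect(x_noisy, sigma, cond):
    return x1  # recovers the clean data exactly, so the loss is zero

print(edm_weight(1.0))  # 5.0
print(denoising_loss(rng, perfect, x1, sigma, cond))  # 0.0
```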
class gensbi.models.ConditionalWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wrapper for conditional models to handle input expansion and calling convention.

Parameters:

model – The conditional model instance to wrap.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)

Call the wrapped model with expanded inputs.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • conditioned (bool | Array, optional) – Whether to use conditioning. Defaults to True.

  • guidance (Array | None, optional) – Optional guidance input.

Returns:

Model output.

Return type:

Array
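This page does not spell out what "input expansion" covers. One common case is broadcasting a scalar or per-sample time so that it combines with (batch, seq_len, channels) arrays. The hypothetical helper below sketches only that case, as an assumption about the wrapper's behaviour.

```python
import numpy as np

def expand_time(t, batch_size):
    """Hypothetical helper: broadcast time to shape (batch_size, 1, 1).

    Accepts a Python scalar, a 0-d array, or an array with batch_size entries.
    """
    t = np.asarray(t, dtype=np.float32).reshape(-1)  # scalar -> (1,), (B,) stays (B,)
    t = np.broadcast_to(t, (batch_size,))  # (1,) -> (batch_size,)
    return t[:, None, None]  # broadcasts against (batch, seq_len, channels)

obs = np.zeros((4, 3, 2))
t = expand_time(0.3, batch_size=obs.shape[0])
print(t.shape)  # (4, 1, 1)
print((obs + t).shape)  # (4, 3, 2)
```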

class gensbi.models.Flux1(params)

Bases: flax.nnx.Module

Transformer model for flow matching on sequences.

Parameters:

params (Flux1Params) – Parameters for the Flux1 model.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)

Forward pass of the Flux1 model.

Parameters:
  • t (jax.Array)

  • obs (jax.Array)

  • obs_ids (jax.Array)

  • cond (jax.Array)

  • cond_ids (jax.Array)

  • conditioned (bool | jax.Array)

  • guidance (jax.Array | None)

Return type:

jax.Array

cond_in
condition_embedding
condition_null
double_blocks
final_layer
hidden_size
in_channels
num_heads
obs_in
out_channels
params
pe_embedder
qkv_features
single_blocks
time_in
vector_in
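Several names on this page suggest support for classifier-free guidance, although the page does not state it: the `conditioned` flag of `Flux1.__call__`, the `condition_null` attribute, and the `cfg_scale` option of `ConditionalCFMLoss`. The following is a minimal sketch of the standard guidance combination. It assumes a model that can be evaluated with and without its condition. `toy_model` is a hypothetical stand-in, not a gensbi model.

```python
import numpy as np

def guided_output(model, t, obs, cond, scale):
    """Classifier-free guidance: extrapolate from the unconditional prediction.

    scale = 0 gives the unconditional output, scale = 1 the conditional one,
    and scale > 1 pushes further in the direction of the condition.
    """
    out_cond = model(t, obs, cond, conditioned=True)
    out_uncond = model(t, obs, cond, conditioned=False)
    return out_uncond + scale * (out_cond - out_uncond)

def toy_model(t, obs, cond, conditioned):
    # Hypothetical model: the condition simply shifts the output when enabled.
    return obs + (cond if conditioned else 0.0)

obs = np.array([1.0, 2.0])
cond = np.array([0.5, -0.5])
print(guided_output(toy_model, 0.5, obs, cond, scale=2.0))  # [2. 1.]
```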
class gensbi.models.Flux1Joint(params)

Bases: flax.nnx.Module

Flux1Joint model for joint density estimation.

Parameters:

params (Flux1JointParams) – Parameters for the Flux1Joint model.

__call__(t, obs, node_ids, condition_mask, guidance=None, edge_mask=None)

Forward pass of the Flux1Joint model.

Parameters:
  • t (jax.Array)

  • obs (jax.Array)

  • node_ids (jax.Array)

  • condition_mask (jax.Array)

  • guidance (jax.Array | None)

  • edge_mask (Optional[jax.Array])

Return type:

jax.Array

condition_embedding
final_layer
hidden_size
in_channels
num_heads
obs_in
out_channels
params
pe_embedder
qkv_features
single_blocks
time_in
vector_in
class gensbi.models.Flux1JointParams

Parameters for the Flux1Joint model.

Parameters:
  • in_channels (int) – Number of input channels.

  • vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.

  • mlp_ratio (float) – Ratio of the MLP hidden dimension to the model hidden size.

  • num_heads (int) – Number of attention heads.

  • depth_single_blocks (int) – Number of single stream blocks.

  • axes_dim (list[int]) – Dimensions of the axes for positional encoding.

  • qkv_bias (bool) – Whether to use bias in QKV layers.

  • rngs (nnx.Rngs) – Random number generators for initialization.

  • joint_dim (int) – Joint dimension.

  • condition_dim (list[int]) – Condition dimension.

  • theta (int) – Base frequency of the rotary positional encoding. Defaults to 10000.

  • guidance_embed (bool) – Whether to use guidance embedding. Defaults to False.

  • param_dtype (DTypeLike) – Data type for model parameters.

__post_init__()
axes_dim: list[int]
condition_dim: list[int]
depth_single_blocks: int
guidance_embed: bool = False
in_channels: int
joint_dim: int
mlp_ratio: float
num_heads: int
param_dtype: jax.typing.DTypeLike
qkv_bias: bool
rngs: flax.nnx.Rngs
theta: int = 10000
vec_in_dim: int | None
class gensbi.models.Flux1Params

Parameters for the Flux1 model.

Parameters:
  • in_channels (int) – Number of input channels.

  • vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.

  • context_in_dim (int) – Dimension of the context input.

  • mlp_ratio (float) – Ratio of the MLP hidden dimension to the model hidden size.

  • num_heads (int) – Number of attention heads.

  • depth (int) – Number of double stream blocks.

  • depth_single_blocks (int) – Number of single stream blocks.

  • axes_dim (list[int]) – Dimensions of the axes for positional encoding.

  • qkv_bias (bool) – Whether to use bias in QKV layers.

  • rngs (nnx.Rngs) – Random number generators for initialization.

  • obs_dim (int) – Observation dimension.

  • cond_dim (int) – Condition dimension.

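  • theta (int) – Base frequency of the rotary positional encoding. Defaults to 10000.

  • guidance_embed (bool) – Whether to use guidance embedding. Defaults to False.

  • param_dtype (DTypeLike) – Data type for model parameters.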

__post_init__()
axes_dim: list[int]
cond_dim: int
context_in_dim: int
depth: int
depth_single_blocks: int
guidance_embed: bool = False
in_channels: int
mlp_ratio: float
num_heads: int
obs_dim: int
param_dtype: jax.typing.DTypeLike
qkv_bias: bool
rngs: flax.nnx.Rngs
theta: int = 10000
vec_in_dim: int | None
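The `theta` and `axes_dim` fields point to rotary positional encoding, as in the original Flux architecture. There, each position-id axis gets its own slice of the per-head dimension, and the entries of `axes_dim` add up to that dimension. The following is a minimal sketch of the rotation angles for one axis, assuming standard RoPE frequencies theta^(-2i/d). `rope_angles` is a hypothetical name, and gensbi's embedder may differ in detail.

```python
import numpy as np

def rope_angles(positions, dim, theta=10000):
    """Rotary-embedding angles for one position axis.

    Returns an array of shape (len(positions), dim // 2); channel pair i
    is rotated by angle = position * theta**(-2i / dim).
    """
    if dim % 2 != 0:
        raise ValueError("RoPE needs an even per-axis dimension")
    exponents = np.arange(0, dim, 2) / dim  # 0, 2/dim, 4/dim, ...
    inv_freq = 1.0 / theta**exponents  # later channel pairs rotate more slowly
    return np.outer(np.asarray(positions, dtype=np.float64), inv_freq)

angles = rope_angles([0, 1, 2], dim=8)
print(angles.shape)  # (3, 4)
```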
class gensbi.models.JointCFMLoss(path, reduction='mean')

Bases: gensbi.flow_matching.loss.ContinuousFMLoss

JointCFMLoss is a class that computes the continuous flow matching loss for the Joint model.

Parameters:
  • path – Probability path for training.

  • reduction (str) – Reduction method: 'none', 'mean' or 'sum'. Defaults to 'mean'.

__call__(vf, batch, *args, condition_mask=None, **kwargs)

Evaluate the continuous flow matching loss.

Parameters:
  • vf (Callable) – Vector field model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).

  • *args – Additional positional arguments.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array
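For a joint model, the `condition_mask` marks which entries of the joint vector are observed. A common scheme, along the lines of the Simformer approach, keeps conditioned entries at their clean values along the path and scores the model only on the remaining entries. The sketch below assumes that scheme, a linear path and a velocity target. The name `masked_joint_cfm_loss` and the boolean mask layout (True = conditioned) are assumptions, not confirmed gensbi behaviour.

```python
import numpy as np

def masked_joint_cfm_loss(vf, x0, x1, t, condition_mask):
    """Hypothetical joint flow matching loss with conditioned nodes held fixed.

    x0, x1: (batch, nodes, channels); t: (batch,);
    condition_mask: (batch, nodes) bool, True where a node is conditioned on.
    """
    t_b = t[:, None, None]
    mask = condition_mask[..., None]  # (batch, nodes, 1), broadcasts over channels
    x_t = (1.0 - t_b) * x0 + t_b * x1
    x_t = np.where(mask, x1, x_t)  # conditioned nodes keep their clean values
    sq_err = (vf(x_t, t, condition_mask) - (x1 - x0)) ** 2
    free = np.broadcast_to(~mask, sq_err.shape)  # only unconditioned nodes count
    return sq_err[free].mean()

rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 3, 1))
x1 = rng.standard_normal((2, 3, 1))
t = np.array([0.25, 0.75])
condition_mask = np.array([[False, False, True], [False, True, True]])

def right_on_free_nodes(x_t, t, condition_mask):
    # Correct velocity on unconditioned nodes, garbage on conditioned ones.
    return np.where(condition_mask[..., None], 1e6, x1 - x0)

print(masked_joint_cfm_loss(right_on_free_nodes, x0, x1, t, condition_mask))  # 0.0
```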

class gensbi.models.JointDiffLoss(path)

Bases: flax.nnx.Module

JointDiffLoss is a class that computes the diffusion score matching loss for the Joint model.

Parameters:

path – Probability path for training.

__call__(key, model, batch, condition_mask=None, **kwargs)

Evaluate the diffusion score matching loss.

Parameters:
  • key (jax.random.PRNGKey) – Random key for stochastic operations.

  • model (Callable) – The F model to evaluate.

  • batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array

loss_fn
path
class gensbi.models.JointWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wrapper for joint models to handle both conditioned and unconditioned inference.

Parameters:

model – The joint model instance to wrap.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, **kwargs)

Call the wrapped model for either conditioned or unconditioned inference.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • conditioned (bool, optional) – Whether to use conditioning. Defaults to True.

  • **kwargs – Additional keyword arguments passed to the model.

Returns:

Model output.

Return type:

Array

conditioned(obs, obs_ids, cond, cond_ids, t, **kwargs)

Perform conditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • t (Array) – Time steps.

  • **kwargs – Additional keyword arguments passed to the model.

Returns:

Conditioned output, restricted to the variables that are not conditioned on.

Return type:

Array

unconditioned(obs, obs_ids, t, **kwargs)

Perform unconditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • t (Array) – Time steps.

  • **kwargs – Additional keyword arguments passed to the model.

Returns:

Unconditioned output.

Return type:

Array
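`JointWrapper` exposes the conditional interface (obs, obs_ids, cond, cond_ids) on top of a joint model that works on a single sequence with a `condition_mask`. Its `conditioned` method returns outputs for the unconditioned variables only. One plausible way to bridge the two interfaces has three steps:

1. Concatenate the observations and conditions along the sequence axis.
2. Build the condition mask.
3. Slice the output back to the observation nodes.

The sketch below illustrates that idea with hypothetical helper names. It is not gensbi's implementation.

```python
import numpy as np

def assemble_joint_inputs(obs, obs_ids, cond, cond_ids):
    """Hypothetical sketch: merge observations and conditions into one joint sequence.

    obs: (batch, n_obs, c); cond: (batch, n_cond, c); ids share the leading two axes.
    Returns joint values, joint ids and a condition mask (True = conditioned).
    """
    values = np.concatenate([obs, cond], axis=1)
    ids = np.concatenate([obs_ids, cond_ids], axis=1)
    batch, n_obs, n_cond = obs.shape[0], obs.shape[1], cond.shape[1]
    mask = np.concatenate(
        [np.zeros((batch, n_obs), dtype=bool), np.ones((batch, n_cond), dtype=bool)],
        axis=1,
    )
    return values, ids, mask

def keep_unconditioned(output, n_obs):
    # The joint model predicts every node; keep only the observation nodes.
    return output[:, :n_obs]

obs = np.zeros((2, 3, 1))
cond = np.ones((2, 2, 1))
obs_ids = np.arange(3).reshape(1, 3, 1).repeat(2, axis=0)
cond_ids = np.arange(3, 5).reshape(1, 2, 1).repeat(2, axis=0)
values, ids, mask = assemble_joint_inputs(obs, obs_ids, cond, cond_ids)
print(values.shape)  # (2, 5, 1)
```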

class gensbi.models.Simformer(params)

Bases: flax.nnx.Module

Simformer model for joint density estimation.

Parameters:

params (SimformerParams) – Parameters for the Simformer model.

__call__(t, obs, node_ids, condition_mask, edge_mask=None)

Forward pass of the Simformer model.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Input data.

  • node_ids (Array) – Node identifiers.

  • condition_mask (Array) – Mask for conditioning.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Model output.

Return type:

Array

condition_embedding
dim_condition
dim_id
dim_value
embedding_net_id
embedding_net_value
embedding_time
in_channels
output_fn
params
total_tokens
transformer
class gensbi.models.SimformerParams

Parameters for the Simformer model.

Parameters:
  • rngs (nnx.Rngs) – Random number generators for initialization.

  • in_channels (int) – Number of input channels.

  • dim_value (int) – Dimension of the value embeddings.

  • dim_id (int) – Dimension of the ID embeddings.

  • dim_condition (int) – Dimension of the condition embeddings.

  • dim_joint (int) – Total dimension of the joint embeddings.

  • fourier_features (int) – Number of Fourier features for time embedding.

  • num_heads (int) – Number of attention heads.

  • num_layers (int) – Number of transformer layers.

  • widening_factor (int) – Widening factor for the transformer.

  • qkv_features (int) – Number of features for QKV layers.

  • num_hidden_layers (int) – Number of hidden layers in the transformer.

__post_init__()
dim_condition: int
dim_id: int
dim_joint: int
dim_value: int
fourier_features: int = 128
in_channels: int
num_heads: int
num_hidden_layers: int = 1
num_layers: int
qkv_features: int | None = None
rngs: flax.nnx.Rngs
widening_factor: int = 3
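`fourier_features` sets the size of the time embedding (128 by default). A common construction is a Gaussian Fourier projection, which uses fixed random frequencies w and concatenates the sine and cosine of 2*pi*w*t. This is a sketch under that assumption. The frequency scale of 30 and the helper name are illustrative choices, not values taken from gensbi.

```python
import numpy as np

def fourier_time_embedding(t, num_features=128, scale=30.0, seed=0):
    """Gaussian Fourier features for the time variable (an assumed design).

    t: (batch,) times; returns (batch, num_features) with sine and cosine halves.
    """
    if num_features % 2 != 0:
        raise ValueError("num_features must be even")
    rng = np.random.default_rng(seed)
    w = rng.standard_normal(num_features // 2) * scale  # fixed random frequencies
    proj = 2.0 * np.pi * np.asarray(t, dtype=np.float64)[:, None] * w[None, :]
    return np.concatenate([np.sin(proj), np.cos(proj)], axis=-1)

emb = fourier_time_embedding(np.array([0.0, 0.5, 1.0]))
print(emb.shape)  # (3, 128)
```

In a trained model the frequencies would normally be stored once, for example as a non-trainable variable, rather than regenerated on every call.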
class gensbi.models.UnconditionalCFMLoss(path, reduction='mean')

Bases: gensbi.flow_matching.loss.ContinuousFMLoss

UnconditionalCFMLoss is a class that computes the continuous flow matching loss for the Unconditional model.

Parameters:
  • path – Probability path for training.

  • reduction (str) – Reduction method: 'none', 'mean' or 'sum'. Defaults to 'mean'.

__call__(vf, batch, *args, **kwargs)

Evaluate the continuous flow matching loss.

Parameters:
  • vf (Callable) – Vector field model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).

  • *args – Additional positional arguments.

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array

class gensbi.models.UnconditionalDiffLoss(path)

Bases: flax.nnx.Module

UnconditionalDiffLoss is a class that computes the diffusion score matching loss for the Unconditional model.

Parameters:

path – Probability path for training.

__call__(key, model, batch, **kwargs)

Evaluate the diffusion score matching loss.

Parameters:
  • key (jax.random.PRNGKey) – Random key for stochastic operations.

  • model (Callable) – The F model to evaluate.

  • batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array

loss_fn
path
class gensbi.models.UnconditionalWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wrapper for unconditional models to handle input expansion and calling convention.

Parameters:

model – The unconditional model instance to wrap.

__call__(t, obs, obs_ids, **kwargs)

Call the wrapped model with expanded inputs.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • **kwargs – Additional keyword arguments passed to the model.

Returns:

Model output.

Return type:

Array