gensbi.models

Classes

Flux

Transformer model for flow matching on sequences.

FluxCFMLoss

FluxCFMLoss is a class that computes the continuous flow matching loss for the Flux model.

FluxDiffLoss

FluxDiffLoss is a class that computes the diffusion score matching loss for the Flux model.

FluxParams

Parameters for the Flux model.

FluxWrapper

Wraps another model and defines a call method that returns the wrapped model's output.

Simformer

Simformer model for joint density estimation.

SimformerCFMLoss

SimformerCFMLoss is a class that computes the continuous flow matching loss for the Simformer model.

SimformerDiffLoss

SimformerDiffLoss is a class that computes the diffusion score matching loss for the Simformer model.

SimformerParams

Parameters for the Simformer model.

SimformerWrapper

Wraps another model and defines a call method that returns the wrapped model's output.

Package Contents

class gensbi.models.Flux(params)

Bases: flax.nnx.Module

Transformer model for flow matching on sequences.

Parameters:

params (FluxParams)

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
Parameters:
  • t (jax.Array)

  • obs (jax.Array)

  • obs_ids (jax.Array)

  • cond (jax.Array)

  • cond_ids (jax.Array)

  • conditioned (bool | jax.Array)

  • guidance (jax.Array | None)

Return type:

jax.Array

cond_in
condition_embedding
condition_null
double_blocks
final_layer
hidden_size
in_channels
num_heads
obs_in
out_channels
params
pe_embedder
qkv_features
single_blocks
time_in
vector_in
class gensbi.models.FluxCFMLoss(path, reduction='mean', cfg_scale=None)

Bases: gensbi.flow_matching.loss.ContinuousFMLoss

FluxCFMLoss is a class that computes the continuous flow matching loss for the Flux model.

Parameters:
  • path – Probability path (x-prediction training).

  • reduction (str, optional) – Reduction applied to the output: 'none' | 'mean' | 'sum'. 'none' applies no reduction, 'mean' averages over sequence elements, and 'sum' sums over sequence elements. Defaults to 'mean'.

__call__(vf, batch, cond, obs_ids, cond_ids)

Evaluates the continuous flow matching loss.

Parameters:
  • vf (callable) – The vector field model to evaluate.

  • batch (tuple) – A tuple containing the input data (x_0, x_1, t).

  • cond (jnp.ndarray) – The conditioning data.

  • obs_ids (jnp.ndarray) – The observation IDs.

  • cond_ids (jnp.ndarray) – The conditioning IDs.

Returns:

The computed loss.

Return type:

jnp.ndarray
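A minimal sketch of what a continuous flow matching loss computes may help here. It assumes a linear path x_t = (1 - t) x_0 + t x_1 with target velocity x_1 - x_0 and a velocity-predicting model; the path and prediction target that gensbi actually uses may differ. `toy_cfm_loss` and `toy_vf` are hypothetical names, the (batch, tokens, channels) layout is an assumption, and the conditioning arguments are simply forwarded to the model.

```python
import jax
import jax.numpy as jnp

def toy_cfm_loss(vf, batch, cond, obs_ids, cond_ids, reduction="mean"):
    """Illustrative CFM loss; assumes a linear path, not gensbi's actual path."""
    x_0, x_1, t = batch                           # noise, data, times in [0, 1]
    t_b = t.reshape(-1, *([1] * (x_1.ndim - 1)))  # broadcast t over non-batch axes
    x_t = (1.0 - t_b) * x_0 + t_b * x_1           # point on the linear path
    target = x_1 - x_0                            # velocity of the linear path
    pred = vf(t, x_t, obs_ids, cond, cond_ids)
    sq_err = (pred - target) ** 2
    if reduction == "none":
        return sq_err
    return jnp.mean(sq_err) if reduction == "mean" else jnp.sum(sq_err)

def toy_vf(t, obs, obs_ids, cond, cond_ids):
    # Hypothetical stand-in for a model: a constant unit velocity field.
    return jnp.ones_like(obs)

key0, key1 = jax.random.split(jax.random.PRNGKey(0))
x_0 = jax.random.normal(key0, (4, 3, 1))  # (batch, tokens, channels), assumed layout
x_1 = x_0 + 1.0                           # data is noise shifted by one
t = jax.random.uniform(key1, (4,))
cond = jnp.zeros((4, 2, 1))
obs_ids = jnp.arange(3)
cond_ids = jnp.arange(2)

loss = toy_cfm_loss(toy_vf, (x_0, x_1, t), cond, obs_ids, cond_ids)
```

Because the toy field matches the path velocity (up to floating-point rounding), the loss here is essentially zero; a real model would be trained to drive it down.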

cfg_scale = None
class gensbi.models.FluxDiffLoss(path)

Bases: flax.nnx.Module

FluxDiffLoss is a class that computes the diffusion score matching loss for the Flux model.

Parameters:

path – Probability path for training.

__call__(key, model, batch, cond, obs_ids, cond_ids)

Evaluate the diffusion score matching loss.

Parameters:
  • key (jax.random.PRNGKey) – Random key for stochastic operations.

  • model (Callable) – F model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).

  • cond (jnp.ndarray) – The conditioning data.

  • obs_ids (jnp.ndarray) – The observation IDs.

  • cond_ids (jnp.ndarray) – The conditioning IDs.

Returns:

Computed loss.

Return type:

Array
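The role of `key` and of the `(x_1, sigma)` batch can be illustrated with a small denoising loss. This is a sketch under assumptions: data is perturbed as x_1 + sigma * eps with Gaussian eps, the model predicts the clean sample, and any loss weighting or preconditioning is omitted. `toy_denoising_loss` and both toy models are hypothetical and do not reproduce gensbi's loss.

```python
import jax
import jax.numpy as jnp

def toy_denoising_loss(key, model, batch):
    """Illustrative denoising loss; weighting and preconditioning are omitted."""
    x_1, sigma = batch                                    # clean data, noise levels
    sigma_b = sigma.reshape(-1, *([1] * (x_1.ndim - 1)))  # broadcast over features
    eps = jax.random.normal(key, x_1.shape)               # the stochastic operation
    x_noisy = x_1 + sigma_b * eps                         # perturb data at level sigma
    denoised = model(x_noisy, sigma)                      # model predicts clean sample
    return jnp.mean((denoised - x_1) ** 2)

def identity_model(x, sigma):
    # Does no denoising, so the residual is exactly the injected noise.
    return x

def zero_model(x, sigma):
    # Exact denoiser for the all-zero data used below.
    return jnp.zeros_like(x)

x_1 = jnp.zeros((8, 3))
sigma = jnp.full((8,), 0.5)
key = jax.random.PRNGKey(0)

loss_identity = toy_denoising_loss(key, identity_model, (x_1, sigma))
loss_zero = toy_denoising_loss(key, zero_model, (x_1, sigma))
```

The exact denoiser reaches zero loss, while the identity model is left with the injected noise.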

loss_fn
path
class gensbi.models.FluxParams

Parameters for the Flux model.

Parameters:
  • in_channels (int) – Number of input channels.

  • vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.

  • context_in_dim (int) – Dimension of the context input.

  • mlp_ratio (float) – Ratio for the MLP layers.

  • num_heads (int) – Number of attention heads.

  • depth (int) – Number of double stream blocks.

  • depth_single_blocks (int) – Number of single stream blocks.

  • axes_dim (list[int]) – Dimensions of the axes for positional encoding.

  • qkv_bias (bool) – Whether to use bias in QKV layers.

  • rngs (nnx.Rngs) – Random number generators for initialization.

  • obs_dim (int) – Observation dimension.

  • cond_dim (int) – Condition dimension.

  • theta (int) – Scaling factor for positional encoding.

  • guidance_embed (bool) – Whether to use guidance embedding.

  • qkv_multiplier (int) – Multiplier for QKV features.

  • param_dtype (DTypeLike) – Data type for model parameters.

__post_init__()
axes_dim: list[int]
cond_dim: int
context_in_dim: int
depth: int
depth_single_blocks: int
guidance_embed: bool = False
in_channels: int
mlp_ratio: float
num_heads: int
obs_dim: int
param_dtype: jax.typing.DTypeLike
qkv_bias: bool
qkv_multiplier: int = 1
rngs: flax.nnx.Rngs
theta: int = 10000
vec_in_dim: int | None
class gensbi.models.FluxWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wraps another model. The call method returns the model output. The vector_field method computes the model's vector field and the divergence method computes its divergence, both in a form suitable for diffrax. ODE solvers that integrate the flow and track density changes need both quantities.
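To make the vector field and divergence concrete, the sketch below computes an exact divergence as the trace of the Jacobian with `jax.jacfwd`. It is illustrative only: `toy_model`, `vector_field`, and `divergence` are hypothetical free functions, the toy model takes only `t` and `obs`, and how the wrapper computes its divergence internally is not shown on this page.

```python
import jax
import jax.numpy as jnp

def toy_model(t, obs):
    # Hypothetical stand-in for a wrapped model: a linear field scaled by t.
    A = jnp.array([[1.0, 2.0], [0.0, 3.0]])
    return t * (obs @ A.T)

def vector_field(t, x):
    """Evaluate the model on a single sample x of shape (dim,)."""
    return toy_model(t, x[None, :])[0]

def divergence(t, x):
    """Exact divergence: trace of the Jacobian of the vector field w.r.t. x."""
    jac = jax.jacfwd(lambda y: vector_field(t, y))(x)
    return jnp.trace(jac)

x = jnp.array([0.5, -1.0])
div = divergence(2.0, x)  # Jacobian is t * A, so the trace is 2 * (1 + 3)
```

The exact trace costs one Jacobian per sample, which is affordable in low dimensions; stochastic trace estimators are a common alternative when the dimension is large.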

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)

Defines how inputs are passed through the wrapped model. The wrapped model is assumed to take \(obs\) and \(t\) as input, along with any additional keyword arguments.

Optional things to do here:
  • check that t has the dimensions the model expects.

  • add custom forward-pass logic.

  • call the wrapped model.

Given obs and t, returns the model output for input obs at time t, using any extra information forwarded to the model.
Parameters:
  • obs (Array) – input data to the model (batch_size, …).

  • t (Array) – time (batch_size).

  • **extras – additional information forwarded to the model, e.g., text condition.

  • obs_ids (jax.Array)

  • cond (jax.Array)

  • cond_ids (jax.Array)

  • conditioned (bool | jax.Array)

  • guidance (jax.Array | None)

Returns:

model output.

Return type:

Array

class gensbi.models.Simformer(params)

Bases: flax.nnx.Module

Simformer model for joint density estimation.

Parameters:

params (SimformerParams) – Parameters for the Simformer model.

__call__(t, obs, node_ids, condition_mask, edge_mask=None)

Forward pass of the Simformer model.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Input data.

  • node_ids (Array) – Node identifiers.

  • condition_mask (Array) – Mask for conditioning.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Model output.

Return type:

Array
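One plausible reading of `edge_mask` is a boolean attention mask over pairs of nodes, where a False entry prevents one token from attending to another. The sketch below shows that mechanism on raw attention scores. This interpretation is an assumption, and `masked_attention_weights` is a hypothetical helper rather than code from gensbi.

```python
import jax
import jax.numpy as jnp

def masked_attention_weights(scores, edge_mask):
    """Softmax over keys, with disallowed edges suppressed before normalising."""
    # Disallowed edges get a large negative score so their weight is ~0.
    scores = jnp.where(edge_mask, scores, -1e9)
    return jax.nn.softmax(scores, axis=-1)

scores = jnp.zeros((3, 3))  # uniform raw scores between 3 nodes
edge_mask = jnp.array([
    [True, False, False],   # node 0 attends only to itself
    [True, True, False],    # node 1 attends to nodes 0 and 1
    [True, True, True],     # node 2 attends to every node
])
weights = masked_attention_weights(scores, edge_mask)
```

Each row still sums to one, but the weight is spread only over the edges that the mask allows.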

condition_embedding
dim_condition
dim_id
dim_value
embedding_net_id
embedding_net_value
embedding_time
output_fn
params
total_tokens
transformer
class gensbi.models.SimformerCFMLoss(path, reduction='mean')

Bases: gensbi.flow_matching.loss.ContinuousFMLoss

SimformerCFMLoss is a class that computes the continuous flow matching loss for the Simformer model.

Parameters:
  • path – Probability path for training.

  • reduction (str) – Reduction method (‘none’, ‘mean’, ‘sum’).

__call__(vf, batch, *args, condition_mask=None, **kwargs)

Evaluate the continuous flow matching loss.

Parameters:
  • vf (Callable) – Vector field model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).

  • *args – Additional positional arguments.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array
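The effect of `condition_mask` on a loss can be sketched as follows, assuming that conditioned entries are treated as observed and therefore excluded from the error. `masked_mse` is a hypothetical helper written for this illustration, and the exact masking rule used by SimformerCFMLoss is not stated on this page.

```python
import jax.numpy as jnp

def masked_mse(pred, target, condition_mask):
    """Mean squared error over entries that are NOT conditioned on."""
    free = ~condition_mask
    sq_err = jnp.where(free, (pred - target) ** 2, 0.0)
    # Normalise by the number of free entries, guarding against division by zero.
    return jnp.sum(sq_err) / jnp.maximum(jnp.sum(free), 1)

pred = jnp.array([[1.0, 5.0, 3.0]])
target = jnp.array([[1.0, 0.0, 1.0]])
condition_mask = jnp.array([[False, True, False]])  # middle variable is conditioned

loss = masked_mse(pred, target, condition_mask)
```

The large error on the conditioned middle entry is ignored, so only the two free entries contribute.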

class gensbi.models.SimformerDiffLoss(path)

Bases: flax.nnx.Module

SimformerDiffLoss is a class that computes the diffusion score matching loss for the Simformer model.

Parameters:

path – Probability path for training.

__call__(key, model, batch, condition_mask=None, **kwargs)

Evaluate the diffusion score matching loss.

Parameters:
  • key (jax.random.PRNGKey) – Random key for stochastic operations.

  • model (Callable) – F model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array

loss_fn
path
class gensbi.models.SimformerParams

Parameters for the Simformer model.

Parameters:
  • rngs (nnx.Rngs) – Random number generators for initialization.

  • dim_value (int) – Dimension of the value embeddings.

  • dim_id (int) – Dimension of the ID embeddings.

  • dim_condition (int) – Dimension of the condition embeddings.

  • dim_joint (int) – Total dimension of the joint embeddings.

  • fourier_features (int) – Number of Fourier features for time embedding.

  • num_heads (int) – Number of attention heads.

  • num_layers (int) – Number of transformer layers.

  • widening_factor (int) – Widening factor for the transformer.

  • qkv_features (int) – Number of features for QKV layers.

  • num_hidden_layers (int) – Number of hidden layers in the transformer.

__post_init__()
dim_condition: int
dim_id: int
dim_joint: int
dim_value: int
fourier_features: int = 128
num_heads: int
num_hidden_layers: int = 1
num_layers: int
qkv_features: int | None = None
rngs: flax.nnx.Rngs
widening_factor: int = 3
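SimformerParams exposes a `__post_init__` hook and an optional `qkv_features` field, but this page does not say what the hook does. The sketch below shows one common use of the pattern on a hypothetical stand-in dataclass: filling an optional field and validating it after the fields are assigned. `ToyParams`, its default rule, and its divisibility check are assumptions, not gensbi's behaviour.

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ToyParams:
    """Hypothetical stand-in for a params dataclass; not gensbi's definition."""
    dim_value: int
    dim_id: int
    dim_condition: int
    num_heads: int
    qkv_features: Optional[int] = None

    def __post_init__(self):
        # Assumed behaviour: derive a default, then validate it.
        token_dim = self.dim_value + self.dim_id + self.dim_condition
        if self.qkv_features is None:
            self.qkv_features = token_dim
        if self.qkv_features % self.num_heads != 0:
            raise ValueError("qkv_features must be divisible by num_heads")

params = ToyParams(dim_value=8, dim_id=4, dim_condition=4, num_heads=4)
```

Resolving derived values in `__post_init__` keeps the constructor call short while ensuring that downstream code never sees an unset optional field.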
class gensbi.models.SimformerWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wraps another model. The call method returns the model output. The vector_field method computes the model's vector field and the divergence method computes its divergence, both in a form suitable for diffrax. ODE solvers that integrate the flow and track density changes need both quantities.

Parameters:

model (Simformer)

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, edge_mask=None)

Defines how inputs are passed through the wrapped model. The wrapped model is assumed to take \(obs\) and \(t\) as input, along with additional keyword arguments.

Parameters:
  • obs (Array) – input data to the model (batch_size, …).

  • t (Array) – time (batch_size).

  • cond (Array) – conditioning data to the model (batch_size, …).

  • obs_ids (Array) – observation ids (batch_size, obs_dim).

  • cond_ids (Array) – condition ids (batch_size, cond_dim).

  • conditioned (bool | Array) – whether to use conditioning or not.

  • edge_mask (Optional[Array]) – mask for edges.

Returns:

model output.

Return type:

Array
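The wrapper accepts observations and conditions separately, whereas Simformer's own call takes a single input with `node_ids` and a `condition_mask`. One way to bridge the two is sketched below, using the (batch_size, obs_dim) and (batch_size, cond_dim) shapes listed above. `build_joint_inputs` is a hypothetical helper that only handles a Python bool for `conditioned`; the wrapper's real packing logic may differ.

```python
import jax.numpy as jnp

def build_joint_inputs(obs, obs_ids, cond, cond_ids, conditioned=True):
    """Hypothetical helper: pack obs/cond into joint values, ids, and a mask."""
    obs_mask = jnp.zeros(obs.shape, dtype=bool)    # free variables
    if not conditioned:
        return obs, obs_ids, obs_mask
    cond_mask = jnp.ones(cond.shape, dtype=bool)   # observed variables
    joint = jnp.concatenate([obs, cond], axis=1)
    node_ids = jnp.concatenate([obs_ids, cond_ids], axis=1)
    condition_mask = jnp.concatenate([obs_mask, cond_mask], axis=1)
    return joint, node_ids, condition_mask

obs = jnp.zeros((2, 3))                            # (batch_size, obs_dim)
cond = jnp.ones((2, 2))                            # (batch_size, cond_dim)
obs_ids = jnp.tile(jnp.arange(3), (2, 1))          # ids 0..2 for observations
cond_ids = jnp.tile(jnp.arange(3, 5), (2, 1))      # ids 3..4 for conditions

joint, node_ids, condition_mask = build_joint_inputs(obs, obs_ids, cond, cond_ids)
```

Under this layout the model output for the joint input would be sliced back to its first obs_dim columns before being returned, which is again an assumption about the wrapper.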

conditioned(t, obs, obs_ids, cond, cond_ids, edge_mask=None)

Perform conditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • t (Array) – Time steps.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Conditioned output.

Return type:

Array

unconditioned(t, obs, obs_ids, edge_mask=None)

Perform unconditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • t (Array) – Time steps.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Unconditioned output.

Return type:

Array

dim_joint
model