gensbi.recipes.simformer

Pipeline for training and using a Simformer model for simulation-based inference.

Example

Classes

SimformerDiffusionPipeline

Training pipeline for the Simformer model based on diffusion.

SimformerFlowPipeline

Training pipeline for the Simformer model based on flow matching.

Functions

sample_strutured_conditional_mask(key, num_samples, ...)

Sample structured conditional masks for the Simformer model.

Module Contents

class gensbi.recipes.simformer.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)

Bases: gensbi.recipes.pipeline.AbstractPipeline

Training pipeline for the Simformer model based on diffusion.

This class implements the AbstractPipeline template for the Simformer: it provides model creation, default parameters, the default training configuration, the loss function, and sampling.

Parameters:
  • train_dataset (iterable) – Training dataset, which should yield batches of data.

  • val_dataset (iterable) – Validation dataset, which should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
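The constructor only states that the datasets should yield batches of data. A minimal sketch of such an iterable, built from a Python generator, assuming the pipeline consumes joint arrays whose first dim_theta columns hold theta and whose remaining dim_x columns hold x. The toy simulator, the prior, and the helper names are hypothetical; check the exact batch layout gensbi expects against its own examples.

```python
import jax
import jax.numpy as jnp


def toy_simulator(key, theta):
    """Hypothetical simulator: x = theta + Gaussian noise (sigma = 0.1)."""
    return theta + 0.1 * jax.random.normal(key, theta.shape)


def make_batch_iterator(key, dim_theta, batch_size):
    """Yield joint (theta, x) batches forever.

    Assumption: each batch is one array of shape (batch_size, dim_theta + dim_x),
    theta first, then x. Here dim_x == dim_theta because of the toy simulator.
    """
    while True:
        key, k_prior, k_sim = jax.random.split(key, 3)
        theta = jax.random.normal(k_prior, (batch_size, dim_theta))  # standard-normal prior
        x = toy_simulator(k_sim, theta)
        yield jnp.concatenate([theta, x], axis=-1)


train_iter = make_batch_iterator(jax.random.PRNGKey(0), dim_theta=3, batch_size=256)
batch = next(train_iter)
print(batch.shape)  # (256, 6)
```

A generator is one simple way to satisfy the "iterable" requirement without materializing the whole dataset in memory.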

_get_default_params()

Return default parameters for the Simformer model.

classmethod _get_default_training_config()

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict

_make_model()

Create and return the Simformer model to be trained.

_wrap_model()

Wrap the model for evaluation using SimformerWrapper.

get_loss_fn()

Return the loss function for training/validation.

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True)

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of steps taken by the sampler.

  • use_ema (bool, optional) – Whether to sample with the exponential moving average (EMA) of the model parameters.

Returns:

samples – Generated samples.

Return type:

array-like
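The default nsteps=18 matches the step count used by the EDM sampler of Karras et al. (2022). Assuming the pipeline uses an EDM-style sampler (this reference does not say so), its nsteps noise levels would follow the schedule sketched below. The values sigma_min=0.002, sigma_max=80 and rho=7 are the EDM paper's defaults, not values read from gensbi, and karras_sigmas is a hypothetical helper.

```python
import jax.numpy as jnp


def karras_sigmas(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """EDM noise levels, decreasing from sigma_max to sigma_min over nsteps points."""
    i = jnp.arange(nsteps)
    inv_rho_max = sigma_max ** (1.0 / rho)
    inv_rho_min = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma^(1/rho) space, then raise back to the power rho.
    return (inv_rho_max + i / (nsteps - 1) * (inv_rho_min - inv_rho_max)) ** rho


sigmas = karras_sigmas()
print(sigmas.shape)  # (18,)
```

The rho exponent concentrates steps at low noise levels, which is why a small number of steps can still give usable samples. In EDM, a final noise level of 0 is appended after these values.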

loss_fn
path
undirected_edge_mask

class gensbi.recipes.simformer.SimformerFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)

Bases: gensbi.recipes.pipeline.AbstractPipeline

Training pipeline for the Simformer model based on flow matching.

This class implements the AbstractPipeline template for the Simformer: it provides model creation, default parameters, the loss function, and sampling.

Parameters:
  • train_dataset (iterable) – Training dataset, which should yield batches of data.

  • val_dataset (iterable) – Validation dataset, which should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

_get_default_params()

Return default parameters for the Simformer model.

_make_model()

Create and return the Simformer model to be trained.

_wrap_model()

Wrap the model for evaluation using SimformerWrapper.

get_loss_fn()

Return the loss function for training/validation.

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True)

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

  • use_ema (bool, optional) – Whether to sample with the exponential moving average (EMA) of the model parameters.

Returns:

samples – Generated samples.

Return type:

array-like
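The step_size argument suggests a fixed-step ODE solver: flow-matching samplers draw from a source distribution (presumably what p0_dist_model provides) and integrate a learned velocity field over t in [0, 1]. A minimal sketch with an explicit Euler solver and a hypothetical closed-form velocity field standing in for the trained model. The solver gensbi actually uses is not stated here; the point is how step_size maps to the number of steps (0.01 gives 100).

```python
import jax
import jax.numpy as jnp


def euler_integrate(velocity_fn, x0, step_size=0.01, t0=0.0, t1=1.0):
    """Integrate dx/dt = velocity_fn(t, x) from t0 to t1 with fixed explicit Euler steps."""
    nsteps = int(round((t1 - t0) / step_size))  # step_size=0.01 -> 100 steps

    def step(x, i):
        t = t0 + i * step_size  # time at the start of step i
        return x + step_size * velocity_fn(t, x), None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(nsteps))
    return x1


# Hypothetical stand-ins: standard-normal source samples and a toy velocity field v(t, x) = -x.
key = jax.random.PRNGKey(0)
x0 = jax.random.normal(key, (1000, 2))
x1 = euler_integrate(lambda t, x: -x, x0, step_size=0.01)
```

With a trained model, the toy field would be replaced by the network's velocity prediction conditioned on x_o. A smaller step_size means more solver steps: slower sampling, but a smaller discretization error.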

loss_fn
p0_dist_model
path
undirected_edge_mask

gensbi.recipes.simformer.sample_strutured_conditional_mask(key, num_samples, theta_dim, x_dim, p_joint=0.2, p_posterior=0.2, p_likelihood=0.2, p_rnd1=0.2, p_rnd2=0.2, rnd1_prob=0.3, rnd2_prob=0.7)

Sample structured conditional masks for the Simformer model.

Parameters:
  • key (jax.random.PRNGKey) – Random key for sampling.

  • num_samples (int) – Number of samples to generate.

  • theta_dim (int) – Dimension of the parameter space.

  • x_dim (int) – Dimension of the observation space.

  • p_joint (float) – Probability of selecting the joint mask.

  • p_posterior (float) – Probability of selecting the posterior mask.

  • p_likelihood (float) – Probability of selecting the likelihood mask.

  • p_rnd1 (float) – Probability of selecting the first random mask.

  • p_rnd2 (float) – Probability of selecting the second random mask.

  • rnd1_prob (float) – Probability of a True value in the first random mask.

  • rnd2_prob (float) – Probability of a True value in the second random mask.

Returns:

condition_mask – Array of shape (num_samples, theta_dim + x_dim) with boolean masks.

Return type:

jnp.ndarray
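The parameters describe a mixture over five mask types: joint, posterior, likelihood, and two random masks. A minimal sketch of that sampling scheme, assuming True marks a variable that is conditioned on (observed) and that theta entries come before x entries, as the (theta_dim + x_dim) shape suggests. This is an illustration, not gensbi's implementation. The function name is hypothetical, the random masks are drawn independently per sample, and masks that condition on every variable are replaced by the joint mask; these are design choices of the sketch.

```python
import jax
import jax.numpy as jnp


def sample_structured_mask_sketch(key, num_samples, theta_dim, x_dim,
                                  p_joint=0.2, p_posterior=0.2, p_likelihood=0.2,
                                  p_rnd1=0.2, p_rnd2=0.2, rnd1_prob=0.3, rnd2_prob=0.7):
    """Illustrative mask sampler; True marks a conditioned (observed) entry (assumed convention)."""
    d = theta_dim + x_dim
    k_choice, k_rnd1, k_rnd2 = jax.random.split(key, 3)
    is_theta = jnp.arange(d) < theta_dim  # theta entries come first, then x

    joint = jnp.zeros((num_samples, d), dtype=bool)            # condition on nothing
    posterior = jnp.broadcast_to(~is_theta, (num_samples, d))  # condition on x
    likelihood = jnp.broadcast_to(is_theta, (num_samples, d))  # condition on theta
    rnd1 = jax.random.bernoulli(k_rnd1, rnd1_prob, (num_samples, d))
    rnd2 = jax.random.bernoulli(k_rnd2, rnd2_prob, (num_samples, d))
    candidates = jnp.stack([joint, posterior, likelihood, rnd1, rnd2])  # (5, num_samples, d)

    # Pick one of the five mask types per sample (weights normalized to sum to 1).
    p = jnp.array([p_joint, p_posterior, p_likelihood, p_rnd1, p_rnd2])
    choice = jax.random.choice(k_choice, 5, shape=(num_samples,), p=p / p.sum())
    masks = candidates[choice, jnp.arange(num_samples)]

    # A mask that conditions on every entry leaves nothing to sample; use the joint mask instead.
    all_true = jnp.all(masks, axis=-1, keepdims=True)
    return jnp.where(all_true, False, masks)


masks = sample_structured_mask_sketch(jax.random.PRNGKey(0), num_samples=4, theta_dim=2, x_dim=3)
print(masks.shape)  # (4, 5)
```

Training on a mixture of masks is what lets a single Simformer model answer posterior, likelihood, and joint queries. With the default arguments, the five mask types are equally likely.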