gensbi.recipes.simformer

Pipelines for training and using a Simformer model for simulation-based inference.

Examples

import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes.simformer import SimformerFlowPipeline

# Define your training and validation datasets.
# Placeholder data: 4 columns = dim_theta + dim_x (defined below).
key_train, key_val = jax.random.split(jax.random.PRNGKey(1))
train_data = jax.random.uniform(key_train, (1024, 4))  # your training dataset
val_data = jax.random.uniform(key_val, (128, 4))  # your validation dataset

batch_size = 32

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data)[...,None])
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data)[...,None])
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)

# Define the model
dim_theta = 2  # Dimension of the parameter space
dim_x = 2      # Dimension of the observation space
pipeline = SimformerFlowPipeline(train_dataset_grain, val_dataset_grain, dim_theta, dim_x)

# Train the model (use separate keys for training and sampling)
train_key, sample_key = jax.random.split(jax.random.PRNGKey(0))
pipeline.train(train_key)

# Sample from the posterior
x_o = jnp.array([0.5, -0.2])  # Observation to condition on (length dim_x)
samples = pipeline.sample(sample_key, x_o, nsamples=10000, step_size=0.01)
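The random placeholders above only fix the array shapes. With a real simulator, each training row would hold one (theta, x) pair. The sketch below shows one way to build such rows. It assumes that theta occupies the first dim_theta entries and x the remaining dim_x entries, followed by the trailing channel axis that the example adds with [..., None]; the helper make_joint_dataset, its uniform prior, and its toy simulator x = theta + noise are hypothetical stand-ins, not part of GenSBI.

```python
import jax
import jax.numpy as jnp


def make_joint_dataset(key, n, dim_theta=2, noise_std=0.1):
    """Hypothetical helper: simulate (theta, x) pairs and stack them as joint rows.

    Assumed layout: theta in the first dim_theta columns, x in the rest,
    plus a trailing channel axis (as produced by `[..., None]` in the example).
    """
    key_theta, key_noise = jax.random.split(key)
    # Toy prior: theta ~ Uniform(-1, 1) in every dimension
    theta = jax.random.uniform(key_theta, (n, dim_theta), minval=-1.0, maxval=1.0)
    # Toy simulator: x = theta + Gaussian noise, so here dim_x == dim_theta
    x = theta + noise_std * jax.random.normal(key_noise, (n, dim_theta))
    joint = jnp.concatenate([theta, x], axis=-1)  # shape (n, dim_theta + dim_x)
    return joint[..., None]                        # shape (n, dim_theta + dim_x, 1)


data = make_joint_dataset(jax.random.PRNGKey(0), 1024)
print(data.shape)  # (1024, 4, 1)
```

Because the channel axis is already present, np.array(data) can be passed to grain.MapDataset.source without a further [..., None].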

Note

If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html

Classes

SimformerDiffusionPipeline

Simformer pipeline trained with a diffusion objective (based on JointDiffusionPipeline).

SimformerFlowPipeline

Simformer pipeline trained with a flow-matching objective (based on JointFlowPipeline).

Functions

parse_simformer_params(config_path)

Parse a Simformer configuration file.

parse_training_config(config_path)

Parse a training configuration file.

sample_strutured_conditional_mask(key, num_samples, ...)

Sample structured conditional masks for the Simformer model.

Module Contents

class gensbi.recipes.simformer.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None, edge_mask=None)

Bases: gensbi.recipes.joint_pipeline.JointDiffusionPipeline

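Training and sampling pipeline for the Simformer model with a diffusion objective.

The Simformer models the joint distribution of the parameters theta and the observations x, so one trained model can be conditioned on different subsets of variables. This class supplies the Simformer model (_make_model) and its default parameters (_get_default_params); the remaining behavior, such as the training loop, is inherited from JointDiffusionPipeline.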

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • edge_mask (array-like, optional) – Mask over pairs of variables that restricts which variables the Simformer's attention layers connect. If None, no edge mask is applied.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
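The expected format of edge_mask is not documented here. A minimal sketch, assuming a boolean (dim_theta + dim_x, dim_theta + dim_x) matrix where mask[i, j] == True lets variable i attend to variable j, builds a fully connected mask and a hypothetical structured one in which each x_i depends on all of theta but not on the other x_j. Both helper functions are illustrative names, not GenSBI API.

```python
import jax.numpy as jnp


def fully_connected_edge_mask(dim_theta: int, dim_x: int) -> jnp.ndarray:
    """Every variable may attend to every other variable."""
    n = dim_theta + dim_x
    return jnp.ones((n, n), dtype=bool)


def theta_to_x_edge_mask(dim_theta: int, dim_x: int) -> jnp.ndarray:
    """Hypothetical structure: each x_i depends on theta only (x's independent given theta).

    Assumed convention: mask[i, j] == True lets variable i attend to variable j.
    """
    n = dim_theta + dim_x
    mask = jnp.zeros((n, n), dtype=bool)
    mask = mask.at[:dim_theta, :].set(True)           # theta rows attend to every variable
    mask = mask.at[dim_theta:, :dim_theta].set(True)  # x rows attend to theta
    idx = jnp.arange(n)
    mask = mask.at[idx, idx].set(True)                # every variable attends to itself
    return mask


mask = theta_to_x_edge_mask(2, 2)
print(mask.astype(int))
```

Whether the pipeline accepts such a mask directly as edge_mask=mask, with this shape, dtype, and orientation, is an assumption to check against the source.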

_get_default_params()

Return default parameters for the Simformer model.

_make_model(params)

Create and return the Simformer model to be trained.

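classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for model checkpoints.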

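sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of steps of the diffusion sampler.

  • use_ema (bool, optional) – If True, sample with the exponential moving average weights (ema_model) instead of model.

  • return_intermediates (bool, optional) – If True, also return the intermediate states of the sampler.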

Returns:

samples – Generated samples.

Return type:

array-like
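The use_ema flag chooses between the pipeline's ema_model and model attributes. As background, an exponential moving average (EMA) of weights is usually maintained by blending the current parameters into a running copy after each update. The sketch below shows that generic rule on a pytree of parameters; it is not GenSBI's implementation, and the decay value is an arbitrary choice for illustration.

```python
import jax
import jax.numpy as jnp


def ema_update(ema_params, params, decay=0.999):
    """One EMA step applied leaf by leaf: ema <- decay * ema + (1 - decay) * params."""
    return jax.tree_util.tree_map(
        lambda e, p: decay * e + (1.0 - decay) * p, ema_params, params
    )


# Toy example: the EMA copy starts at zero and tracks fixed parameters.
params = {"w": jnp.ones(3), "b": jnp.zeros(())}
ema = jax.tree_util.tree_map(jnp.zeros_like, params)
for _ in range(10):
    ema = ema_update(ema, params, decay=0.5)
print(ema["w"])  # each entry equals 1 - 0.5**10
```

Averaged weights change more slowly than the raw weights, which is why sampling from them is often the default.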

edge_mask = None
ema_model
model

class gensbi.recipes.simformer.SimformerFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None, edge_mask=None)

Bases: gensbi.recipes.joint_pipeline.JointFlowPipeline

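Training and sampling pipeline for the Simformer model with a flow-matching objective.

The Simformer models the joint distribution of the parameters theta and the observations x, so one trained model can be conditioned on different subsets of variables. This class supplies the Simformer model (_make_model) and its default parameters (_get_default_params); the remaining behavior, such as the training loop, is inherited from JointFlowPipeline.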

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • edge_mask (array-like, optional) – Mask over pairs of variables that restricts which variables the Simformer's attention layers connect. If None, no edge mask is applied.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

_get_default_params()

Return default parameters for the Simformer model.

_make_model(params)

Create and return the Simformer model to be trained.

compute_unnorm_logprob(x_1, x_o, step_size=0.01, use_ema=True, time_grid=None)

Compute the unnormalized log-probability of x_1 given the observation x_o.

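The step_size and time_grid arguments suggest that the log-probability is obtained by integrating the flow ODE. One standard way to evaluate a density under a continuous flow dx/dt = v(x, t) is the instantaneous change-of-variables formula, log p1(x1) = log p0(x0) - (integral from t = 0 to 1 of div v(x_t, t) dt), where x0 is found by integrating backward from x1. The sketch below assumes noise at t = 0, data at t = 1, a standard normal base density, and explicit Euler steps; GenSBI's solver, time convention, and the reason its result is called unnormalized are not documented here, and flow_logprob is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def flow_logprob(velocity, x_1, step_size=0.01):
    """Estimate log p_1(x_1) for dx/dt = velocity(x, t), t in [0, 1], N(0, I) base at t = 0.

    Integrates backward from t = 1 to t = 0 with explicit Euler steps while
    accumulating the divergence (exact Jacobian trace) of the velocity field.
    """
    n_steps = int(round(1.0 / step_size))
    h = 1.0 / n_steps
    x = x_1
    div_integral = 0.0
    for i in range(n_steps):
        t = 1.0 - i * h
        jac = jax.jacfwd(velocity)(x, t)       # (d, d) Jacobian with respect to x
        div_integral = div_integral + h * jnp.trace(jac)
        x = x - h * velocity(x, t)             # Euler step backward in time
    log_p0 = jnp.sum(norm.logpdf(x))           # base log-density at the recovered x_0
    return log_p0 - div_integral


# Linear toy field v(x, t) = 0.5 * x, for which x_1 = x_0 * exp(0.5) exactly.
v = lambda x, t: 0.5 * x
x1 = jnp.array([1.0, -0.5])
est = flow_logprob(v, x1, step_size=0.01)
```

The exact Jacobian trace costs one Jacobian per step; larger models often replace it with a stochastic trace estimator.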
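classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for model checkpoints.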

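sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size of the ODE solver.

  • use_ema (bool, optional) – If True, sample with the exponential moving average weights (ema_model) instead of model.

  • time_grid (array-like, optional) – Explicit time grid for the ODE solver. If None, the grid follows from step_size.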

Returns:

samples – Generated samples.

Return type:

array-like
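To illustrate what step_size and time_grid control, the sketch below integrates a flow ODE dx/dt = v(x, t) from t = 0 (noise) to t = 1 (data) with explicit Euler steps, either on a uniform grid derived from step_size or on a user-supplied grid. The solver choice, the time convention, and the helper name euler_sample are assumptions for illustration; how GenSBI's solver conditions the velocity field on x_o is not shown.

```python
import jax
import jax.numpy as jnp


def euler_sample(velocity, x_0, step_size=0.01, time_grid=None):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with explicit Euler steps.

    If time_grid is None, a uniform grid with spacing step_size is used;
    otherwise the given increasing grid of times is followed exactly.
    """
    if time_grid is None:
        n_steps = int(round(1.0 / step_size))
        time_grid = jnp.linspace(0.0, 1.0, n_steps + 1)
    x = x_0
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        x = x + (t1 - t0) * velocity(x, t0)   # one Euler step from t0 to t1
    return x


# Constant velocity 2: every starting point moves by +2 between t = 0 and t = 1.
x = euler_sample(lambda x, t: 2.0 * jnp.ones_like(x), jnp.zeros(3), step_size=0.01)
# Custom three-point grid with a time-dependent velocity.
y = euler_sample(lambda x, t: t * jnp.ones_like(x), jnp.zeros(2),
                 time_grid=jnp.array([0.0, 0.5, 1.0]))
```

A smaller step_size means more solver steps and lower discretization error at a higher cost per sample; a custom grid can concentrate steps where the velocity changes quickly.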

edge_mask = None
ema_model
model

gensbi.recipes.simformer.parse_simformer_params(config_path)

Parse a Simformer configuration file.

Parameters:

config_path (str) – Path to the configuration file.

Returns:

config – Parsed configuration dictionary.

Return type:

dict

gensbi.recipes.simformer.parse_training_config(config_path)

Parse a training configuration file.

Parameters:

config_path (str) – Path to the configuration file.

Returns:

config – Parsed configuration dictionary.

Return type:

dict

gensbi.recipes.simformer.sample_strutured_conditional_mask(key, num_samples, theta_dim, x_dim, p_joint=0.2, p_posterior=0.2, p_likelihood=0.2, p_rnd1=0.2, p_rnd2=0.2, rnd1_prob=0.3, rnd2_prob=0.7)

Sample structured conditional masks for the Simformer model.

Parameters:
  • key (jax.random.PRNGKey) – Random key for sampling.

  • num_samples (int) – Number of samples to generate.

  • theta_dim (int) – Dimension of the parameter space.

  • x_dim (int) – Dimension of the observation space.

  • p_joint (float) – Probability of selecting the joint mask.

  • p_posterior (float) – Probability of selecting the posterior mask.

  • p_likelihood (float) – Probability of selecting the likelihood mask.

  • p_rnd1 (float) – Probability of selecting the first random mask.

  • p_rnd2 (float) – Probability of selecting the second random mask.

  • rnd1_prob (float) – Probability of a True value in the first random mask.

  • rnd2_prob (float) – Probability of a True value in the second random mask.

Returns:

condition_mask – Array of shape (num_samples, theta_dim + x_dim) with boolean masks.

Return type:

jnp.ndarray
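The parameters describe a mixture of five mask types, selected with probabilities p_joint, p_posterior, p_likelihood, p_rnd1, and p_rnd2 (0.2 each by default, summing to 1). The sketch below is a hypothetical re-implementation of that idea, not the library's code. It assumes that True marks a conditioned (observed) variable, that the joint mask conditions on nothing, the posterior mask on x, and the likelihood mask on theta, and that the random masks are independent Bernoulli draws with rates rnd1_prob and rnd2_prob.

```python
import jax
import jax.numpy as jnp


def sample_masks_sketch(key, num_samples, theta_dim, x_dim,
                        probs=(0.2, 0.2, 0.2, 0.2, 0.2),
                        rnd1_prob=0.3, rnd2_prob=0.7):
    """Hypothetical sketch: True marks a conditioned variable (assumed convention).

    probs gives the selection probabilities of (joint, posterior, likelihood, rnd1, rnd2).
    """
    d = theta_dim + x_dim
    k_choice, k_rnd1, k_rnd2 = jax.random.split(key, 3)
    joint = jnp.zeros(d, dtype=bool)                                                   # nothing conditioned
    posterior = jnp.concatenate([jnp.zeros(theta_dim, bool), jnp.ones(x_dim, bool)])   # condition on x
    likelihood = jnp.concatenate([jnp.ones(theta_dim, bool), jnp.zeros(x_dim, bool)])  # condition on theta
    fixed = jnp.broadcast_to(jnp.stack([joint, posterior, likelihood]), (num_samples, 3, d))
    # Fresh random masks for every sample
    rnd1 = jax.random.bernoulli(k_rnd1, rnd1_prob, (num_samples, 1, d))
    rnd2 = jax.random.bernoulli(k_rnd2, rnd2_prob, (num_samples, 1, d))
    candidates = jnp.concatenate([fixed, rnd1, rnd2], axis=1)  # (num_samples, 5, d)
    # Pick one of the five mask types per sample
    choice = jax.random.choice(k_choice, 5, shape=(num_samples,), p=jnp.asarray(probs))
    return candidates[jnp.arange(num_samples), choice]         # (num_samples, d)


key = jax.random.PRNGKey(0)
masks = sample_masks_sketch(key, 1000, theta_dim=2, x_dim=3)
```

Training on such a mixture is what lets a single Simformer answer posterior, likelihood, and other conditional queries at inference time.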