gensbi.recipes.simformer
Pipelines for training and using a Simformer model for simulation-based inference.
Examples
import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes.simformer import SimformerFlowPipeline

# Random keys for data generation, training, and sampling.
key = jax.random.PRNGKey(0)
key_train, key_val, key_fit, key_sample = jax.random.split(key, 4)

# Define your training and validation datasets.
# Placeholders of shape (N, dim_theta + dim_x); replace with your own joint samples.
train_data = jax.random.uniform(key_train, (1024, 4))  # your training dataset
val_data = jax.random.uniform(key_val, (128, 4))  # your validation dataset

batch_size = 32

train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data)[..., None])
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    # .mp_prefetch()  # Uncomment to use multiprocessing prefetching
)

val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data)[..., None])
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    # .mp_prefetch()  # Uncomment to use multiprocessing prefetching
)

# Define the model
dim_theta = 2  # Dimension of the parameter space
dim_x = 2  # Dimension of the observation space
pipeline = SimformerFlowPipeline(train_dataset_grain, val_dataset_grain, dim_theta, dim_x)

# Train the model
pipeline.train(key_fit)

# Sample from the posterior
x_o = jnp.array([0.5, -0.2])  # Example observation
samples = pipeline.sample(key_sample, x_o, nsamples=10000, step_size=0.01)
Note
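If you plan on using multiprocessing prefetching, make sure your script is wrapped in an if __name__ == "__main__": guard. With the spawn or forkserver start methods, worker processes re-import the main module, so unguarded top-level code would run again in each worker. See https://docs.python.org/3/library/multiprocessing.html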
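The example feeds the pipeline a single array in which each row is one joint (theta, x) sample, with a trailing feature axis added by [..., None]. The NumPy-only sketch below mimics that layout with a hypothetical toy simulator. It assumes theta occupies the first dim_theta columns and x the last dim_x, an ordering the example does not state explicitly, and the manual reshape stands in for grain's .batch() without shuffling.

```python
import numpy as np

rng = np.random.default_rng(0)
n, dim_theta, dim_x = 1024, 2, 2

# Hypothetical toy simulator: observations are noisy copies of the parameters.
theta = rng.normal(size=(n, dim_theta))
x = theta + 0.1 * rng.normal(size=(n, dim_x))

# Assumed layout: theta in the first dim_theta columns, x in the remaining dim_x.
joint = np.concatenate([theta, x], axis=-1)  # shape (1024, 4)
joint = joint[..., None]  # shape (1024, 4, 1): one scalar feature per variable

# Unshuffled stand-in for .batch(batch_size): consecutive rows form a batch.
batch_size = 32
batches = joint.reshape(-1, batch_size, dim_theta + dim_x, 1)  # shape (32, 32, 4, 1)
```

Each batch therefore has shape (batch_size, dim_theta + dim_x, 1), which is what the grain pipeline in the example yields per step.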
Classes
SimformerDiffusionPipeline | Simformer training pipeline built on JointDiffusionPipeline.
SimformerFlowPipeline | Simformer training pipeline built on JointFlowPipeline.
Functions
parse_simformer_params | Parse a Simformer configuration file.
parse_training_config | Parse a training configuration file.
sample_strutured_conditional_mask | Sample structured conditional masks for the Simformer model.
Module Contents
- class gensbi.recipes.simformer.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None, edge_mask=None)
Bases: gensbi.recipes.joint_pipeline.JointDiffusionPipeline
Concrete Simformer pipeline for training and evaluating a joint diffusion model. It follows the GenSBI pipeline template, which covers model creation, default parameter setup, loss function, sampling, and evaluation.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
edge_mask (array-like, optional) – Edge mask for the Simformer attention, describing which variables may attend to each other.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)
Initialize the pipeline from a configuration file.
- Parameters:
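train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for model checkpoints.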
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)
Generate samples from the trained model.
- Parameters:
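key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
nsteps (int, optional) – Number of sampler steps.
use_ema (bool, optional) – Whether to sample with the exponential moving average (EMA) of the model weights.
return_intermediates (bool, optional) – Whether to also return the intermediate states of the sampler.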
- Returns:
samples – Generated samples.
- Return type:
array-like
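The default nsteps=18 matches the step count used by the EDM sampler of Karras et al. (2022). If the diffusion sampler here follows that scheme, which this page does not confirm, nsteps sets the length of a decreasing noise schedule like the one sketched below. The function karras_sigmas and its defaults (sigma_min=0.002, sigma_max=80, rho=7, the values from the EDM paper) are illustrative assumptions, not GenSBI parameters.

```python
import jax.numpy as jnp


def karras_sigmas(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Hypothetical helper: EDM noise levels from sigma_max down to sigma_min."""
    # Uniform ramp in [0, 1] with one entry per sampler step.
    ramp = jnp.linspace(0.0, 1.0, nsteps)
    # Interpolate linearly in sigma**(1/rho) space, then map back with **rho;
    # larger rho concentrates more steps at low noise levels.
    max_inv = sigma_max ** (1.0 / rho)
    min_inv = sigma_min ** (1.0 / rho)
    return (max_inv + ramp * (min_inv - max_inv)) ** rho
```

Under this assumption, increasing nsteps refines the schedule (more, smaller noise decrements) at the cost of more model evaluations.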
- class gensbi.recipes.simformer.SimformerFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None, edge_mask=None)
Bases: gensbi.recipes.joint_pipeline.JointFlowPipeline
Concrete Simformer pipeline for training and evaluating a joint flow model. It follows the GenSBI pipeline template, which covers model creation, default parameter setup, loss function, sampling, and evaluation.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
edge_mask (array-like, optional) – Edge mask for the Simformer attention, describing which variables may attend to each other.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)
Initialize the pipeline from a configuration file.
- Parameters:
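train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for model checkpoints.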
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)
Generate samples from the trained model.
- Parameters:
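key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
use_ema (bool, optional) – Whether to sample with the exponential moving average (EMA) of the model weights.
time_grid (array-like, optional) – Time grid for the sampler.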
- Returns:
samples – Generated samples.
- Return type:
array-like
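For intuition about step_size and time_grid, the sketch below integrates a velocity field with fixed-step Euler updates. It is not the GenSBI sampler: the helper euler_sample, the t=0 to t=1 time convention, and the rule that a missing time_grid is built from step_size are all assumptions. Under them, step_size=0.01 corresponds to 100 solver steps.

```python
import jax
import jax.numpy as jnp


def euler_sample(velocity_fn, x0, step_size=0.01, time_grid=None):
    """Hypothetical helper: fixed-step Euler integration of dx/dt = velocity_fn(x, t)."""
    if time_grid is None:
        # Assumed convention: integrate from t=0 (noise) to t=1 (data).
        n_steps = int(round(1.0 / step_size))
        time_grid = jnp.linspace(0.0, 1.0, n_steps + 1)
    time_grid = jnp.asarray(time_grid)

    def step(x, times):
        t0, t1 = times[0], times[1]
        # Advance along the velocity evaluated at the left end of the interval.
        return x + (t1 - t0) * velocity_fn(x, t0), None

    # Consecutive (t_i, t_{i+1}) pairs, shape (n_steps, 2).
    intervals = jnp.stack([time_grid[:-1], time_grid[1:]], axis=1)
    x1, _ = jax.lax.scan(step, x0, intervals)
    return x1
```

Passing an explicit time_grid overrides step_size, which lets a caller place more steps where the velocity field changes quickly.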
- gensbi.recipes.simformer.parse_simformer_params(config_path)
Parse a Simformer configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
- Returns:
config – Parsed configuration dictionary.
- Return type:
dict
- gensbi.recipes.simformer.parse_training_config(config_path)
Parse a training configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
- Returns:
config – Parsed configuration dictionary.
- Return type:
dict
- gensbi.recipes.simformer.sample_strutured_conditional_mask(key, num_samples, theta_dim, x_dim, p_joint=0.2, p_posterior=0.2, p_likelihood=0.2, p_rnd1=0.2, p_rnd2=0.2, rnd1_prob=0.3, rnd2_prob=0.7)
Sample structured conditional masks for the Simformer model.
- Parameters:
key (jax.random.PRNGKey) – Random key for sampling.
num_samples (int) – Number of samples to generate.
theta_dim (int) – Dimension of the parameter space.
x_dim (int) – Dimension of the observation space.
p_joint (float) – Probability of selecting the joint mask.
p_posterior (float) – Probability of selecting the posterior mask.
p_likelihood (float) – Probability of selecting the likelihood mask.
p_rnd1 (float) – Probability of selecting the first random mask.
p_rnd2 (float) – Probability of selecting the second random mask.
rnd1_prob (float) – Probability of a True value in the first random mask.
rnd2_prob (float) – Probability of a True value in the second random mask.
- Returns:
condition_mask – Array of shape (num_samples, theta_dim + x_dim) with boolean masks.
- Return type:
jnp.ndarray
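The parameter list above does not spell out what each mask type contains. The JAX sketch below is one plausible reading, not the GenSBI implementation. It assumes True marks a conditioned (observed) variable, theta occupies the first theta_dim positions, the joint mask conditions on nothing, the posterior mask conditions on x, the likelihood mask conditions on theta, and the five selection probabilities are normalized before one mask type is drawn per sample. The name sketch_structured_mask is hypothetical.

```python
import jax
import jax.numpy as jnp


def sketch_structured_mask(key, num_samples, theta_dim, x_dim,
                           p_joint=0.2, p_posterior=0.2, p_likelihood=0.2,
                           p_rnd1=0.2, p_rnd2=0.2, rnd1_prob=0.3, rnd2_prob=0.7):
    """Hypothetical re-implementation; True marks a conditioned variable."""
    dim = theta_dim + x_dim
    key_type, key_rnd1, key_rnd2 = jax.random.split(key, 3)
    idx = jnp.arange(dim)

    # Fixed masks, assuming theta comes first and x second.
    joint = jnp.zeros(dim, dtype=bool)   # nothing observed: p(theta, x)
    posterior = idx >= theta_dim         # x observed: p(theta | x)
    likelihood = idx < theta_dim         # theta observed: p(x | theta)

    # Per-sample random masks with different densities of True values.
    rnd1 = jax.random.bernoulli(key_rnd1, rnd1_prob, (num_samples, dim))
    rnd2 = jax.random.bernoulli(key_rnd2, rnd2_prob, (num_samples, dim))

    # All five candidates, shape (5, num_samples, dim).
    shape = (num_samples, dim)
    candidates = jnp.stack([
        jnp.broadcast_to(joint, shape),
        jnp.broadcast_to(posterior, shape),
        jnp.broadcast_to(likelihood, shape),
        rnd1,
        rnd2,
    ])

    # Draw one mask type per sample from the normalized probabilities.
    probs = jnp.array([p_joint, p_posterior, p_likelihood, p_rnd1, p_rnd2])
    choice = jax.random.choice(key_type, 5, shape=(num_samples,), p=probs / probs.sum())
    return candidates[choice, jnp.arange(num_samples)]  # (num_samples, dim)
```

Mixing fixed and random masks in one batch lets a single network learn the posterior, the likelihood, and arbitrary other conditionals at the same time.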