gensbi.recipes.joint_pipeline#
Pipeline for training and using a Flux1 model for simulation-based inference.
Examples
import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes import JointPipeline
# Define your training and validation datasets.
train_data = jax.random.uniform(jax.random.PRNGKey(0), (1024, 4)) # your training dataset
val_data = jax.random.uniform(jax.random.PRNGKey(1), (128, 4)) # your validation dataset
batch_size = 32
train_dataset_grain = (
grain.MapDataset.source(np.array(train_data)[...,None])
.shuffle(42)
.repeat()
.to_iter_dataset()
.batch(batch_size)
# .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)
val_dataset_grain = (
grain.MapDataset.source(np.array(val_data)[...,None])
.shuffle(42)
.repeat()
.to_iter_dataset()
.batch(batch_size)
# .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)
# Define your model
model = ... # your nnx.Module model here, e.g., a simple MLP, or the Simformer or Flux1Joint model
# if you define a custom model, it should take as input the following arguments:
# t: Array,
# obs: Array,
# node_ids: Array,
# condition_mask: Array,
# *args,
# **kwargs,
# the obs should have shape (batch_size, dim_joint, c),
# node_ids and condition_mask should match obs shape,
# and the output will be of the same shape as obs
dim_theta = 2 # Dimension of the parameter space
dim_x = 2 # Dimension of the observation space
pipeline = JointPipeline(model, train_dataset_grain, val_dataset_grain, dim_theta, dim_x)
# Train the model
rngs = jax.random.PRNGKey(0)
pipeline.train(rngs)
# Sample from the posterior
x_o = jnp.array([0.5, -0.2]) # Example
samples = pipeline.sample(rngs, x_o, nsamples=10000, step_size=0.01)
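The input/output contract for custom models described in the comments above can be checked before training. The following is a minimal sketch, not part of GenSBI: `toy_model` and `check_model_contract` are hypothetical names, NumPy arrays stand in for JAX arrays, and the shape `(batch_size,)` for `t` is an assumption (the source only fixes the shapes of `obs`, `node_ids`, and `condition_mask`).

```python
import numpy as np


def toy_model(t, obs, node_ids, condition_mask, *args, **kwargs):
    """Hypothetical stand-in for a custom model (not an nnx.Module).

    Returns zeros on conditioned entries and echoes obs elsewhere,
    so the output has the same shape as obs.
    """
    return np.where(condition_mask, 0.0, obs)


def check_model_contract(model, batch_size=4, dim_joint=4, c=1):
    """Call `model` on dummy inputs and verify the output shape equals obs.shape."""
    t = np.full((batch_size,), 0.5)               # assumed shape for the time input
    obs = np.zeros((batch_size, dim_joint, c))    # (batch_size, dim_joint, c)
    # node_ids and condition_mask match the shape of obs, as required above
    node_ids = np.broadcast_to(np.arange(dim_joint)[None, :, None], obs.shape)
    condition_mask = np.zeros(obs.shape, dtype=bool)
    out = model(t, obs, node_ids, condition_mask)
    if out.shape != obs.shape:
        raise ValueError(f"expected output shape {obs.shape}, got {out.shape}")
    return out.shape


# dim_joint = dim_theta + dim_x = 2 + 2 in the example above
shape = check_model_contract(toy_model, batch_size=4, dim_joint=4, c=1)
```

A real model would be an `nnx.Module` called with JAX arrays; the same check applies as long as the call signature is unchanged.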
Note
If you plan to use multiprocessing prefetching, wrap your script in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
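The guard matters because worker processes started with the spawn method re-import the main module, so any unguarded top-level code would run again in every worker. A minimal sketch of the script layout follows; `build_dataset` and `main` are hypothetical placeholders, and the plain-Python batching stands in for the grain pipeline (including `.mp_prefetch()`) shown above.

```python
def build_dataset(batch_size=32):
    """Hypothetical placeholder for the grain dataset construction."""
    data = list(range(128))
    # Split the data into consecutive batches of `batch_size` items
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


def main():
    batches = build_dataset()
    # The pipeline construction and pipeline.train(...) call would go here.
    return len(batches)


if __name__ == "__main__":
    # Only the parent process runs this; re-importing workers skip it.
    main()
```

Keeping all dataset and training code inside `main()` means that importing the module has no side effects.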
Classes#
JointDiffusionPipeline | Abstract base class for GenSBI training pipelines. |
JointFlowPipeline | Abstract base class for GenSBI training pipelines. |
Functions#
sample_strutured_conditional_mask | Sample structured conditional masks for the Joint model. |
Module Contents#
- class gensbi.recipes.joint_pipeline.JointDiffusionPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod _get_default_training_config()[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate. Defaults to 10000.
nsteps (int, optional) – Number of sampler steps. Defaults to 18.
- Returns:
samples – Generated samples.
- Return type:
array-like
- class gensbi.recipes.joint_pipeline.JointFlowPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases: gensbi.recipes.pipeline.AbstractPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- compute_unnorm_logprob(x_1, x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate. Defaults to 10000.
step_size (float, optional) – Step size for the sampler. Defaults to 0.01.
- Returns:
samples – Generated samples.
- Return type:
array-like
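To give a feel for how `step_size` and `time_grid` interact, here is a minimal NumPy sketch of fixed-step Euler integration of a velocity field from t = 0 to t = 1. It is an illustration under assumptions, not the pipeline's implementation: the source does not state which ODE solver is used, the integration direction is assumed, and `euler_sample` and the constant velocity field are hypothetical.

```python
import numpy as np


def euler_sample(velocity, x0, step_size=0.01, time_grid=None):
    """Integrate dx/dt = velocity(t, x) with explicit Euler steps.

    If no time_grid is given, a uniform grid on [0, 1] is built from step_size
    (assumption: integration runs from t=0 to t=1).
    """
    if time_grid is None:
        n_steps = int(round(1.0 / step_size))
        time_grid = np.linspace(0.0, 1.0, n_steps + 1)
    x = np.asarray(x0, dtype=float)
    for t0, t1 in zip(time_grid[:-1], time_grid[1:]):
        x = x + (t1 - t0) * velocity(t0, x)
    return x


rng = np.random.default_rng(0)
x0 = rng.standard_normal((1000, 2))      # draws from the base distribution
shift = np.array([1.0, -2.0])            # toy constant velocity field
x1 = euler_sample(lambda t, x: shift, x0, step_size=0.01)
```

With `step_size=0.01` the uniform grid has 100 steps; passing an explicit `time_grid` in this sketch overrides `step_size` entirely.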
- gensbi.recipes.joint_pipeline.sample_strutured_conditional_mask(key, num_samples, theta_dim, x_dim, p_joint=0.2, p_posterior=0.2, p_likelihood=0.2, p_rnd1=0.2, p_rnd2=0.2, rnd1_prob=0.3, rnd2_prob=0.7)[source]#
Sample structured conditional masks for the Joint model.
- Parameters:
key (jax.random.PRNGKey) – Random key for sampling.
num_samples (int) – Number of samples to generate.
theta_dim (int) – Dimension of the parameter space.
x_dim (int) – Dimension of the observation space.
p_joint (float) – Probability of selecting the joint mask.
p_posterior (float) – Probability of selecting the posterior mask.
p_likelihood (float) – Probability of selecting the likelihood mask.
p_rnd1 (float) – Probability of selecting the first random mask.
p_rnd2 (float) – Probability of selecting the second random mask.
rnd1_prob (float) – Probability of a True value in the first random mask.
rnd2_prob (float) – Probability of a True value in the second random mask.
- Returns:
condition_mask – Array of shape (num_samples, theta_dim + x_dim) with boolean masks.
- Return type:
jnp.ndarray
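The parameters above suggest a mixture over five mask types. The NumPy sketch below illustrates one way such a sampler could work; it is not the GenSBI implementation. It assumes that True marks a conditioned variable, that the joint mask conditions on nothing, that the posterior mask conditions on x, and that the likelihood mask conditions on theta. `sketch_structured_mask` is a hypothetical name, and a NumPy `Generator` replaces the JAX key.

```python
import numpy as np


def sketch_structured_mask(rng, num_samples, theta_dim, x_dim,
                           p_joint=0.2, p_posterior=0.2, p_likelihood=0.2,
                           p_rnd1=0.2, p_rnd2=0.2, rnd1_prob=0.3, rnd2_prob=0.7):
    """Draw one boolean mask per sample from a mixture of five mask types."""
    dim = theta_dim + x_dim
    shape = (num_samples, dim)
    joint = np.zeros(dim, dtype=bool)             # assumed: nothing conditioned
    posterior = np.arange(dim) >= theta_dim       # assumed: x conditioned
    likelihood = ~posterior                       # assumed: theta conditioned
    rnd1 = rng.random(shape) < rnd1_prob          # Bernoulli(rnd1_prob) entries
    rnd2 = rng.random(shape) < rnd2_prob          # Bernoulli(rnd2_prob) entries
    candidates = np.stack([
        np.broadcast_to(joint, shape),
        np.broadcast_to(posterior, shape),
        np.broadcast_to(likelihood, shape),
        rnd1,
        rnd2,
    ])                                            # (5, num_samples, dim)
    probs = np.array([p_joint, p_posterior, p_likelihood, p_rnd1, p_rnd2])
    choice = rng.choice(5, size=num_samples, p=probs / probs.sum())
    # Pick, for each sample, the row of the mask type that was drawn
    return candidates[choice, np.arange(num_samples)]


masks = sketch_structured_mask(np.random.default_rng(0), 8, theta_dim=2, x_dim=2)
```

The result has shape `(num_samples, theta_dim + x_dim)`, matching the documented return value; the five selection probabilities are normalized so they need not sum exactly to 1 in this sketch.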