gensbi.recipes.pipeline#

Pipeline module for GenSBI.

This module provides an abstract pipeline class for training and evaluating conditional generative models (such as conditional flow matching or diffusion models) in the GenSBI framework. It handles model creation, the training loop, optimizer setup, checkpointing, and evaluation utilities.

Example

from gensbi.recipes.pipeline import AbstractPipeline

# Implement a subclass with your model and loss definition
class MyPipeline(AbstractPipeline):

    def _make_model(self):
        ...

    def _get_default_params(self, rngs):
        ...

    def get_loss_fn(self):
        ...

    def sample(self, rng, x_o, nsamples=10000, step_size=0.01):
        ...

# Instantiate and train
pipeline = MyPipeline(train_dataset, val_dataset, dim_theta=2, dim_x=2)
pipeline.train(rngs)
Classes#

AbstractPipeline

Abstract base class for GenSBI training pipelines.

ModelEMA

Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters.

Functions#

ema_step(ema_model, model, ema_optimizer)

Perform a single EMA update step.

Module Contents#

class gensbi.recipes.pipeline.AbstractPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: abc.ABC

Abstract base class for GenSBI training pipelines.

This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

abstract _get_default_params(rngs)[source]#

Return a dictionary of default model parameters.

Parameters:

rngs (flax.nnx.Rngs)
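
As a sketch, a subclass might return a flat dictionary of architecture settings. The keys below are hypothetical; use whatever your _make_model implementation expects:

def _get_default_params(self, rngs):
    # Hypothetical keys for illustration only; the actual defaults
    # depend on the model built in _make_model.
    return {
        "dim_theta": self.dim_theta,
        "dim_x": self.dim_x,
        "hidden_dim": 128,
        "rngs": rngs,
    }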

classmethod _get_default_training_config()[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict

_get_ema_optimizer()[source]#

Construct the EMA optimizer for maintaining an exponential moving average of model parameters.

Returns:

ema_optimizer – The EMA optimizer instance.

Return type:

ModelEMA

_get_optimizer()[source]#

Construct the optimizer for training, including learning rate scheduling and gradient clipping.

Returns:

optimizer – The optimizer instance for the model.

Return type:

nnx.Optimizer
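
For reference, an optimizer combining the documented ingredients could be assembled as below. The schedule, clip norm, and base optimizer are illustrative assumptions, not the pipeline's actual settings:

import optax
from flax import nnx

# Assumed settings for illustration; the pipeline reads the real
# values from its training configuration.
schedule = optax.cosine_decay_schedule(init_value=1e-3, decay_steps=10_000)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),       # gradient clipping
    optax.adamw(learning_rate=schedule),  # learning rate scheduling
)
optimizer = nnx.Optimizer(model, tx)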

abstract _make_model()[source]#

Create and return the model to be trained.

_next_batch()[source]#

Return the next batch from the training dataset.

_next_val_batch()[source]#

Return the next batch from the validation dataset.

abstract _wrap_model()[source]#

Wrap the model for evaluation (either using SimformerWrapper or Flux1Wrapper).

abstract get_loss_fn()[source]#

Return the loss function for training/validation.

get_train_step_fn(loss_fn)[source]#

Return the training step function, which performs a single optimization step.

Returns:

train_step – JIT-compiled training step function.

Return type:

Callable

get_val_step_fn(loss_fn)[source]#

Return the validation step function, which computes the loss on a single validation batch without updating parameters.

Returns:

val_step – JIT-compiled validation step function.

Return type:

Callable
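
Given a pipeline instance as in the module example, the step functions are derived from the loss function:

loss_fn = pipeline.get_loss_fn()
train_step = pipeline.get_train_step_fn(loss_fn)  # updates parameters
val_step = pipeline.get_val_step_fn(loss_fn)      # computes loss only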

restore_model(experiment_id=None)[source]#

Restore the model from a saved checkpoint.

abstract sample(rng, x_o, nsamples=10000, step_size=0.01)[source]#

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

Returns:

samples – Generated samples.

Return type:

array-like
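
For example, given a trained pipeline and an observation x_o of dimensionality dim_x:

import jax

rng = jax.random.PRNGKey(0)
samples = pipeline.sample(rng, x_o, nsamples=10000, step_size=0.01)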

save_model(experiment_id=None)[source]#

Save the model to a checkpoint.

train(rngs, nsteps=None, save_model=True)[source]#

Run the training loop for the model.

Parameters:
  • rngs (nnx.Rngs) – Random number generators for training/validation steps.

  • nsteps (int, optional) – Number of training steps; if None, a default from the training configuration is used.

  • save_model (bool, optional) – Whether to save the model after training.

Returns:

  • loss_array (list) – List of training losses.

  • val_loss_array (list) – List of validation losses.

Return type:

Tuple[list, list]
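
A typical call, mirroring the module-level example:

from flax import nnx

rngs = nnx.Rngs(0)
loss_array, val_loss_array = pipeline.train(rngs, nsteps=1000, save_model=False)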

update_params(new_params)[source]#

Update the model parameters and re-initialize the model.

Parameters:

new_params (dict) – New model parameters.

update_training_config(new_config)[source]#

Update the training configuration with new parameters.

Parameters:

new_config (dict) – New training configuration parameters.
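
For instance (the keys shown are hypothetical; see _get_default_params and _get_default_training_config for the actual defaults):

# Hypothetical keys, for illustration only.
pipeline.update_training_config({"nsteps": 20000})
pipeline.update_params({"hidden_dim": 256})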

cond_ids[source]#
dim_joint[source]#
dim_theta[source]#
dim_x[source]#
ema_model[source]#
ema_model_wrapped = None[source]#
loss_fn = None[source]#
model[source]#
model_wrapped = None[source]#
node_ids[source]#
obs_ids[source]#
p0_dist_model = None[source]#
params = None[source]#
path = None[source]#
train_dataset[source]#
train_dataset_iter[source]#
training_config = None[source]#
val_dataset[source]#
val_dataset_iter[source]#
class gensbi.recipes.pipeline.ModelEMA(model, tx)[source]#

Bases: flax.nnx.Optimizer

Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters.

This optimizer keeps an exponential moving average of the model parameters, which can help stabilize training and improve evaluation performance. The EMA parameters are updated at each training step.

Parameters:
  • model (nnx.Module) – The model whose parameters will be tracked.

  • tx (optax.GradientTransformation) – The Optax transformation defining the EMA update rule.

update(model, model_orginal)[source]#

Update the EMA parameters using the current model parameters.

Parameters:
  • model (nnx.Module) – The model with EMA parameters to be updated.

  • model_orginal (nnx.Module) – The original model with current parameters.

gensbi.recipes.pipeline.ema_step(ema_model, model, ema_optimizer)[source]#

Perform a single EMA update, moving ema_model's parameters toward the current parameters of model.

Parameters:

ema_optimizer (flax.nnx.Optimizer)
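
A minimal sketch of the EMA machinery in isolation, assuming optax.ema as the update rule and a toy nnx module (both are illustrative choices, not prescribed by the API):

import optax
from flax import nnx
from gensbi.recipes.pipeline import ModelEMA, ema_step

model = nnx.Linear(2, 2, rngs=nnx.Rngs(0))  # toy stand-in for the real model
ema_model = nnx.clone(model)                # EMA copy starts from current weights
ema_optimizer = ModelEMA(ema_model, optax.ema(decay=0.999))

# After each training step on `model`:
ema_step(ema_model, model, ema_optimizer)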