gensbi.recipes.pipeline#
Pipeline module for GenSBI.
This module provides an abstract pipeline class for training and evaluating conditional generative models (such as conditional flow matching or diffusion models) in the GenSBI framework. It handles model creation, training loop, optimizer setup, checkpointing, and evaluation utilities.
Example

```python
from gensbi.recipes.pipeline import AbstractPipeline

# Implement a subclass with your model and loss definition
class MyPipeline(AbstractPipeline):
    def _make_model(self):
        ...

    def _get_default_params(self, rngs):
        ...

    def get_loss_fn(self):
        ...

    def sample(self, rng, x_o, nsamples=10000, step_size=0.01):
        ...

# Instantiate and train
pipeline = MyPipeline(train_dataset, val_dataset, dim_theta=2, dim_x=2)
pipeline.train(rngs)
```
Classes#

| Class | Summary |
| --- | --- |
| `AbstractPipeline` | Abstract base class for GenSBI training pipelines. |
| `ModelEMA` | Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters. |
Module Contents#
- class gensbi.recipes.pipeline.AbstractPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases: `abc.ABC`
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- abstract _get_default_params(rngs)[source]#
Return a dictionary of default model parameters.
- Parameters:
rngs (flax.nnx.Rngs) – Random number generators used for model initialization.
- Returns:
params – Default model parameters.
- Return type:
dict
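A minimal sketch of a subclass override, assuming the pipeline exposes `dim_theta` and `dim_x` as attributes; the remaining keys (`hidden_dim`, `num_layers`) are hypothetical and depend on the concrete model.

```python
import flax.nnx as nnx

class MyPipeline(AbstractPipeline):
    def _get_default_params(self, rngs: nnx.Rngs) -> dict:
        # Hypothetical keys: the actual parameter names depend on the
        # model constructed in _make_model.
        return {
            "dim_theta": self.dim_theta,
            "dim_x": self.dim_x,
            "hidden_dim": 128,
            "num_layers": 4,
            "rngs": rngs,
        }
```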
- classmethod _get_default_training_config()[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
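As an illustration, a config of this shape would be consistent with the optimizer described below; the key names and values are assumptions, not GenSBI's actual defaults (consult the source for those).

```python
# Hypothetical configuration keys; the real defaults live in
# _get_default_training_config.
training_config = {
    "nsteps": 30_000,        # total optimization steps
    "batch_size": 256,       # examples per batch
    "learning_rate": 3e-4,   # peak learning rate for the schedule
    "grad_clip": 1.0,        # global-norm clipping threshold
    "ema_decay": 0.999,      # decay rate for the EMA optimizer
}

pipeline = MyPipeline(train_dataset, val_dataset, dim_theta=2, dim_x=2,
                      training_config=training_config)
```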
- _get_ema_optimizer()[source]#
Construct the EMA optimizer for maintaining an exponential moving average of model parameters.
- Returns:
ema_optimizer – The EMA optimizer instance.
- Return type:
ModelEMA
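A plausible construction, using `optax.ema` as the update rule passed to `ModelEMA`; the decay value is illustrative, not the library's default.

```python
import optax

# optax.ema is a stock GradientTransformation that accumulates an
# exponential moving average of whatever updates it receives.
ema_optimizer = ModelEMA(model, tx=optax.ema(decay=0.999))
```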
- _get_optimizer()[source]#
Construct the optimizer for training, including learning rate scheduling and gradient clipping.
- Returns:
optimizer – The optimizer instance for the model.
- Return type:
nnx.Optimizer
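The docstring mentions learning-rate scheduling and gradient clipping; a typical optax construction consistent with that description (the schedule shape and all values are assumptions) looks like:

```python
import flax.nnx as nnx
import optax

# Warmup followed by cosine decay; the concrete values are illustrative.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=3e-4,
    warmup_steps=1_000, decay_steps=30_000, end_value=1e-5,
)
tx = optax.chain(
    optax.clip_by_global_norm(1.0),  # gradient clipping
    optax.adamw(learning_rate=schedule),
)
optimizer = nnx.Optimizer(model, tx)
```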
- abstract _wrap_model()[source]#
Wrap the model for evaluation (either using SimformerWrapper or Flux1Wrapper).
- get_train_step_fn(loss_fn)[source]#
Return the training step function, which performs a single optimization step.
- Returns:
train_step – JIT-compiled training step function.
- Return type:
Callable
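A sketch of the pattern such a step function typically follows in flax.nnx, assuming a `loss_fn(model, batch)` signature; note that the exact `nnx.Optimizer.update` signature varies between flax versions.

```python
import flax.nnx as nnx

@nnx.jit
def train_step(model, optimizer, batch):
    # Differentiate the loss with respect to the model's parameters.
    loss, grads = nnx.value_and_grad(loss_fn)(model, batch)
    optimizer.update(grads)  # apply one optimization step in place
    return loss
```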
- get_val_step_fn(loss_fn)[source]#
Return the validation step function, which computes the loss on a validation batch without updating model parameters.
- Returns:
val_step – JIT-compiled validation step function.
- Return type:
Callable
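The validation counterpart evaluates the loss only, with no gradients and no optimizer update; a sketch under the same assumed `loss_fn` signature:

```python
import flax.nnx as nnx

@nnx.jit
def val_step(model, batch):
    # Forward pass only: no gradient computation, no parameter update.
    return loss_fn(model, batch)
```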
- abstract sample(rng, x_o, nsamples=10000, step_size=0.01)[source]#
Generate samples from the trained model.
- Parameters:
rng (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
- Returns:
samples – Generated samples.
- Return type:
array-like
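A usage sketch; the observation `x_o` here is a hypothetical two-dimensional array matching `dim_x=2` from the example at the top of this page.

```python
import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
x_o = jnp.array([0.5, -0.3])  # hypothetical observation with dim_x = 2
samples = pipeline.sample(rng, x_o, nsamples=10_000, step_size=0.01)
# samples approximates the posterior p(theta | x_o)
```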
- train(rngs, nsteps=None, save_model=True)[source]#
Run the training loop for the model.
- Parameters:
rngs (nnx.Rngs) – Random number generators for training/validation steps.
nsteps (int, optional) – Number of training steps to run. If None, the default from the training configuration is used.
save_model (bool, optional) – Whether to save a checkpoint of the trained model.
- Returns:
loss_array (list) – List of training losses.
val_loss_array (list) – List of validation losses.
- Return type:
Tuple[list, list]
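Typical usage, continuing the example at the top of this page:

```python
import flax.nnx as nnx

rngs = nnx.Rngs(0)
loss_array, val_loss_array = pipeline.train(rngs, nsteps=10_000, save_model=True)
```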
- update_params(new_params)[source]#
Update the model parameters and re-initialize the model.
- Parameters:
new_params (dict) – New model parameters.
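For instance, to retrain with a wider network; this sketch assumes the current parameters are reachable as `pipeline.params` and uses a hypothetical `hidden_dim` key:

```python
new_params = dict(pipeline.params)  # assumed attribute holding the params dict
new_params["hidden_dim"] = 256      # hypothetical key
pipeline.update_params(new_params)  # re-initializes the model with the new params
```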
- class gensbi.recipes.pipeline.ModelEMA(model, tx)[source]#
Bases: `flax.nnx.Optimizer`
Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters.
This optimizer keeps an exponential moving average of the model parameters, which can help stabilize training and improve evaluation performance. The EMA parameters are updated at each training step.
- Parameters:
model (nnx.Module) – The model whose parameters will be tracked.
tx (optax.GradientTransformation) – The Optax transformation defining the EMA update rule.
- update(model, model_orginal)[source]#
Update the EMA parameters using the current model parameters.
- Parameters:
model (nnx.Module) – The model with EMA parameters to be updated.
model_orginal (nnx.Module) – The original model with current parameters.
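Within the training loop, the EMA copy is refreshed after each optimization step on the live model; a sketch with assumed names:

```python
# `ema_model` holds the smoothed parameters, `model` the live ones.
ema_optimizer.update(ema_model, model)
```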