gensbi.recipes.pipeline
Pipeline module for GenSBI.
This module provides an abstract pipeline class for training and evaluating conditional generative models (such as conditional flow matching or diffusion models) in the GenSBI framework. It handles model creation, the training loop, optimizer setup, checkpointing, and evaluation utilities.
For practical implementations, subclasses should implement specific model architectures, loss functions, and sampling methods. See JointPipeline and ConditionalPipeline for concrete examples.
Classes
- AbstractPipeline – Abstract base class for GenSBI training pipelines.
- ModelEMA – Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters.
Functions
Module Contents
- class gensbi.recipes.pipeline.AbstractPipeline(train_dataset, val_dataset, dim_theta, dim_x, model, params, training_config=None)
Bases: abc.ABC
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
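The constructor's fallback behaviour (build the model with _make_model when model is None, use _get_default_params when params is None, and use _get_default_training_config when training_config is None) follows the template-method pattern. The sketch below is a hypothetical, heavily simplified stand-in rather than GenSBI's code: MiniPipeline, ToyPipeline, the placeholder "model" dictionary, and all parameter and configuration keys are illustrative assumptions, and the rngs argument of the real _get_default_params is omitted.

```python
import abc


class MiniPipeline(abc.ABC):
    """Hypothetical, simplified stand-in for AbstractPipeline (not GenSBI code)."""

    def __init__(self, dim_theta, dim_x, model=None, params=None, training_config=None):
        self.dim_theta = dim_theta
        self.dim_x = dim_x
        # Fall back to subclass-provided defaults when arguments are omitted.
        self.params = params if params is not None else self._get_default_params()
        self.training_config = (
            training_config
            if training_config is not None
            else self._get_default_training_config()
        )
        self.model = model if model is not None else self._make_model(self.params)

    @abc.abstractmethod
    def _get_default_params(self):
        """Subclasses must supply default model hyperparameters."""

    @abc.abstractmethod
    def _make_model(self, params):
        """Subclasses must build the model from its hyperparameters."""

    @classmethod
    def _get_default_training_config(cls):
        # Illustrative keys only; these are not GenSBI's documented defaults.
        return {"nsteps": 1000, "max_lr": 1e-3}


class ToyPipeline(MiniPipeline):
    def _get_default_params(self):
        return {"hidden_dim": 32}

    def _make_model(self, params):
        # Placeholder "model": just record its input size and configuration.
        return {"in_dim": self.dim_theta + self.dim_x, **params}


pipe = ToyPipeline(dim_theta=2, dim_x=3)
print(pipe.model)  # {'in_dim': 5, 'hidden_dim': 32}
```

Because the abstract methods are unimplemented on the base class, instantiating MiniPipeline directly raises a TypeError; only a concrete subclass can be constructed.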
- abstractmethod _get_default_params(rngs)
Return a dictionary of default model parameters.
- Parameters:
rngs (flax.nnx.Rngs)
- classmethod _get_default_training_config()
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
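Because _get_default_training_config is a classmethod, the defaults can be inspected on the class itself before any pipeline exists. If you only want to change a few settings, one option is to start from the defaults, override selected keys, and pass the result as training_config. In the sketch below, DemoPipeline and every configuration key and value are hypothetical placeholders, not GenSBI's actual defaults.

```python
class DemoPipeline:
    """Hypothetical stand-in exposing a classmethod like _get_default_training_config."""

    @classmethod
    def _get_default_training_config(cls):
        # Placeholder keys and values; consult the real method for GenSBI's defaults.
        return {"nsteps": 30_000, "max_lr": 1e-3, "min_lr": 1e-5, "ema_decay": 0.999}


# Classmethods can be called on the class, with no instance required.
config = DemoPipeline._get_default_training_config()

# Override selected entries. Each call returns a fresh dict, so the
# defaults produced by later calls are unaffected.
config.update({"nsteps": 5_000, "max_lr": 5e-4})
print(config["nsteps"])  # 5000
```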
- _get_ema_optimizer()
Construct the EMA optimizer for maintaining an exponential moving average of model parameters.
- Returns:
ema_optimizer – The EMA optimizer instance.
- Return type:
ModelEMA
- _get_optimizer()
Construct the optimizer for training, including learning rate scheduling and gradient clipping.
- Returns:
optimizer – The optimizer instance for the model.
- Return type:
nnx.Optimizer
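The specific schedule, clipping threshold, and base optimizer used by _get_optimizer are not stated here. As an illustration only, the sketch below implements two common choices in NumPy, a linear-warmup, cosine-decay learning-rate schedule and clipping by global gradient norm, and applies them in a plain SGD update for brevity. The function names, warmup length, and max_norm value are assumptions; in a JAX setting these roles are typically filled by Optax schedules and optax.clip_by_global_norm chained with an optimizer such as Adam.

```python
import math

import numpy as np


def warmup_cosine_lr(step, max_lr=1e-3, min_lr=1e-5, warmup_steps=100, total_steps=1000):
    """Linear warmup to max_lr, then cosine decay to min_lr (illustrative schedule)."""
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps
    progress = min((step - warmup_steps) / max(1, total_steps - warmup_steps), 1.0)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


def clip_by_global_norm(grads, max_norm=1.0):
    """Rescale a list of gradient arrays so their joint L2 norm is at most max_norm."""
    global_norm = math.sqrt(sum(float(np.sum(g ** 2)) for g in grads))
    scale = min(1.0, max_norm / (global_norm + 1e-12))
    return [g * scale for g in grads], global_norm


def sgd_step(params, grads, step):
    """One plain SGD update using the scheduled learning rate and clipped gradients."""
    clipped, _ = clip_by_global_norm(grads)
    lr = warmup_cosine_lr(step)
    return [p - lr * g for p, g in zip(params, clipped)]


params = [np.ones(3), np.zeros(2)]
grads = [np.full(3, 10.0), np.full(2, 10.0)]  # global norm = 10 * sqrt(5), well above 1
new_params = sgd_step(params, grads, step=0)
```

Clipping by the global norm, rather than per array, preserves the direction of the full gradient while bounding the size of any single update.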
- abstractmethod _wrap_model()
Wrap the model for evaluation, using either JointWrapper or ConditionalWrapper.
- get_train_step_fn(loss_fn)
Return the training step function, which performs a single optimization step.
- Returns:
train_step – JIT-compiled training step function.
- Return type:
Callable
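The factory pattern behind get_train_step_fn(loss_fn) can be illustrated without JAX: the returned closure evaluates the loss and its gradient on one batch and applies a parameter update. In the pipeline, the gradient is presumably obtained through JAX automatic differentiation and the step is JIT-compiled. Here, as an assumption that keeps the sketch self-contained, loss_and_grad is a hand-written least-squares loss with an analytic gradient, and the update is plain SGD with a fixed learning rate.

```python
import numpy as np


def loss_and_grad(w, batch):
    """Mean squared error of a linear model y ~ X @ w, with its analytic gradient."""
    X, y = batch
    residual = X @ w - y
    loss = float(np.mean(residual ** 2))
    grad = 2.0 * X.T @ residual / len(y)
    return loss, grad


def get_train_step_fn(loss_fn, lr=0.1):
    """Return a closure that performs one optimization step (hypothetical, simplified)."""

    def train_step(w, batch):
        loss, grad = loss_fn(w, batch)
        return w - lr * grad, loss  # updated parameters and the pre-update loss

    return train_step


rng = np.random.default_rng(0)
X = rng.normal(size=(64, 2))
w_true = np.array([1.5, -0.5])
batch = (X, X @ w_true)

train_step = get_train_step_fn(loss_and_grad)
w = np.zeros(2)
for _ in range(200):
    w, loss = train_step(w, batch)
```

Returning a closure lets the loss function be fixed once, so the step that is called in the loop (and, in JAX, compiled) has a stable signature.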
- get_val_step_fn(loss_fn)
Return the validation step function, which evaluates the loss on a single batch without updating the model parameters.
- Returns:
val_step – JIT-compiled validation step function.
- Return type:
Callable
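A validation step differs from a training step in that it only evaluates the loss and leaves the parameters unchanged. The sketch below, which assumes a hand-written least-squares mse_loss in place of the pipeline's model loss, averages the validation loss over a few batches.

```python
import numpy as np


def mse_loss(w, batch):
    """Mean squared error of a linear model y ~ X @ w (illustrative loss)."""
    X, y = batch
    return float(np.mean((X @ w - y) ** 2))


def get_val_step_fn(loss_fn):
    """Return a closure that evaluates the loss on one batch without any update."""

    def val_step(w, batch):
        return loss_fn(w, batch)

    return val_step


rng = np.random.default_rng(1)
w_true = np.array([2.0, 0.0])
val_batches = []
for _ in range(4):
    X = rng.normal(size=(32, 2))
    val_batches.append((X, X @ w_true))

val_step = get_val_step_fn(mse_loss)
w = np.array([1.0, 0.0])  # deliberately wrong first coefficient
mean_val_loss = float(np.mean([val_step(w, b) for b in val_batches]))
```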
- abstractmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)
Initialize the pipeline from a configuration file.
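The configuration file format is not specified in this section. As a hypothetical illustration, the sketch below assumes a JSON file with separate model and training sections and shows how such a file could be split into the params and training_config dictionaries that the constructor accepts. The file layout, the section names, and the load_pipeline_config helper are assumptions, not GenSBI's schema.

```python
import json
import os
import tempfile


def load_pipeline_config(config_path):
    """Split a config file into model params and training config (hypothetical layout)."""
    with open(config_path) as f:
        config = json.load(f)
    params = config.get("model", {})
    training_config = config.get("training", {})
    return params, training_config


example = {
    "model": {"hidden_dim": 64, "depth": 4},
    "training": {"nsteps": 10000, "max_lr": 0.001},
}

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    with open(path, "w") as f:
        json.dump(example, f)
    params, training_config = load_pipeline_config(path)

print(params["hidden_dim"], training_config["nsteps"])  # 64 10000
```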
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- abstractmethod sample(rng, x_o, nsamples=10000, step_size=0.01)
Generate samples from the trained model.
- Parameters:
rng (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
- Returns:
samples – Generated samples.
- Return type:
array-like
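For flow-matching models, sampling typically means drawing nsamples points from a simple base distribution and integrating a learned, conditional velocity field from t = 0 to t = 1, with step_size setting the integration step. The sketch below assumes a fixed-step Euler integrator, a NumPy Generator instead of a JAX PRNG key, and a toy in which theta and x share the same dimension. The learned network is replaced by a hypothetical closed-form velocity that shifts samples toward x_o. A concrete pipeline may use a different ODE solver or, for diffusion models, a different procedure entirely.

```python
import numpy as np


def velocity(x, t, x_o):
    """Hypothetical stand-in for a learned conditional velocity field v(x, t | x_o).

    A constant field equal to x_o transports N(0, I) to N(x_o, I) over t in [0, 1].
    """
    return np.broadcast_to(x_o, x.shape)


def sample(rng, x_o, nsamples=10_000, step_size=0.01):
    """Euler integration of dx/dt = v(x, t | x_o) from t = 0 to t = 1."""
    x_o = np.asarray(x_o, dtype=float)
    x = rng.normal(size=(nsamples, x_o.shape[-1]))  # draws from the base distribution
    n_steps = int(round(1.0 / step_size))
    dt = 1.0 / n_steps
    for i in range(n_steps):
        x = x + dt * velocity(x, i * dt, x_o)
    return x


samples = sample(np.random.default_rng(0), x_o=[3.0, -1.0])
print(samples.shape)  # (10000, 2)
```

With step_size=0.01 the integrator takes 100 steps; a smaller step size costs more network evaluations but reduces discretization error for curved trajectories.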
- train(rngs, nsteps=None, save_model=True)
Run the training loop for the model.
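- Parameters:
rngs (nnx.Rngs) – Random number generators for training/validation steps.
nsteps (Optional[int]) – Number of training steps to run.
save_model (bool, optional) – Whether to save the trained model. Defaults to True.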
- Returns:
loss_array (list) – List of training losses.
val_loss_array (list) – List of validation losses.
- Return type:
Tuple[list, list]
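The training loop ties the step functions together: it repeatedly draws a training batch, applies a training step, and periodically evaluates a validation batch, recording both losses. The sketch below is a simplified, framework-free illustration. Its signature differs from train, and the batch generator, the val_every interval, the least-squares model, and the fixed learning rate are all assumptions. Checkpointing (save_model) and EMA updates are omitted.

```python
import numpy as np


def make_batches(rng, w_true, batch_size=32):
    """Infinite generator of (X, y) batches for a noiseless linear model."""
    while True:
        X = rng.normal(size=(batch_size, w_true.size))
        yield X, X @ w_true


def mse_and_grad(w, batch):
    """Mean squared error of y ~ X @ w and its analytic gradient."""
    X, y = batch
    r = X @ w - y
    return float(np.mean(r ** 2)), 2.0 * X.T @ r / len(y)


def train(w, train_batches, val_batches, nsteps=300, lr=0.05, val_every=50):
    loss_array, val_loss_array = [], []
    for step in range(nsteps):
        loss, grad = mse_and_grad(w, next(train_batches))
        w = w - lr * grad  # training step: update parameters
        loss_array.append(loss)
        if (step + 1) % val_every == 0:
            val_loss, _ = mse_and_grad(w, next(val_batches))  # validation: no update
            val_loss_array.append(val_loss)
    return w, loss_array, val_loss_array


rng = np.random.default_rng(0)
w_true = np.array([0.5, -2.0, 1.0])
w, loss_array, val_loss_array = train(
    np.zeros(3), make_batches(rng, w_true), make_batches(rng, w_true)
)
```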
- update_params(new_params)
Update the model parameters and re-initialize the model.
- Parameters:
new_params (dict) – New model parameters.
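Because the model is built from its parameter dictionary, changing hyperparameters means rebuilding the model rather than mutating it in place. A minimal sketch of that pattern follows; ToyPipeline, its _make_model body, and the choice to replace (rather than merge) the parameter dictionary are hypothetical. In GenSBI the rebuilt model would be an nnx.Module rather than a list.

```python
class ToyPipeline:
    """Hypothetical pipeline whose model is derived entirely from self.params."""

    def __init__(self, params):
        self.params = dict(params)
        self.model = self._make_model(self.params)

    def _make_model(self, params):
        # Placeholder "model": a list of layer widths built from the hyperparameters.
        return [params["hidden_dim"]] * params["depth"]

    def update_params(self, new_params):
        """Replace the parameters and rebuild the model from them."""
        self.params = dict(new_params)
        self.model = self._make_model(self.params)


pipe = ToyPipeline({"hidden_dim": 32, "depth": 2})
pipe.update_params({"hidden_dim": 64, "depth": 3})
print(pipe.model)  # [64, 64, 64]
```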
- class gensbi.recipes.pipeline.ModelEMA(model, tx)
Bases: flax.nnx.Optimizer
Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters.
This optimizer keeps an exponential moving average of the model parameters, which can help stabilize training and improve evaluation performance. The EMA parameters are updated at each training step.
- Parameters:
model (nnx.Module) – The model whose parameters will be tracked.
tx (optax.GradientTransformation) – The Optax transformation defining the EMA update rule.
- update(model, model_orginal)
Update the EMA parameters using the current model parameters.
- Parameters:
model (nnx.Module) – The model with EMA parameters to be updated.
model_orginal (flax.nnx.Module) – The original model with current parameters.
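The EMA update rule is supplied through the tx transformation, so the exact decay behaviour depends on how the pipeline constructs it. The standard exponential moving average, which the sketch below assumes with a fixed decay, moves each tracked parameter a fraction 1 - decay of the way toward the current training value: θ_ema ← decay · θ_ema + (1 - decay) · θ. Parameters are represented as a plain dict of NumPy arrays rather than nnx.Module state, and ema_update is a hypothetical helper whose two arguments mirror the roles of model (the EMA copy) and model_orginal (the current weights).

```python
import numpy as np


def ema_update(ema_params, current_params, decay=0.99):
    """Move each EMA parameter toward the current one: e <- decay*e + (1-decay)*p."""
    for name, value in current_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


ema = {"w": np.zeros(2)}
current = {"w": np.array([1.0, -1.0])}

# With the current weights held fixed, repeated updates pull the EMA toward them.
for _ in range(1000):
    ema_update(ema, current)
# ema["w"] is now close to [1.0, -1.0]
```

A decay close to 1 averages over roughly 1 / (1 - decay) recent steps (about 100 here), which smooths out step-to-step noise in the training weights.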