"""
Pipeline module for GenSBI.
This module provides an abstract pipeline class for training and evaluating conditional generative models
(such as conditional flow matching or diffusion models) in the GenSBI framework. It handles model creation,
training loop, optimizer setup, checkpointing, and evaluation utilities.
Example:
.. code-block:: python
from gensbi.recipes.pipeline import AbstractPipeline
# Implement a subclass with your model and loss definition
class MyPipeline(AbstractPipeline):
def _make_model(self):
...
def _get_default_params(self, rngs):
...
def get_loss_fn(self):
...
def sample(self, rng, x_o, nsamples=10000, step_size=0.01):
...
# Instantiate and train
pipeline = MyPipeline(train_dataset, val_dataset, dim_theta=2, dim_x=2)
pipeline.train(rngs)
"""
from flax import nnx
import jax
from jax import numpy as jnp
from typing import Any, Callable, Optional, Tuple
from jax import Array
from numpyro import distributions as dist
import abc
from functools import partial
import optax
from optax.contrib import reduce_on_plateau
import orbax.checkpoint as ocp
from tqdm import tqdm
import os
class ModelEMA(nnx.Optimizer):
"""
Exponential Moving Average (EMA) optimizer for maintaining a smoothed version of model parameters.
This optimizer keeps an exponential moving average of the model parameters, which can help stabilize training
and improve evaluation performance. The EMA parameters are updated at each training step.
Parameters
----------
model : nnx.Module
The model whose parameters will be tracked.
tx : optax.GradientTransformation
The Optax transformation defining the EMA update rule.
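
    Example:

    A minimal usage sketch of keeping an EMA copy of a model (``model`` stands
    in for any ``nnx.Module`` being trained; the decay value is illustrative):

    .. code-block:: python

        ema_model = nnx.clone(model)
        ema_optimizer = ModelEMA(ema_model, optax.ema(0.99))

        # After each training step, pull the EMA copy towards the live model.
        ema_optimizer.update(ema_model, model)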
"""
def __init__(
self,
model: nnx.Module,
tx: optax.GradientTransformation,
):
super().__init__(model, tx, wrt=[nnx.Param, nnx.BatchStat])
    def update(self, model, model_original: nnx.Module):
        """
        Update the EMA parameters using the current model parameters.

        Parameters
        ----------
        model : nnx.Module
            The model with EMA parameters to be updated.
        model_original : nnx.Module
            The original model with the current parameters.
        """
        params = nnx.state(model_original, self.wrt)
ema_params = nnx.state(model, self.wrt)
self.step.value += 1
ema_state = optax.EmaState(count=self.step, ema=ema_params)
_, new_ema_state = self.tx.update(params, ema_state)
nnx.update(model, new_ema_state.ema)
@nnx.jit
def ema_step(ema_model, model, ema_optimizer: nnx.Optimizer):
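    """JIT-compiled helper that advances ``ema_model`` towards ``model`` by one EMA step."""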
ema_optimizer.update(ema_model, model)
class AbstractPipeline(abc.ABC):
"""
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models.
Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
Parameters
----------
train_dataset : iterable
Training dataset, should yield batches of data.
val_dataset : iterable
Validation dataset, should yield batches of data.
dim_theta : int
Dimensionality of the parameter (theta) space.
dim_x : int
Dimensionality of the observation (x) space.
params : dict, optional
Model parameters. If None, uses defaults from `_get_default_params`.
training_config : dict, optional
Training configuration. If None, uses defaults from `_get_default_training_config`.
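
    Example:

    A sketch of a typical workflow, assuming ``MyPipeline`` is a concrete
    subclass, ``train_dataset`` / ``val_dataset`` yield batches of joint
    samples, and ``x_o`` is the observed conditioning data:

    .. code-block:: python

        rngs = nnx.Rngs(train_step=0, val_step=1)
        pipeline = MyPipeline(train_dataset, val_dataset, dim_theta=2, dim_x=2)
        loss, val_loss = pipeline.train(rngs)
        samples = pipeline.sample(jax.random.PRNGKey(0), x_o, nsamples=10_000)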
"""
def __init__(self, train_dataset, val_dataset, dim_theta: int, dim_x: int, params=None, training_config=None):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.train_dataset_iter = iter(self.train_dataset)
        self.val_dataset_iter = iter(self.val_dataset)

        self.dim_theta = dim_theta
        self.dim_joint = dim_theta + dim_x
        self.node_ids = jnp.arange(self.dim_joint)
        self.obs_ids = jnp.arange(self.dim_theta)  # observation ids
        self.cond_ids = jnp.arange(self.dim_theta, self.dim_joint)  # conditional ids

        if params is None:
            params = self._get_default_params()
        self.params = params

        self.training_config = training_config
        if training_config is None:
            self.training_config = self._get_default_training_config()
        self.training_config["min_scale"] = self.training_config["min_lr"] / self.training_config["max_lr"] if self.training_config["max_lr"] > 0 else 0.0

        os.makedirs(self.training_config["checkpoint_dir"], exist_ok=True)

        self.model = self._make_model()
        self.model_wrapped = None  # to be set in subclass

        self.ema_model = nnx.clone(self.model)
        self.ema_model_wrapped = None  # to be set in subclass

        self.p0_dist_model = None  # to be set in subclass
        self.loss_fn = None  # to be set in subclass
        self.path = None  # to be set in subclass
@abc.abstractmethod
def _make_model(self):
"""
Create and return the model to be trained.
"""
... # pragma: no cover
def _get_ema_optimizer(self):
"""
Construct the EMA optimizer for maintaining an exponential moving average of model parameters.
Returns
-------
ema_optimizer : ModelEMA
The EMA optimizer instance.
"""
ema_tx = optax.ema(self.training_config["ema_decay"])
ema_optimizer = ModelEMA(self.ema_model, ema_tx)
return ema_optimizer
def _get_optimizer(self):
"""
Construct the optimizer for training, including learning rate scheduling and gradient clipping.
Returns
-------
optimizer : nnx.Optimizer
The optimizer instance for the model.
"""
opt = optax.chain(
optax.adaptive_grad_clip(10.0),
optax.adamw(self.training_config["max_lr"]),
reduce_on_plateau(
patience=self.training_config["patience"],
cooldown=self.training_config["cooldown"],
factor=self.training_config["factor"],
rtol=self.training_config["rtol"],
accumulation_size=self.training_config["accumulation_size"],
min_scale=self.training_config["min_scale"],
),
)
if self.training_config["multistep"] > 1:
opt = optax.MultiSteps(opt, self.training_config["multistep"])
optimizer = nnx.Optimizer(self.model, opt, wrt=nnx.Param)
return optimizer
@abc.abstractmethod
def _get_default_params(self, rngs: nnx.Rngs):
"""
Return a dictionary of default model parameters.
"""
... # pragma: no cover
@classmethod
def _get_default_training_config(cls):
"""
Return a dictionary of default training configuration parameters.
Returns
-------
training_config : dict
Default training configuration.
"""
training_config = {}
training_config["num_steps"] = 30_000
training_config["ema_decay"] = 0.99
training_config["patience"] = 10
training_config["cooldown"] = 2
training_config["factor"] = 0.5
training_config["accumulation_size"] = 100
training_config["rtol"] = 1e-4
training_config["max_lr"] = 1e-3
training_config["min_lr"] = 1e-8
training_config["val_every"] = 100
training_config["early_stopping"] = True
training_config["experiment_id"] = 1
training_config["multistep"] = 1
training_config["checkpoint_dir"] = os.path.join(os.getcwd(), "checkpoints")
return training_config
def update_training_config(self, new_config):
"""
Update the training configuration with new parameters.
Parameters
----------
new_config : dict
New training configuration parameters.
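
        Example:

        A sketch of overriding a few defaults on an existing pipeline instance
        before calling ``train`` (the values shown are illustrative):

        .. code-block:: python

            # The learning-rate floor ratio (``min_scale``) is recomputed automatically.
            pipeline.update_training_config({"num_steps": 50_000, "max_lr": 3e-4})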
"""
self.training_config.update(new_config)
self.training_config["min_scale"] = self.training_config["min_lr"] / self.training_config["max_lr"] if self.training_config["max_lr"] > 0 else 0.0
return
def update_params(self, new_params):
"""
Update the model parameters and re-initialize the model.
Parameters
----------
new_params : dict
New model parameters.
"""
self.params = new_params
self.model = self._make_model()
self.model_wrapped = None # to be set in subclass
return
def _next_batch(self):
"""
Return the next batch from the training dataset.
"""
return next(self.train_dataset_iter)
def _next_val_batch(self):
"""
Return the next batch from the validation dataset.
"""
return next(self.val_dataset_iter)
@abc.abstractmethod
def get_loss_fn(self):
"""
Return the loss function for training/validation.
"""
... # pragma: no cover
def get_train_step_fn(self, loss_fn):
"""
Return the training step function, which performs a single optimization step.
Returns
-------
train_step : Callable
JIT-compiled training step function.
"""
@nnx.jit
def train_step(model, optimizer, x_1: Array, rng: jax.random.PRNGKey):
loss, grads = nnx.value_and_grad(loss_fn)(model, x_1, rng)
optimizer.update(model, grads, value=loss)
return loss
return train_step
def get_val_step_fn(self, loss_fn):
"""
        Return the validation step function, which computes the loss on a validation batch without updating parameters.
Returns
-------
val_step : Callable
JIT-compiled validation step function.
"""
@nnx.jit
def val_step(model, x_1: Array, rng: jax.random.PRNGKey):
loss = loss_fn(model, x_1, rng)
return loss
return val_step
def restore_model(self, experiment_id=None):
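        """
        Restore the model and EMA model from Orbax checkpoints and wrap them for evaluation.

        Parameters
        ----------
        experiment_id : int, optional
            Checkpoint id to restore. If None, uses ``training_config["experiment_id"]``.
        """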
if experiment_id is None:
experiment_id = self.training_config["experiment_id"]
graphdef, model_state = nnx.split(self.model)
with ocp.CheckpointManager(
self.training_config["checkpoint_dir"],
options=ocp.CheckpointManagerOptions(read_only=True),
) as read_mgr:
restored = read_mgr.restore(
experiment_id,
args=ocp.args.Composite(state=ocp.args.PyTreeRestore(item=model_state)),
)
self.model = nnx.merge(graphdef, restored["state"])
# restore the ema model
model_state_ema = nnx.state(self.ema_model)
with ocp.CheckpointManager(
os.path.join(self.training_config["checkpoint_dir"], "ema"),
options=ocp.CheckpointManagerOptions(read_only=True),
) as read_mgr_ema:
restored_ema = read_mgr_ema.restore(
experiment_id,
args=ocp.args.Composite(
state=ocp.args.PyTreeRestore(item=model_state_ema)
),
)
self.ema_model = nnx.merge(graphdef, restored_ema["state"])
# wrap models
self._wrap_model()
print("Restored model from checkpoint")
return
@abc.abstractmethod
def _wrap_model(self):
"""
Wrap the model for evaluation (either using SimformerWrapper or Flux1Wrapper).
"""
... # pragma: no cover
def save_model(self, experiment_id=None):
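        """
        Save the model and EMA model states to Orbax checkpoints.

        Parameters
        ----------
        experiment_id : int, optional
            Checkpoint id to save under. If None, uses ``training_config["experiment_id"]``.
        """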
if experiment_id is None:
experiment_id = self.training_config["experiment_id"]
checkpoint_dir = self.training_config["checkpoint_dir"]
checkpoint_dir_ema = os.path.join(self.training_config["checkpoint_dir"], "ema")
os.makedirs(checkpoint_dir, exist_ok=True)
os.makedirs(checkpoint_dir_ema, exist_ok=True)
# Save the model
checkpoint_manager = ocp.CheckpointManager(
checkpoint_dir,
options=ocp.CheckpointManagerOptions(
max_to_keep=None,
keep_checkpoints_without_metrics=True,
create=True,
),
)
model_state = nnx.state(self.model)
checkpoint_manager.save(
experiment_id, args=ocp.args.Composite(state=ocp.args.PyTreeSave(model_state))
)
checkpoint_manager.close()
        # save the EMA model state to its own checkpoint directory
        ema_state = nnx.state(self.ema_model)
checkpoint_manager_ema = ocp.CheckpointManager(
checkpoint_dir_ema,
options=ocp.CheckpointManagerOptions(
max_to_keep=None,
keep_checkpoints_without_metrics=True,
create=True,
),
)
checkpoint_manager_ema.save(
experiment_id, args=ocp.args.Composite(state=ocp.args.PyTreeSave(ema_state))
)
checkpoint_manager_ema.close()
print("Saved model to checkpoint")
return
def train(self, rngs: nnx.Rngs, nsteps: Optional[int] = None, save_model=True) -> Tuple[list, list]:
"""
Run the training loop for the model.
Parameters
----------
        rngs : nnx.Rngs
            Random number generators; must provide ``train_step`` and ``val_step`` streams.
        nsteps : int, optional
            Number of training steps. If None, uses ``training_config["num_steps"]``.
        save_model : bool, optional
            Whether to save a checkpoint when training finishes. Default is True.
Returns
-------
loss_array : list
List of training losses.
val_loss_array : list
List of validation losses.
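
        Example:

        A minimal call sketch (``pipeline`` is an instance of a concrete subclass):

        .. code-block:: python

            # The pipeline draws keys from the ``train_step`` and ``val_step`` RNG streams.
            rngs = nnx.Rngs(train_step=0, val_step=1)
            loss, val_loss = pipeline.train(rngs, nsteps=10_000)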
"""
optimizer = self._get_optimizer()
ema_optimizer = self._get_ema_optimizer()
best_state = nnx.state(self.model)
best_state_ema = nnx.state(self.ema_model)
loss_fn = self.get_loss_fn()
train_step = self.get_train_step_fn(loss_fn)
val_step = self.get_val_step_fn(loss_fn)
min_val = val_step(self.model, self._next_val_batch(), rngs.val_step())
val_error_ratio = 1.1
counter = 0
cmax = 10
loss_array = []
val_loss_array = []
self.model.train()
if nsteps is None:
nsteps = self.training_config["num_steps"]
early_stopping = self.training_config["early_stopping"]
val_every = self.training_config["val_every"]
experiment_id = self.training_config["experiment_id"]
pbar = tqdm(range(nsteps))
l_train = None
for j in pbar:
if counter > cmax and early_stopping:
print("Early stopping")
graphdef = nnx.graphdef(self.model)
self.model = nnx.merge(graphdef, best_state)
self.ema_model = nnx.merge(graphdef, best_state_ema)
break
loss = train_step(self.model, optimizer, self._next_batch(), rngs.train_step())
# update the parameters ema
ema_step(self.ema_model, self.model, ema_optimizer)
if j == 0:
l_train = loss
else:
l_train = 0.9 * l_train + 0.1 * loss
if j > 0 and j % val_every == 0:
l_val = val_step(self.model, self._next_val_batch(), rngs.val_step())
ratio = l_val / l_train
if ratio > val_error_ratio:
counter += 1
else:
counter = 0
pbar.set_postfix(
loss=f"{l_train:.4f}",
ratio=f"{ratio:.4f}",
counter=counter,
val_loss=f"{l_val:.4f}",
)
loss_array.append(l_train)
val_loss_array.append(l_val)
if l_val < min_val:
min_val = l_val
best_state = nnx.state(self.model)
best_state_ema = nnx.state(self.ema_model)
l_val = 0
l_train = 0
self.model.eval()
if save_model:
self.save_model(experiment_id)
self._wrap_model()
return loss_array, val_loss_array
@abc.abstractmethod
def sample(self, rng, x_o, nsamples=10_000, step_size=0.01):
"""
Generate samples from the trained model.
Parameters
----------
rng : jax.random.PRNGKey
Random number generator key.
x_o : array-like
Conditioning variable (e.g., observed data).
nsamples : int, optional
Number of samples to generate.
step_size : float, optional
Step size for the sampler.
Returns
-------
samples : array-like
Generated samples.
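
        Example:

        A usage sketch (``pipeline`` is a trained instance of a concrete subclass
        and ``x_o`` is the observed conditioning data):

        .. code-block:: python

            rng = jax.random.PRNGKey(0)
            samples = pipeline.sample(rng, x_o, nsamples=10_000, step_size=0.01)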
"""
... # pragma: no cover