# Source code for gensbi.recipes.unconditional_pipeline

"""
Pipelines for training and using an unconditional model (flow matching or diffusion) for simulation-based inference.

Examples:
    .. code-block:: python
    
        import grain
        import numpy as np
        import jax
        from jax import numpy as jnp
        from gensbi.recipes.unconditional_pipeline import UnconditionalFlowPipeline

        # Define your training and validation datasets.
        train_data = jax.random.uniform(jax.random.PRNGKey(1), (1024, 4)) # your training dataset
        val_data = jax.random.uniform(jax.random.PRNGKey(2), (128, 4)) # your validation dataset

        batch_size = 32

        train_dataset_grain = (
            grain.MapDataset.source(np.array(train_data)[...,None])
            .shuffle(42)
            .repeat()
            .to_iter_dataset()
            .batch(batch_size)
            # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
        )

        val_dataset_grain = (
            grain.MapDataset.source(np.array(val_data)[...,None])
            .shuffle(42)
            .repeat()
            .to_iter_dataset()
            .batch(batch_size)
            # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
        )


        # Define your model
        model = ...  # your nnx.Module model here, e.g., a simple MLP, or the Simformer model
        # if you define a custom model, it should take as input the following arguments:
        #    t: Array,
        #    obs: Array,
        #    node_ids: Array (optional, if your model is a transformer-based model)
        #    *args
        #    **kwargs

        # the obs input should have shape (batch_size, dim_joint, c), and the output will be of the same shape

        # Define the pipeline
        dim_theta = 4  # Dimension of the parameter space (must match the second axis of the data)
        pipeline = UnconditionalFlowPipeline(model, train_dataset_grain, val_dataset_grain, dim_theta)

        # Train the model
        rngs = jax.random.PRNGKey(0)
        pipeline.train(rngs)

        # Sample from the learned distribution
        samples = pipeline.sample(rngs, nsamples=10000, step_size=0.01)

        
    .. note::
    
        If you plan on using multiprocessing prefetching, ensure that your script is wrapped in an `if __name__ == "__main__":` guard. See https://docs.python.org/3/library/multiprocessing.html
"""

import jax
import jax.numpy as jnp
from flax import nnx
import optax
from optax.contrib import reduce_on_plateau
from numpyro import distributions as dist
from tqdm.auto import tqdm
from functools import partial
import orbax.checkpoint as ocp

from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.path.scheduler import CondOTScheduler
from gensbi.flow_matching.solver import ODESolver

from gensbi.diffusion.path import EDMPath
from gensbi.diffusion.path.scheduler import EDMScheduler, VEScheduler
from gensbi.diffusion.solver import SDESolver

from gensbi.models import  UnconditionalCFMLoss, UnconditionalWrapper, UnconditionalDiffLoss

from einops import repeat

from gensbi.utils.model_wrapping import _expand_dims

import os

import yaml

from gensbi.recipes.pipeline import AbstractPipeline


class UnconditionalFlowPipeline(AbstractPipeline):
    """Flow-matching pipeline for an unconditional model.

    Trains with a conditional flow-matching loss on an affine probability path
    with an optimal-transport scheduler (``AffineProbPath`` + ``CondOTScheduler``),
    samples by integrating the learned velocity field with a Dopri5 ODE solver,
    and can evaluate unnormalized log-probabilities by integrating backwards.
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        dim_theta: int,
        params=None,
        training_config=None,
    ):
        """
        Flow pipeline for training and using an unconditional model for
        simulation-based inference.

        Parameters
        ----------
        model : nnx.Module
            The model to be trained.
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_theta : int
            Dimension of the parameter space.
        params : UnconditionalParams, optional
            Parameters for the unconditional model. If None, default parameters
            are used.
        training_config : dict, optional
            Configuration for training. If None, default configuration is used.
        """
        # The model is unconditional, so the observation dimension passed to the
        # base pipeline is 0.
        super().__init__(
            train_dataset, val_dataset, dim_theta, 0, model, params, training_config
        )

        # NOTE(review): self.obs_ids is presumably created by AbstractPipeline.__init__;
        # here it is passed through gensbi.utils.model_wrapping._expand_dims.
        self.obs_ids = _expand_dims(self.obs_ids)

        self.path = AffineProbPath(scheduler=CondOTScheduler())
        self.loss_fn = UnconditionalCFMLoss(self.path)

        # Standard-normal source distribution over arrays of shape (dim_theta, 1).
        # reinterpreted_batch_ndims=1 turns only the trailing axis into an event axis.
        self.p0_dist_model = dist.Independent(
            dist.Normal(
                loc=jnp.zeros((self.dim_theta, 1)),
                scale=jnp.ones((self.dim_theta, 1)),
            ),
            reinterpreted_batch_ndims=1,
        )

    @classmethod
    def init_pipeline_from_config(cls, *args, **kwargs):
        """Not supported for this pipeline.

        Any arguments are accepted and ignored so that callers always get the
        informative ``NotImplementedError`` instead of a signature ``TypeError``.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Initialization from config not implemented for UnconditionalFlowPipeline."
        )

    def _make_model(self):
        """Not supported: the model must be passed to ``__init__``.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Model creation not implemented for UnconditionalFlowPipeline."
        )

    def _get_default_params(self):
        """Not supported for this pipeline.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Default parameters not implemented for UnconditionalFlowPipeline."
        )

    def get_loss_fn(self):
        """Build the flow-matching training loss.

        Returns
        -------
        callable
            ``loss_fn(model, batch, key)``. ``batch`` is an array of data samples
            (the targets ``x_1``) with the batch on axis 0; ``key`` is a JAX PRNG key
            used to draw source samples ``x_0`` and per-sample times ``t``.
        """

        def loss_fn(model, batch, key: jax.random.PRNGKey):
            x_1 = batch
            batch_size = x_1.shape[0]
            rng_x0, rng_t = jax.random.split(key, 2)

            # Source samples, shape (batch_size, dim_theta, 1).
            x_0 = self.p0_dist_model.sample(rng_x0, (batch_size,))
            # One time per sample in [0, 1). JAX requires the shape as a sequence,
            # not a bare int.
            t = jax.random.uniform(rng_t, (batch_size,))

            # Unconditional training: no entry is treated as conditioned-on.
            condition_mask = jnp.zeros(x_1.shape, dtype=jnp.bool_)
            return self.loss_fn(
                model,
                (x_0, x_1, t),
                node_ids=self.obs_ids,
                condition_mask=condition_mask,
            )

        return loss_fn

    def _wrap_model(self):
        """Wrap the trained and EMA models in ``UnconditionalWrapper``."""
        self.model_wrapped = UnconditionalWrapper(self.model)
        self.ema_model_wrapped = UnconditionalWrapper(self.ema_model)
        return

    def sample(
        self,
        rng,
        nsamples=10_000,
        step_size=0.01,
        use_ema=True,
        time_grid=None,
        **model_extras,
    ):
        """Draw samples by integrating the learned velocity field from t=0 to t=1.

        Parameters
        ----------
        rng : jax.random.PRNGKey
            Key used to draw the Gaussian initial points.
        nsamples : int, optional
            Number of samples to draw.
        step_size : float, optional
            Step size passed to the Dopri5 solver.
        use_ema : bool, optional
            Use the EMA model (default) instead of the raw trained model.
        time_grid : array-like, optional
            Non-decreasing integration times. If given, intermediate states at
            these times are returned; otherwise only the endpoint of ``[0, 1]``.
        **model_extras
            Extra keyword arguments forwarded to the model; ``obs_ids`` is
            supplied by default and may be overridden.

        Returns
        -------
        Array
            Output of the ODE sampler applied to initial points of shape
            ``(nsamples, dim_theta)``.

        Raises
        ------
        ValueError
            If ``time_grid`` is not non-decreasing.
        """
        vf_wrapped = self.ema_model_wrapped if use_ema else self.model_wrapped

        if time_grid is None:
            time_grid = jnp.array([0.0, 1.0])
            return_intermediates = False
        else:
            # asarray so that Python lists are compared element-wise.
            time_grid = jnp.asarray(time_grid)
            if not jnp.all(time_grid[:-1] <= time_grid[1:]):
                raise ValueError("time_grid must be non-decreasing for sampling.")
            return_intermediates = True

        # NOTE(review): initial points are 2-D here, unlike the (dim_theta, 1)
        # source used in training; presumably UnconditionalWrapper adds the
        # channel axis — confirm.
        x_init = jax.random.normal(rng, (nsamples, self.dim_theta))
        solver = ODESolver(velocity_model=vf_wrapped)
        model_extras = {"obs_ids": self.obs_ids, **model_extras}
        sampler_ = solver.get_sampler(
            method="Dopri5",
            step_size=step_size,
            return_intermediates=return_intermediates,
            model_extras=model_extras,
            time_grid=time_grid,
        )
        return sampler_(x_init)

    def compute_unnorm_logprob(
        self, x_1, step_size=0.01, use_ema=True, time_grid=None, **model_extras
    ):
        """Evaluate the model's unnormalized log-probability of ``x_1``.

        Integrates the velocity field backwards (t=1 -> t=0) and scores the
        result under a standard-normal base density.

        Parameters
        ----------
        x_1 : Array
            Points to evaluate, shape ``(num_samples, dim_obs)``. Multi-channel
            input is not supported.
        step_size : float, optional
            Step size passed to the Dopri5 solver.
        use_ema : bool, optional
            Use the EMA model (default) instead of the raw trained model.
        time_grid : array-like, optional
            Non-increasing integration times. If given, intermediate results
            are returned; otherwise only the endpoint of ``[1, 0]``.
        **model_extras
            Extra keyword arguments forwarded to the model; ``obs_ids`` is
            supplied by default and may be overridden.

        Returns
        -------
        Array
            Whatever the solver's log-probability function returns for ``x_1``.

        Raises
        ------
        ValueError
            If ``time_grid`` is not non-increasing or ``x_1`` is not 2-D.
        """
        model = self.ema_model_wrapped if use_ema else self.model_wrapped

        if time_grid is None:
            time_grid = jnp.array([1.0, 0.0])
            return_intermediates = False
        else:
            time_grid = jnp.asarray(time_grid)
            if not jnp.all(time_grid[:-1] >= time_grid[1:]):
                raise ValueError(
                    "time_grid must be non-increasing for log-probability evaluation."
                )
            return_intermediates = True

        solver = ODESolver(velocity_model=model)

        if x_1.ndim != 2:
            raise ValueError(
                "x_1 must be of shape (num_samples, dim_obs), currently sampling for multiple channels is not supported."
            )

        # Standard-normal base density over a single (dim_obs,) vector.
        p0_cond = dist.Independent(
            dist.Normal(
                loc=jnp.zeros((x_1.shape[1],)), scale=jnp.ones((x_1.shape[1],))
            ),
            reinterpreted_batch_ndims=1,
        )

        # TODO(review): the training loss passes these ids as `node_ids`, while
        # the solver receives them as `obs_ids`; confirm which key the wrapper expects.
        model_extras = {"obs_ids": self.obs_ids, **model_extras}
        logp_sampler = solver.get_unnormalized_logprob(
            time_grid=time_grid,
            method="Dopri5",
            step_size=step_size,
            log_p0=p0_cond.log_prob,
            model_extras=model_extras,
            return_intermediates=return_intermediates,
        )

        if len(x_1) > 4:
            # Warm-up call on a small slice, intended to trigger compilation first.
            # NOTE(review): if logp_sampler is jit-compiled per input shape, this
            # compiles a (4, dim_obs) variant that the full call does not reuse.
            _ = logp_sampler(x_1[:4])

        return logp_sampler(x_1)
class UnconditionalDiffusionPipeline(AbstractPipeline):
    """Diffusion (EDM) pipeline for an unconditional model.

    Trains with a denoising loss on an ``EDMPath`` whose noise range comes from
    the training config (``sigma_min``/``sigma_max``) and samples with an
    ``SDESolver``.
    """

    def __init__(
        self,
        model,
        train_dataset,
        val_dataset,
        dim_theta: int,
        params=None,
        training_config=None,
    ):
        """
        Diffusion pipeline for training and using an unconditional model for
        simulation-based inference.

        Parameters
        ----------
        model : nnx.Module
            The model to be trained.
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_theta : int
            Dimension of the parameter space.
        params : UnconditionalParams, optional
            Parameters for the unconditional model. If None, default parameters
            are used.
        training_config : dict, optional
            Configuration for training. If None, default configuration is used.
            Missing ``sigma_min``/``sigma_max`` entries fall back to the defaults
            from ``_get_default_training_config``.
        """
        # The model is unconditional, so the observation dimension passed to the
        # base pipeline is 0.
        super().__init__(
            train_dataset, val_dataset, dim_theta, 0, model, params, training_config
        )

        # NOTE(review): self.obs_ids is presumably created by AbstractPipeline.__init__;
        # here it is passed through gensbi.utils.model_wrapping._expand_dims.
        self.obs_ids = _expand_dims(self.obs_ids)

        # A user-supplied training_config may omit the EDM noise range; fall
        # back to the class defaults instead of raising KeyError.
        defaults = self._get_default_training_config()
        self.path = EDMPath(
            scheduler=EDMScheduler(
                sigma_min=self.training_config.get("sigma_min", defaults["sigma_min"]),
                sigma_max=self.training_config.get("sigma_max", defaults["sigma_max"]),
            )
        )
        self.loss_fn = UnconditionalDiffLoss(self.path)

    @classmethod
    def init_pipeline_from_config(cls, *args, **kwargs):
        """Not supported for this pipeline.

        Any arguments are accepted and ignored so that callers always get the
        informative ``NotImplementedError`` instead of a signature ``TypeError``.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Initialization from config not implemented for UnconditionalDiffusionPipeline."
        )

    def _make_model(self):
        """Not supported: the model must be passed to ``__init__``.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Model creation not implemented for UnconditionalDiffusionPipeline."
        )

    def _get_default_params(self):
        """Not supported for this pipeline.

        Raises
        ------
        NotImplementedError
            Always.
        """
        raise NotImplementedError(
            "Default parameters not implemented for UnconditionalDiffusionPipeline."
        )

    @classmethod
    def _get_default_training_config(cls):
        """Return the base training config extended with the EDM noise range.

        Returns
        -------
        dict
            The parent class's defaults plus ``sigma_min`` and ``sigma_max``.
        """
        config = super()._get_default_training_config()
        config.update(
            {
                "sigma_min": 0.002,  # from edm paper
                "sigma_max": 80.0,
            }
        )
        return config

    def get_loss_fn(self):
        """Build the diffusion training loss.

        Returns
        -------
        callable
            ``loss_fn(model, batch, key)``. ``batch`` is an array of data samples
            with the batch on axis 0; ``key`` is a JAX PRNG key, split into one key
            for the loss and one for drawing per-sample noise levels.
        """

        def loss_fn(model, batch, key: jax.random.PRNGKey):
            rng_x0, rng_sigma = jax.random.split(key, 2)
            x_1 = batch
            batch_size = x_1.shape[0]

            sigma = self.path.sample_sigma(rng_sigma, batch_size)
            # Append singleton axes so sigma (one value per sample) broadcasts
            # against x_1, e.g. (B,) -> (B, 1, 1) for a 3-D batch.
            sigma = jnp.reshape(sigma, (batch_size,) + (1,) * (x_1.ndim - 1))

            return self.loss_fn(rng_x0, model, (x_1, sigma), node_ids=self.obs_ids)

        return loss_fn

    def _wrap_model(self):
        """Wrap the trained and EMA models in ``UnconditionalWrapper``."""
        self.model_wrapped = UnconditionalWrapper(self.model)
        self.ema_model_wrapped = UnconditionalWrapper(self.ema_model)
        return

    def sample(
        self,
        rng,
        nsamples=10_000,
        nsteps=18,
        use_ema=True,
        return_intermediates=False,
        **model_extras,
    ):
        """Draw samples with the SDE solver.

        Parameters
        ----------
        rng : jax.random.PRNGKey
            Key split into one key for the prior draw and one for the solver.
        nsamples : int, optional
            Number of samples to draw.
        nsteps : int, optional
            Number of solver steps.
        use_ema : bool, optional
            Use the EMA model (default) instead of the raw trained model.
        return_intermediates : bool, optional
            Forwarded to ``SDESolver.sample``.
        **model_extras
            Extra keyword arguments forwarded to the model; ``obs_ids`` is
            supplied by default and may be overridden.

        Returns
        -------
        Array
            Solver output with the trailing channel axis (size 1) squeezed away.
        """
        model = self.ema_model_wrapped if use_ema else self.model_wrapped

        key_prior, key_solver = jax.random.split(rng, 2)
        solver = SDESolver(score_model=model, path=self.path)
        model_extras = {"obs_ids": self.obs_ids, **model_extras}

        # Prior draws carry an explicit single-channel axis.
        x_init = self.path.sample_prior(key_prior, (nsamples, self.dim_theta, 1))
        samples = solver.sample(
            key_solver,
            x_init,
            nsteps=nsteps,
            model_extras=model_extras,
            return_intermediates=return_intermediates,
        )
        return jnp.squeeze(samples, axis=-1)