Source code for gensbi.recipes.simformer

"""
Pipeline for training and using a Simformer model for simulation-based inference.

Examples:
    .. code-block:: python
    
        import grain
        import numpy as np
        import jax
        from jax import numpy as jnp
        from gensbi.recipes import SimformerFlowPipeline

        # Define your training and validation datasets.
        key = jax.random.PRNGKey(0)
        train_key, val_key = jax.random.split(key)
        train_data = jax.random.uniform(train_key, (1024, 4))  # your training dataset
        val_data = jax.random.uniform(val_key, (128, 4))  # your validation dataset

        batch_size = 32

        train_dataset_grain = (
            grain.MapDataset.source(np.array(train_data)[...,None])
            .shuffle(42)
            .repeat()
            .to_iter_dataset()
            .batch(batch_size)
            # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
        )

        val_dataset_grain = (
            grain.MapDataset.source(np.array(val_data)[...,None])
            .shuffle(42)
            .repeat()
            .to_iter_dataset()
            .batch(batch_size) 
            # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
        )
        
        # Define the model
        dim_theta = 2  # Dimension of the parameter space
        dim_x = 2      # Dimension of the observation space
        pipeline = SimformerFlowPipeline(train_dataset_grain, val_dataset_grain, dim_theta, dim_x)

        # Train the model
        rngs = jax.random.PRNGKey(0)
        pipeline.train(rngs)

        # Sample from the posterior
        x_o = jnp.array([0.5, -0.2])  # Example
        samples = pipeline.sample(rngs, x_o, nsamples=10000, step_size=0.01)

    .. note::
    
        If you plan on using multiprocessing prefetching, ensure that your script is wrapped in a `if __name__ == "__main__":` guard. See https://docs.python.org/3/library/multiprocessing.html

"""

import jax
import jax.numpy as jnp
from flax import config, nnx

import yaml


from gensbi.models import (
    Simformer,
    SimformerParams,
)


from gensbi.recipes.joint_pipeline import JointFlowPipeline, JointDiffusionPipeline


def parse_simformer_params(config_path: str):
    """
    Parse a Simformer configuration file.

    Parameters
    ----------
    config_path : str
        Path to the configuration file.

    Returns
    -------
    params_dict : dict
        Dictionary of Simformer model parameters parsed from the ``model``
        section, with defaults filled in for missing keys.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    model_params = config.get("model", {})

    params_dict = dict(
        in_channels=model_params.get("in_channels", 1),
        dim_value=model_params.get("dim_value", 40),
        dim_id=model_params.get("dim_id", 40),
        dim_condition=model_params.get("dim_condition", 10),
        fourier_features=model_params.get("fourier_features", 128),
        num_heads=model_params.get("num_heads", 4),
        num_layers=model_params.get("num_layers", 8),
        widening_factor=model_params.get("widening_factor", 3),
        qkv_features=model_params.get("qkv_features", 90),
        num_hidden_layers=model_params.get("num_hidden_layers", 1),
    )

    return params_dict

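# A minimal sketch of the ``model`` section that ``parse_simformer_params``
# reads. Key names are taken from the ``.get`` calls above; every key is
# optional and falls back to the default shown. Values are illustrative,
# not tuned recommendations:
#
#   model:
#     in_channels: 1
#     dim_value: 40
#     dim_id: 40
#     dim_condition: 10
#     fourier_features: 128
#     num_heads: 4
#     num_layers: 8
#     widening_factor: 3
#     qkv_features: 90
#     num_hidden_layers: 1
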
def parse_training_config(config_path: str):
    """
    Parse a training configuration file.

    Parameters
    ----------
    config_path : str
        Path to the configuration file.

    Returns
    -------
    training_config : dict
        Training configuration dictionary, with defaults filled in for
        missing keys.
    """
    with open(config_path, "r") as f:
        config = yaml.safe_load(f)

    # Training parameters
    train_params = config.get("training", {})
    multistep = train_params.get("multistep", 1)
    experiment_id = train_params.get("experiment_id", 1)
    early_stopping = train_params.get("early_stopping", True)
    nsteps = train_params.get("nsteps", 30000) * multistep
    val_every = train_params.get("val_every", 100) * multistep

    # Optimizer parameters
    opt_params = config.get("optimizer", {})
    PATIENCE = opt_params.get("patience", 10)
    COOLDOWN = opt_params.get("cooldown", 2)
    FACTOR = opt_params.get("factor", 0.5)
    ACCUMULATION_SIZE = opt_params.get("accumulation_size", 100) * multistep
    RTOL = opt_params.get("rtol", 1e-4)
    MAX_LR = opt_params.get("max_lr", 1e-3)
    MIN_LR = opt_params.get("min_lr", 0.0)
    MIN_SCALE = MIN_LR / MAX_LR if MAX_LR > 0 else 0.0
    ema_decay = opt_params.get("ema_decay", 0.999)

    # Overwrite the defaults with the config file values
    training_config = {}
    training_config["num_steps"] = nsteps
    training_config["ema_decay"] = ema_decay
    training_config["patience"] = PATIENCE
    training_config["cooldown"] = COOLDOWN
    training_config["factor"] = FACTOR
    training_config["accumulation_size"] = ACCUMULATION_SIZE
    training_config["rtol"] = RTOL
    training_config["max_lr"] = MAX_LR
    training_config["min_lr"] = MIN_LR
    training_config["min_scale"] = MIN_SCALE
    training_config["val_every"] = val_every
    training_config["early_stopping"] = early_stopping
    training_config["experiment_id"] = experiment_id
    training_config["multistep"] = multistep

    return training_config

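# A sketch of the ``training`` and ``optimizer`` sections this parser reads
# (illustrative values; the defaults above apply to any missing key). Note
# that ``nsteps``, ``val_every``, and ``accumulation_size`` are multiplied by
# ``multistep`` before being stored:
#
#   training:
#     multistep: 1
#     experiment_id: 1
#     early_stopping: true
#     nsteps: 30000
#     val_every: 100
#   optimizer:
#     patience: 10
#     cooldown: 2
#     factor: 0.5
#     accumulation_size: 100
#     rtol: 1.0e-4
#     max_lr: 1.0e-3
#     min_lr: 0.0
#     ema_decay: 0.999
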
def sample_strutured_conditional_mask(
    key,
    num_samples,
    theta_dim,
    x_dim,
    p_joint=0.2,
    p_posterior=0.2,
    p_likelihood=0.2,
    p_rnd1=0.2,
    p_rnd2=0.2,
    rnd1_prob=0.3,
    rnd2_prob=0.7,
):
    """
    Sample structured conditional masks for the Simformer model.

    Parameters
    ----------
    key : jax.random.PRNGKey
        Random key for sampling.
    num_samples : int
        Number of samples to generate.
    theta_dim : int
        Dimension of the parameter space.
    x_dim : int
        Dimension of the observation space.
    p_joint : float
        Probability of selecting the joint mask.
    p_posterior : float
        Probability of selecting the posterior mask.
    p_likelihood : float
        Probability of selecting the likelihood mask.
    p_rnd1 : float
        Probability of selecting the first random mask.
    p_rnd2 : float
        Probability of selecting the second random mask.
    rnd1_prob : float
        Probability of a True value in the first random mask.
    rnd2_prob : float
        Probability of a True value in the second random mask.

    Returns
    -------
    condition_mask : jnp.ndarray
        Boolean array of shape (num_samples, theta_dim + x_dim, 1).
    """
    # Joint, posterior, likelihood, random1_mask, random2_mask
    key1, key2, key3 = jax.random.split(key, 3)

    joint_mask = jnp.array([False] * (theta_dim + x_dim), dtype=jnp.bool_)
    posterior_mask = jnp.array([False] * theta_dim + [True] * x_dim, dtype=jnp.bool_)
    likelihood_mask = jnp.array([True] * theta_dim + [False] * x_dim, dtype=jnp.bool_)
    random1_mask = jax.random.bernoulli(
        key2, rnd1_prob, shape=(theta_dim + x_dim,)
    ).astype(jnp.bool_)
    random2_mask = jax.random.bernoulli(
        key3, rnd2_prob, shape=(theta_dim + x_dim,)
    ).astype(jnp.bool_)

    mask_options = jnp.stack(
        [joint_mask, posterior_mask, likelihood_mask, random1_mask, random2_mask],
        axis=0,
    )  # (5, theta_dim + x_dim)

    idx = jax.random.choice(
        key1,
        5,
        shape=(num_samples,),
        p=jnp.array([p_joint, p_posterior, p_likelihood, p_rnd1, p_rnd2]),
    )

    condition_mask = mask_options[idx]

    # If a sampled mask conditions on every variable, reset it to all False
    # (otherwise there would be nothing left to generate)
    all_ones_mask = jnp.all(condition_mask, axis=-1)
    condition_mask = jnp.where(all_ones_mask[..., None], False, condition_mask)

    return condition_mask[..., None]

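# Usage sketch: draw 8 condition masks for a 2-parameter, 2-observation task.
# With the default probabilities, each row is one of the five mask types; a
# True entry marks a variable that is conditioned on rather than generated.
#
# key = jax.random.PRNGKey(0)
# masks = sample_strutured_conditional_mask(key, num_samples=8, theta_dim=2, x_dim=2)
# masks.shape  # (8, 4, 1)
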
class SimformerFlowPipeline(JointFlowPipeline):
    def __init__(
        self,
        train_dataset,
        val_dataset,
        dim_theta: int,
        dim_x: int,
        params=None,
        training_config=None,
        edge_mask=None,
    ):
        """
        Flow pipeline for training and using a Simformer model for simulation-based inference.

        Parameters
        ----------
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_theta : int
            Dimension of the parameter space.
        dim_x : int
            Dimension of the observation space.
        params : SimformerParams, optional
            Parameters for the Simformer model. If None, default parameters are used.
        training_config : dict, optional
            Configuration for training. If None, default configuration is used.
        edge_mask : jnp.ndarray, optional
            Edge mask for the Simformer model. If None, no mask is applied.
        """
        super().__init__(
            None, train_dataset, val_dataset, dim_theta, dim_x, params, training_config
        )

        if params is None:
            self.params = self._get_default_params()

        self.model = self._make_model(self.params)
        self.ema_model = nnx.clone(self.model)
        self.edge_mask = edge_mask

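    # ``edge_mask`` sketch. Assumption (not confirmed by this module): a
    # boolean (dim_joint, dim_joint) adjacency matrix where entry (i, j)
    # lets token i attend to token j. A fully connected mask would be:
    #
    # edge_mask = jnp.ones((dim_theta + dim_x, dim_theta + dim_x), dtype=jnp.bool_)
    # pipeline = SimformerFlowPipeline(
    #     train_dataset, val_dataset, dim_theta, dim_x, edge_mask=edge_mask
    # )
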
    @classmethod
    def init_pipeline_from_config(
        cls,
        train_dataset,
        val_dataset,
        dim_theta: int,
        dim_x: int,
        config_path: str,
        checkpoint_dir: str,
    ):
        """
        Initialize the pipeline from a configuration file.

        Parameters
        ----------
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_theta : int
            Dimension of the parameter space.
        dim_x : int
            Dimension of the observation space.
        config_path : str
            Path to the configuration file.
        checkpoint_dir : str
            Directory where training checkpoints are stored.

        Returns
        -------
        pipeline : SimformerFlowPipeline
            The initialized pipeline.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        # Methodology
        strategy = config.get("strategy", {})
        method = strategy.get("method")
        model_type = strategy.get("model")

        assert (
            method == "flow"
        ), f"Method {method} not supported in SimformerFlowPipeline."
        assert (
            model_type == "simformer"
        ), f"Model type {model_type} not supported in SimformerFlowPipeline."

        # Model parameters from config
        dim_joint = dim_theta + dim_x
        params_dict = parse_simformer_params(config_path)
        params = SimformerParams(
            rngs=nnx.Rngs(0),
            dim_joint=dim_joint,
            **params_dict,
        )

        # Training parameters
        training_config = cls._get_default_training_config()
        training_config["checkpoint_dir"] = checkpoint_dir

        training_config_ = parse_training_config(config_path)
        for key, value in training_config_.items():
            training_config[key] = value  # update with config file values

        pipeline = cls(
            train_dataset,
            val_dataset,
            dim_theta,
            dim_x,
            params,
            training_config,
        )

        return pipeline

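    # Config sketch for ``init_pipeline_from_config``. The ``strategy`` keys
    # are the ones asserted above; the ``model``, ``training``, and
    # ``optimizer`` sections are consumed by the parsers defined earlier in
    # this module. File names are hypothetical:
    #
    #   strategy:
    #     method: flow
    #     model: simformer
    #
    # pipeline = SimformerFlowPipeline.init_pipeline_from_config(
    #     train_dataset_grain, val_dataset_grain, dim_theta, dim_x,
    #     config_path="config.yaml", checkpoint_dir="checkpoints/",
    # )
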
    def _make_model(self, params):
        """
        Create and return the Simformer model to be trained.
        """
        model = Simformer(params)
        return model

    def _get_default_params(self):
        """
        Return default parameters for the Simformer model.
        """
        params = SimformerParams(
            rngs=nnx.Rngs(0),
            in_channels=1,
            dim_value=40,
            dim_id=40,
            dim_condition=10,
            dim_joint=self.dim_joint,
            fourier_features=128,
            num_heads=4,
            num_layers=8,
            widening_factor=3,
            qkv_features=40,
            num_hidden_layers=1,
        )
        return params

    def sample(
        self, key, x_o, nsamples=10_000, step_size=0.01, use_ema=True, time_grid=None
    ):
        model_extras = {
            "edge_mask": self.edge_mask,
        }
        return super().sample(
            key,
            x_o,
            nsamples=nsamples,
            step_size=step_size,
            use_ema=use_ema,
            time_grid=time_grid,
            **model_extras,
        )

    def compute_unnorm_logprob(
        self, x_1, x_o, step_size=0.01, use_ema=True, time_grid=None
    ):
        model_extras = {
            "edge_mask": self.edge_mask,
        }
        return super().compute_unnorm_logprob(
            x_1,
            x_o,
            step_size=step_size,
            use_ema=use_ema,
            time_grid=time_grid,
            **model_extras,
        )

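    # Usage sketch (assuming ``x_1`` holds the parameter values to score, in
    # the flow-matching convention of samples at time t=1; this reading is an
    # assumption, not confirmed by this module):
    #
    # logp = pipeline.compute_unnorm_logprob(samples, x_o, step_size=0.01)
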
class SimformerDiffusionPipeline(JointDiffusionPipeline):
    def __init__(
        self,
        train_dataset,
        val_dataset,
        dim_theta: int,
        dim_x: int,
        params=None,
        training_config=None,
        edge_mask=None,
    ):
        """
        Diffusion pipeline for training and using a Simformer model for simulation-based inference.

        Parameters
        ----------
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_theta : int
            Dimension of the parameter space.
        dim_x : int
            Dimension of the observation space.
        params : SimformerParams, optional
            Parameters for the Simformer model. If None, default parameters are used.
        training_config : dict, optional
            Configuration for training. If None, default configuration is used.
        edge_mask : jnp.ndarray, optional
            Edge mask for the Simformer model. If None, no mask is applied.
        """
        super().__init__(
            None, train_dataset, val_dataset, dim_theta, dim_x, params, training_config
        )

        if params is None:
            self.params = self._get_default_params()

        self.model = self._make_model(self.params)
        self.ema_model = nnx.clone(self.model)
        self.edge_mask = edge_mask

    @classmethod
    def init_pipeline_from_config(
        cls,
        train_dataset,
        val_dataset,
        dim_theta: int,
        dim_x: int,
        config_path: str,
        checkpoint_dir: str,
    ):
        """
        Initialize the pipeline from a configuration file.

        Parameters
        ----------
        train_dataset : grain dataset or iterator over batches
            Training dataset.
        val_dataset : grain dataset or iterator over batches
            Validation dataset.
        dim_theta : int
            Dimension of the parameter space.
        dim_x : int
            Dimension of the observation space.
        config_path : str
            Path to the configuration file.
        checkpoint_dir : str
            Directory where training checkpoints are stored.

        Returns
        -------
        pipeline : SimformerDiffusionPipeline
            The initialized pipeline.
        """
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        # Methodology
        strategy = config.get("strategy", {})
        method = strategy.get("method")
        model_type = strategy.get("model")

        assert (
            method == "diffusion"
        ), f"Method {method} not supported in SimformerDiffusionPipeline."
        assert (
            model_type == "simformer"
        ), f"Model type {model_type} not supported in SimformerDiffusionPipeline."

        # Model parameters from config
        dim_joint = dim_theta + dim_x
        params_dict = parse_simformer_params(config_path)
        params = SimformerParams(
            rngs=nnx.Rngs(0),
            dim_joint=dim_joint,
            **params_dict,
        )

        # Training parameters
        training_config = cls._get_default_training_config()
        training_config["checkpoint_dir"] = checkpoint_dir

        training_config_ = parse_training_config(config_path)
        for key, value in training_config_.items():
            training_config[key] = value  # update with config file values

        pipeline = cls(
            train_dataset,
            val_dataset,
            dim_theta,
            dim_x,
            params,
            training_config,
        )

        return pipeline

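    # Same configuration layout as ``SimformerFlowPipeline.init_pipeline_from_config``
    # above, except the ``strategy`` section must declare the diffusion method
    # to pass the assertions in this classmethod:
    #
    #   strategy:
    #     method: diffusion
    #     model: simformer
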
    def _make_model(self, params):
        """
        Create and return the Simformer model to be trained.
        """
        model = Simformer(params)
        return model

    def _get_default_params(self):
        """
        Return default parameters for the Simformer model.
        """
        params = SimformerParams(
            rngs=nnx.Rngs(0),
            in_channels=1,
            dim_value=40,
            dim_id=40,
            dim_condition=10,
            dim_joint=self.dim_joint,
            fourier_features=128,
            num_heads=4,
            num_layers=8,
            widening_factor=3,
            qkv_features=40,
            num_hidden_layers=1,
        )
        return params

    def sample(
        self,
        key,
        x_o,
        nsamples=10_000,
        nsteps=18,
        use_ema=True,
        return_intermediates=False,
    ):
        model_extras = {
            "edge_mask": self.edge_mask,
        }
        return super().sample(
            key,
            x_o,
            nsamples=nsamples,
            nsteps=nsteps,
            use_ema=use_ema,
            return_intermediates=return_intermediates,
            **model_extras,
        )

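# Usage sketch: the diffusion sampler is driven by a step count rather than
# the flow pipeline's ODE ``step_size``. For example:
#
# key = jax.random.PRNGKey(0)
# x_o = jnp.array([0.5, -0.2])
# samples = pipeline.sample(key, x_o, nsamples=10_000, nsteps=18)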