gensbi.recipes#
Cookie-cutter modules for creating and training SBI models.
Submodules#
Classes#
ConditionalDiffusionPipeline | Abstract base class for GenSBI training pipelines.
ConditionalFlowPipeline | Abstract base class for GenSBI training pipelines.
Flux1DiffusionPipeline | Abstract base class for GenSBI training pipelines.
Flux1FlowPipeline | Abstract base class for GenSBI training pipelines.
Flux1JointDiffusionPipeline | Abstract base class for GenSBI training pipelines.
Flux1JointFlowPipeline | Abstract base class for GenSBI training pipelines.
JointDiffusionPipeline | Abstract base class for GenSBI training pipelines.
JointFlowPipeline | Abstract base class for GenSBI training pipelines.
SimformerDiffusionPipeline | Abstract base class for GenSBI training pipelines.
SimformerFlowPipeline | Abstract base class for GenSBI training pipelines.
UnconditionalDiffusionPipeline | Abstract base class for GenSBI training pipelines.
UnconditionalFlowPipeline | Abstract base class for GenSBI training pipelines.
Package Contents#
- class gensbi.recipes.ConditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
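The parameter list above only requires that the datasets be iterables yielding batches. A minimal sketch of such an iterable follows, assuming each batch is a (theta, x) pair of nested lists; the batch layout, the toy simulator, and the name make_batches are hypothetical and not taken from GenSBI.

```python
import random


def make_batches(n_samples, batch_size, dim_theta, dim_x, seed=0):
    """Yield (theta, x) batches from a toy simulator (hypothetical layout)."""
    rng = random.Random(seed)
    batch_theta, batch_x = [], []
    for _ in range(n_samples):
        theta = [rng.gauss(0.0, 1.0) for _ in range(dim_theta)]
        # Toy simulator: every observation entry is the mean of theta plus noise.
        mean = sum(theta) / dim_theta
        x = [mean + 0.1 * rng.gauss(0.0, 1.0) for _ in range(dim_x)]
        batch_theta.append(theta)
        batch_x.append(x)
        if len(batch_theta) == batch_size:
            yield batch_theta, batch_x
            batch_theta, batch_x = [], []
    # An incomplete final batch is dropped in this sketch.


train_dataset = make_batches(n_samples=64, batch_size=16, dim_theta=2, dim_x=3)
first_theta, first_x = next(train_dataset)
```

Here dim_theta and dim_x match the widths of the rows in each batch, which is the role the parameter descriptions give them.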
- classmethod _get_default_training_config()[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(rng, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Generate samples from the trained model.
- Parameters:
rng (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
nsteps (int, optional) – Number of sampler steps.
- Returns:
samples – Generated samples.
- Return type:
array-like
- cond_ids#
- loss_fn#
- obs_ids#
- path#
- class gensbi.recipes.ConditionalFlowPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
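The "If None, uses defaults" behaviour described for params and training_config can be sketched with a small stand-in class. ToyPipeline and the dictionary keys below are hypothetical; only the method names _get_default_params and _get_default_training_config come from the parameter descriptions above.

```python
class ToyPipeline:
    """Hypothetical stand-in; not the GenSBI implementation."""

    def __init__(self, dim_theta, dim_x, params=None, training_config=None):
        self.dim_theta = dim_theta
        self.dim_x = dim_x
        # Fall back to the defaults only when the caller passes None.
        self.params = params if params is not None else self._get_default_params()
        self.training_config = (
            training_config
            if training_config is not None
            else self._get_default_training_config()
        )

    def _get_default_params(self):
        # Hypothetical keys, chosen for illustration only.
        return {"hidden_size": 64, "depth": 2}

    @classmethod
    def _get_default_training_config(cls):
        # A fresh dict on every call, so callers can edit it safely.
        return {"nsteps": 1000, "learning_rate": 1e-3}


default_pipe = ToyPipeline(dim_theta=2, dim_x=3)

# Start from the defaults and override a single entry.
custom_cfg = ToyPipeline._get_default_training_config()
custom_cfg["learning_rate"] = 3e-4
custom_pipe = ToyPipeline(2, 3, training_config=custom_cfg)
```

Fetching the default dictionary from the classmethod and editing it keeps every unspecified entry at its default value.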
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- compute_unnorm_logprob(x_1, x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
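The entry above carries no description. For flow models, a log-density is commonly obtained with the continuous change-of-variables formula: integrate the sample backwards from t=1 to t=0 while accumulating the divergence of the velocity field, then subtract that integral from the base log-density. The sketch below illustrates this on a 1-D toy field and is an assumption about the general technique, not GenSBI's code; toy_velocity, toy_divergence, and toy_logprob are hypothetical names, and the toy result is normalized whereas the method name suggests an unnormalized value.

```python
import math


def toy_velocity(x, t, a=0.5):
    # Hypothetical linear velocity field v(x, t) = a * x.
    return a * x


def toy_divergence(x, t, a=0.5):
    # In 1-D the divergence of a * x is the constant a.
    return a


def log_standard_normal(x):
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def toy_logprob(x_1, step_size=0.01):
    """Euler-integrate backwards from t=1 to t=0, accumulating the divergence."""
    n_steps = round(1.0 / step_size)
    x, div_integral = x_1, 0.0
    for k in range(n_steps):
        t = 1.0 - k * step_size
        div_integral += toy_divergence(x, t) * step_size
        x = x - step_size * toy_velocity(x, t)
    # log p_1(x_1) = log p_0(x_0) - integral of the divergence over [0, 1]
    return log_standard_normal(x) - div_integral


value = toy_logprob(1.2)
```

For this field the exact answer is known in closed form (the flow maps N(0, 1) to N(0, e^{2a})), so the Euler estimate can be checked against it; the gap shrinks as step_size decreases.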
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(rng, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
rng (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
- Returns:
samples – Generated samples.
- Return type:
array-like
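The step_size argument above suggests a fixed-step ODE integration from t=0 to t=1. A minimal sketch of that idea with forward Euler on a 1-D toy velocity field follows; the field, the constants MEAN and SCALE, and euler_sample are hypothetical, and the solver GenSBI actually uses is not stated in this reference.

```python
import random

MEAN, SCALE = 3.0, 0.5  # hypothetical target distribution N(MEAN, SCALE**2)


def toy_velocity(x, t):
    # Velocity of the straight-line flow x_t = (1 - t + t * SCALE) * x_0 + t * MEAN.
    x_0 = (x - t * MEAN) / (1.0 - t + t * SCALE)
    return (SCALE - 1.0) * x_0 + MEAN


def euler_sample(x_0, step_size=0.01):
    """Integrate dx/dt = v(x, t) from t=0 to t=1 with fixed Euler steps."""
    n_steps = round(1.0 / step_size)
    x = x_0
    for k in range(n_steps):
        t = k * step_size
        x = x + step_size * toy_velocity(x, t)
    return x


rng = random.Random(0)
samples = [euler_sample(rng.gauss(0.0, 1.0)) for _ in range(1000)]
sample_mean = sum(samples) / len(samples)
```

With the default step_size=0.01 the loop takes 100 steps per sample. A smaller step size costs more velocity evaluations; for curved trajectories it also reduces discretization error, although this particular toy flow is straight and Euler is exact for it.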
- cond_ids#
- loss_fn#
- obs_ids#
- p0_dist_model#
- path#
- class gensbi.recipes.Flux1DiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.conditional_pipeline.ConditionalDiffusionPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_theta (int)
dim_x (int)
checkpoint_dir (str)
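The classmethod above acts as an alternate constructor that reads its settings from a file. The pattern can be sketched as follows, assuming a JSON file with "model" and "training" sections; the file format, the section names, and ToyPipeline are hypothetical, since this reference does not specify the configuration format.

```python
import json
import os
import tempfile


class ToyPipeline:
    """Hypothetical stand-in used only to illustrate the constructor pattern."""

    def __init__(self, train_dataset, val_dataset, dim_theta, dim_x,
                 params=None, training_config=None):
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.dim_theta = dim_theta
        self.dim_x = dim_x
        self.params = params
        self.training_config = training_config

    @classmethod
    def init_pipeline_from_config(cls, train_dataset, val_dataset, dim_theta,
                                  dim_x, config_path, checkpoint_dir):
        with open(config_path) as handle:
            config = json.load(handle)
        # Copy the training section and record where checkpoints should go.
        training_config = dict(config.get("training", {}))
        training_config["checkpoint_dir"] = checkpoint_dir
        return cls(train_dataset, val_dataset, dim_theta, dim_x,
                   params=config.get("model"), training_config=training_config)


with tempfile.TemporaryDirectory() as tmp:
    config_path = os.path.join(tmp, "config.json")
    with open(config_path, "w") as handle:
        json.dump({"model": {"depth": 2}, "training": {"nsteps": 100}}, handle)
    pipeline = ToyPipeline.init_pipeline_from_config(
        train_dataset=[], val_dataset=[], dim_theta=2, dim_x=3,
        config_path=config_path,
        checkpoint_dir=os.path.join(tmp, "checkpoints"),
    )
```

Keeping the file parsing in a classmethod leaves the regular constructor free of I/O, so the same class can be built either from in-memory dictionaries or from a file.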
- ema_model#
- model#
- class gensbi.recipes.Flux1FlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.conditional_pipeline.ConditionalFlowPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_theta (int)
dim_x (int)
checkpoint_dir (str)
- ema_model#
- model#
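The ema_model attribute sits alongside model, and the sample methods accept use_ema. An exponential moving average (EMA) of the weights is usually maintained with the update ema = decay * ema + (1 - decay) * param after each optimizer step. The sketch below shows that update on a plain dictionary; ema_update and the decay values are hypothetical, and whether GenSBI uses exactly this rule is an assumption.

```python
def ema_update(ema_params, params, decay=0.999):
    """Return a new dict blending the running average with the latest weights."""
    return {
        name: decay * ema_params[name] + (1.0 - decay) * params[name]
        for name in params
    }


params = {"w": 0.0}
ema_params = dict(params)
for step in range(3):
    # Stand-in for an optimizer step that moves the raw weight.
    params = {"w": params["w"] + 1.0}
    ema_params = ema_update(ema_params, params, decay=0.9)
```

The averaged weights lag behind the raw ones and change more smoothly, which is the usual reason for sampling from the EMA copy.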
- class gensbi.recipes.Flux1JointDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.joint_pipeline.JointDiffusionPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_theta (int)
dim_x (int)
checkpoint_dir (str)
- ema_model#
- model#
- class gensbi.recipes.Flux1JointFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.joint_pipeline.JointFlowPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_theta (int)
dim_x (int)
checkpoint_dir (str)
- ema_model#
- model#
- class gensbi.recipes.JointDiffusionPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod _get_default_training_config()[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
nsteps (int, optional) – Number of sampler steps.
- Returns:
samples – Generated samples.
- Return type:
array-like
- cond_ids#
- loss_fn#
- node_ids#
- obs_ids#
- path#
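The attributes cond_ids, node_ids, and obs_ids are listed above without descriptions. One plausible reading, stated here as an assumption, is that a joint model indexes every variable as a node, with the theta entries treated as the variables to sample and the x entries as the conditioning variables. The layout and the helper split_joint below are hypothetical.

```python
dim_theta, dim_x = 2, 3

# Assumed layout: theta entries first, then x entries.
node_ids = list(range(dim_theta + dim_x))
obs_ids = node_ids[:dim_theta]    # variables to be sampled (assumption)
cond_ids = node_ids[dim_theta:]   # variables held fixed at x_o (assumption)


def split_joint(joint_sample):
    """Split one joint vector into its theta part and its x part."""
    theta = [joint_sample[i] for i in obs_ids]
    x = [joint_sample[i] for i in cond_ids]
    return theta, x


theta, x = split_joint([0.1, 0.2, 1.0, 2.0, 3.0])
```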
- class gensbi.recipes.JointFlowPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- compute_unnorm_logprob(x_1, x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
- Returns:
samples – Generated samples.
- Return type:
array-like
- cond_ids#
- loss_fn#
- node_ids#
- obs_ids#
- p0_dist_model#
- path#
- class gensbi.recipes.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None, edge_mask=None)[source]#
Bases:
gensbi.recipes.joint_pipeline.JointDiffusionPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_theta (int)
dim_x (int)
checkpoint_dir (str)
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
nsteps (int, optional) – Number of sampler steps.
- Returns:
samples – Generated samples.
- Return type:
array-like
- edge_mask = None#
- ema_model#
- model#
- class gensbi.recipes.SimformerFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None, edge_mask=None)[source]#
Bases:
gensbi.recipes.joint_pipeline.JointFlowPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#
Initialize the pipeline from a configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
dim_theta (int)
dim_x (int)
checkpoint_dir (str)
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)[source]#
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
- Returns:
samples – Generated samples.
- Return type:
array-like
- edge_mask = None#
- ema_model#
- model#
- class gensbi.recipes.UnconditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_theta, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod _get_default_training_config()[source]#
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(rng, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
rng (jax.random.PRNGKey) – Random number generator key.
nsamples (int, optional) – Number of samples to generate.
nsteps (int, optional) – Number of sampler steps.
- Returns:
samples – Generated samples.
- Return type:
array-like
- loss_fn#
- obs_ids#
- path#
- class gensbi.recipes.UnconditionalFlowPipeline(model, train_dataset, val_dataset, dim_theta, params=None, training_config=None)[source]#
Bases:
gensbi.recipes.pipeline.AbstractPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- _wrap_model()[source]#
Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).
- classmethod init_pipeline_from_config()[source]#
- Abstractmethod:
Initialize the pipeline from a configuration file.
- Parameters:
train_dataset (iterable) – Training dataset.
val_dataset (iterable) – Validation dataset.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
config_path (str) – Path to the configuration file.
checkpoint_dir (str) – Directory for saving checkpoints.
- Returns:
pipeline – An instance of the pipeline initialized from the configuration.
- Return type:
- sample(rng, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#
Generate samples from the trained model.
- Parameters:
rng (jax.random.PRNGKey) – Random number generator key.
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
- Returns:
samples – Generated samples.
- Return type:
array-like
- loss_fn#
- obs_ids#
- p0_dist_model#
- path#