gensbi.recipes
Cookie-cutter modules for creating and training SBI (simulation-based inference) models.
Submodules
Classes
- FluxDiffusionPipeline – Training pipeline for a Flux-based diffusion model.
- FluxFlowPipeline – Training pipeline for a Flux-based flow-matching model.
- SimformerDiffusionPipeline – Training pipeline for a Simformer-based diffusion model.
- SimformerFlowPipeline – Training pipeline for a Simformer-based flow-matching model.
Package Contents
- class gensbi.recipes.FluxDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)
Bases: gensbi.recipes.pipeline.AbstractPipeline
Training pipeline for a Flux-based conditional diffusion model.
This is a concrete subclass of AbstractPipeline. It provides the model creation, default parameter setup, loss function, sampling, and evaluation methods that the abstract base class leaves to its subclasses.
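The split between an abstract pipeline template and concrete subclasses can be sketched with Python's `abc` module. This is a minimal illustration, not GenSBI's `AbstractPipeline`: the names `ToyAbstractPipeline` and `ToyPipeline`, treating `_get_default_params` as an abstract method, and the config key `max_epochs` are all hypothetical. It shows the general pattern: the base class handles the "None means use the defaults" logic, and a class that leaves abstract methods unimplemented cannot be instantiated.

```python
from abc import ABC, abstractmethod


class ToyAbstractPipeline(ABC):
    """Hypothetical stand-in for a pipeline template (not GenSBI's AbstractPipeline)."""

    def __init__(self, dim_theta, dim_x, params=None, training_config=None):
        self.dim_theta = dim_theta
        self.dim_x = dim_x
        # None means "use the defaults supplied by the (sub)class".
        self.params = params if params is not None else self._get_default_params()
        if training_config is None:
            training_config = self._get_default_training_config()
        self.training_config = training_config

    @abstractmethod
    def _get_default_params(self):
        """Subclasses return their default model parameters."""

    @classmethod
    def _get_default_training_config(cls):
        # Hypothetical key; real defaults are library-specific.
        return {"max_epochs": 10}

    @abstractmethod
    def sample(self, rng, x_o, nsamples=10000):
        """Subclasses draw nsamples conditioned on x_o."""


class ToyPipeline(ToyAbstractPipeline):
    def _get_default_params(self):
        return {"depth": 2}

    def sample(self, rng, x_o, nsamples=10000):
        # Placeholder output; a real pipeline would run its trained model here.
        return [[0.0] * self.dim_theta for _ in range(nsamples)]


pipeline = ToyPipeline(dim_theta=3, dim_x=5)
draws = pipeline.sample(rng=None, x_o=[0.0] * 5, nsamples=4)
```

Passing `params` or `training_config` explicitly bypasses the defaults, which mirrors the documented `None` behaviour of the constructor parameters.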
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
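The constructor only requires `train_dataset` and `val_dataset` to be iterables that yield batches. A minimal sketch of such an iterable follows. The `(theta, x)` pair layout, the toy simulator, and the class name `ToyBatches` are assumptions for illustration, because the expected batch format is not documented in this section. Defining `__iter__` (rather than handing over a one-shot generator) lets the object be iterated more than once; whether the pipeline iterates it repeatedly or expects an endless stream is not stated here.

```python
import random


class ToyBatches:
    """Re-iterable toy dataset yielding batches of (theta, x) pairs.

    Hypothetical: the batch layout GenSBI expects is not specified in these docs.
    """

    def __init__(self, n_batches, batch_size, dim_theta, dim_x, seed=0):
        self.n_batches = n_batches
        self.batch_size = batch_size
        self.dim_theta = dim_theta
        self.dim_x = dim_x
        self.seed = seed

    def __iter__(self):
        rng = random.Random(self.seed)  # same seed gives the same batches on every pass
        for _ in range(self.n_batches):
            batch = []
            for _ in range(self.batch_size):
                theta = [rng.gauss(0.0, 1.0) for _ in range(self.dim_theta)]
                # Toy simulator: x is a noisy copy of theta, tiled to length dim_x.
                x = [theta[i % self.dim_theta] + 0.1 * rng.gauss(0.0, 1.0)
                     for i in range(self.dim_x)]
                batch.append((theta, x))
            yield batch


train_dataset = ToyBatches(n_batches=100, batch_size=256, dim_theta=2, dim_x=4, seed=0)
val_dataset = ToyBatches(n_batches=10, batch_size=256, dim_theta=2, dim_x=4, seed=1)
```

Objects like these would then be passed positionally, following the documented signature, e.g. `FluxDiffusionPipeline(train_dataset, val_dataset, dim_theta=2, dim_x=4)`.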
- classmethod _get_default_training_config()
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
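The constructor documentation says only that `training_config=None` falls back to these defaults; it does not say whether a partial dictionary is merged with them. One cautious caller-side pattern, sketched below, is to start from the defaults, override selected keys, and pass the complete dictionary. `ToyPipelineConfig` and its keys are hypothetical; with GenSBI the starting dictionary would come from the documented classmethod `FluxDiffusionPipeline._get_default_training_config()`.

```python
import copy


class ToyPipelineConfig:
    """Hypothetical class mimicking the _get_default_training_config pattern."""

    @classmethod
    def _get_default_training_config(cls):
        # Illustrative keys only; GenSBI's real defaults are not listed in these docs.
        return {"learning_rate": 1e-3, "batch_size": 256, "ema_decay": 0.999}


def build_training_config(overrides=None):
    """Start from the class defaults and apply user overrides on top."""
    # Deep copy so that mutating the result never touches the defaults.
    config = copy.deepcopy(ToyPipelineConfig._get_default_training_config())
    config.update(overrides or {})
    return config


custom = build_training_config({"learning_rate": 5e-4})
```

Building the full dictionary up front keeps the outcome independent of how the pipeline treats missing keys.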
- _wrap_model()
Wrap the model for evaluation with Flux1Wrapper.
- sample(rng, x_o, nsamples=10000, nsteps=18, use_ema=True)
Generate samples from the trained model.
- Parameters:
rng (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
nsteps (int, optional) – Number of steps taken by the sampler.
use_ema (bool, optional) – If True, sample with the exponential-moving-average (EMA) copy of the model weights.
- Returns:
samples – Generated samples.
- Return type:
array-like
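Here `nsteps` fixes the sampler's step budget. As an illustration of how a diffusion sampler can spend such a budget, the function below builds the noise-level schedule of Karras et al. (2022, "EDM"), whose CIFAR-10 configuration also uses 18 steps. Whether `FluxDiffusionPipeline` uses this schedule, and with which `sigma_min`, `sigma_max`, and `rho`, is an assumption; the values shown are the EDM paper's defaults.

```python
def edm_sigma_schedule(nsteps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, spaced as in Karras et al. (2022)."""
    inv_max = sigma_max ** (1.0 / rho)
    inv_min = sigma_min ** (1.0 / rho)
    # Interpolate linearly in sigma**(1/rho) space, then map back with **rho;
    # with rho > 1 this places shorter steps near the low-noise end.
    return [
        (inv_max + i / (nsteps - 1) * (inv_min - inv_max)) ** rho
        for i in range(nsteps)
    ]


sigmas = edm_sigma_schedule(nsteps=18)
```

Each sampler step moves from one noise level to the next, so raising `nsteps` refines the grid at the cost of more model evaluations.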
- cond_ids
- loss_fn
- obs_ids
- path
- class gensbi.recipes.FluxFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)
Bases: gensbi.recipes.pipeline.AbstractPipeline
Training pipeline for a Flux-based conditional flow-matching model.
This is a concrete subclass of AbstractPipeline. It provides the model creation, default parameter setup, loss function, sampling, and evaluation methods that the abstract base class leaves to its subclasses.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- _wrap_model()
Wrap the model for evaluation with Flux1Wrapper.
- sample(rng, x_o, nsamples=10000, step_size=0.01, use_ema=True)
Generate samples from the trained model.
- Parameters:
rng (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
use_ema (bool, optional) – If True, sample with the exponential-moving-average (EMA) copy of the model weights.
- Returns:
samples – Generated samples.
- Return type:
array-like
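For the flow pipelines, `step_size` sets the increment used to integrate the learned velocity field, so the default 0.01 corresponds to 100 steps if time runs from 0 to 1. The sketch below shows fixed-step Euler integration of dx/dt = v(x, t). The Euler scheme, the [0, 1] time interval, and the toy velocity field (the exact field for a straight-line path toward a single point `x1`) are assumptions, not a description of GenSBI's sampler; `x0` stands in for a draw from the base distribution.

```python
def euler_integrate(velocity, x0, step_size=0.01):
    """Integrate dx/dt = velocity(x, t) from t=0 toward t=1 with fixed Euler steps."""
    # step_size=0.01 -> 100 velocity evaluations. If 1/step_size is not an
    # integer, this sketch rounds and the last step does not land exactly on t=1.
    nsteps = round(1.0 / step_size)
    x = x0
    for i in range(nsteps):
        t = i * step_size
        x = x + step_size * velocity(x, t)
    return x


# Hypothetical target: a straight-line path toward the single point x1.
# Its exact velocity is (x1 - x) / (1 - t), finite for every t < 1 used above.
x1 = 2.5


def toy_velocity(x, t):
    return (x1 - x) / (1.0 - t)


sample = euler_integrate(toy_velocity, x0=-1.0, step_size=0.01)
```

For this particular field each Euler step shrinks the gap to `x1` by exactly the remaining fraction of time, so the sample lands on `x1`; a learned, curved field would instead carry a discretization error that shrinks as `step_size` decreases.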
- cond_ids
- loss_fn
- obs_ids
- p0_dist_model
- path
- class gensbi.recipes.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)
Bases: gensbi.recipes.pipeline.AbstractPipeline
Training pipeline for a Simformer-based conditional diffusion model.
This is a concrete subclass of AbstractPipeline. It provides the model creation, default parameter setup, loss function, sampling, and evaluation methods that the abstract base class leaves to its subclasses.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- classmethod _get_default_training_config()
Return a dictionary of default training configuration parameters.
- Returns:
training_config – Default training configuration.
- Return type:
dict
- _wrap_model()
Wrap the model for evaluation with SimformerWrapper.
- sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True)
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
nsteps (int, optional) – Number of steps taken by the sampler.
use_ema (bool, optional) – If True, sample with the exponential-moving-average (EMA) copy of the model weights.
- Returns:
samples – Generated samples.
- Return type:
array-like
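The `use_ema` flag presumably selects an exponential moving average (EMA) of the weights accumulated during training, which is the usual meaning of EMA in diffusion and flow-matching training; the docstring does not define it, so treat this as an assumption. A minimal sketch of the EMA update on a dictionary of scalar weights, with the hypothetical names `ema_update`, `params`, and `ema_params`:

```python
def ema_update(ema_params, params, decay=0.999):
    """One EMA step per weight: ema <- decay * ema + (1 - decay) * current."""
    return {k: decay * ema_params[k] + (1.0 - decay) * params[k] for k in params}


# Toy run: the "optimizer" has already converged to w = 1.0, while the EMA
# copy starts at the initial value 0.0 and drifts toward it.
params = {"w": 1.0}
ema_params = {"w": 0.0}
for _ in range(500):
    ema_params = ema_update(ema_params, params, decay=0.99)
```

After n updates toward a fixed target the EMA sits at 1 - decay**n of the way there, so a decay close to 1 gives smoother but slower-moving weights; the decay value here is illustrative only.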
- loss_fn
- path
- undirected_edge_mask
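`undirected_edge_mask` is listed without a description. In the Simformer approach (Gloeckler et al., 2024), attention between the `dim_theta + dim_x` variables can be restricted by a known dependency graph, and a symmetric boolean matrix is one common encoding of such a graph. The sketch below builds that matrix from directed edges. The matrix format, the self-loops on the diagonal, the helper name, and the example graph are assumptions; these docs do not specify what the attribute holds.

```python
def undirected_edge_mask(n_nodes, directed_edges):
    """Symmetric boolean adjacency (with self-loops) from a list of (src, dst) edges."""
    # Every variable may attend to itself.
    mask = [[i == j for j in range(n_nodes)] for i in range(n_nodes)]
    for src, dst in directed_edges:
        mask[src][dst] = True
        mask[dst][src] = True  # drop the edge direction
    return mask


# Hypothetical model: 2 parameters (nodes 0, 1), each generating one observation (nodes 2, 3).
dim_theta, dim_x = 2, 2
edges = [(0, 2), (1, 3)]
mask = undirected_edge_mask(dim_theta + dim_x, edges)
```

Variables with no path between them (here node 0 and node 1) stay masked, which is what lets such a mask encode conditional-independence structure.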
- class gensbi.recipes.SimformerFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)
Bases: gensbi.recipes.pipeline.AbstractPipeline
Training pipeline for a Simformer-based conditional flow-matching model.
This is a concrete subclass of AbstractPipeline. It provides the model creation, default parameter setup, loss function, sampling, and evaluation methods that the abstract base class leaves to its subclasses.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- _wrap_model()
Wrap the model for evaluation with SimformerWrapper.
- sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True)
Generate samples from the trained model.
- Parameters:
key (jax.random.PRNGKey) – Random number generator key.
x_o (array-like) – Conditioning variable (e.g., observed data).
nsamples (int, optional) – Number of samples to generate.
step_size (float, optional) – Step size for the sampler.
use_ema (bool, optional) – If True, sample with the exponential-moving-average (EMA) copy of the model weights.
- Returns:
samples – Generated samples.
- Return type:
array-like
- loss_fn
- p0_dist_model
- path
- undirected_edge_mask
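In the Simformer approach the transformer models the joint distribution over all `dim_theta + dim_x` variables, and sampling given `x_o` corresponds to conditioning on the observed entries. A minimal sketch of that bookkeeping with a boolean condition mask follows; the helper names, the [theta, x] ordering, and the use of plain lists are assumptions, not GenSBI's internal layout.

```python
def posterior_condition_mask(dim_theta, dim_x):
    """True marks observed (conditioned-on) variables, False marks variables to sample."""
    return [False] * dim_theta + [True] * dim_x


def fill_observed(joint_template, x_o, condition_mask):
    """Write x_o into the observed slots of a joint (theta, x) vector; other slots are kept."""
    if len(x_o) != sum(condition_mask):
        raise ValueError("x_o must provide one value per observed slot")
    observed_values = iter(x_o)
    return [next(observed_values) if observed else value
            for value, observed in zip(joint_template, condition_mask)]


cond_mask = posterior_condition_mask(dim_theta=2, dim_x=3)
joint = fill_observed([0.0] * 5, x_o=[1.0, 2.0, 3.0], condition_mask=cond_mask)
```

Flipping entries of the mask would condition on a different subset of variables, which is the flexibility a joint model of this kind is meant to offer.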