gensbi.recipes#

Cookie-cutter modules for creating and training SBI models.

Classes#

FluxDiffusionPipeline

Training pipeline for a diffusion model built on the Flux architecture.

FluxFlowPipeline

Training pipeline for a flow-matching model built on the Flux architecture.

SimformerDiffusionPipeline

Training pipeline for a diffusion model built on the Simformer architecture.

SimformerFlowPipeline

Training pipeline for a flow-matching model built on the Simformer architecture.

Package Contents#

class gensbi.recipes.FluxDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Training pipeline for a diffusion model built on the Flux architecture.

This concrete pipeline implements the model creation, default parameter setup, loss function, and sampling methods documented below on top of the generic training and evaluation template defined by gensbi.recipes.pipeline.AbstractPipeline.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
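A minimal construction sketch following the signature above. The dataset objects, their batch structure, and the problem dimensions are placeholders and assumptions for illustration, not requirements stated on this page:

    from gensbi.recipes import FluxDiffusionPipeline

    dim_theta, dim_x = 3, 5   # hypothetical problem dimensions

    # Placeholders: iterables yielding batches of simulated (theta, x) data.
    # The exact batch structure the pipeline expects is an assumption here.
    train_batches = ...
    val_batches = ...

    pipeline = FluxDiffusionPipeline(
        train_batches,
        val_batches,
        dim_theta,
        dim_x,
        params=None,           # None -> defaults from _get_default_params()
        training_config=None,  # None -> defaults from _get_default_training_config()
    )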

_get_default_params()[source]#

Return default parameters for the Flux model.

classmethod _get_default_training_config()[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict
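Since the defaults are returned as a plain dict, one way to customise training is to copy and edit them before constructing the pipeline. This is a sketch only, reusing the placeholders from the construction example above; the "learning_rate" key is hypothetical, as the actual keys are not listed on this page:

    # Start from the library defaults and override selected entries.
    training_config = FluxDiffusionPipeline._get_default_training_config()
    training_config["learning_rate"] = 1e-4   # hypothetical key, for illustration only

    pipeline = FluxDiffusionPipeline(
        train_batches, val_batches, dim_theta, dim_x,
        training_config=training_config,
    )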

_make_model()[source]#

Create and return the Flux model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (using either SimformerWrapper or Flux1Wrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

sample(rng, x_o, nsamples=10000, nsteps=18, use_ema=True)[source]#

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of sampling steps.

  • use_ema (bool, optional) – Whether to sample using the exponential moving average (EMA) model weights.

Returns:

samples – Generated samples.

Return type:

array-like
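A sampling sketch matching the signature above, assuming pipeline is an already trained FluxDiffusionPipeline and that a single observation of dimension dim_x is passed as the conditioning variable; the zero-filled observation and its shape are placeholders, not a documented requirement:

    import jax
    import jax.numpy as jnp

    rng = jax.random.PRNGKey(0)
    x_o = jnp.zeros((dim_x,))   # placeholder observation; expected shape is an assumption

    # 5000 posterior samples, 18 diffusion steps, using the EMA weights.
    samples = pipeline.sample(rng, x_o, nsamples=5000, nsteps=18, use_ema=True)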

cond_ids#
loss_fn#
obs_ids#
path#
class gensbi.recipes.FluxFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Training pipeline for a flow-matching model built on the Flux architecture.

This concrete pipeline implements the model creation, default parameter setup, loss function, and sampling methods documented below on top of the generic training and evaluation template defined by gensbi.recipes.pipeline.AbstractPipeline.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

_get_default_params()[source]#

Return default parameters for the Flux model.

_make_model()[source]#

Create and return the Flux model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (using either SimformerWrapper or Flux1Wrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

sample(rng, x_o, nsamples=10000, step_size=0.01, use_ema=True)[source]#

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

  • use_ema (bool, optional) – Whether to sample using the exponential moving average (EMA) model weights.

Returns:

samples – Generated samples.

Return type:

array-like
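The flow-matching variant exposes a step size instead of a step count. A hedged sketch, again with a zero-filled placeholder observation and assuming pipeline is a trained FluxFlowPipeline:

    import jax
    import jax.numpy as jnp

    rng = jax.random.PRNGKey(0)
    x_o = jnp.zeros((dim_x,))   # placeholder observation

    # A smaller step_size means more integration steps for the flow sampler.
    samples = pipeline.sample(rng, x_o, nsamples=10000, step_size=0.01)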

cond_ids#
loss_fn#
obs_ids#
p0_dist_model#
path#
class gensbi.recipes.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Training pipeline for a diffusion model built on the Simformer architecture.

This concrete pipeline implements the model creation, default parameter setup, loss function, and sampling methods documented below on top of the generic training and evaluation template defined by gensbi.recipes.pipeline.AbstractPipeline.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

_get_default_params()[source]#

Return default parameters for the Simformer model.

classmethod _get_default_training_config()[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict

_make_model()[source]#

Create and return the Simformer model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (using either SimformerWrapper or Flux1Wrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of sampling steps.

  • use_ema (bool, optional) – Whether to sample using the exponential moving average (EMA) model weights.

Returns:

samples – Generated samples.

Return type:

array-like
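Note that, unlike the Flux pipelines, the first positional argument is named key here. A minimal sketch, assuming pipeline is a trained SimformerDiffusionPipeline and using a zero-filled placeholder observation:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(1)
    x_o = jnp.zeros((dim_x,))   # placeholder observation

    samples = pipeline.sample(key, x_o, nsamples=10000, nsteps=18)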

loss_fn#
path#
undirected_edge_mask#
class gensbi.recipes.SimformerFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Training pipeline for a flow-matching model built on the Simformer architecture.

This concrete pipeline implements the model creation, default parameter setup, loss function, and sampling methods documented below on top of the generic training and evaluation template defined by gensbi.recipes.pipeline.AbstractPipeline.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

_get_default_params()[source]#

Return default parameters for the Simformer model.

_make_model()[source]#

Create and return the Simformer model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (using either SimformerWrapper or Flux1Wrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

  • use_ema (bool, optional) – Whether to sample using the exponential moving average (EMA) model weights.

Returns:

samples – Generated samples.

Return type:

array-like
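A final sketch for the flow-matching Simformer variant, assuming pipeline is a trained SimformerFlowPipeline. It passes use_ema=False to sample from the raw (non-averaged) weights instead of the EMA weights; the observation is again a zero-filled placeholder:

    import jax
    import jax.numpy as jnp

    key = jax.random.PRNGKey(2)
    x_o = jnp.zeros((dim_x,))   # placeholder observation

    samples = pipeline.sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=False)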

loss_fn#
p0_dist_model#
path#
undirected_edge_mask#