gensbi.recipes#

Cookie-cutter modules for creating and training SBI models.

Submodules#

Classes#

ConditionalDiffusionPipeline

Abstract pipeline for training conditional diffusion models.

ConditionalFlowPipeline

Abstract pipeline for training conditional flow-matching models.

Flux1DiffusionPipeline

Conditional diffusion pipeline built on the Flux1 model.

Flux1FlowPipeline

Conditional flow-matching pipeline built on the Flux1 model.

Flux1JointDiffusionPipeline

Joint diffusion pipeline built on a Flux1-style joint model.

Flux1JointFlowPipeline

Joint flow-matching pipeline built on a Flux1-style joint model.

JointDiffusionPipeline

Abstract pipeline for training joint diffusion models over parameters and observations.

JointFlowPipeline

Abstract pipeline for training joint flow-matching models over parameters and observations.

SimformerDiffusionPipeline

Joint diffusion pipeline built on the Simformer model.

SimformerFlowPipeline

Joint flow-matching pipeline built on the Simformer model.

UnconditionalDiffusionPipeline

Abstract pipeline for training unconditional diffusion models.

UnconditionalFlowPipeline

Abstract pipeline for training unconditional flow-matching models.

Package Contents#

class gensbi.recipes.ConditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Abstract base class for conditional diffusion training pipelines.

This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
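
The abstract methods listed below define what a concrete pipeline must supply. A rough, hypothetical sketch of a subclass (the MLPScoreNet module, its call signature, the params argument of _make_model, and the self.dim_theta / self.dim_x attributes are illustrative assumptions rather than documented API; a full subclass would also implement the abstract init_pipeline_from_config classmethod):

import jax
import jax.numpy as jnp
from flax import nnx

from gensbi.recipes import ConditionalDiffusionPipeline


class MLPScoreNet(nnx.Module):
    """Placeholder score network; the call signature expected by the pipeline is an assumption."""

    def __init__(self, dim_theta, dim_x, hidden, *, rngs):
        self.fc1 = nnx.Linear(dim_theta + dim_x + 1, hidden, rngs=rngs)
        self.fc2 = nnx.Linear(hidden, dim_theta, rngs=rngs)

    def __call__(self, theta_t, t, x_o):
        h = jnp.concatenate([theta_t, x_o, t], axis=-1)
        return self.fc2(jax.nn.relu(self.fc1(h)))


class MyDiffusionPipeline(ConditionalDiffusionPipeline):
    """Hypothetical concrete pipeline illustrating the abstract-method contract."""

    def _get_default_params(self):
        # Hyperparameters consumed by _make_model when params=None.
        return {"hidden": 128}

    def _make_model(self, params):
        # dim_theta / dim_x are assumed to be stored by the base constructor.
        return MLPScoreNet(self.dim_theta, self.dim_x, params["hidden"], rngs=nnx.Rngs(0))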

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

classmethod _get_default_training_config()[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

abstractmethod classmethod init_pipeline_from_config()[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(rng, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of sampling steps.

Returns:

samples – Generated samples.

Return type:

array-like

cond_ids#
loss_fn#
obs_ids#
path#
class gensbi.recipes.ConditionalFlowPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Abstract base class for conditional flow-matching training pipelines.

This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

compute_unnorm_logprob(x_1, x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Compute the unnormalized log-probability of x_1 under the trained flow model, conditioned on x_o.

get_loss_fn()[source]#

Return the loss function for training/validation.

abstractmethod classmethod init_pipeline_from_config()[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(rng, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

Returns:

samples – Generated samples.

Return type:

array-like

cond_ids#
loss_fn#
obs_ids#
p0_dist_model#
path#
class gensbi.recipes.Flux1DiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalDiffusionPipeline

Conditional diffusion pipeline built on the Flux1 model.

This pipeline provides the Flux1 model construction and default parameters; training, sampling, and evaluation are inherited from ConditionalDiffusionPipeline. A usage sketch follows the parameter list below.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
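
A usage sketch, assuming (theta, x) batch tuples as the dataset format (the exact batch layout expected by the pipeline is not specified on this page) and omitting the training loop; only the constructor and sample signatures documented here are used:

import jax
import jax.numpy as jnp

from gensbi.recipes import Flux1DiffusionPipeline

dim_theta, dim_x = 2, 3

def toy_batches(seed, nbatches=100, batch_size=256):
    # Synthetic (theta, x) pairs standing in for simulator output.
    for key in jax.random.split(jax.random.PRNGKey(seed), nbatches):
        k1, k2 = jax.random.split(key)
        theta = jax.random.normal(k1, (batch_size, dim_theta))
        x = theta @ jnp.ones((dim_theta, dim_x)) + 0.1 * jax.random.normal(k2, (batch_size, dim_x))
        yield theta, x

pipeline = Flux1DiffusionPipeline(
    toy_batches(0), toy_batches(1), dim_theta, dim_x,
    params=None,           # fall back to _get_default_params()
    training_config=None,  # fall back to _get_default_training_config()
)

# ... train the pipeline (training loop not shown) ...

# Draw posterior samples for a single observation.
rng = jax.random.PRNGKey(42)
x_o = jnp.zeros((1, dim_x))  # the expected shape of x_o is an assumption
samples = pipeline.sample(rng, x_o, nsamples=5_000, nsteps=18)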

_get_default_params()[source]#

Return default parameters for the Flux1 model.

_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

ema_model#
model#
class gensbi.recipes.Flux1FlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.conditional_pipeline.ConditionalFlowPipeline

Conditional flow-matching pipeline built on the Flux1 model.

This pipeline provides the Flux1 model construction and default parameters; training, sampling, and evaluation are inherited from ConditionalFlowPipeline. A usage sketch follows the parameter list below.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
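
The flow-matching variant samples by integrating an ODE (controlled by step_size rather than a step count) and additionally exposes compute_unnorm_logprob. A sketch under the same assumptions as the diffusion example above:

import jax
import jax.numpy as jnp

from gensbi.recipes import Flux1FlowPipeline

dim_theta, dim_x = 2, 3

def toy_batches(seed, nbatches=50, batch_size=128):
    # Same assumed (theta, x) batch format as in the diffusion example.
    for key in jax.random.split(jax.random.PRNGKey(seed), nbatches):
        k1, k2 = jax.random.split(key)
        theta = jax.random.normal(k1, (batch_size, dim_theta))
        x = theta @ jnp.ones((dim_theta, dim_x)) + 0.1 * jax.random.normal(k2, (batch_size, dim_x))
        yield theta, x

pipeline = Flux1FlowPipeline(toy_batches(0), toy_batches(1), dim_theta, dim_x)
# ... train the pipeline (training loop not shown) ...

rng = jax.random.PRNGKey(0)
x_o = jnp.zeros((1, dim_x))  # the expected shape of x_o is an assumption

# ODE-based sampling; step_size controls the integration step of the sampler.
samples = pipeline.sample(rng, x_o, nsamples=10_000, step_size=0.01, use_ema=True)

# Unnormalized log-probability of candidate parameter values for the same observation.
logp = pipeline.compute_unnorm_logprob(samples[:100], x_o, step_size=0.01)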

_get_default_params()[source]#

Return default parameters for the Flux1 model.

_make_model(params)[source]#

Create and return the Flux1 model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

ema_model#
model#
class gensbi.recipes.Flux1JointDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.joint_pipeline.JointDiffusionPipeline

Joint diffusion pipeline built on a Flux1-style joint model.

This pipeline provides the joint model construction and default parameters; training, sampling, and evaluation are inherited from JointDiffusionPipeline.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

_get_default_params()[source]#

Return default parameters for the Flux1 joint model.

_make_model(params)[source]#

Create and return the Flux1 joint model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

ema_model#
model#
class gensbi.recipes.Flux1JointFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.joint_pipeline.JointFlowPipeline

Joint flow-matching pipeline built on a Flux1-style joint model.

This pipeline provides the joint model construction and default parameters; training, sampling, and evaluation are inherited from JointFlowPipeline. A configuration-driven construction sketch is shown under init_pipeline_from_config below.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

_get_default_params()[source]#

Return default parameters for the Flux1 joint model.

_make_model(params)[source]#

Create and return the Flux1 joint model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.
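
When hyperparameters and training settings live in a file, the pipeline can be built directly with this classmethod. A sketch in which the placeholder batches, the config file name and format, and the checkpoint directory are all hypothetical:

import jax.numpy as jnp

from gensbi.recipes import Flux1JointFlowPipeline

# Placeholder batches in the assumed (theta, x) format; real data would come from a simulator.
theta, x = jnp.zeros((128, 2)), jnp.zeros((128, 3))
train_ds = val_ds = [(theta, x)]

pipeline = Flux1JointFlowPipeline.init_pipeline_from_config(
    train_ds,
    val_ds,
    dim_theta=2,
    dim_x=3,
    config_path="pipeline_config.yaml",             # hypothetical config file
    checkpoint_dir="checkpoints/flux1_joint_flow",  # hypothetical checkpoint directory
)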

ema_model#
model#
class gensbi.recipes.JointDiffusionPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Abstract base class for joint diffusion training pipelines.

This class provides a template for implementing training and evaluation pipelines for joint generative models over parameters and observations. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

classmethod _get_default_training_config()[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

abstractmethod classmethod init_pipeline_from_config()[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of sampling steps.

Returns:

samples – Generated samples.

Return type:

array-like

cond_ids#
loss_fn#
node_ids#
obs_ids#
path#
class gensbi.recipes.JointFlowPipeline(model, train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Abstract base class for joint flow-matching training pipelines.

This class provides a template for implementing training and evaluation pipelines for joint generative models over parameters and observations. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

compute_unnorm_logprob(x_1, x_o, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Compute the unnormalized log-probability of x_1 under the trained flow model, conditioned on x_o.

get_loss_fn()[source]#

Return the loss function for training/validation.

abstractmethod classmethod init_pipeline_from_config()[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

Returns:

samples – Generated samples.

Return type:

array-like

cond_ids#
loss_fn#
node_ids#
obs_ids#
p0_dist_model#
path#
class gensbi.recipes.SimformerDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None, edge_mask=None)[source]#

Bases: gensbi.recipes.joint_pipeline.JointDiffusionPipeline

Joint diffusion pipeline built on the Simformer model.

This pipeline provides the Simformer model construction, default parameters, and sampling interface; the remaining training and evaluation utilities are inherited from JointDiffusionPipeline. A usage sketch follows the parameter list below.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

  • edge_mask (optional) – Edge mask for the Simformer attention over the joint set of variables. Defaults to None.
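
A usage sketch for the Simformer recipe. The form of edge_mask shown below (a boolean matrix over the dim_theta + dim_x variables) is an assumption about how the attention mask is specified; passing None keeps the default behaviour:

import jax
import jax.numpy as jnp

from gensbi.recipes import SimformerDiffusionPipeline

dim_theta, dim_x = 2, 3
num_nodes = dim_theta + dim_x

def toy_batches(seed, nbatches=50, batch_size=128):
    # Same assumed (theta, x) batch format as in the earlier examples.
    for key in jax.random.split(jax.random.PRNGKey(seed), nbatches):
        k1, k2 = jax.random.split(key)
        theta = jax.random.normal(k1, (batch_size, dim_theta))
        x = theta @ jnp.ones((dim_theta, dim_x)) + 0.1 * jax.random.normal(k2, (batch_size, dim_x))
        yield theta, x

# Assumed form of the edge mask: boolean (num_nodes, num_nodes); all-True means dense attention.
edge_mask = jnp.ones((num_nodes, num_nodes), dtype=bool)

pipeline = SimformerDiffusionPipeline(
    toy_batches(0), toy_batches(1), dim_theta, dim_x, edge_mask=edge_mask
)
# ... train the pipeline (training loop not shown) ...

key = jax.random.PRNGKey(0)
x_o = jnp.zeros((1, dim_x))  # the expected shape of x_o is an assumption
samples = pipeline.sample(key, x_o, nsamples=10_000, nsteps=18)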

_get_default_params()[source]#

Return default parameters for the Simformer model.

_make_model(params)[source]#

Create and return the Simformer model to be trained.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

sample(key, x_o, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of sampling steps.

Returns:

samples – Generated samples.

Return type:

array-like

edge_mask = None#
ema_model#
model#
class gensbi.recipes.SimformerFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None, edge_mask=None)[source]#

Bases: gensbi.recipes.joint_pipeline.JointFlowPipeline

Joint flow-matching pipeline built on the Simformer model.

This pipeline provides the Simformer model construction, default parameters, and sampling interface; the remaining training and evaluation utilities are inherited from JointFlowPipeline.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

  • edge_mask (optional) – Edge mask for the Simformer attention over the joint set of variables. Defaults to None.

_get_default_params()[source]#

Return default parameters for the Simformer model.

_make_model(params)[source]#

Create and return the Simformer model to be trained.

compute_unnorm_logprob(x_1, x_o, step_size=0.01, use_ema=True, time_grid=None)[source]#

Compute the unnormalized log-probability of x_1 under the trained flow model, conditioned on x_o.

classmethod init_pipeline_from_config(train_dataset, val_dataset, dim_theta, dim_x, config_path, checkpoint_dir)[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

sample(key, x_o, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None)[source]#

Generate samples from the trained model.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

Returns:

samples – Generated samples.

Return type:

array-like

edge_mask = None#
ema_model#
model#
class gensbi.recipes.UnconditionalDiffusionPipeline(model, train_dataset, val_dataset, dim_theta, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Abstract base class for unconditional diffusion training pipelines.

This class provides a template for implementing training and evaluation pipelines for unconditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

classmethod _get_default_training_config()[source]#

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

get_loss_fn()[source]#

Return the loss function for training/validation.

abstractmethod classmethod init_pipeline_from_config()[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(rng, nsamples=10000, nsteps=18, use_ema=True, return_intermediates=False, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • nsamples (int, optional) – Number of samples to generate.

  • nsteps (int, optional) – Number of sampling steps.

Returns:

samples – Generated samples.

Return type:

array-like

loss_fn#
obs_ids#
path#
class gensbi.recipes.UnconditionalFlowPipeline(model, train_dataset, val_dataset, dim_theta, params=None, training_config=None)[source]#

Bases: gensbi.recipes.pipeline.AbstractPipeline

Abstract base class for unconditional flow-matching training pipelines.

This class provides a template for implementing training and evaluation pipelines for unconditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods. A usage sketch follows the parameter list below.

Parameters:
  • train_dataset (iterable) – Training dataset, should yield batches of data.

  • val_dataset (iterable) – Validation dataset, should yield batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
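
The unconditional recipes drop the observation space: the constructor takes no dim_x and sample takes no x_o. A hypothetical sketch (the TinyVectorField network, its call signature, the theta-only batch format, the self.dim_theta attribute, and the stub init_pipeline_from_config are illustrative assumptions):

import jax
import jax.numpy as jnp
from flax import nnx

from gensbi.recipes import UnconditionalFlowPipeline


class TinyVectorField(nnx.Module):
    """Placeholder velocity-field network; the call signature expected by the pipeline is an assumption."""

    def __init__(self, dim_theta, hidden, *, rngs):
        self.fc1 = nnx.Linear(dim_theta + 1, hidden, rngs=rngs)
        self.fc2 = nnx.Linear(hidden, dim_theta, rngs=rngs)

    def __call__(self, theta_t, t):
        h = jnp.concatenate([theta_t, t], axis=-1)
        return self.fc2(jax.nn.relu(self.fc1(h)))


class MyUncondFlowPipeline(UnconditionalFlowPipeline):
    """Hypothetical concrete subclass."""

    def _get_default_params(self):
        return {"hidden": 64}

    def _make_model(self, params):
        # dim_theta is assumed to be stored by the base constructor.
        return TinyVectorField(self.dim_theta, params["hidden"], rngs=nnx.Rngs(0))

    @classmethod
    def init_pipeline_from_config(cls, *args, **kwargs):
        raise NotImplementedError


dim_theta = 2
theta_batches = [jax.random.normal(jax.random.PRNGKey(i), (128, dim_theta)) for i in range(10)]

# model=None triggers model creation via _make_model.
pipeline = MyUncondFlowPipeline(None, theta_batches, theta_batches, dim_theta)
# ... train the pipeline (training loop not shown) ...

rng = jax.random.PRNGKey(0)
samples = pipeline.sample(rng, nsamples=5_000, step_size=0.01)          # no x_o
logp = pipeline.compute_unnorm_logprob(samples[:100], step_size=0.01)   # no x_o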

abstractmethod _get_default_params()[source]#

Return a dictionary of default model parameters.

abstractmethod _make_model()[source]#

Create and return the model to be trained.

_wrap_model()[source]#

Wrap the model for evaluation (either using JointWrapper or ConditionalWrapper).

compute_unnorm_logprob(x_1, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Compute the unnormalized log-probability of x_1 under the trained flow model.

get_loss_fn()[source]#

Return the loss function for training/validation.

abstractmethod classmethod init_pipeline_from_config()[source]#

Initialize the pipeline from a configuration file.

Parameters:
  • train_dataset (iterable) – Training dataset.

  • val_dataset (iterable) – Validation dataset.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • config_path (str) – Path to the configuration file.

  • checkpoint_dir (str) – Directory for saving checkpoints.

Returns:

pipeline – An instance of the pipeline initialized from the configuration.

Return type:

AbstractPipeline

sample(rng, nsamples=10000, step_size=0.01, use_ema=True, time_grid=None, **model_extras)[source]#

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • nsamples (int, optional) – Number of samples to generate.

  • step_size (float, optional) – Step size for the sampler.

Returns:

samples – Generated samples.

Return type:

array-like

loss_fn#
obs_ids#
p0_dist_model#
path#