gensbi.recipes.flux

Pipelines for training Flux1 models for simulation-based inference and sampling from them.

Example

Classes

FluxDiffusionPipeline

Pipeline for training and sampling a Flux1 model with a diffusion objective.

FluxFlowPipeline

Pipeline for training and sampling a Flux1 model with a flow-matching objective.

Module Contents

class gensbi.recipes.flux.FluxDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)

Bases: gensbi.recipes.pipeline.AbstractPipeline

Training and sampling pipeline for a Flux1 model with a diffusion objective.

This class is a concrete implementation of the AbstractPipeline template for conditional generative models. It creates the Flux model, supplies default model parameters and a default training configuration, defines the training/validation loss, and generates samples conditioned on observed data.

Parameters:
  • train_dataset (iterable) – Training dataset; an iterable that yields batches of data.

  • val_dataset (iterable) – Validation dataset; an iterable that yields batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
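The constructor only requires the datasets to be iterables that yield batches. The sketch below shows one way to build such iterables. The generator `toy_batches`, the toy prior and simulator, and the `(theta, x)` tuple layout with shapes `(batch_size, dim_theta)` and `(batch_size, dim_x)` are all hypothetical. This page does not specify the batch format the pipeline expects, so check AbstractPipeline before relying on it.

```python
import numpy as np

def toy_batches(rng, A, batch_size):
    """Hypothetical batch generator: yields (theta, x) pairs forever.

    The tuple layout and shapes are illustrative assumptions, not the
    format required by gensbi.
    """
    dim_theta, dim_x = A.shape
    while True:
        # Toy prior: standard-normal parameters.
        theta = rng.normal(size=(batch_size, dim_theta))
        # Toy simulator: fixed linear map plus Gaussian observation noise.
        x = theta @ A + 0.1 * rng.normal(size=(batch_size, dim_x))
        yield theta, x

dim_theta, dim_x = 3, 5
# One simulator matrix shared by training and validation data.
A = np.random.default_rng(0).normal(size=(dim_theta, dim_x))
train_dataset = toy_batches(np.random.default_rng(1), A, batch_size=256)
val_dataset = toy_batches(np.random.default_rng(2), A, batch_size=256)

theta, x = next(train_dataset)
print(theta.shape, x.shape)  # (256, 3) (256, 5)
```

If the batch layout matches, the two generators would be passed positionally, following the constructor signature above: `FluxDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x)`.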

_get_default_params()

Return default parameters for the Flux model.

classmethod _get_default_training_config()

Return a dictionary of default training configuration parameters.

Returns:

training_config – Default training configuration.

Return type:

dict
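Because `_get_default_training_config` is a classmethod, it can be called on the class before any pipeline exists. The constructor falls back to it when `training_config` is None. The sketch below illustrates both points with a hypothetical stand-in class. `ToyPipeline` and its configuration keys (`learning_rate`, `batch_size`) are invented for illustration and are not the actual defaults of FluxDiffusionPipeline.

```python
class ToyPipeline:
    """Hypothetical stand-in that mimics the constructor's defaulting rule."""

    def __init__(self, training_config=None):
        # If no configuration is given, fall back to the class defaults.
        if training_config is None:
            training_config = self._get_default_training_config()
        self.training_config = training_config

    @classmethod
    def _get_default_training_config(cls):
        # Invented keys; the real defaults are not listed on this page.
        return {"learning_rate": 1e-3, "batch_size": 256}

# Start from the class defaults (no instance needed) and override one entry.
config = {**ToyPipeline._get_default_training_config(), "learning_rate": 5e-4}

default_pipeline = ToyPipeline()
custom_pipeline = ToyPipeline(training_config=config)
print(default_pipeline.training_config["learning_rate"])  # 0.001
print(custom_pipeline.training_config["learning_rate"])   # 0.0005
```

Copying the defaults into a new dict before overriding keeps the class-level defaults untouched.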

_make_model()

Create and return the Flux model to be trained.

_wrap_model()

Wrap the model for evaluation (either using SimformerWrapper or Flux1Wrapper).

get_loss_fn()

Return the loss function for training/validation.
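This page does not state which diffusion loss `get_loss_fn` returns. As one common possibility, assumed here rather than confirmed, the sketch below implements the EDM denoising loss of Karras et al. (2022):

- noise levels are drawn log-normally;
- the data are corrupted with Gaussian noise of that level;
- the squared denoising error is weighted by λ(σ) = (σ² + σ_data²) / (σ σ_data)².

The `denoiser(x_noisy, sigma)` signature and the constants are illustrative, not read from gensbi.

```python
import numpy as np

def edm_loss(denoiser, y, rng, sigma_data=0.5, p_mean=-1.2, p_std=1.2):
    """EDM-style weighted denoising loss (illustrative sketch).

    denoiser(x_noisy, sigma) returns an estimate of the clean sample;
    y is a (batch, dim) array of clean training samples.
    """
    batch = y.shape[0]
    # Per-example noise level with log(sigma) ~ N(p_mean, p_std^2).
    sigma = np.exp(p_mean + p_std * rng.normal(size=(batch, 1)))
    # Loss weight lambda(sigma) that normalizes the error scale across noise levels.
    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    noisy = y + sigma * rng.normal(size=y.shape)
    err = denoiser(noisy, sigma) - y
    return float(np.mean(weight * err**2))

# Toy check: for data ~ N(0, sigma_data^2) the optimal denoiser is known in closed form.
def optimal_denoiser(x, sigma, sd=0.5):
    return x * sd**2 / (sd**2 + sigma**2)

rng = np.random.default_rng(0)
y = 0.5 * rng.normal(size=(20000, 2))
print(edm_loss(optimal_denoiser, y, rng))      # close to 1.0
print(edm_loss(lambda x, s: x, y, rng))        # much larger: the identity ignores the noise
```

For Gaussian data, the exact denoiser has an expected weighted loss of exactly 1 at every noise level. That is the normalization the EDM weighting is designed to produce.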

sample(rng, x_o, nsamples=10000, nsteps=18, use_ema=True)

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

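  • nsteps (int, optional) – Number of steps taken by the sampler.

  • use_ema (bool, optional) – Whether to sample with the exponential-moving-average (EMA) copy of the model parameters.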

Returns:

samples – Generated samples.

Return type:

array-like
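The default `nsteps=18` coincides with the 18-step setting of the deterministic Heun sampler in Karras et al. (2022, EDM). Assuming, without confirmation from this page, that the pipeline uses an EDM-style sampler, `nsteps` would set the number of noise levels in a schedule like the one below. The sketch runs that sampler with the exact denoiser for one-dimensional Gaussian data, so the output spread can be checked against the data's. `sigma_min`, `sigma_max` and `rho` are the EDM paper's defaults, not values taken from gensbi.

```python
import numpy as np

def edm_sigmas(nsteps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. schedule: nsteps levels from sigma_max down to sigma_min, then 0."""
    i = np.arange(nsteps)
    inv_rho = 1.0 / rho
    sigmas = (sigma_max**inv_rho
              + i / (nsteps - 1) * (sigma_min**inv_rho - sigma_max**inv_rho)) ** rho
    return np.append(sigmas, 0.0)

def heun_sample(denoiser, x, sigmas):
    """Deterministic Heun solver for the ODE dx/dsigma = (x - D(x, sigma)) / sigma."""
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        d_cur = (x - denoiser(x, s_cur)) / s_cur
        x_next = x + (s_next - s_cur) * d_cur            # Euler predictor
        if s_next > 0:                                   # Heun corrector, skipped on the final step
            d_next = (x_next - denoiser(x_next, s_next)) / s_next
            x_next = x + (s_next - s_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x

# Toy check: data ~ N(0, sigma_data^2) has the exact denoiser below.
sigma_data = 1.0
def gaussian_denoiser(x, sigma):
    return x * sigma_data**2 / (sigma_data**2 + sigma**2)

sigmas = edm_sigmas(nsteps=18)
rng = np.random.default_rng(0)
x_init = sigmas[0] * rng.normal(size=20000)     # pure noise at sigma_max
samples = heun_sample(gaussian_denoiser, x_init, sigmas)
print(len(sigmas) - 1, samples.std())           # 18 steps; std near sigma_data = 1.0
```

With Heun's correction on all but the last step, 18 steps cost 35 denoiser evaluations.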

cond_ids
loss_fn
obs_ids
path
class gensbi.recipes.flux.FluxFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)

Bases: gensbi.recipes.pipeline.AbstractPipeline

Training and sampling pipeline for a Flux1 model with a flow-matching objective.

This class is a concrete implementation of the AbstractPipeline template for conditional generative models. It creates the Flux model, supplies default model parameters, defines the training/validation loss, and generates samples conditioned on observed data.

Parameters:
  • train_dataset (iterable) – Training dataset; an iterable that yields batches of data.

  • val_dataset (iterable) – Validation dataset; an iterable that yields batches of data.

  • dim_theta (int) – Dimensionality of the parameter (theta) space.

  • dim_x (int) – Dimensionality of the observation (x) space.

  • params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.

  • training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.

_get_default_params()

Return default parameters for the Flux model.

_make_model()

Create and return the Flux model to be trained.

_wrap_model()

Wrap the model for evaluation (either using SimformerWrapper or Flux1Wrapper).

get_loss_fn()

Return the loss function for training/validation.
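This page does not say which probability path or objective `get_loss_fn` uses for the flow pipeline. As an assumption for illustration only, the sketch below implements the conditional flow-matching loss for the linear path x_t = (1 - t) x_0 + t x_1. Here x_0 is drawn from a standard-normal source (compare the `p0_dist_model` attribute), x_1 is a data sample, and the regression target is x_1 - x_0. The `velocity(x_t, t)` signature is hypothetical.

```python
import numpy as np

def cfm_loss(velocity, x1, rng):
    """Conditional flow-matching loss on the linear path (illustrative sketch).

    velocity(x_t, t) returns a predicted velocity with the same shape as x_t;
    x1 is a (batch, dim) array of data samples.
    """
    batch = x1.shape[0]
    x0 = rng.normal(size=x1.shape)          # source samples, assumed N(0, I) here
    t = rng.uniform(size=(batch, 1))        # one time per example, in [0, 1)
    xt = (1.0 - t) * x0 + t * x1            # point on the straight path from x0 to x1
    target = x1 - x0                        # conditional velocity along that path
    return float(np.mean((velocity(xt, t) - target) ** 2))

# Toy check with 1-D data x1 ~ N(mu, s^2), whose exact marginal velocity is known.
mu, s = 2.0, 0.5

def marginal_velocity(x, t):
    # E[x1 - x0 | x_t = x] for source N(0, 1) and data N(mu, s^2).
    var_t = (1.0 - t) ** 2 + (t * s) ** 2
    return mu + (t * s**2 - (1.0 - t)) / var_t * (x - t * mu)

rng = np.random.default_rng(0)
x1 = mu + s * rng.normal(size=(20000, 1))
print(cfm_loss(marginal_velocity, x1, rng))                 # smaller
print(cfm_loss(lambda x, t: np.zeros_like(x), x1, rng))     # larger, about mu^2 + s^2 + 1
```

The exact marginal velocity minimizes this loss, so it scores lower than the zero field. Training pushes a network toward that minimizer without ever computing it explicitly.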

sample(rng, x_o, nsamples=10000, step_size=0.01, use_ema=True)

Generate samples from the trained model.

Parameters:
  • rng (jax.random.PRNGKey) – Random number generator key.

  • x_o (array-like) – Conditioning variable (e.g., observed data).

  • nsamples (int, optional) – Number of samples to generate.

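  • step_size (float, optional) – Step size for the sampler.

  • use_ema (bool, optional) – Whether to sample with the exponential-moving-average (EMA) copy of the model parameters.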

Returns:

samples – Generated samples.

Return type:

array-like
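Suppose, as assumed here, that the flow pipeline samples by integrating a learned velocity field from the source at t = 0 to the data at t = 1 with a fixed-step solver. Then `step_size=0.01` corresponds to 1 / 0.01 = 100 solver steps. The sketch uses explicit Euler and, in place of a trained network, a hypothetical toy velocity field whose trajectories run in straight lines from x_0 to mu + s·x_0. The solver, the time direction and `straight_velocity` are assumptions, not details confirmed by this page.

```python
import numpy as np

def euler_sample(velocity, x0, step_size=0.01):
    """Integrate dx/dt = velocity(x, t) from t = 0 to t = 1 with fixed Euler steps."""
    nsteps = int(round(1.0 / step_size))   # step_size=0.01 -> 100 steps
    dt = 1.0 / nsteps                      # recomputed so the last step lands exactly on t = 1
    x = x0
    for k in range(nsteps):
        x = x + dt * velocity(x, k * dt)
    return x

# Hypothetical toy field: straight-line paths from x0 to mu + s * x0.
mu, s = 2.0, 0.5
def straight_velocity(x, t):
    start = (x - t * mu) / (1.0 - t + t * s)   # recover the starting point of the path through x
    return mu + (s - 1.0) * start

rng = np.random.default_rng(0)
x0 = rng.normal(size=10000)                    # draws from the N(0, 1) source
samples = euler_sample(straight_velocity, x0, step_size=0.01)
print(samples.mean(), samples.std())           # close to 2.0 and 0.5
```

Every trajectory of this toy field is a straight line, so Euler is exact here for any step size. A trained velocity field is only approximately straight, which is why a smaller `step_size` generally buys accuracy at the cost of more network evaluations.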

cond_ids
loss_fn
obs_ids
p0_dist_model
path