gensbi.recipes.flux1joint
Pipeline for training and using a Flux1 model for simulation-based inference.
Examples
import grain
import numpy as np
import jax
from jax import numpy as jnp
from gensbi.recipes.flux1joint import Flux1JointFlowPipeline
# Define your training and validation datasets.
key_train, key_val = jax.random.split(jax.random.PRNGKey(42))
train_data = jax.random.uniform(key_train, (1024, 4)) # your training dataset
val_data = jax.random.uniform(key_val, (128, 4)) # your validation dataset
batch_size = 32
train_dataset_grain = (
    grain.MapDataset.source(np.array(train_data)[..., None])
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)
val_dataset_grain = (
    grain.MapDataset.source(np.array(val_data)[..., None])
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
    # .mp_prefetch() # Uncomment if you want to use multiprocessing prefetching
)
# Define the model
dim_theta = 2 # Dimension of the parameter space
dim_x = 2 # Dimension of the observation space
pipeline = Flux1JointFlowPipeline(train_dataset_grain, val_dataset_grain, dim_theta, dim_x)
# Train the model
rngs = jax.random.PRNGKey(0)
pipeline.train(rngs)
# Sample from the posterior
x_o = jnp.array([0.5, -0.2]) # Example
samples = pipeline.sample(rngs, x_o, nsamples=10000, step_size=0.01)
Note
If you plan to use multiprocessing prefetching, make sure the entry point of your script is wrapped in an if __name__ == "__main__": guard. See https://docs.python.org/3/library/multiprocessing.html
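The guard described in the note can be laid out roughly as follows. This is a minimal sketch of the script structure only: build_batches and main are hypothetical names, and the NumPy batching is a stand-in for the grain pipeline, which is where .mp_prefetch() would be enabled in a real script.

```python
import numpy as np


def build_batches(data, batch_size):
    """Stand-in for the grain pipeline; a real script would build the
    grain dataset (optionally with .mp_prefetch()) here instead."""
    return [data[i:i + batch_size] for i in range(0, len(data), batch_size)]


def main():
    data = np.random.default_rng(0).uniform(size=(128, 4))
    batches = build_batches(data, batch_size=32)
    # Pipeline construction, training, and sampling would follow here.
    return len(batches)


if __name__ == "__main__":
    # Worker processes that re-import this module skip this branch,
    # so they do not start the training run a second time.
    print(f"prepared {main()} batches")
```

Keeping all work inside main() and leaving only definitions at module level means the module can be imported by worker processes without side effects.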
Classes
Flux1JointDiffusionPipeline | Abstract base class for GenSBI training pipelines.
Flux1JointFlowPipeline | Abstract base class for GenSBI training pipelines.
Functions
parse_flux1joint_params | Parse a Flux1Joint configuration file.
parse_training_config | Parse a training configuration file.
sample_strutured_conditional_mask | Sample structured conditional masks for the Simformer model.
Module Contents
- class gensbi.recipes.flux1joint.Flux1JointDiffusionPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)
Bases: gensbi.recipes.joint_pipeline.JointDiffusionPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
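As a rough illustration of what "should yield batches of data" could look like, the sketch below builds a plain Python generator with NumPy. The batch layout (batch_size, dim_theta + dim_x, 1), with parameters first, observations second, and a trailing singleton axis, is an assumption inferred from the example at the top of this page rather than a documented requirement. batch_iterator is a hypothetical helper and not part of gensbi.

```python
import numpy as np


def batch_iterator(theta, x, batch_size, seed=0):
    """Hypothetical helper: endlessly yield shuffled joint batches."""
    # Assumed layout: parameters first, then observations, plus a
    # trailing singleton axis -> (N, dim_theta + dim_x, 1).
    joint = np.concatenate([theta, x], axis=-1)[..., None]
    rng = np.random.default_rng(seed)
    while True:
        order = rng.permutation(len(joint))
        # Drop the last incomplete batch so every batch has the same shape.
        for start in range(0, len(order) - batch_size + 1, batch_size):
            yield joint[order[start:start + batch_size]]


rng = np.random.default_rng(42)
theta = rng.uniform(size=(1024, 2))  # dim_theta = 2
x = rng.uniform(size=(1024, 2))      # dim_x = 2
first_batch = next(batch_iterator(theta, x, batch_size=32))
```

Whether the pipeline accepts a bare generator like this, or expects a grain dataset as in the example above, is not stated here, so treat this as a shape reference only.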
- class gensbi.recipes.flux1joint.Flux1JointFlowPipeline(train_dataset, val_dataset, dim_theta, dim_x, params=None, training_config=None)
Bases: gensbi.recipes.joint_pipeline.JointFlowPipeline
Abstract base class for GenSBI training pipelines.
This class provides a template for implementing training and evaluation pipelines for conditional generative models. Subclasses should implement model creation, default parameter setup, loss function, sampling, and evaluation methods.
- Parameters:
train_dataset (iterable) – Training dataset, should yield batches of data.
val_dataset (iterable) – Validation dataset, should yield batches of data.
dim_theta (int) – Dimensionality of the parameter (theta) space.
dim_x (int) – Dimensionality of the observation (x) space.
model (nnx.Module, optional) – The model to be trained. If None, the model is created using _make_model.
params (dict, optional) – Model parameters. If None, uses defaults from _get_default_params.
training_config (dict, optional) – Training configuration. If None, uses defaults from _get_default_training_config.
- gensbi.recipes.flux1joint.parse_flux1joint_params(config_path)
Parse a Flux1Joint configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
- Returns:
config – Parsed configuration dictionary.
- Return type:
dict
- gensbi.recipes.flux1joint.parse_training_config(config_path)
Parse a training configuration file.
- Parameters:
config_path (str) – Path to the configuration file.
- Returns:
config – Parsed configuration dictionary.
- Return type:
dict
- gensbi.recipes.flux1joint.sample_strutured_conditional_mask(key, num_samples, theta_dim, x_dim, p_joint=0.2, p_posterior=0.2, p_likelihood=0.2, p_rnd1=0.2, p_rnd2=0.2, rnd1_prob=0.3, rnd2_prob=0.7)
Sample structured conditional masks for the Simformer model.
- Parameters:
key (jax.random.PRNGKey) – Random key for sampling.
num_samples (int) – Number of samples to generate.
theta_dim (int) – Dimension of the parameter space.
x_dim (int) – Dimension of the observation space.
p_joint (float) – Probability of selecting the joint mask.
p_posterior (float) – Probability of selecting the posterior mask.
p_likelihood (float) – Probability of selecting the likelihood mask.
p_rnd1 (float) – Probability of selecting the first random mask.
p_rnd2 (float) – Probability of selecting the second random mask.
rnd1_prob (float) – Probability of a True value in the first random mask.
rnd2_prob (float) – Probability of a True value in the second random mask.
- Returns:
condition_mask – Array of shape (num_samples, theta_dim + x_dim) with boolean masks.
- Return type:
jnp.ndarray
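To make the parameters above concrete, here is one way such a sampler could be written in JAX. It is an illustrative sketch, not the gensbi source: sketch_structured_mask is a hypothetical name, and it assumes that True marks a conditioned variable, that the posterior mask conditions on the x block, and that the likelihood mask conditions on the theta block.

```python
import jax
import jax.numpy as jnp


def sketch_structured_mask(key, num_samples, theta_dim, x_dim,
                           p_joint=0.2, p_posterior=0.2, p_likelihood=0.2,
                           p_rnd1=0.2, p_rnd2=0.2,
                           rnd1_prob=0.3, rnd2_prob=0.7):
    """Illustrative sampler of per-sample condition masks (assumed semantics)."""
    dim = theta_dim + x_dim
    key_type, key_r1, key_r2 = jax.random.split(key, 3)

    # Fixed templates. Assumption: True means "conditioned on".
    joint = jnp.zeros((dim,), dtype=bool)        # nothing conditioned
    posterior = joint.at[theta_dim:].set(True)   # condition on x
    likelihood = joint.at[:theta_dim].set(True)  # condition on theta

    # Two random masks with different densities of True entries.
    rnd1 = jax.random.bernoulli(key_r1, rnd1_prob, (num_samples, dim))
    rnd2 = jax.random.bernoulli(key_r2, rnd2_prob, (num_samples, dim))

    candidates = jnp.stack([
        jnp.broadcast_to(joint, (num_samples, dim)),
        jnp.broadcast_to(posterior, (num_samples, dim)),
        jnp.broadcast_to(likelihood, (num_samples, dim)),
        rnd1,
        rnd2,
    ])  # shape (5, num_samples, dim)

    # Pick one of the five mask types independently for every sample.
    probs = jnp.array([p_joint, p_posterior, p_likelihood, p_rnd1, p_rnd2])
    choice = jax.random.choice(key_type, 5, shape=(num_samples,), p=probs)
    return candidates[choice, jnp.arange(num_samples)]


mask = sketch_structured_mask(jax.random.PRNGKey(0), num_samples=8,
                              theta_dim=2, x_dim=3)
```

The result has shape (num_samples, theta_dim + x_dim) and boolean dtype, matching the documented return value. Setting one of the p_* arguments to 1.0 and the others to 0.0 is a quick way to inspect a single mask type in isolation.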