Two Moons Flux1 Flow Matching Example#
This notebook demonstrates conditional flow matching on the Two Moons simulation-based inference benchmark using JAX and Flax.
About Simulation-Based Inference (SBI): SBI refers to a class of methods for inferring parameters of complex models when the likelihood function is intractable, but simulation is possible. SBI algorithms learn to approximate the posterior distribution over parameters given observed data, enabling inference in scientific and engineering domains where traditional methods fail.
The Two Moons Dataset: The Two Moons dataset is a two-dimensional simulation-based inference benchmark designed to test an algorithm’s ability to handle complex posterior distributions. Its posterior is both bimodal (two distinct peaks) and locally crescent-shaped, making it a challenging task for inference algorithms. The primary purpose of this benchmark is to evaluate how well different methods can capture and represent multimodality and intricate structure in the posterior.
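For intuition, the following is a minimal sketch of the two-moons simulator in the convention of the sbibm benchmark suite; the TwoMoons task used below ships its own implementation, which may differ in details.
[ ]:
# Illustrative sketch of the two-moons simulator (sbibm convention);
# the TwoMoons task used below provides its own implementation.
import jax
import jax.numpy as jnp

def simulate_two_moons(rng, theta):
    rng_a, rng_r = jax.random.split(rng)
    a = jax.random.uniform(rng_a, minval=-jnp.pi / 2, maxval=jnp.pi / 2)
    r = 0.1 + 0.01 * jax.random.normal(rng_r)  # crescent radius ~ N(0.1, 0.01^2)
    p = jnp.array([r * jnp.cos(a) + 0.25, r * jnp.sin(a)])  # local crescent shape
    # the |theta_1 + theta_2| term makes the posterior bimodal
    shift = jnp.array([-jnp.abs(theta[0] + theta[1]), -theta[0] + theta[1]]) / jnp.sqrt(2.0)
    return p + shift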
Purpose of This Notebook: This notebook trains and evaluates a Flux1 flow-matching model on the Two Moons task. The goal is to assess the model’s ability to learn and represent a non-trivial posterior distribution with both global (bimodal) and local (crescent-shaped) complexity.
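As background, conditional flow matching trains a velocity field to transport noise into posterior samples along simple probability paths. Below is a minimal sketch of the standard linear-path objective; velocity_fn is a stand-in for the conditional vector field, and the actual loss implemented by the GenSBI pipeline may differ.
[ ]:
# Minimal sketch of a conditional flow-matching loss on a linear path
# x_t = (1 - t) * x_0 + t * x_1. The GenSBI pipeline implements its own
# training objective; velocity_fn here is a hypothetical model callable.
import jax
import jax.numpy as jnp

def cfm_loss(velocity_fn, rng, x1, cond):
    rng_t, rng_x0 = jax.random.split(rng)
    x0 = jax.random.normal(rng_x0, x1.shape)         # base (noise) samples
    t = jax.random.uniform(rng_t, (x1.shape[0], 1))  # random times in [0, 1]
    xt = (1 - t) * x0 + t * x1                       # point along the path
    target = x1 - x0                                 # velocity of the linear path
    pred = velocity_fn(xt, t, cond)                  # predicted velocity
    return jnp.mean((pred - target) ** 2)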
Table of Contents#
| Section | Description |
|---|---|
| 1. Introduction & Setup | Overview, environment, device, autoreload |
| 2. Task & Data Preparation | Define task, visualize data, create datasets |
| 3. Model Configuration & Definition | Load config, set parameters, instantiate model |
| 4. Training | Train or restore model, manage checkpoints |
| 5. Evaluation & Visualization | Visualize loss, sample posterior, compute log prob |
| 6. Animations | Create and display GIFs of results |
## 1. Introduction & Setup#
In this section, we introduce the problem, set up the computational environment, import the required libraries, configure JAX for CPU or GPU usage, and enable autoreload for iterative development. We also ensure compatibility with Google Colab.
[1]:
# Check if running on Colab and install dependencies if needed
try:
    import google.colab

    colab = True
except ImportError:
    colab = False

if colab:
    # Install required packages and clone the repository
    %pip install "gensbi_examples[cuda12] @ git+https://github.com/aurelio-amerio/GenSBI-examples"
    !git clone https://github.com/aurelio-amerio/GenSBI-examples
    %cd GenSBI-examples/examples/sbi-benchmarks/two_moons
[2]:
# load autoreload extension
%load_ext autoreload
%autoreload 2
[3]:
import os
# select device
os.environ["JAX_PLATFORMS"] = "cuda"
# os.environ["JAX_PLATFORMS"] = "cpu"
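To confirm which backend JAX actually picked up, list the visible devices (this must run after setting JAX_PLATFORMS):
[ ]:
import jax
# e.g. [CudaDevice(id=0)] on GPU, [CpuDevice(id=0)] on CPU
print(jax.devices())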
## 2. Task & Data Preparation#
In this section, we define the Two Moons task, visualize reference samples, and create the training and validation datasets required for model learning. Batch size and sample count are set for reproducibility and performance.
[4]:
# Flags: restore a trained checkpoint instead of training from scratch.
restore_model = True
train_model = False
[5]:
import orbax.checkpoint as ocp
# get the current notebook path
notebook_path = os.getcwd()
checkpoint_dir = os.path.join(notebook_path, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
[6]:
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from flax import nnx
from numpyro import distributions as dist
import numpy as np
[7]:
from gensbi.utils.plotting import plot_marginals
[8]:
from gensbi_examples.tasks import TwoMoons
task = TwoMoons()
[9]:
# reference posterior for an observation
obs, reference_samples = task.get_reference(num_observation=8)
[10]:
# plot the 2D posterior
plot_marginals(np.asarray(reference_samples, dtype=np.float32), gridsize=50, range=[(-1, 0), (0, 1)], plot_levels=False, backend="seaborn")
plt.show()
[11]:
# make a dataset
nsamples = int(1e5)
[12]:
# Set batch size for training. Larger batches give smoother gradient estimates, but are limited by available GPU memory.
batch_size = 4096
# Create training and validation datasets using the Two Moons task object.
train_dataset = task.get_train_dataset(batch_size)
val_dataset = task.get_val_dataset()
# Create iterators for the training and validation datasets.
dataset_iter = iter(train_dataset)
val_dataset_iter = iter(val_dataset)
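As a quick sanity check of the data pipeline, we can pull one batch and print its shapes; jax.tree_util.tree_map works regardless of the batch's exact pytree structure:
[ ]:
# Peek at a single training batch and report the shape of every leaf.
batch = next(dataset_iter)
print(jax.tree_util.tree_map(lambda a: a.shape, batch))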
## 3. Model Configuration & Definition#
In this section, we load the model and optimizer configuration, set all relevant parameters, and instantiate the Flux1 model via the Flux1FlowPipeline recipe.
[ ]:
from gensbi.recipes import Flux1FlowPipeline
[ ]:
import yaml
# Path to the flux1 flow-matching configuration file.
config_path = f"{notebook_path}/config/config_flow_flux.yaml"
[15]:
# Extract dimensionality information from the task object.
dim_theta = task.dim_theta # Number of parameters to infer
dim_data = task.dim_data # Number of observed data dimensions
dim_joint = task.dim_joint # Joint dimension (for model input)
[16]:
pipeline = Flux1FlowPipeline.init_pipeline_from_config(
    train_dataset,
    val_dataset,
    dim_theta,
    dim_data,
    config_path,
    checkpoint_dir,
)
## 4. Training#
In this section, we either train the model from scratch or restore it from a saved checkpoint. Since restore_model = True was set above, we restore the trained model from the checkpoint directory.
[17]:
pipeline.restore_model()
Restored model from checkpoint
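For completeness, the restore_model and train_model flags defined earlier suggest the usual train-or-restore pattern. This is only a hedged sketch: pipeline.train() is a hypothetical entry point here; consult the Flux1FlowPipeline documentation for the actual training API.
[ ]:
# Hedged sketch of the train-or-restore pattern using the flags above.
# NOTE: pipeline.train() is hypothetical; check the Flux1FlowPipeline API
# for the real training entry point and its arguments.
if restore_model:
    pipeline.restore_model()
elif train_model:
    pipeline.train()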
## 5. Evaluation & Visualization#
In this section, we evaluate the trained Flux1 model by sampling from the posterior and comparing the results to reference data. We also compute and visualize the unnormalized log probability over a grid to assess model calibration and density estimation. These analyses provide insight into model performance and reliability.
### 5.1. Posterior Sampling#
In this section, we sample from the posterior distribution using the trained model and visualize the results. Posterior samples are generated for a selected observation and compared to reference samples to assess model accuracy.
[18]:
# We want to do conditional inference, so we need an observation for which to compute the posterior.
def get_samples(idx, nsamples=10_000, use_ema=False, rng=None):
    observation, reference_samples = task.get_reference(idx)
    true_param = jnp.array(task.get_true_parameters(idx))
    if rng is None:
        rng = jax.random.PRNGKey(42)
    time_grid = jnp.linspace(0, 1, 100)
    samples = pipeline.sample(rng, observation, nsamples, use_ema=use_ema, time_grid=time_grid)
    return samples, true_param, reference_samples
[19]:
samples, true_param, reference_samples = get_samples(8)
[20]:
samples.shape
[20]:
(100, 10000, 2)
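The leading axis indexes the 100 points of time_grid, so samples[-1] contains the draws at t = 1, i.e. the posterior samples proper. As a crude sanity check (the posterior is bimodal, so moments alone are not conclusive; a two-sample test such as C2ST would be more rigorous), we can compare summary statistics against the reference samples:
[ ]:
# Compare first and second moments of the model's t = 1 samples
# against the reference posterior samples.
model_samples = samples[-1]
print("model mean:", jnp.mean(model_samples, axis=0))
print("ref   mean:", jnp.mean(jnp.asarray(reference_samples), axis=0))
print("model std: ", jnp.std(model_samples, axis=0))
print("ref   std: ", jnp.std(jnp.asarray(reference_samples), axis=0))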
### 5.2. Visualize Posterior Samples#
In this section, we plot the posterior samples as a 2D histogram to visualize the learned distribution and compare it to the ground truth.
[21]:
from gensbi.utils.plotting import plot_marginals, plot_2d_dist_contour
[22]:
plot_marginals(samples[-1], plot_levels=False, gridsize=50, range=[(-1., 0), (0, 1.)])
plt.show()

plot_marginals(samples[-1], plot_levels=False, backend="seaborn", gridsize=50, range=[(-1., 0), (0, 1.)])
plt.text(1.05, 1.05, "t = 1.00", transform=plt.gca().transAxes)
plt.show()
## 6. Animations#
In this section, we create and display animations of posterior samples and density contours over time. These visualizations illustrate the evolution of the learned distribution during the sampling process, providing dynamic insight into model behavior and convergence.
[40]:
import imageio.v3 as imageio
import io
from tqdm import tqdm
[41]:
# samples
images = []
for i in tqdm(range(len(samples))):
    fig, axes = plot_marginals(
        samples[i],
        plot_levels=False,
        gridsize=50,
        range=[(-1.0, 0), (0, 1.0)],
        backend="seaborn",
    )
    # manually set the ticks to make a prettier plot
    axes[0, 0].set_ylim(0, 6)
    axes[0, 0].set_yticks([5])
    axes[1, 1].set_xlim(0, 6)
    axes[1, 1].set_xticks([5])
    axes[1, 1].text(0, 1.03, f"t = {(i+1)/len(samples):.2f}", transform=plt.gca().transAxes)
    buf = io.BytesIO()
    plt.savefig(buf, format="png", dpi=300)
    buf.seek(0)
    image = imageio.imread(buf)
    buf.close()
    images.append(image)
    plt.close()
100%|██████████| 100/100 [00:30<00:00, 3.29it/s]
[42]:
# repeat the last frame 20 times to make the gif pause at the end
images += [images[-1]] * 20
[43]:
imageio.imwrite(
    "animated_plot_samples_flux1.gif",
    images,
    duration=5000 / len(images),  # per-frame duration in milliseconds
    loop=0,  # 0 means loop indefinitely
)
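To display the resulting GIF inline in the notebook:
[ ]:
from IPython.display import Image
# render the saved GIF inline
Image(filename="animated_plot_samples_flux1.gif")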
