Gaussian Linear Simformer Conditional Flow Matching Example#
This notebook demonstrates how to train and sample from a conditional flow-matching model on the Gaussian Linear task using JAX and Flax. We cover environment setup, data generation, model definition, training, sampling, and evaluation.
1. Environment Setup#
We set up the notebook environment, import required libraries, and configure JAX for CPU or GPU usage. This section also ensures compatibility with Google Colab.
[1]:
# Check if running on Colab and install dependencies if needed
try:
    import google.colab
    colab = True
except ImportError:
    colab = False

if colab:
    # Install required packages and clone the repository
    %pip install "gensbi_examples[cuda12] @ git+https://github.com/aurelio-amerio/GenSBI-examples"
    !git clone https://github.com/aurelio-amerio/GenSBI-examples
    %cd GenSBI-examples/examples/sbi-benchmarks/gaussian_linear
[2]:
# load autoreload extension
%load_ext autoreload
%autoreload 2
[3]:
# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
import os
os.environ['JAX_PLATFORMS']="cuda"
# os.environ['JAX_PLATFORMS']="cpu"
[4]:
experiment_id = 2
Set Training and Model Restoration Flags#
Configure whether to restore a pretrained model or train from scratch.
[5]:
restore_model=False
train_model=True
Set Checkpoint Directory#
Specify the directory for saving and restoring model checkpoints.
[6]:
import orbax.checkpoint as ocp
# get the current notebook path
notebook_path = os.getcwd()
checkpoint_dir = f"{notebook_path}/checkpoints/gaussian_linear_simformer"
os.makedirs(checkpoint_dir, exist_ok=True)
2. Library Imports and JAX Mesh Setup#
Import required libraries and set up the JAX mesh for sharding.
[7]:
import orbax.checkpoint as ocp
[8]:
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from flax import nnx
import optax
from optax.contrib import reduce_on_plateau
from numpyro import distributions as dist
import numpy as np
from tqdm.auto import tqdm
from functools import partial
# Define the mesh for JAX sharding (for model restoration on CPU/GPU)
devices = jax.devices()
mesh = jax.sharding.Mesh(devices, axis_names=('data',)) # A simple 1D mesh
[9]:
print(devices)
[CudaDevice(id=0)]
3. Optimizer and Learning Rate Schedule Parameters#
Define optimizer hyperparameters and learning rate scheduling.
[10]:
# @markdown Define optimizer and learning rate schedule parameters
# @markdown Number of steps with no improvement after which the learning rate is reduced:
PATIENCE = 10  # @param{type:"integer"}
# @markdown Number of steps to wait before resuming normal operation after a learning rate reduction:
COOLDOWN = 2  # @param{type:"integer"}
# @markdown Factor by which to reduce the learning rate:
FACTOR = 0.5  # @param{type:"number"}
# @markdown Number of iterations over which to accumulate an average loss value:
ACCUMULATION_SIZE = 100
# @markdown Relative tolerance for measuring the new optimum:
RTOL = 1e-4  # @param{type:"number"}
# @markdown Learning rate bounds:
MAX_LR = 1e-3  # @param{type:"number"}
MIN_LR = 0  # @param{type:"number"}
MIN_SCALE = MIN_LR / MAX_LR
4. Task and Dataset Setup#
Define the Gaussian Linear task and prepare training and validation datasets.
[11]:
from gensbi.flow_matching.path.scheduler import CondOTScheduler
from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.solver import ODESolver
from gensbi.utils.plotting import plot_marginals, plot_2d_dist_contour
Define the Task#
[12]:
from gensbi_examples.tasks import GaussianLinear
task = GaussianLinear()
./task_data/data_gaussian_linear.npz already exists, skipping download.
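For context, the Gaussian Linear benchmark is a linear-Gaussian model: a 10-dimensional Gaussian prior over the parameters θ and a Gaussian likelihood whose mean is θ itself, so a reference posterior is available in closed form. Below is a minimal illustrative sketch of such a generative model; the names and scales are assumptions for the example, the actual prior and noise levels are defined inside the GaussianLinear task.

# Illustrative sketch of a Gaussian linear generative model (scales are assumptions;
# the GaussianLinear task defines the actual prior and noise levels).
prior_sketch = dist.Independent(
    dist.Normal(jnp.zeros(10), 0.1 * jnp.ones(10)), reinterpreted_batch_ndims=1
)

def simulator_sketch(rng, theta):
    # The observation is the parameter vector plus isotropic Gaussian noise.
    return theta + 0.1 * jax.random.normal(rng, theta.shape)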
[13]:
idx=1
obs, reference_samples = task.get_reference(num_observation=idx)
true_param = task.get_true_parameters(idx).reshape(-1)
Visualize Reference Samples#
Plot the reference posterior samples from the Gaussian Linear task.
[14]:
plot_marginals(reference_samples,true_param=true_param, gridsize=25)
plt.show()
[Figure: marginal distributions of the reference posterior samples, with the true parameters overlaid.]
5. Dataset Preparation#
Create training and validation datasets for the model.
[15]:
# make a dataset
nsamples = int(1e5)
[16]:
batch_size = 1024*4 # the model greatly benefits from larger batch sizes to avoid overfitting, but this is limited by the GPU memory
train_dataset = task.get_train_dataset(batch_size)
val_dataset = task.get_val_dataset()
dataset_iter = iter(train_dataset)
val_dataset_iter = iter(val_dataset)
[17]:
next(dataset_iter).shape, next(val_dataset_iter).shape
[17]:
((4096, 20), (512, 20))
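Each batch row is a joint sample over all 20 variables: the first dim_theta entries are the parameters θ and the remaining dim_data entries are the simulated observation x (this ordering is also what the condition masks below assume). A quick illustrative split, not part of the original pipeline:

# Illustrative check: split a joint batch into its parameter and data blocks.
joint_batch = next(dataset_iter)          # shape (4096, 20)
d_theta = int(task.dim_theta)
theta_batch, x_batch = joint_batch[:, :d_theta], joint_batch[:, d_theta:]
print(theta_batch.shape, x_batch.shape)   # (4096, 10) (4096, 10)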
6. Model Definition#
Define the Simformer model and the conditional flow-matching loss.
Note:
The model uses edge masks as attention masks, which are crucial for controlling which variables are attended to during training and inference. These masks enable both posterior estimation (conditioning on observed data) and unconditional density estimation (no conditioning).
The marginalization function is used to construct edge masks that marginalize out arbitrary variables, allowing the model to learn and evaluate arbitrary marginal distributions.
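To make the distinction concrete: a condition mask is a boolean vector over the joint variables marking which entries are observed (and therefore held fixed), while an edge mask is a boolean matrix saying which variables may attend to which. A toy 3-variable illustration follows; the names are hypothetical and not tied to this task's dimensions.

# Toy illustration (3 variables): condition mask vs edge mask.
toy_condition_mask = jnp.array([False, False, True])      # variable 2 is observed / conditioned on
toy_dense_edge_mask = jnp.ones((3, 3), dtype=jnp.bool_)   # every variable attends to every variable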
[18]:
from gensbi.models import Simformer, SimformerParams, SimformerCFMLoss, SimformerConditioner
[19]:
path = AffineProbPath(scheduler=CondOTScheduler()) # define the probability path
[20]:
dim_theta = task.dim_theta
dim_data = task.dim_data
dim_joint = task.dim_joint
node_ids = jnp.arange(dim_joint)
[21]:
dim_theta, dim_data
[21]:
(array(10), array(10))
[22]:
params = SimformerParams(
    rngs=nnx.Rngs(0),
    dim_value=40,
    dim_id=40,
    dim_condition=10,
    dim_joint=dim_joint,
    fourier_features=128,
    num_heads=6,
    num_layers=8,
    widening_factor=3,
    qkv_features=90,  # this bottlenecks the attention features to 90, instead of using the full token dimension
    num_hidden_layers=1,
)
[23]:
loss_fn_cfm = SimformerCFMLoss(path)
[24]:
def marginalize(rng: jax.random.PRNGKey, edge_mask: jax.Array):
    # This function creates an edge mask that marginalizes out a single node from the adjacency matrix.
    # By setting the corresponding row and column to False (except the diagonal), we can compute arbitrary marginals.
    idx = jax.random.choice(rng, jnp.arange(edge_mask.shape[0]), shape=(1,), replace=False)
    edge_mask = edge_mask.at[idx, :].set(False)
    edge_mask = edge_mask.at[:, idx].set(False)
    edge_mask = edge_mask.at[idx, idx].set(True)
    return edge_mask
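As a quick illustration of what marginalize produces (a toy check, not part of the training pipeline): starting from a dense mask, one randomly chosen node ends up attending only to itself and being ignored by all other nodes.

# Toy check: marginalize one node out of a dense 4x4 edge mask.
toy_mask = jnp.ones((4, 4), dtype=jnp.bool_)
print(marginalize(jax.random.PRNGKey(0), toy_mask).astype(jnp.int32))
# One row and column are zeroed out, except for the diagonal entry.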
[25]:
# Edge masks are attention masks that control which variables are attended to.
# - undirected_edge_mask: all variables attend to each other (for unconditional density estimation)
# - posterior_faithfull: mask for posterior estimation (conditioning on observed data)
undirected_edge_mask = jnp.ones((dim_joint, dim_joint), dtype=jnp.bool_)
posterior_mask = jnp.concatenate([jnp.zeros((dim_theta), dtype=jnp.bool_), jnp.ones((dim_data), dtype=jnp.bool_)], axis=-1)
posterior_faithfull = task.get_edge_mask_fn("faithfull")(node_ids, condition_mask=posterior_mask)
[26]:
@partial(jax.jit, static_argnames=["nsamples"])
def get_random_condition_mask(rng: jax.random.PRNGKey, nsamples):
    mask_joint = jnp.zeros((5 * nsamples, dim_joint), dtype=jnp.bool_)
    mask_posterior = jnp.concatenate([jnp.zeros((nsamples, dim_theta), dtype=jnp.bool_), jnp.ones((nsamples, dim_data), dtype=jnp.bool_)], axis=-1)
    mask1 = jax.random.bernoulli(rng, p=0.3, shape=(nsamples, dim_joint))
    filter = ~jnp.all(mask1, axis=-1)
    mask1 = jnp.logical_and(mask1, filter.reshape(-1, 1))
    masks = jnp.concatenate([mask_joint, mask1, mask_posterior], axis=0)
    return jax.random.choice(rng, masks, shape=(nsamples,), replace=False, axis=0)
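The sampled condition masks mix three regimes: fully unconditional rows (all False, for joint density estimation), rows conditioning on a random subset of variables, and posterior rows conditioning on all data dimensions. A quick illustrative check (not part of the original pipeline):

# Illustrative check of the sampled condition masks.
example_masks = get_random_condition_mask(jax.random.PRNGKey(0), 8)
print(example_masks.shape)          # (8, 20): one boolean mask per batch element
print(example_masks.sum(axis=-1))   # number of conditioned variables per row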
[27]:
p0_dist_model = dist.Independent(
    dist.Normal(loc=jnp.zeros((dim_joint,)), scale=jnp.ones((dim_joint,))),
    reinterpreted_batch_ndims=1,
)
[28]:
def loss_fn_(vf_model, x_1, key: jax.random.PRNGKey):
    batch_size = x_1.shape[0]
    rng_x0, rng_t, rng_condition, rng_edge_mask1, rng_edge_mask2 = jax.random.split(key, 5)

    # Sample the base distribution and random times
    x_0 = p0_dist_model.sample(rng_x0, (batch_size,))
    t = jax.random.uniform(rng_t, x_1.shape[0])
    batch = (x_0, x_1, t)

    # Condition mask: randomly condition on a subset of variables. The masks sampled here
    # determine which conditionals the trained model can be asked for later.
    condition_mask = get_random_condition_mask(rng_condition, batch_size)

    # Dense (undirected) edge mask
    undirected_edge_mask_ = jnp.repeat(undirected_edge_mask[None, ...], 3 * batch_size, axis=0)
    # Faithful posterior edge mask
    faithfull_edge_mask_ = jnp.repeat(posterior_faithfull[None, ...], 3 * batch_size, axis=0)
    # Include marginal consistency by generating edge masks that marginalize out random nodes.
    # This allows the model to learn arbitrary marginal distributions.
    marginal_mask = jax.vmap(marginalize, in_axes=(0, None))(jax.random.split(rng_edge_mask1, (batch_size,)), undirected_edge_mask)

    # Randomly choose between dense, posterior, and marginal edge masks for each batch element.
    edge_masks = jnp.concatenate([undirected_edge_mask_, faithfull_edge_mask_, marginal_mask], axis=0)
    edge_masks = jax.random.choice(rng_edge_mask2, edge_masks, shape=(batch_size,), axis=0)

    loss = loss_fn_cfm(vf_model, batch, node_ids=node_ids, edge_mask=edge_masks, condition_mask=condition_mask)
    return loss
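For reference, with the conditional-OT path defined above the flow-matching target is the straight-line interpolation between a base sample and a data sample, and the loss regresses the predicted velocity onto its constant time derivative. This is the standard formulation; SimformerCFMLoss may additionally restrict the loss to the unconditioned variables:

$$x_t = (1 - t)\,x_0 + t\,x_1, \qquad \mathcal{L}(\theta) = \mathbb{E}_{t,\,x_0,\,x_1}\,\big\lVert v_\theta(x_t, t) - (x_1 - x_0) \big\rVert^2 .$$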
[29]:
@nnx.jit
def train_loss(vf_model, key: jax.random.PRNGKey):
    x_1 = next(dataset_iter)
    return loss_fn_(vf_model, x_1, key)
[30]:
@nnx.jit
def val_loss(vf_model, key):
    x_1 = next(val_dataset_iter)
    return loss_fn_(vf_model, x_1, key)
[31]:
@nnx.jit
def train_step(model, optimizer, rng):
    loss_fn = lambda model: train_loss(model, rng)
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads, value=loss)  # in-place update; `value` is used by the reduce_on_plateau schedule
    return loss
[32]:
vf_model = Simformer(params)
7. Model Restoration#
Restore the model from checkpoint if requested.
[33]:
if restore_model:
    model_state = nnx.state(vf_model)
    graphdef, abstract_state = nnx.split(vf_model)
    with ocp.CheckpointManager(
        checkpoint_dir, options=ocp.CheckpointManagerOptions(read_only=True)
    ) as read_mgr:
        restored = read_mgr.restore(
            experiment_id,
            # pass in the model_state to restore the exact same State type
            args=ocp.args.Composite(state=ocp.args.PyTreeRestore(item=model_state)),
        )
    vf_model = nnx.merge(graphdef, restored["state"])
    print("Restored model from checkpoint")
8. Optimizer Setup#
Set up the optimizer and learning rate schedule.
[34]:
# reduce-on-plateau schedule
nsteps = 10_000
nepochs = 3
multistep = 1  # if the GPU cannot fit batch sizes of at least 4k, increase this value to reach the desired effective batch size

opt = optax.chain(
    optax.adaptive_grad_clip(10.0),
    optax.adamw(MAX_LR),
    reduce_on_plateau(
        patience=PATIENCE,
        cooldown=COOLDOWN,
        factor=FACTOR,
        rtol=RTOL,
        accumulation_size=ACCUMULATION_SIZE,
        min_scale=MIN_SCALE,
    ),
)

if multistep > 1:
    opt = optax.MultiSteps(opt, multistep)

optimizer = nnx.Optimizer(vf_model, opt)
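When multistep > 1, optax.MultiSteps accumulates gradients over multistep calls before applying an update, so the effective batch size is batch_size × multistep. A small illustrative check (the numbers in the comment assume the 4096-sample batches defined above):

# Effective batch size with gradient accumulation: e.g. multistep = 4 with 4096-sample
# batches behaves like a single update on 16384 samples, at unchanged per-step memory.
effective_batch_size = batch_size * multistep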
[35]:
rngs = nnx.Rngs(0)
[36]:
best_state = nnx.state(vf_model)
best_val_loss_value = val_loss(vf_model, jax.random.PRNGKey(0))
val_error_ratio = 1.1
counter = 0
cmax = 10
print_every = 100
loss_array = []
val_loss_array = []
early_stopping = True
9. Training Loop#
Train the model using the defined optimizer and loss function. Early stopping and learning rate scheduling are used for efficient training.
[37]:
if train_model:
    vf_model.train()
    for ep in range(nepochs):
        pbar = tqdm(range(nsteps))
        l = 0
        v_l = 0
        for j in pbar:
            if counter > cmax and early_stopping:
                print("Early stopping")
                # restore the best model state
                graphdef, abstract_state = nnx.split(vf_model)
                vf_model = nnx.merge(graphdef, best_state)
                break

            loss = train_step(vf_model, optimizer, rngs.train_step())
            l += loss.item()

            v_loss = val_loss(vf_model, rngs.val_step())
            v_l += v_loss.item()

            if j > 0 and j % print_every == 0:
                loss_ = l / print_every
                val_ = v_l / print_every
                ratio1 = val_ / loss_
                ratio2 = val_ / best_val_loss_value
                # if ratio1 < val_error_ratio and ratio2 < 1.05:
                if ratio1 < val_error_ratio:
                    if val_ < best_val_loss_value:
                        best_val_loss_value = val_
                        best_state = nnx.state(vf_model)
                    counter = 0
                else:
                    counter += 1

                pbar.set_postfix(
                    loss=f"{loss_:.4f}",
                    ratio=f"{ratio1:.4f}",
                    counter=counter,
                    val_loss=f"{val_:.4f}",
                )
                loss_array.append(loss_)
                val_loss_array.append(val_)
                l = 0
                v_l = 0

    vf_model.eval()
10. Save Model Checkpoint#
Save the trained model to a checkpoint for future restoration.
[38]:
# save the model
if train_model:
    checkpoint_manager = ocp.CheckpointManager(
        checkpoint_dir,
        options=ocp.CheckpointManagerOptions(
            max_to_keep=None,
            keep_checkpoints_without_metrics=True,
            create=True,
        ),
    )
    model_state = nnx.state(vf_model)
    checkpoint_manager.save(
        experiment_id, args=ocp.args.Composite(state=ocp.args.PyTreeSave(model_state))
    )
    checkpoint_manager.close()
11. Training and Validation Loss Visualization#
Plot the training and validation loss curves.
[ ]:
if train_model:
    plt.plot(loss_array, label="train loss")
    plt.plot(val_loss_array, label="val loss")
    plt.xlabel("steps")
    plt.ylabel("loss")
    plt.legend()
    plt.show()
12. Posterior Sampling#
Sample from the posterior distribution using the trained model and visualize the results.
[ ]:
from gensbi.utils.model_wrapping import ModelWrapper


class SimWrapper(ModelWrapper):
    def __init__(self, model):
        super().__init__(model)

    def _call_model(self, x, t, args, **kwargs):
        return self.model(obs=x, timesteps=t, **kwargs)
[ ]:
obs_ids = jnp.arange(dim_theta)  # ids of the parameter nodes we sample
cond_ids = jnp.arange(dim_theta, dim_joint)  # ids of the data nodes we condition on

step_size = 0.01


# conditional sampling
def get_samples(vf_wrapped, idx, nsamples=10_000, edge_mask=undirected_edge_mask):
    observation, reference_samples = task.get_reference(idx)
    true_param = jnp.array(task.get_true_parameters(idx))

    rng = jax.random.PRNGKey(45)
    key1, key2 = jax.random.split(rng, 2)

    x_init = jax.random.normal(key1, (nsamples, dim_theta))  # base samples, shape (nsamples, dim_theta)
    cond = jnp.broadcast_to(observation[..., None], (1, dim_data, 1))  # conditioning values, shape (1, dim_data, 1)

    solver = ODESolver(velocity_model=vf_wrapped)  # create an ODESolver instance
    model_extras = {"cond": cond, "obs_ids": obs_ids, "cond_ids": cond_ids, "edge_mask": edge_mask}

    sampler_ = solver.get_sampler(method='Dopri5', step_size=step_size, return_intermediates=False, model_extras=model_extras)
    samples = sampler_(x_init)  # sample from the model
    return samples, true_param, reference_samples


def plot_samples(samples, true_param):
    plot_marginals(samples, true_param=true_param)
    plt.show()
[ ]:
vf_cond = SimformerConditioner(vf_model)
vf_wrapped = SimWrapper(vf_cond)
[ ]:
idx=1
samples, true_param, reference_samples = get_samples(vf_wrapped, idx, nsamples=100_000, edge_mask=posterior_faithfull)
Visualize Posterior Samples#
Plot the posterior samples as a 2D histogram.
[ ]:
plot_marginals(samples, true_param=true_param.reshape(-1))
plt.show()
[Figure: marginal distributions of the posterior samples, with the true parameters overlaid.]
13. Posterior Evaluation#
Evaluate the posterior by computing the likelihood on a grid and visualizing the results.
NOTE: the code for marginal likelihood estimation is not working yet. Currently it can only evaluate the full conditional likelihood, which is inconvenient when the dimensionality of the output is high.
We evaluate the marginal posterior for a single pair of parameters, here the last two dimensions of the parameter space (plotted below as $\theta_9$ and $\theta_{10}$).
[ ]:
idx = 1
observation, reference_samples = task.get_reference(idx)
solver = ODESolver(velocity_model=vf_wrapped) # create an ODESolver class
[ ]:
p0_cond = dist.Independent(
    dist.Normal(loc=jnp.zeros((2,)), scale=jnp.ones((2,))),
    reinterpreted_batch_ndims=1,
)
[ ]:
grid_size = 200
theta1 = jnp.linspace(-1, 1, grid_size)
theta2 = jnp.linspace(-1, 1, grid_size)
x_1 = jnp.meshgrid(theta1, theta2)
x_1 = jnp.stack([x_1[0].flatten(), x_1[1].flatten()], axis=1)
cond = jnp.broadcast_to(observation, (1, dim_data))  # conditioning values, shape (1, dim_data)
obs_ids = jnp.array([8, 9])  # we only evaluate the posterior for two of the parameter dimensions
cond_ids = jnp.arange(dim_theta, dim_joint) # conditional ids
node_ids_cond = jnp.concatenate([obs_ids, cond_ids])
cond_mask_partial_posterior = jnp.concatenate([jnp.zeros((obs_ids.shape[0],), dtype=jnp.bool_), jnp.ones((cond_ids.shape[0],), dtype=jnp.bool_)], axis=-1)
posterior_mask_cond = task.get_edge_mask_fn("faithfull")(node_ids_cond, condition_mask=cond_mask_partial_posterior)
model_extras = {"cond": cond, "obs_ids": obs_ids, "cond_ids": cond_ids, "edge_mask": posterior_mask_cond}
[ ]:
# get the log-probability on the grid
logp_sampler = solver.get_unnormalized_logprob(time_grid=[1.0, 0.0], method='Dopri5', step_size=step_size, log_p0=p0_cond.log_prob, model_extras=model_extras)
y_init = x_1
exact_log_p = logp_sampler(y_init)
p = jnp.exp(exact_log_p)[-1]
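The returned log-probabilities are unnormalized. For plotting, the grid values can be normalized with a simple Riemann sum; this is a small illustrative step (the name p_normalized is not part of the original notebook):

# Normalize the density over the evaluation grid (Riemann-sum approximation).
dtheta = (theta1[1] - theta1[0]) * (theta2[1] - theta2[0])
p_normalized = p / (p.sum() * dtheta)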
[ ]:
x = theta1
y = theta2
Z = np.array(p.reshape((grid_size, grid_size)))
plot_2d_dist_contour(x,y,Z, true_param=[true_param[0, 8], true_param[0, 9]])
plt.xlabel(r"$\theta_9$", fontsize=12)
plt.ylabel(r"$\theta_{10}$", fontsize=12)
#plot the true value
# plt.scatter(true_param[0, 8], true_param[0, 9], color='red', label='True Parameter', s=50, marker="s", zorder=10)
# plt.axvline(true_param[0, 8], color='red', linestyle='-', linewidth=1, zorder=9)
# plt.axhline(true_param[0, 9], color='red', linestyle='-', linewidth=1, zorder=9)
plt.show()
[Figure: 2D contour of the marginal posterior density over $\theta_9$ and $\theta_{10}$.]
14. Classifier Two-Sample Test (C2ST)#
Evaluate the quality of the posterior samples using the C2ST metric. Values closer to 0.5 are better.
[ ]:
from gensbi_examples.c2st import c2st
[ ]:
idx = 1
samples, true_param, reference_samples = get_samples(vf_wrapped, idx, nsamples=10_000)
[ ]:
c2st_accuracy = c2st(reference_samples, samples)
[ ]:
c2st_accuracy #0.5277 for obs 1
array(0.51165, dtype=float32)