Gaussian Linear Simformer Conditional Flow Matching Example#

This notebook demonstrates how to train and sample from a conditional flow-matching model on the Gaussian Linear task using JAX and Flax. We cover environment setup, data generation, model definition, training, sampling, and evaluation.

1. Environment Setup#

We set up the notebook environment, import required libraries, and configure JAX for CPU or GPU usage. This section also ensures compatibility with Google Colab.

[1]:
# Check if running on Colab and install dependencies if needed
try:
    import google.colab
    colab = True
except ImportError:
    colab = False

if colab:
    # Install required packages and clone the repository
    %pip install "gensbi_examples[cuda12] @ git+https://github.com/aurelio-amerio/GenSBI-examples"
    !git clone https://github.com/aurelio-amerio/GenSBI-examples
    %cd GenSBI-examples/examples/sbi-benchmarks/gaussian_linear
[2]:
# load autoreload extension
%load_ext autoreload
%autoreload 2
[3]:
# Set JAX backend (use 'cuda' for GPU, 'cpu' otherwise)
import os
os.environ['JAX_PLATFORMS']="cuda"
# os.environ['JAX_PLATFORMS']="cpu"
[4]:
experiment_id = 2

Set Training and Model Restoration Flags#

Configure whether to restore a pretrained model or train from scratch.

[5]:
restore_model=False
train_model=True

Set Checkpoint Directory#

Specify the directory for saving and restoring model checkpoints.

[6]:
import orbax.checkpoint as ocp
# get the current notebook path
notebook_path = os.getcwd()
checkpoint_dir = f"{notebook_path}/checkpoints/gaussian_linear_simformer"

os.makedirs(checkpoint_dir, exist_ok=True)

2. Library Imports and JAX Mesh Setup#

Import required libraries and set up the JAX mesh for sharding.

[7]:
import orbax.checkpoint as ocp
[8]:
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
from flax import nnx
import optax
from optax.contrib import reduce_on_plateau

from numpyro import distributions as dist

import numpy as np

from tqdm.auto import tqdm

from functools import partial

# Define the mesh for JAX sharding (for model restoration on CPU/GPU)
devices = jax.devices()
mesh = jax.sharding.Mesh(devices, axis_names=('data',)) # A simple 1D mesh

[9]:
print(devices)
[CudaDevice(id=0)]

3. Optimizer and Learning Rate Schedule Parameters#

Define optimizer hyperparameters and learning rate scheduling.

[10]:
# @markdown Define optimizer and learning rate schedule parameters
# @markdown Number of steps with no improvement after which the learning rate is reduced:
PATIENCE = 10  # @param{type:"integer"}
# @markdown Number of steps to wait before resuming normal operation after the learning rate reduction:
COOLDOWN = 2  # @param{type:"integer"}
# @markdown Factor by which to reduce the learning rate:
FACTOR = 0.5  # @param{type:"number"}
# @markdown Number of iterations to accumulate an average value:
ACCUMULATION_SIZE = 100
# @markdown Relative tolerance for measuring the new optimum:
RTOL = 1e-4  # @param{type:"number"}
# @markdown Learning rate bounds:
MAX_LR = 1e-3  # @param{type:"number"}
MIN_LR = 0  # @param{type:"number"}
MIN_SCALE = MIN_LR / MAX_LR

4. Task and Dataset Setup#

Define the Gaussian Linear task and prepare training and validation datasets.

[11]:
from gensbi.flow_matching.path.scheduler import CondOTScheduler
from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.solver import ODESolver
from gensbi.utils.plotting import plot_marginals, plot_2d_dist_contour

Define the Task#

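For context, the Gaussian Linear task couples a 10-dimensional Gaussian prior over the parameters with a Gaussian likelihood centred on them, so the reference posterior is Gaussian and known in closed form. The sketch below writes this generative model with numpyro distributions purely as an illustration; the specific noise scale (variance 0.1, following the common sbibm convention) is an assumption and may differ from the dataset bundled with gensbi_examples.

# Illustrative sketch of the Gaussian Linear generative model (not used by the task class).
# Assumption: theta ~ N(0, 0.1 I) and x | theta ~ N(theta, 0.1 I), as in the usual sbibm setup.
import jax
import jax.numpy as jnp
from numpyro import distributions as dist

dim = 10
prior = dist.Independent(dist.Normal(jnp.zeros(dim), jnp.sqrt(0.1) * jnp.ones(dim)), 1)

def simulate(key, n):
    key_theta, key_x = jax.random.split(key)
    theta = prior.sample(key_theta, (n,))                                     # (n, 10) parameters
    x = dist.Independent(dist.Normal(theta, jnp.sqrt(0.1)), 1).sample(key_x)  # (n, 10) data
    return jnp.concatenate([theta, x], axis=-1)                               # (n, 20) joint samples

print(simulate(jax.random.PRNGKey(0), 5).shape)  # (5, 20), matching the (batch, dim_joint) layout below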
[12]:
from gensbi_examples.tasks import GaussianLinear
task = GaussianLinear()
./task_data/data_gaussian_linear.npz already exists, skipping download.
[13]:
idx=1
obs, reference_samples = task.get_reference(num_observation=idx)
true_param = task.get_true_parameters(idx).reshape(-1)

Visualize Reference Samples#

Plot the reference samples for the selected observation of the Gaussian Linear task.

[14]:
plot_marginals(reference_samples,true_param=true_param, gridsize=25)
plt.show()
../_images/notebooks_gaussian_linear_simformer_22_1.png

5. Dataset Preparation#

Create training and validation datasets for the model.

[15]:
# make a dataset
nsamples = int(1e5)
[16]:
batch_size = 1024*4 # the model greatly benefits from larger batch sizes to avoid overfitting, but this is limited by the GPU memory
train_dataset = task.get_train_dataset(batch_size)
val_dataset = task.get_val_dataset()

dataset_iter = iter(train_dataset)
val_dataset_iter = iter(val_dataset)
[17]:
next(dataset_iter).shape, next(val_dataset_iter).shape
[17]:
((4096, 20), (512, 20))

6. Model Definition#

Define the Simformer model and the conditional flow-matching loss.

Note:

  • The model uses edge masks as attention masks, which are crucial for controlling which variables are attended to during training and inference. These masks enable both posterior estimation (conditioning on observed data) and unconditional density estimation (no conditioning).

  • The marginalization function is used to construct edge masks that marginalize out arbitrary variables, allowing the model to learn and evaluate arbitrary marginal distributions (a small illustrative example follows below).

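To make the marginalization masks concrete, the tiny example below (an illustration only, mirroring the marginalize function defined later in this section) shows what such an edge mask looks like for a 3-variable joint: the row and column of the marginalized node are set to False while its diagonal entry stays True, so the remaining nodes no longer attend to it.

# Illustration: an attention (edge) mask that marginalizes out node 1 of a 3-node joint.
import jax.numpy as jnp

dense = jnp.ones((3, 3), dtype=jnp.bool_)   # dense mask: every node attends to every node
idx = 1                                     # node to marginalize out
marginal = dense.at[idx, :].set(False)      # node 1 attends to nothing else
marginal = marginal.at[:, idx].set(False)   # no other node attends to node 1
marginal = marginal.at[idx, idx].set(True)  # keep the self-connection on the diagonal

print(marginal.astype(int))
# [[1 0 1]
#  [0 1 0]
#  [1 0 1]]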
[18]:
from gensbi.models import Simformer, SimformerParams, SimformerCFMLoss, SimformerConditioner
[19]:
path = AffineProbPath(scheduler=CondOTScheduler()) # define the probability path
[20]:
dim_theta = task.dim_theta
dim_data = task.dim_data

dim_joint = task.dim_joint
node_ids = jnp.arange(dim_joint)
[21]:
dim_theta, dim_data
[21]:
(array(10), array(10))
[22]:
params = SimformerParams(
    rngs = nnx.Rngs(0),
    dim_value = 40,
    dim_id = 40,
    dim_condition = 10,
    dim_joint= dim_joint,
    fourier_features = 128,
    num_heads = 6,
    num_layers = 8,
    widening_factor = 3,
    qkv_features = 90, # this bottlenecks the transformer features to 90, instead of the token dimension
    num_hidden_layers = 1)
[23]:
loss_fn_cfm = SimformerCFMLoss(path)
[24]:
def marginalize(rng: jax.random.PRNGKey, edge_mask: jax.Array):
    # This function creates an edge mask that marginalizes out a single node from the adjacency matrix.
    # By setting the corresponding row and column to False (except the diagonal), we can compute arbitrary marginals.
    idx = jax.random.choice(rng, jnp.arange(edge_mask.shape[0]), shape=(1,), replace=False)
    edge_mask = edge_mask.at[idx, :].set(False)
    edge_mask = edge_mask.at[:, idx].set(False)
    edge_mask = edge_mask.at[idx, idx].set(True)
    return edge_mask
[25]:
# Edge masks are attention masks that control which variables are attended to.
# - undirected_edge_mask: all variables attend to each other (for unconditional density estimation)
# - posterior_faithfull: mask for posterior estimation (conditioning on observed data)
undirected_edge_mask = jnp.ones((dim_joint, dim_joint), dtype=jnp.bool_)
posterior_mask = jnp.concatenate([jnp.zeros((dim_theta), dtype=jnp.bool_), jnp.ones((dim_data), dtype=jnp.bool_)], axis=-1)
posterior_faithfull = task.get_edge_mask_fn("faithfull")(node_ids, condition_mask=posterior_mask)
[26]:
@partial(jax.jit, static_argnames=["nsamples"])
def get_random_condition_mask(rng: jax.random.PRNGKey, nsamples):
    mask_joint = jnp.zeros((5*nsamples, dim_joint ), dtype=jnp.bool_)
    mask_posterior = jnp.concatenate([jnp.zeros((nsamples, dim_theta), dtype=jnp.bool_), jnp.ones((nsamples, dim_data), dtype=jnp.bool_)], axis=-1)

    mask1 = jax.random.bernoulli(rng, p=0.3, shape=(nsamples, dim_joint))
    filter = ~jnp.all(mask1, axis=-1)
    mask1 = jnp.logical_and(mask1, filter.reshape(-1,1))

    # masks = jnp.concatenate([mask_joint, mask1, mask_posterior, mask_likelihood], axis=0)
    masks = jnp.concatenate([mask_joint, mask1, mask_posterior], axis=0)
    return  jax.random.choice(rng, masks, shape=(nsamples,), replace=False, axis=0)
[27]:
p0_dist_model = dist.Independent(
    dist.Normal(loc=jnp.zeros((dim_joint,)), scale=jnp.ones((dim_joint,))),
    reinterpreted_batch_ndims=1
)
[28]:
def loss_fn_(vf_model, x_1, key: jax.random.PRNGKey):

    batch_size = x_1.shape[0]

    rng_x0, rng_t, rng_condition, rng_edge_mask1, rng_edge_mask2 = jax.random.split(key, 5)

    # Generate data and random times
    x_0 = p0_dist_model.sample(rng_x0, (batch_size,)) # (batch_size, dim_joint)

    t = jax.random.uniform(rng_t, x_1.shape[0])

    batch = (x_0, x_1, t)

    # Condition mask -> randomly condition on some data. Here you can choose between the different condition masks, and you should specify the conditionals you may want to compute afterwards.
    condition_mask = get_random_condition_mask(rng_condition, batch_size)

    # undirected_edge_mask
    undirected_edge_mask_ = jnp.repeat(undirected_edge_mask[None,...], 3*batch_size, axis=0) # Dense default mask

    # faithfull posterior mask
    faithfull_edge_mask_ = jnp.repeat(posterior_faithfull[None,...], 3*batch_size, axis=0) # Faithful posterior mask

    # Include marginal consistency by generating edge masks that marginalize out random nodes.
    # This allows the model to learn arbitrary marginal distributions.
    marginal_mask = jax.vmap(marginalize, in_axes=(0,None))(jax.random.split(rng_edge_mask1, (batch_size,)), undirected_edge_mask)
    edge_masks = jnp.concatenate([undirected_edge_mask_, faithfull_edge_mask_, marginal_mask], axis=0)
    # Randomly choose between dense, posterior, and marginal edge masks for each batch element.
    edge_masks = jax.random.choice(rng_edge_mask2, edge_masks, shape=(batch_size,), axis=0) # Randomly choose between dense, posterior, and marginal masks

    loss = loss_fn_cfm(vf_model, batch, node_ids=node_ids, edge_mask=edge_masks,condition_mask=condition_mask, )

    return loss
[29]:
@nnx.jit
def train_loss(vf_model, key: jax.random.PRNGKey):
    x_1 = next(dataset_iter) # (batch_size, dim_joint)
    return loss_fn_(vf_model, x_1, key)
[30]:
@nnx.jit
def val_loss(vf_model, key):
    x_1 = next(val_dataset_iter)
    return loss_fn_(vf_model, x_1, key)
[31]:
@nnx.jit
def train_step(model, optimizer, rng):
    loss_fn = lambda model: train_loss(model, rng)
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(grads, value=loss)  # In place updates.
    return loss
[32]:
vf_model = Simformer(params)

7. Model Restoration#

Restore the model from checkpoint if requested.

[33]:
if restore_model:
    model_state = nnx.state(vf_model)
    graphdef, abstract_state = nnx.split(vf_model)

    with ocp.CheckpointManager(
        checkpoint_dir, options=ocp.CheckpointManagerOptions(read_only=True)
    ) as read_mgr:
        restored = read_mgr.restore(
            experiment_id,
            # pass in the model_state to restore the exact same State type
            args=ocp.args.Composite(state=ocp.args.PyTreeRestore(item=model_state))
        )

    vf_model= nnx.merge(graphdef, restored["state"])
    print("Restored model from checkpoint")

8. Optimizer Setup#

Set up the optimizer and learning rate schedule.

[34]:
# reduce on plateau schedule
nsteps = 10_000
nepochs = 3

multistep = 1 # if the GPU cannot support batch sizes of at least 4k, adjust this value accordingly to get the desired effective batch size

opt = optax.chain(
    optax.adaptive_grad_clip(10.0),
    optax.adamw(MAX_LR),
    reduce_on_plateau(
        patience=PATIENCE,
        cooldown=COOLDOWN,
        factor=FACTOR,
        rtol=RTOL,
        accumulation_size=ACCUMULATION_SIZE,
        min_scale=MIN_SCALE,
    ),
)
if multistep > 1:
    opt = optax.MultiSteps(opt, multistep)
optimizer = nnx.Optimizer(vf_model, opt)
[35]:
rngs = nnx.Rngs(0)
[36]:
best_state = nnx.state(vf_model)
best_val_loss_value = val_loss(vf_model, jax.random.PRNGKey(0))
val_error_ratio = 1.1
counter = 0
cmax = 10
print_every = 100

loss_array = []
val_loss_array = []

early_stopping = True

9. Training Loop#

Train the model using the defined optimizer and loss function. Early stopping and learning rate scheduling are used for efficient training.

[37]:
if train_model:
    vf_model.train()

    for ep in range(nepochs):
        pbar = tqdm(range(nsteps))
        l = 0
        v_l = 0
        for j in pbar:
            if counter > cmax and early_stopping:
                print("Early stopping")
                # restore the model state
                graphdef, abstract_state = nnx.split(vf_model)

                vf_model = nnx.merge(graphdef, best_state)
                break

            loss = train_step(vf_model, optimizer, rngs.train_step())
            l += loss.item()

            v_loss = val_loss(vf_model, rngs.val_step())
            v_l += v_loss.item()

            if j > 0 and j % 100 == 0:
                loss_ = l / 100
                val_ = v_l / 100

                ratio1 = val_ / loss_
                ratio2 = val_ / best_val_loss_value

                # if ratio1 < val_error_ratio and ratio2 < 1.05:
                if ratio1 < val_error_ratio:
                    if val_ < best_val_loss_value:
                        best_val_loss_value = val_
                        best_state = nnx.state(vf_model)
                    counter = 0
                else:
                    counter += 1

                pbar.set_postfix(
                    loss=f"{loss_:.4f}",
                    ratio=f"{ratio1:.4f}",
                    counter=counter,
                    val_loss=f"{val_:.4f}",
                )
                loss_array.append(loss_)
                val_loss_array.append(val_)
                l = 0
                v_l = 0


    vf_model.eval()

10. Save Model Checkpoint#

Save the trained model to a checkpoint for future restoration.

[38]:
# save the model
if train_model:
    checkpoint_manager = ocp.CheckpointManager(checkpoint_dir,
        options=ocp.CheckpointManagerOptions(
            max_to_keep=None,
            keep_checkpoints_without_metrics=True,
            create=True,
        ),
    )
    model_state = nnx.state(vf_model)
    checkpoint_manager.save(
        experiment_id, args=ocp.args.Composite(state=ocp.args.PyTreeSave(model_state))
    )

    checkpoint_manager.close()

11. Training and Validation Loss Visualization#

Plot the training and validation loss curves.

[ ]:
if train_model:
    plt.plot(loss_array, label="train loss")
    plt.plot(val_loss_array, label="val loss")
    plt.xlabel("steps")
    plt.ylabel("loss")
    plt.legend()
    plt.show()

12. Posterior Sampling#

Sample from the posterior distribution using the trained model and visualize the results.

[ ]:
from gensbi.utils.model_wrapping import ModelWrapper
class SimWrapper(ModelWrapper):
    def __init__(self, model):
        super().__init__(model)

    def _call_model(self, x, t, args, **kwargs):
        return self.model(obs=x, timesteps=t, **kwargs)
[ ]:
obs_ids = jnp.arange(dim_theta)  # observation ids
cond_ids = jnp.arange(dim_theta, dim_joint)  # conditional ids
step_size = 0.01

# conditional sampling
def get_samples(vf_wrapped, idx, nsamples=10_000, edge_mask=undirected_edge_mask):
    observation, reference_samples =  task.get_reference(idx)
    true_param = jnp.array(task.get_true_parameters(idx))

    rng = jax.random.PRNGKey(45)

    key1,key2 = jax.random.split(rng, 2)

    x_init = jax.random.normal(key1,(nsamples, dim_theta)) # (nsamples, dim_theta)
    cond = jnp.broadcast_to(observation[...,None], (1, dim_data, 1)) # (1, dim_data, 1)

    solver = ODESolver(velocity_model=vf_wrapped)  # create an ODESolver class
    model_extras = {"cond": cond, "obs_ids": obs_ids, "cond_ids": cond_ids, "edge_mask": edge_mask}

    sampler_ = solver.get_sampler(method='Dopri5', step_size=step_size, return_intermediates=False, model_extras=model_extras)
    samples = sampler_(x_init)  # sample from the model

    return samples, true_param, reference_samples

def plot_samples(samples, true_param):
    plot_marginals(samples, true_param=true_param)
    plt.show()
[ ]:
vf_cond = SimformerConditioner(vf_model)
vf_wrapped = SimWrapper(vf_cond)
[ ]:
idx=1
samples, true_param, reference_samples = get_samples(vf_wrapped, idx, nsamples=100_000, edge_mask=posterior_faithfull)

Visualize Posterior Samples#

Plot the posterior samples as a 2D histogram.

[ ]:
plot_marginals(samples, true_param=true_param.reshape(-1))
plt.show()
../_images/notebooks_gaussian_linear_simformer_62_1.png

13. Posterior Evaluation#

Evaluate the posterior by computing the likelihood on a grid and visualizing the results.

NOTE: the code for marginal likelihood estimation is not working yet. Currently only the full conditional likelihood can be evaluated, which is inconvenient when the dimensionality of the output is high.

We evaluate the marginal posterior for a single pair of parameters, which here corresponds to the last two dimensions of the parameter space (theta_9 and theta_10).

[ ]:
idx = 1
observation, reference_samples = task.get_reference(idx)
solver = ODESolver(velocity_model=vf_wrapped)  # create an ODESolver class
[ ]:
p0_cond = dist.Independent(
    dist.Normal(loc=jnp.zeros((2,)), scale=jnp.ones((2,))),
    reinterpreted_batch_ndims=1
)
[ ]:

[ ]:
grid_size = 200

theta1 = jnp.linspace(-1, 1, grid_size)
theta2 = jnp.linspace(-1, 1, grid_size)
x_1 = jnp.meshgrid(theta1, theta2)

x_1 = jnp.stack([x_1[0].flatten(), x_1[1].flatten()], axis=1)

cond = jnp.broadcast_to(observation, (1, dim_data)) # (1, dim_data)

obs_ids = jnp.array([8,9])  # we only evaluate the posterior for two dimensions (theta_9 and theta_10)
cond_ids = jnp.arange(dim_theta, dim_joint)  # conditional ids

node_ids_cond = jnp.concatenate([obs_ids, cond_ids])
cond_mask_partial_posterior = jnp.concatenate([jnp.zeros((obs_ids.shape[0],), dtype=jnp.bool_), jnp.ones((cond_ids.shape[0],), dtype=jnp.bool_)], axis=-1)
posterior_mask_cond = task.get_edge_mask_fn("faithfull")(node_ids_cond, condition_mask=cond_mask_partial_posterior)

model_extras = {"cond": cond, "obs_ids": obs_ids, "cond_ids": cond_ids, "edge_mask": posterior_mask_cond}
[ ]:
# get the logprob
# logp_sampler = solver.get_unnormalized_logprob(condition_mask=condition_mask, time_grid=[1.0,0.0],method='Dopri5', step_size=step_size, log_p0=p0_dist_model.log_prob, model_extras=model_extras)
logp_sampler = solver.get_unnormalized_logprob(time_grid=[1.0,0.0],method='Dopri5', step_size=step_size, log_p0=p0_cond.log_prob, model_extras=model_extras)

y_init = x_1

exact_log_p = logp_sampler(y_init)
p = jnp.exp(exact_log_p)[-1]

[ ]:
x = theta1
y = theta2
Z = np.array(p.reshape((grid_size, grid_size)))


plot_2d_dist_contour(x,y,Z, true_param=[true_param[0, 8], true_param[0, 9]])
plt.xlabel(r"$\theta_9$", fontsize=12)
plt.ylabel(r"$\theta_{10}$", fontsize=12)
#plot the true value
# plt.scatter(true_param[0, 8], true_param[0, 9], color='red', label='True Parameter', s=50, marker="s", zorder=10)
# plt.axvline(true_param[0, 8], color='red', linestyle='-', linewidth=1, zorder=9)
# plt.axhline(true_param[0, 9], color='red', linestyle='-', linewidth=1, zorder=9)
plt.show()
../_images/notebooks_gaussian_linear_simformer_71_0.png

14. Classifier Two-Sample Test (C2ST)#

Evaluate the quality of the posterior samples using the C2ST metric. Values closer to 0.5 are better.

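As a reminder of what the metric measures: C2ST trains a classifier to distinguish reference posterior samples from model samples and reports its held-out accuracy, so 0.5 means the two sample sets are statistically indistinguishable. Below is a minimal sketch of this idea using a scikit-learn classifier; it only illustrates the principle, while the notebook uses the c2st implementation shipped with gensbi_examples.

# Minimal sketch of the C2ST idea (the cells below use gensbi_examples.c2st.c2st instead).
# Assumes scikit-learn is available; X and Y are (n, d) arrays of posterior samples.
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.neural_network import MLPClassifier

def c2st_sketch(X, Y, seed=0, folds=5):
    data = np.concatenate([X, Y], axis=0)
    labels = np.concatenate([np.zeros(len(X)), np.ones(len(Y))])
    clf = MLPClassifier(hidden_layer_sizes=(64, 64), max_iter=1000, random_state=seed)
    scores = cross_val_score(clf, data, labels, cv=folds, scoring="accuracy")
    return scores.mean()  # ~0.5 when model samples match the reference posterior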
[ ]:
from gensbi_examples.c2st import c2st
[ ]:
idx = 1
samples, true_param, reference_samples = get_samples(vf_wrapped, idx, nsamples=10_000)

[ ]:
c2st_accuracy = c2st(reference_samples, samples)
[ ]:
c2st_accuracy #0.5277 for obs 1
array(0.51165, dtype=float32)
[ ]: