Getting Started
Welcome to GenSBI! This page will help you get started with installation and basic usage.
Installation
GenSBI is in early development. To install it directly from the GitHub repository, run:
pip install git+https://github.com/aurelio-amerio/GenSBI.git
If a GPU is available, it is advisable to install the CUDA version of the package:
pip install "GenSBI[cuda12] @ git+https://github.com/aurelio-amerio/GenSBI.git"
Requirements
Python 3.11+
JAX
Flax
(See pyproject.toml for full requirements)
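As a quick sanity check after installation, the sketch below reports whether the requirements listed above are satisfied in the current environment. It relies only on the standard library. The helper names `check_environment` and `jax_backend` are hypothetical and not part of GenSBI, and the import names `jax`, `flax` and `gensbi` are assumed to match the installed packages.

```python
import importlib.util
import sys


def check_environment(packages=("jax", "flax", "gensbi")):
    """Report whether the Python version and required packages are available.

    Hypothetical helper for illustration; it is not part of GenSBI.
    """
    report = {"python>=3.11": sys.version_info >= (3, 11)}
    for name in packages:
        # find_spec returns None when a top-level package cannot be found
        report[name] = importlib.util.find_spec(name) is not None
    return report


def jax_backend():
    """Return the name of the default JAX backend, or None if JAX is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


if __name__ == "__main__":
    for requirement, ok in check_environment().items():
        print(f"{requirement}: {'ok' if ok else 'missing'}")
    print("JAX backend:", jax_backend())
```

If the reported backend is `cpu` on a machine with a GPU, the CUDA-enabled install was probably not picked up.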
Basic Usage
The most basic usage of GenSBI involves defining a simulation-based inference pipeline using one of the provided recipes. Here is a minimal example of setting up a flow-based inference pipeline using Simformer:
import jax
from jax import numpy as jnp
import numpy as np
from gensbi.recipes import SimformerFlowPipeline
from gensbi.utils.plotting import plot_marginals
import grain # data loaders for jax
import matplotlib.pyplot as plt # for plotting
thetas_train = ... # inference parameters for training, shape (N_train, dim_theta, 1)
thetas_val = ... # inference parameters for validation, shape (N_val, dim_theta, 1)
xs_train = ... # observed/simulated data for training, shape (N_train, dim_data, 1)
xs_val = ... # observed/simulated data for validation, shape (N_val, dim_data, 1)
# concatenate thetas and xs for dataset creation
train_data = np.concatenate((thetas_train, xs_train), axis=1)
val_data = np.concatenate((thetas_val, xs_val), axis=1)
# define a batch size for training, and create a batched dataset using grain
batch_size = 256
train_dataset = (
    grain.MapDataset.source(np.array(train_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
)
val_dataset = (
    grain.MapDataset.source(np.array(val_data))
    .shuffle(42)
    .repeat()
    .to_iter_dataset()
    .batch(batch_size)
)
# define the dimensionality of the parameters and of the data
dim_theta = ... # number of inference parameters, that is thetas_train.shape[1]
dim_data = ... # dimensionality of observed data, that is xs_train.shape[1]
# define the pipeline
pipeline = SimformerFlowPipeline(train_dataset, val_dataset, dim_theta, dim_data)
# train the model
pipeline.train(num_steps=10000)
# sample from the posterior given a target observation
rng = jax.random.PRNGKey(0)
nsamples = 10_000
observation = ... # target observation, shape (1, dim_data, 1)
samples = pipeline.sample(rng, observation, nsamples)
# plot the posterior distributions
plot_marginals(samples, gridsize=50)
plt.show()
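To make the placeholder arrays in the example above concrete, here is a minimal sketch that builds toy training and validation arrays with the expected shapes using NumPy. The simulator (`toy_simulator`), the helper `make_split`, the uniform prior and the sample sizes are all hypothetical and chosen only for illustration; they are not part of GenSBI.

```python
import numpy as np


def toy_simulator(thetas, rng):
    """Hypothetical simulator: each observation is its parameter plus Gaussian noise."""
    return thetas + 0.1 * rng.standard_normal(thetas.shape)


def make_split(n_samples, dim_theta, rng):
    """Draw parameters from a uniform prior and simulate matching observations."""
    thetas = rng.uniform(-1.0, 1.0, size=(n_samples, dim_theta))
    xs = toy_simulator(thetas, rng)
    # add a trailing axis so the arrays have shape (N, dim, 1)
    return thetas[..., None], xs[..., None]


rng = np.random.default_rng(0)
dim_theta = 2
thetas_train, xs_train = make_split(1_000, dim_theta, rng)
thetas_val, xs_val = make_split(100, dim_theta, rng)

# parameters come first along axis 1, followed by the data
train_data = np.concatenate((thetas_train, xs_train), axis=1)
val_data = np.concatenate((thetas_val, xs_val), axis=1)

dim_data = xs_train.shape[1]
print(train_data.shape)  # (1000, 4, 1)
print(val_data.shape)    # (100, 4, 1)
```

Here `dim_data` equals `dim_theta` only because the toy simulator returns one value per parameter; in general the two differ.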
See the Examples page for practical demonstrations on common SBI benchmarks.
Citing GenSBI
If you use this library, please consider citing this work and the original methodology papers.