from typing import Callable, Optional, Sequence, Tuple, Union
from abc import abstractmethod
import jax
from jax import jit, vmap
import jax.numpy as jnp
from jax import Array
import diffrax
from functools import partial
from einops import repeat
from numpyro.distributions import Independent, Normal
from diffrax import (
diffeqsolve,
ControlTerm,
MultiTerm,
ODETerm,
VirtualBrownianTree,
)
from flow_matching.solver.solver import Solver
from utils.model_wrapping import ModelWrapper
[docs]
class BaseSDESolver(Solver):
    r"""Base class for drawing samples by integrating a stochastic differential equation (SDE).

    The SDE is built from a velocity field model: subclasses supply the drift
    :math:`\tilde{f}` and the diffusion :math:`\tilde{g}` (see arXiv:2410.02217), and
    :meth:`get_sampler` integrates them with a diffrax SDE solver from ``t = eps0`` to
    ``t = 1``, starting from samples of a Gaussian prior with independent components.

    Args:
        velocity_model (ModelWrapper): a wrapped velocity field model; it must expose
            ``get_vector_field(**kwargs)`` returning a callable ``(t, x, args) -> u_t(x)``
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        mu0: Array,
        sigma0: Array,
        eps0: float = 1e-5,
    ):
        """
        Args:
            velocity_model (ModelWrapper): a velocity field model, properly wrapped
            mu0 (Array): mean of the prior gaussian distribution
            sigma0 (Array): standard deviation of the prior gaussian distribution
            eps0 (float): time at which the SDE integration starts. Defaults to 1e-5.
        """
        super().__init__()
        self.velocity_model = velocity_model
        self.mu0 = mu0
        self.sigma0 = sigma0
        self.prior_distribution = Independent(
            Normal(mu0, sigma0), reinterpreted_batch_ndims=1
        )
        # Size of the state vector; also used as the shape of the Brownian motion.
        self.dim = mu0.shape[0]
        self.eps0 = eps0

    @abstractmethod
    def get_f_tilde(self, **kwargs) -> Callable:
        r"""Get the function :math:`\tilde{f}` for the velocity model. See arXiv.2410.02217

        Also known as the "drift" term in the SDE context.

        Args:
            **kwargs: received from :meth:`get_sampler`; the implementations in this
                module forward them to ``velocity_model.get_vector_field``.

        Returns:
            Callable: a function ``(t, x, args) -> drift``.
        """
        raise NotImplementedError

    @abstractmethod
    def get_g_tilde(self) -> Callable:
        r"""Get the function :math:`\tilde{g}` for the velocity model. See arXiv.2410.02217

        Also known as the "diffusion" term in the SDE context.

        Returns:
            Callable: a function ``(t, x, args) -> diffusion``.
        """
        raise NotImplementedError

    def get_score(self, **kwargs):
        r"""Obtain the score function given the velocity model. See arXiv.2410.02217

        The expression assumes the velocity model follows the linear path
        :math:`x_t = (1 - t) x_0 + t x_1` with :math:`x_0 \sim N(\mu_0, \sigma_0^2)`
        and :math:`u_t(x) = E[x_1 - x_0 | x_t = x]`. Then
        :math:`E[x_0 | x_t = x] = x - t u_t(x)` and

        .. math::
            \nabla_x \log p_t(x) = \frac{t u_t(x) + \mu_0 - x}{(1 - t) \sigma_0^2}

        Note the plus sign on the velocity term: the velocity used here is the
        forward one (prior at ``t = 0``, target at ``t = 1``), the same function
        that is used as the drift of the sampler.

        Args:
            **kwargs: forwarded to ``velocity_model.get_vector_field``.

        Returns:
            Callable: a function ``(t, x, args) -> score``. It divides by ``1 - t``
            and is therefore singular at ``t = 1``.
        """
        vf = self.velocity_model.get_vector_field(**kwargs)

        def score(t, x, args):
            numerator = t * vf(t, x, args) + self.mu0 - x
            return numerator / ((1 - t) * self.sigma0**2)

        return score

    def get_sampler(
        self, args=None, nsteps=300, method="SEA", adaptive=False, **kwargs
    ) -> Callable:
        """Stochastic sampler for the SDE.

        Args:
            args: additional arguments to pass to the velocity model
            nsteps: number of steps for the SDE solver; the initial step size is ``1 / nsteps``
            method: the method to use for the SDE solver, can be one of "Euler", "SEA", "ShARK". Defaults to "SEA". Euler is the simplest algorithm. SEA (Shifted Euler method) has a better constant factor in the global error and an improved local error. ShARK (Shifted Additive-noise Runge-Kutta) provides a more accurate solution with a higher computational cost, and implements adaptive stepsize control.
            adaptive: whether to use adaptive stepsize control (only for ShARK). Defaults to False.
            **kwargs: forwarded to :meth:`get_f_tilde`.

        Returns:
            Callable: a jitted function ``sample(key, nsamples)`` that draws
            ``nsamples`` points from the prior, integrates the SDE from ``eps0`` to 1
            and returns the last saved solver state (``sol.ys[-1]``).

        Raises:
            ValueError: if ``method`` is not one of the supported solver names.
        """
        solvers = {
            "Euler": diffrax.Euler,
            "SEA": diffrax.SEA,
            "ShARK": diffrax.ShARK,
        }
        if method not in solvers:
            raise ValueError(
                f"Method {method} not supported. Choose from {solvers.keys()}."
            )
        solver = solvers[method]()

        # Euler is driven by plain Brownian increments; the other solvers are
        # given a Brownian motion that also carries space-time Levy area.
        if method == "Euler":
            levy_area = diffrax.BrownianIncrement
        else:
            levy_area = diffrax.SpaceTimeLevyArea

        drift = self.get_f_tilde(**kwargs)  # drift term, (t,x,args) -> f_tilde
        diff = self.get_g_tilde()  # diffusion term, (t,x,args) -> g_tilde

        t0 = self.eps0
        t1 = 1
        dt = t1 / nsteps
        dtmin = min(2e-5, dt)  # minimum step size
        # Resolution of the virtual Brownian tree: half of the smallest step.
        tol = dtmin / 2

        if method == "ShARK" and adaptive:
            stepsize_controller = diffrax.PIDController(
                rtol=1e-5, atol=1e-5, dtmin=dtmin, dtmax=2 * dt
            )
        else:
            stepsize_controller = diffrax.ConstantStepSize()

        @jit
        def sample_batch(key, y0):
            # NOTE(review): the Brownian path has shape (dim,) while y0 is expected to
            # be (nsamples, dim) -- it looks like every sample of the batch is driven
            # by the same noise realisation. Confirm this is intended.
            brownian_motion = VirtualBrownianTree(
                t0, t1, tol=tol, shape=(self.dim,), key=key, levy_area=levy_area
            )
            terms = MultiTerm(ODETerm(drift), ControlTerm(diff, brownian_motion))
            sol = diffeqsolve(
                terms,
                solver,
                t0,
                t1,
                dt0=dt,
                y0=y0,
                args=args,
                stepsize_controller=stepsize_controller,
            )
            return sol.ys[-1]

        # nsamples fixes the shape of the prior draw, so it must be static under jit.
        @partial(jit, static_argnums=1)
        def sample(key, nsamples):
            # Independent keys for the prior draw and for the Brownian motion.
            key1, key2 = jax.random.split(key)
            y0s = self.prior_distribution.sample(key1, (nsamples,))
            return sample_batch(key2, y0s)

        return sample

    def sample(
        self,
        key: jax.Array,
        nsamples: int,
        nsteps: int = 300,
        method="SEA",
        adaptive=True,
        **kwargs,
    ) -> jax.Array:
        """Sample from the SDE using the provided key and number of samples.

        Builds a new sampler with :meth:`get_sampler` on every call and runs it once.

        Args:
            key (jax.Array): JAX random key for sampling.
            nsamples (int): Number of samples to generate.
            nsteps (int): Number of steps for the SDE solver.
            method: Name of the SDE solver, one of "Euler", "SEA", "ShARK".
            adaptive: Whether to use adaptive stepsize control (only has an effect
                with "ShARK"). Defaults to True here, whereas :meth:`get_sampler`
                defaults to False.
            **kwargs: Additional arguments passed on to :meth:`get_sampler`
                (including ``args`` for the velocity model).

        Returns:
            jax.Array: The final state of the SDE integration for the drawn samples.
        """
        sampler = self.get_sampler(
            nsteps=nsteps, method=method, adaptive=adaptive, **kwargs
        )
        return sampler(key, nsamples)
[docs]
class ZeroEnds(BaseSDESolver):
    """
    SDE of the "ZeroEnds" kind, following tab 1 of http://arxiv.org/abs/2410.02217, with change of variable for time: t -> 1-t to match our time notation.

    The diffusion is ``alpha * sqrt(t * (1 - t))`` times the identity, so it
    vanishes at both ``t = 0`` and ``t = 1``; the drift is the velocity field plus
    ``0.5 * alpha**2 * t * (1 - t)`` times the score.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        mu0: Array,
        sigma0: Array,
        alpha: float,
        eps0: float = 1e-3,
    ):
        """
        Args:
            velocity_model (ModelWrapper): wrapped velocity field model
            mu0 (Array): mean of the prior gaussian distribution
            sigma0 (Array): standard deviation of the prior gaussian distribution
            alpha (float): scale of the diffusion term
            eps0 (float): start time handed to the base class. Defaults to 1e-3.
        """
        super().__init__(velocity_model, mu0, sigma0, eps0=eps0)
        self.alpha = alpha

    def get_f_tilde(self, **kwargs) -> Callable:
        """Build the drift ``(t, x, args) -> vf + 0.5 * alpha**2 * t * (1 - t) * score``.

        Args:
            **kwargs: forwarded to both ``get_score`` and the velocity model's
                ``get_vector_field``.
        """
        score_fn = self.get_score(**kwargs)
        velocity_fn = self.velocity_model.get_vector_field(**kwargs)

        def f_tilde(t, x, args):
            score_weight = 0.5 * self.alpha**2 * t * (1 - t)
            return velocity_fn(t, x, args) + score_weight * score_fn(t, x, args)

        return f_tilde

    def get_g_tilde(self) -> Callable:
        """Build the diffusion ``(t, x, args) -> alpha * sqrt(t * (1 - t)) * I``.

        The returned function unpacks ``x.shape`` as ``(b, d)`` and yields one
        ``d x d`` diagonal matrix per batch entry, i.e. an array of shape ``(b, d, d)``.
        """

        def g_tilde(t, x, args):
            n_batch, n_dim = x.shape
            amplitude = self.alpha * jnp.sqrt(t * (1 - t))  # scalar
            diagonal = jnp.diag(jnp.repeat(amplitude, n_dim))
            # Same matrix for every batch entry.
            return repeat(diagonal, "i j -> b i j", b=n_batch)

        return g_tilde
[docs]
class NonSingular(BaseSDESolver):
    """
    NonSingular SDE, from tab 1 of http://arxiv.org/abs/2410.02217, with change of variable for time: t -> 1-t to match our time notation.

    The diffusion is ``alpha * sqrt(1 - t)`` times the identity and the drift is
    the velocity field plus ``0.5 * alpha**2 * (1 - t)`` times the score.
    """

    def __init__(
        self,
        velocity_model: ModelWrapper,
        mu0: Array,
        sigma0: Array,
        alpha: float,
        eps0: float = 1e-5,
    ):
        """
        Args:
            velocity_model (ModelWrapper): a velocity field model, properly wrapped
            mu0 (Array): mean of the prior gaussian distribution
            sigma0 (Array): standard deviation of the prior gaussian distribution
            alpha (float): scale of the diffusion term
            eps0 (float): time at which the SDE integration starts, forwarded to
                ``BaseSDESolver``. Defaults to 1e-5.
        """
        super().__init__(velocity_model, mu0, sigma0, eps0=eps0)
        self.alpha = alpha

    def get_f_tilde(self, **kwargs) -> Callable:
        """Build the drift ``(t, x, args) -> vf + 0.5 * alpha**2 * (1 - t) * score``.

        Args:
            **kwargs: forwarded to both ``get_score`` and the velocity model's
                ``get_vector_field``.
        """
        score = self.get_score(**kwargs)
        vf = self.velocity_model.get_vector_field(**kwargs)

        def f_tilde(t, x, args):
            return vf(t, x, args) + 0.5 * self.alpha**2 * (1 - t) * score(t, x, args)

        return f_tilde

    def get_g_tilde(self) -> Callable:
        """Build the diffusion ``(t, x, args) -> alpha * sqrt(1 - t) * I``.

        The returned function unpacks ``x.shape`` as ``(b, d)`` and yields one
        ``d x d`` diagonal matrix per batch entry, i.e. an array of shape ``(b, d, d)``.
        """

        def g_tilde(t, x, args):
            b, d = x.shape
            scale = self.alpha * jnp.sqrt(1 - t)  # scalar
            diagonal = jnp.diag(jnp.repeat(scale, d))
            # Same matrix for every batch entry.
            return repeat(diagonal, "i j -> b i j", b=b)

        return g_tilde