gensbi.flow_matching.solver.sde_solver
Classes
| Class | Description |
|---|---|
| BaseSDESolver | Base class for solving stochastic differential equations (SDEs) derived from a specified velocity model. |
| NonSingular | NonSingular SDE solver. |
| ZeroEnds | ZeroEnds SDE solver. |
Module Contents
- class gensbi.flow_matching.solver.sde_solver.BaseSDESolver(velocity_model, mu0, sigma0, eps0=1e-05)
Bases: gensbi.flow_matching.solver.solver.Solver
Base class for solving stochastic differential equations (SDEs) derived from a specified velocity model.
This class uses a velocity field model to build the drift and diffusion terms of an SDE and integrates it with numerical SDE solvers.
- Parameters:
velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)
mu0 (jax.Array)
sigma0 (jax.Array)
eps0 (float)
- abstract get_f_tilde()
Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.
- Return type:
Callable
- abstract get_g_tilde()
Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.
- Return type:
Callable
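As a hedged illustration of what \(\tilde{f}\) and \(\tilde{g}\) represent: one standard construction, assumed here (the exact form this class uses is given in arXiv:2410.02217 and may differ), turns the deterministic flow \(dx = u_t(x)\,dt\) into an SDE with the same marginals \(p_t\) for any diffusion schedule \(g_t\): \(dx = [u_t(x) + \tfrac{1}{2} g_t^2 \nabla_x \log p_t(x)]\,dt + g_t\,dW_t\). The score term exactly cancels the extra spreading caused by the noise in the Fokker-Planck equation. The helpers `make_f_tilde` and `make_g_tilde` and their arguments `score_fn` and `g_fn` are hypothetical.

```python
from typing import Callable

import numpy as np

Array = np.ndarray


def make_f_tilde(velocity: Callable[[Array, float], Array],
                 score_fn: Callable[[Array, float], Array],
                 g_fn: Callable[[float], float]) -> Callable[[Array, float], Array]:
    """Hypothetical helper: a drift that keeps the flow's marginals p_t unchanged."""
    def f_tilde(x: Array, t: float) -> Array:
        # Velocity plus a score correction that compensates for the injected noise.
        return velocity(x, t) + 0.5 * g_fn(t) ** 2 * score_fn(x, t)
    return f_tilde


def make_g_tilde(g_fn: Callable[[float], float]) -> Callable[[Array, float], float]:
    """Hypothetical helper: a diffusion that depends only on t (additive noise)."""
    def g_tilde(x: Array, t: float) -> float:
        return g_fn(t)
    return g_tilde
```

Setting \(g_t = 0\) makes \(\tilde{f}\) reduce to the velocity, recovering the deterministic flow. In this sketch the noise is additive (independent of \(x\)), which is the setting the SEA and ShARK solvers are designed for.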
- get_sampler(args=None, nsteps=300, method='SEA', adaptive=False, **kwargs)
Build a stochastic sampler for the SDE.
- Parameters:
args – additional arguments to pass to the velocity model.
nsteps – number of steps for the SDE solver.
method – the SDE solver method, one of “Euler”, “SEA”, or “ShARK”. Defaults to “SEA”. Euler is the simplest algorithm. SEA (Shifted Euler method) has a better constant factor in the global error and an improved local error. ShARK (Shifted Additive-noise Runge-Kutta) gives a more accurate solution at a higher computational cost and supports adaptive step-size control.
adaptive – whether to use adaptive step-size control (ShARK only). Defaults to False.
- Returns:
A function that takes initial conditions and returns the solution at final time.
- Return type:
Callable
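For SDEs, the “Euler” method usually means the Euler-Maruyama scheme: each step adds the drift times \(dt\) and the diffusion times a Gaussian increment with variance \(dt\). The following is a minimal NumPy sketch of that scheme, assuming fixed steps over \([t_0, t_1]\). It is not this library's implementation, which the page does not show. The name `euler_maruyama` is hypothetical, and a NumPy generator stands in for JAX random keys.

```python
from typing import Callable

import numpy as np


def euler_maruyama(f: Callable[[np.ndarray, float], np.ndarray],
                   g: Callable[[np.ndarray, float], float],
                   x0: np.ndarray,
                   t0: float,
                   t1: float,
                   nsteps: int,
                   rng: np.random.Generator) -> np.ndarray:
    """Hypothetical sketch: integrate dx = f(x, t) dt + g(x, t) dW from t0 to t1, return x(t1)."""
    dt = (t1 - t0) / nsteps
    x = np.array(x0, dtype=float)
    for i in range(nsteps):
        t = t0 + i * dt  # drift and diffusion are evaluated at the left endpoint
        # Brownian increment: independent N(0, dt) draws, one per component.
        dw = rng.normal(0.0, np.sqrt(dt), size=x.shape)
        x = x + f(x, t) * dt + g(x, t) * dw
    return x
```

As a sanity check, pure Brownian motion (\(f = 0\), \(g = 1\)) started at 0 should have variance close to \(t_1\) at the final time. With \(g = 0\), the scheme reduces to the explicit Euler method for the ODE \(dx = f\,dt\).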
- get_score(**kwargs)
Obtain the score function from the velocity model. See arXiv:2410.02217.
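Here is a worked sketch of how a score can be recovered from a velocity. It assumes a specific probability path, which this page does not state: the linear Gaussian path \(x_t = t\,x_1 + (1-t)\,x_0\) with \(x_0 \sim \mathcal{N}(\mu_0, \sigma_0^2 I)\), where mu0 and sigma0 play the roles of \(\mu_0\) and \(\sigma_0\). Because \(u_t(x) = \mathbb{E}[x_1 - x_0 \mid x_t = x]\) and \(x_t = x_0 + t(x_1 - x_0)\), we get \(\mathbb{E}[x_0 \mid x_t = x] = x - t\,u_t(x)\). Averaging the conditional Gaussian score then gives \(\nabla_x \log p_t(x) = (\mu_0 + t\,u_t(x) - x) / ((1-t)\,\sigma_0^2)\). This expression is singular at \(t = 1\), so any use of it must stop short of that endpoint. The function `score_from_velocity` is hypothetical.

```python
from typing import Union

import numpy as np


def score_from_velocity(u: np.ndarray, x: np.ndarray, t: float,
                        mu0: Union[float, np.ndarray],
                        sigma0: Union[float, np.ndarray]) -> np.ndarray:
    """Hypothetical: score of p_t under the assumed path x_t = t*x1 + (1-t)*x0, x0 ~ N(mu0, sigma0^2).

    `u` is the velocity u_t(x) already evaluated at (x, t). Requires t < 1.
    """
    if not t < 1.0:
        raise ValueError("the score is singular at t = 1")
    # E[x0 | x_t = x] = x - t*u, and the score is -(E[x0 | x_t = x] - mu0) / ((1 - t) * sigma0^2).
    return (mu0 + t * u - x) / ((1.0 - t) * sigma0 ** 2)
```

For a Gaussian data distribution \(x_1 \sim \mathcal{N}(m_1, s_1^2)\), both the velocity and the exact score of \(p_t\) have closed forms, so the identity can be checked directly.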
- sample(key, nsamples, nsteps=300, method='SEA', adaptive=True, **kwargs)
Sample from the SDE using the provided key and number of samples.
- Parameters:
key (jax.Array) – JAX random key for sampling.
nsamples (int) – Number of samples to generate.
nsteps (int) – Number of steps for the SDE solver.
**kwargs – Additional arguments to pass to the velocity model.
- Returns:
Sampled trajectories from the SDE.
- Return type:
Array
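The following is a hedged end-to-end sketch of what sampling from such an SDE involves conceptually: draw initial conditions from the Gaussian base distribution, then integrate to \(t = 1\). It rests on several assumptions that this page does not state. The base distribution is taken to be \(\mathcal{N}(\mu_0, \sigma_0^2)\), given by mu0 and sigma0. The path is the linear Gaussian path \(x_t = t\,x_1 + (1-t)\,x_0\), which gives the score \((\mu_0 + t\,u_t(x) - x) / ((1-t)\,\sigma_0^2)\). The diffusion is a toy schedule \(g_t = \alpha(1-t)\), not the NonSingular or ZeroEnds schedules. An integer NumPy seed stands in for the JAX key, and the data are 1-D for simplicity. The name `sample_sketch` is hypothetical.

```python
from typing import Callable

import numpy as np


def sample_sketch(velocity: Callable[[np.ndarray, float], np.ndarray],
                  mu0: float, sigma0: float, alpha: float,
                  seed: int, nsamples: int, nsteps: int = 300) -> np.ndarray:
    """Hypothetical sampler: x0 ~ N(mu0, sigma0^2), then Euler-Maruyama from t = 0 to t = 1."""
    rng = np.random.default_rng(seed)  # stands in for a JAX PRNG key
    x = rng.normal(mu0, sigma0, size=nsamples)
    dt = 1.0 / nsteps
    for i in range(nsteps):
        t = i * dt  # left endpoint, so t < 1 and the score below stays finite
        u = velocity(x, t)
        # Score under the assumed linear Gaussian path x_t = t*x1 + (1-t)*x0.
        score = (mu0 + t * u - x) / ((1.0 - t) * sigma0 ** 2)
        g = alpha * (1.0 - t)  # toy schedule, not NonSingular or ZeroEnds
        drift = u + 0.5 * g ** 2 * score  # marginal-preserving drift
        x = x + drift * dt + g * rng.normal(0.0, np.sqrt(dt), size=nsamples)
    return x
```

With a Gaussian target, the velocity for this assumed path has a closed form. The sampler's output mean and standard deviation can therefore be compared against the target's, and they should match up to Monte Carlo and discretization error whatever the value of alpha.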
- class gensbi.flow_matching.solver.sde_solver.NonSingular(velocity_model, mu0, sigma0, alpha)
Bases: BaseSDESolver
NonSingular SDE solver.
From Table 1 of arXiv:2410.02217, with the change of time variable \(t \to 1 - t\) to match our time convention.
- Parameters:
velocity_model (gensbi.utils.model_wrapping.ModelWrapper)
mu0 (jax.Array)
sigma0 (jax.Array)
alpha (float)
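The change of time variable \(t \to 1 - t\) can be made explicit for the deterministic part of the flow. Write \(s\) for the paper's time and \(t = 1 - s\) for ours, so that \(ds/dt = -1\). A velocity \(v_s\) expressed in the paper's time then maps to our velocity \(u_t\) by the chain rule. This is a generic identity, not a statement of how the class implements the conversion:

```latex
\frac{dx}{ds} = v_s(x), \qquad t = 1 - s
\;\Longrightarrow\;
\frac{dx}{dt} = \frac{dx}{ds}\,\frac{ds}{dt} = -\,v_{1-t}(x) \;=:\; u_t(x).
```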
- get_f_tilde(**kwargs)
Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.
- Return type:
Callable
- class gensbi.flow_matching.solver.sde_solver.ZeroEnds(velocity_model, mu0, sigma0, alpha, eps0=0.001)
Bases: BaseSDESolver
ZeroEnds SDE solver.
From Table 1 of arXiv:2410.02217, with the change of time variable \(t \to 1 - t\) to match our time convention.
- Parameters:
velocity_model (gensbi.utils.model_wrapping.ModelWrapper)
mu0 (jax.Array)
sigma0 (jax.Array)
alpha (float)
eps0 (float)
- get_f_tilde(**kwargs)
Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.
- Return type:
Callable