gensbi.flow_matching.solver.sde_solver

Classes

BaseSDESolver

A class to solve stochastic differential equations (SDEs) using a specified velocity model.

NonSingular

NonSingular SDE solver.

ZeroEnds

ZeroEnds SDE solver.

Module Contents

class gensbi.flow_matching.solver.sde_solver.BaseSDESolver(velocity_model, mu0, sigma0, eps0=1e-05)

Bases: gensbi.flow_matching.solver.solver.Solver

A class to solve stochastic differential equations (SDEs) using a specified velocity model.

This class uses a velocity field model to solve SDEs over a given time grid with numerical SDE solvers.

Parameters:
  • velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)

  • mu0 (jax.Array)

  • sigma0 (jax.Array)

  • eps0 (float)

abstract get_f_tilde()

Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.

Return type:

Callable

abstract get_g_tilde()

Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.

Return type:

Callable
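For orientation, the drift and diffusion terms can be read as the coefficients of a generic SDE with additive noise. The form below is a sketch rather than something taken from the gensbi source: it assumes that \(\tilde{g}\) depends on time only and that \(W_t\) is a standard Wiener process. The exact expressions for each solver are given in arXiv:2410.02217.

\[
\mathrm{d}X_t = \tilde{f}(X_t, t)\,\mathrm{d}t + \tilde{g}(t)\,\mathrm{d}W_t
\]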

get_sampler(args=None, nsteps=300, method='SEA', adaptive=False, **kwargs)

Stochastic sampler for the SDE.

Parameters:
  • args – Additional arguments to pass to the velocity model.

  • nsteps – Number of steps for the SDE solver.

  • method – The method to use for the SDE solver; one of “Euler”, “SEA”, or “ShARK”. Defaults to “SEA”. Euler is the simplest algorithm. SEA (Shifted Euler method) has a better constant factor in the global error and an improved local error. ShARK (Shifted Additive-noise Runge-Kutta) provides a more accurate solution at a higher computational cost and supports adaptive step-size control.

  • adaptive – Whether to use adaptive step-size control (only for ShARK). Defaults to False.

Returns:

A function that takes initial conditions and returns the solution at final time.

Return type:

Callable
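As a rough illustration of what the simplest method in that list does, the sketch below implements a scalar Euler-Maruyama step loop in plain Python. It is not the gensbi implementation: the function `euler_maruyama`, the toy drift and diffusion, and the time interval from 0 to 1 are all assumptions made for this example.

```python
import math
import random


def euler_maruyama(drift, diffusion, x0, nsteps=300, t0=0.0, t1=1.0, rng=None):
    """Integrate dX = drift(x, t) dt + diffusion(t) dW from t0 to t1.

    Scalar Euler-Maruyama scheme; returns the state at the final time.
    """
    if rng is None:
        rng = random.Random()
    dt = (t1 - t0) / nsteps
    x = x0
    for i in range(nsteps):
        t = t0 + i * dt
        # Brownian increment over one step: dW ~ N(0, dt).
        dw = rng.gauss(0.0, math.sqrt(dt))
        x = x + drift(x, t) * dt + diffusion(t) * dw
    return x


# Hypothetical drift and diffusion, chosen only for illustration.
def toy_drift(x, t):
    return -x


def toy_diffusion(t):
    return 0.5


# One stochastic sample, seeded for reproducibility.
x_final = euler_maruyama(toy_drift, toy_diffusion, x0=1.0, rng=random.Random(0))

# With zero diffusion the scheme reduces to explicit Euler for dx/dt = -x,
# whose exact solution at t = 1 is exp(-1).
x_det = euler_maruyama(toy_drift, lambda t: 0.0, x0=1.0, nsteps=1000)
```

The zero-diffusion run gives a quick sanity check, since the result can be compared with the closed-form solution of the corresponding ODE.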

get_score(**kwargs)

Obtain the score function given the velocity model. See arXiv:2410.02217.
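To make the velocity-to-score relation concrete, here is a minimal sketch under assumptions that may differ from what `get_score` actually does: a linear path \(x_t = t\,x_1 + (1 - t)\,x_0\), a standard normal source \(x_0\) at \(t = 0\), and scalar inputs. Under those assumptions the score is \((t\,u_t(x) - x) / (1 - t)\). The names `score_from_velocity`, `toy_velocity`, and `toy_score` are hypothetical, and gensbi's own formula also involves `mu0` and `sigma0`.

```python
def score_from_velocity(velocity, x, t):
    """Score of p_t under the assumptions stated above.

    Assumes x_t = t * x_1 + (1 - t) * x_0 with x_0 ~ N(0, 1), so that
    score(x, t) = (t * u_t(x) - x) / (1 - t). Singular at t = 1.
    """
    return (t * velocity(x, t) - x) / (1.0 - t)


# Toy check with Gaussian data x_1 ~ N(m, s^2), where everything is analytic.
m, s = 2.0, 0.5


def toy_velocity(x, t):
    var_t = t**2 * s**2 + (1.0 - t) ** 2  # variance of p_t
    return m + (t * s**2 - (1.0 - t)) * (x - t * m) / var_t


def toy_score(x, t):
    var_t = t**2 * s**2 + (1.0 - t) ** 2
    return -(x - t * m) / var_t  # score of N(t * m, var_t)


x, t = 0.7, 0.3
estimated = score_from_velocity(toy_velocity, x, t)
exact = toy_score(x, t)
```

In this Gaussian toy case the score recovered from the velocity agrees with the closed-form score of \(p_t\) up to floating-point error.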

sample(key, nsamples, nsteps=300, method='SEA', adaptive=True, **kwargs)

Sample from the SDE using the provided key and number of samples.

Parameters:
  • key (jax.Array) – JAX random key for sampling.

  • nsamples (int) – Number of samples to generate.

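  • nsteps (int) – Number of steps for the SDE solver. Defaults to 300.

  • method (str) – The method to use for the SDE solver; one of “Euler”, “SEA”, or “ShARK”. Defaults to “SEA”.

  • adaptive (bool) – Whether to use adaptive step-size control (only for ShARK). Defaults to True.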

  • **kwargs – Additional arguments to pass to the velocity model.

Returns:

Sampled trajectories from the SDE.

Return type:

Array

dim
eps0 = 1e-05
mu0
prior_distribution
sigma0
velocity_model

class gensbi.flow_matching.solver.sde_solver.NonSingular(velocity_model, mu0, sigma0, alpha)

Bases: BaseSDESolver

NonSingular SDE solver.

From Table 1 of arXiv:2410.02217, with the change of time variable \(t \to 1 - t\) to match our time notation.

get_f_tilde(**kwargs)

Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.

Return type:

Callable

get_g_tilde()

Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.

Return type:

Callable

alpha

class gensbi.flow_matching.solver.sde_solver.ZeroEnds(velocity_model, mu0, sigma0, alpha, eps0=0.001)

Bases: BaseSDESolver

ZeroEnds SDE solver.

From Table 1 of arXiv:2410.02217, with the change of time variable \(t \to 1 - t\) to match our time notation.

get_f_tilde(**kwargs)

Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.

Return type:

Callable

get_g_tilde()

Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.

Return type:

Callable

alpha