gensbi.flow_matching.solver package#

Submodules#

gensbi.flow_matching.solver.ode_solver module#

class gensbi.flow_matching.solver.ode_solver.ODESolver(velocity_model: ModelWrapper)[source]#

Bases: Solver

A class to solve ordinary differential equations (ODEs) using a specified velocity model.

This class uses a velocity field model to solve ODEs over a given time grid with numerical ODE solvers.

Parameters:

velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)

get_sampler(step_size: float | None, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Array = Array([0., 1.], dtype=float32), return_intermediates: bool = False, model_extras: dict = {}) Callable[source]#

Solve the ODE with the velocity field.

Example:

# Note: this example uses the PyTorch `flow_matching` package (torch tensors),
# whereas the signatures on this page use JAX arrays.
import torch
from flow_matching.utils import ModelWrapper
from flow_matching.solver import ODESolver

class DummyModel(ModelWrapper):
    def __init__(self):
        super().__init__(None)

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        # Velocity field u_t(x) = 3 t^2, independent of x
        return torch.ones_like(x) * 3.0 * t**2

velocity_model = DummyModel()
solver = ODESolver(velocity_model=velocity_model)
x_init = torch.tensor([0.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

# Integrates dx/dt = 3 t^2 from t=0 to t=1; the analytic solution is x(1) = 1
result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

Parameters:
  • x_init (Array) – initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …]. Taken by sample; get_sampler itself has no x_init argument.

  • step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.

  • method (str | AbstractERK) – The integration method, given by name or as an AbstractERK instance. Defaults to “Dopri5”.

  • atol (float) – Absolute tolerance, used for adaptive step solvers.

  • rtol (float) – Relative tolerance, used for adaptive step solvers.

  • time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to Array([0., 1.]).

  • return_intermediates (bool, optional) – If True then return intermediate time steps according to time_grid. Defaults to False.

  • model_extras (dict) – Additional inputs for the model. Defaults to {}.

Returns:

The last timestep when return_intermediates=False, otherwise all values specified in time_grid.

Return type:

Union[Array, Sequence[Array]]
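To make the roles of step_size and time_grid concrete, here is a minimal pure-Python sketch of a fixed-step Euler integrator. It is not the library's implementation; the function name euler_solve_sketch and its scalar state are assumptions made for illustration only.

```python
def euler_solve_sketch(velocity, x_init, time_grid, step_size):
    """Hypothetical helper: integrate dx/dt = velocity(x, t) with fixed Euler steps.

    Integration runs from time_grid[0] to time_grid[-1]; a descending grid
    gives a negative step and therefore solves in reverse time.
    """
    t_start, t_end = time_grid[0], time_grid[-1]
    n_steps = max(1, round(abs(t_end - t_start) / step_size))
    dt = (t_end - t_start) / n_steps  # signed step
    x, t = x_init, t_start
    for _ in range(n_steps):
        x = x + dt * velocity(x, t)
        t = t + dt
    return x


# Same velocity field as the example above: u_t(x) = 3 t^2
forward = euler_solve_sketch(lambda x, t: 3.0 * t**2, 0.0, [0.0, 1.0], 0.001)
backward = euler_solve_sketch(lambda x, t: 3.0 * t**2, 1.0, [1.0, 0.0], 0.001)
```

Under these assumptions, forward approximates the analytic value x(1) = 1 and backward returns close to 0, with an error that shrinks as step_size decreases.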

get_unnormalized_logprob(log_p0: Callable[[Array], Array], step_size: float = 0.01, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid=[1.0, 0.0], return_intermediates: bool = False, *, model_extras: dict = {}) Callable[source]#

Solve for the log likelihood of a target sample at \(t=1\) by integrating the ODE backward to \(t=0\), where log_p0 is evaluated.

Parameters:
  • x_1 (Array) – target sample (e.g., samples \(X_1 \sim p_1\)).

  • log_p0 (Callable[[Array], Array]) – Log probability function of source distribution.

  • step_size (float) – Step size for fixed-step solvers. Defaults to 0.01.

  • method (str) – Integration method to use.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Array) – Must start at 1.0 and end at 0.0.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • exact_divergence (bool) – Use exact divergence vs Hutchinson estimator.

  • model_extras (dict) – Additional model inputs, passed as a keyword-only argument. Defaults to {}.

Returns:

Samples and log likelihood values.

Return type:

Union[Tuple[Array, Array], Tuple[Sequence[Array], Array]]
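The idea behind a flow-based log likelihood can be sketched in one dimension. Along a trajectory of dx/dt = u_t(x), the log density changes as d/dt log p_t(x_t) = -div u_t(x_t), so log p_1(x_1) = log p_0(x_0) - ∫_0^1 div u_t(x_t) dt. The sketch below is an illustration under stated assumptions, not the library's code: the helper names are hypothetical, the divergence is supplied analytically, and a midpoint rule replaces the solvers named above.

```python
import math


def std_normal_logpdf(x):
    """Log density of the standard normal source distribution (assumed here)."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def log_prob_via_flow(x1, velocity, divergence, log_p0, n_steps=1000):
    """Hypothetical helper: integrate backward from t=1 to t=0 with midpoint steps."""
    dt = -1.0 / n_steps  # negative step: time runs from 1 down to 0
    x, t, div_integral = x1, 1.0, 0.0
    for _ in range(n_steps):
        x_mid = x + 0.5 * dt * velocity(x, t)
        t_mid = t + 0.5 * dt
        div_integral += dt * divergence(x_mid, t_mid)  # accumulates the integral from 1 to 0
        x = x + dt * velocity(x_mid, t_mid)
        t = t + dt
    # log p_1(x_1) = log p_0(x_0) + (integral of div u from t=1 to t=0)
    return x, log_p0(x) + div_integral


a = 0.5  # linear velocity u_t(x) = a * x, whose divergence is the constant a
x0, logp = log_prob_via_flow(1.3, lambda x, t: a * x, lambda x, t: a, std_normal_logpdf)
```

For this linear field the target is a Gaussian with standard deviation e^a, so the result can be checked against the closed-form density.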

sample(x_init: Array, step_size: float | None, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Array = Array([0., 1.], dtype=float32), return_intermediates: bool = False, model_extras: dict = {}) Array | Sequence[Array][source]#

unnormalized_logprob(x_1: Array, log_p0: Callable[[Array], Array], step_size: float = 0.01, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid=[1.0, 0.0], return_intermediates: bool = False, *, model_extras: dict = {}) Tuple[Array, Array] | Tuple[Sequence[Array], Array][source]#

gensbi.flow_matching.solver.sde_solver module#

class gensbi.flow_matching.solver.sde_solver.BaseSDESolver(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, eps0: float = 1e-05)[source]#

Bases: Solver

A base class for solving stochastic differential equations (SDEs) built from a specified velocity model.

Subclasses supply the drift and diffusion terms through get_f_tilde and get_g_tilde; the resulting SDE is then integrated numerically by the sampler.

Parameters:

velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)

abstractmethod get_f_tilde() Callable[source]#

Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.

abstractmethod get_g_tilde() Callable[source]#

Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.

get_sampler(args=None, nsteps=300, method='SEA', adaptive=False, **kwargs) Callable[source]#

Stochastic sampler for the SDE.

Parameters:
  • args – Additional arguments to pass to the velocity model.

  • nsteps – Number of steps for the SDE solver. Defaults to 300.

  • method – The method to use for the SDE solver, one of “Euler”, “SEA” or “ShARK”. Defaults to “SEA”. Euler is the simplest algorithm. SEA (Shifted Euler method) has a better constant factor in the global error and an improved local error. ShARK (Shifted Additive-noise Runge-Kutta) provides a more accurate solution at a higher computational cost, and implements adaptive step-size control.

  • adaptive – Whether to use adaptive step-size control (only for ShARK). Defaults to False, per the signature above.
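As a rough illustration of the simplest of these methods, the sketch below implements a plain Euler (Euler-Maruyama) step for an SDE dX = f(X, t) dt + g(t) dW in pure Python. It is not the library's sampler: the function name, the scalar state, and the assumption that the diffusion depends on time only are all illustrative, and SEA and ShARK are not shown.

```python
import math
import random


def euler_maruyama_sketch(drift, diffusion, x0, n_steps=300, t0=0.0, t1=1.0, seed=0):
    """Hypothetical helper: integrate dX = drift(X, t) dt + diffusion(t) dW."""
    rng = random.Random(seed)
    dt = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        dw = rng.gauss(0.0, math.sqrt(dt))  # Brownian increment, variance dt
        x = x + drift(x, t) * dt + diffusion(t) * dw
        t = t + dt
    return x


# With zero diffusion the scheme reduces to the deterministic Euler method.
deterministic = euler_maruyama_sketch(lambda x, t: -x, lambda t: 0.0, 1.0)
noisy = euler_maruyama_sketch(lambda x, t: -x, lambda t: 0.3, 1.0, seed=42)
```

The fixed seed makes the noisy run reproducible, which mirrors the role of the random key passed to sample.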

get_score(**kwargs)[source]#

Obtain the score function given the velocity model. See arXiv.2410.02217

sample(key: Array, nsamples: int, nsteps: int = 300, method='SEA', adaptive=True, **kwargs) Array[source]#

Sample from the SDE using the provided key and number of samples.

Parameters:
  • key (jax.Array) – JAX random key for sampling.

  • nsamples (int) – Number of samples to generate.

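  • nsteps (int) – Number of steps for the SDE solver. Defaults to 300.

  • method (str) – The method to use for the SDE solver. Defaults to “SEA”.

  • adaptive (bool) – Whether to use adaptive step-size control. Defaults to True.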

  • **kwargs – Additional arguments to pass to the velocity model.

Returns:

Sampled trajectories from the SDE.

Return type:

jax.Array

class gensbi.flow_matching.solver.sde_solver.NonSingular(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, alpha: float)[source]#

Bases: BaseSDESolver

NonSingular SDE, from Table 1 of http://arxiv.org/abs/2410.02217, with the change of variable for time t -> 1-t to match our time notation.

get_f_tilde(**kwargs) Callable[source]#

Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.

get_g_tilde() Callable[source]#

Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.

class gensbi.flow_matching.solver.sde_solver.ZeroEnds(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, alpha: float, eps0: float = 0.001)[source]#

Bases: BaseSDESolver

ZeroEnds SDE, from Table 1 of http://arxiv.org/abs/2410.02217, with the change of variable for time t -> 1-t to match our time notation.

get_f_tilde(**kwargs) Callable[source]#

Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.

get_g_tilde() Callable[source]#

Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.

gensbi.flow_matching.solver.solver module#

class gensbi.flow_matching.solver.solver.Solver[source]#

Bases: ABC

Abstract base class for solvers.

abstractmethod sample(x_0: Array) Array[source]#
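A short sketch of what the abstract sample method implies for subclasses. The classes below are stand-ins written for illustration (SolverSketch is not the library's Solver); they only show that a subclass must implement sample before it can be instantiated.

```python
from abc import ABC, abstractmethod


class SolverSketch(ABC):
    """Stand-in for an abstract solver base class (hypothetical name)."""

    @abstractmethod
    def sample(self, x_0):
        """Map an initial state x_0 to a sample."""


class IdentitySolver(SolverSketch):
    """Trivial concrete solver that returns its input unchanged."""

    def sample(self, x_0):
        return x_0


result = IdentitySolver().sample([1.0, 2.0])

try:
    SolverSketch()  # abstract classes cannot be instantiated
    instantiated = True
except TypeError:
    instantiated = False
```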

gensbi.flow_matching.solver.utils module#

gensbi.flow_matching.solver.utils.get_nearest_times(time_grid: Array, t_discretization: Array) Array[source]#

Find the nearest times in t_discretization for each time in time_grid.

Parameters:
  • time_grid (Array) – Query times to find nearest neighbors for, shape (N,)

  • t_discretization (Array) – Reference time points to match against, shape (M,)

Returns:

Nearest times from t_discretization for each point in time_grid, shape (N,)

Return type:

Array
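The lookup described above can be sketched in pure Python as follows. This is an illustrative stand-in using lists rather than arrays, and the name nearest_times_sketch is hypothetical; it is not the library's implementation.

```python
def nearest_times_sketch(time_grid, t_discretization):
    """For each query time, return the closest reference time (ties go to the first match)."""
    return [min(t_discretization, key=lambda s: abs(s - t)) for t in time_grid]


# Three query times matched against four reference times
nearest = nearest_times_sketch([0.0, 0.5, 1.0], [0.0, 0.24, 0.51, 0.98])
```

The output has one entry per query time, matching the shape (N,) stated above.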

Module contents#

class gensbi.flow_matching.solver.NonSingular(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, alpha: float)[source]#

Bases: BaseSDESolver

NonSingular SDE, from Table 1 of http://arxiv.org/abs/2410.02217, with the change of variable for time t -> 1-t to match our time notation.

get_f_tilde(**kwargs) Callable[source]#

Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.

get_g_tilde() Callable[source]#

Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.

class gensbi.flow_matching.solver.ODESolver(velocity_model: ModelWrapper)[source]#

Bases: Solver

A class to solve ordinary differential equations (ODEs) using a specified velocity model.

This class uses a velocity field model to solve ODEs over a given time grid with numerical ODE solvers.

Parameters:

velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)

get_sampler(step_size: float | None, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Array = Array([0., 1.], dtype=float32), return_intermediates: bool = False, model_extras: dict = {}) Callable[source]#

Solve the ODE with the velocity field.

Example:

# Note: this example uses the PyTorch `flow_matching` package (torch tensors),
# whereas the signatures on this page use JAX arrays.
import torch
from flow_matching.utils import ModelWrapper
from flow_matching.solver import ODESolver

class DummyModel(ModelWrapper):
    def __init__(self):
        super().__init__(None)

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        # Velocity field u_t(x) = 3 t^2, independent of x
        return torch.ones_like(x) * 3.0 * t**2

velocity_model = DummyModel()
solver = ODESolver(velocity_model=velocity_model)
x_init = torch.tensor([0.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])

# Integrates dx/dt = 3 t^2 from t=0 to t=1; the analytic solution is x(1) = 1
result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)

Parameters:
  • x_init (Array) – initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …]. Taken by sample; get_sampler itself has no x_init argument.

  • step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.

  • method (str | AbstractERK) – The integration method, given by name or as an AbstractERK instance. Defaults to “Dopri5”.

  • atol (float) – Absolute tolerance, used for adaptive step solvers.

  • rtol (float) – Relative tolerance, used for adaptive step solvers.

  • time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to Array([0., 1.]).

  • return_intermediates (bool, optional) – If True then return intermediate time steps according to time_grid. Defaults to False.

  • model_extras (dict) – Additional inputs for the model. Defaults to {}.

Returns:

The last timestep when return_intermediates=False, otherwise all values specified in time_grid.

Return type:

Union[Array, Sequence[Array]]

get_unnormalized_logprob(log_p0: Callable[[Array], Array], step_size: float = 0.01, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid=[1.0, 0.0], return_intermediates: bool = False, *, model_extras: dict = {}) Callable[source]#

Solve for the log likelihood of a target sample at \(t=1\) by integrating the ODE backward to \(t=0\), where log_p0 is evaluated.

Parameters:
  • x_1 (Array) – target sample (e.g., samples \(X_1 \sim p_1\)).

  • log_p0 (Callable[[Array], Array]) – Log probability function of source distribution.

  • step_size (float) – Step size for fixed-step solvers. Defaults to 0.01.

  • method (str) – Integration method to use.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Array) – Must start at 1.0 and end at 0.0.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • exact_divergence (bool) – Use exact divergence vs Hutchinson estimator.

  • model_extras (dict) – Additional model inputs, passed as a keyword-only argument. Defaults to {}.

Returns:

Samples and log likelihood values.

Return type:

Union[Tuple[Array, Array], Tuple[Sequence[Array], Array]]

sample(x_init: Array, step_size: float | None, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Array = Array([0., 1.], dtype=float32), return_intermediates: bool = False, model_extras: dict = {}) Array | Sequence[Array][source]#

unnormalized_logprob(x_1: Array, log_p0: Callable[[Array], Array], step_size: float = 0.01, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid=[1.0, 0.0], return_intermediates: bool = False, *, model_extras: dict = {}) Tuple[Array, Array] | Tuple[Sequence[Array], Array][source]#

class gensbi.flow_matching.solver.Solver[source]#

Bases: ABC

Abstract base class for solvers.

abstractmethod sample(x_0: Array) Array[source]#

class gensbi.flow_matching.solver.ZeroEnds(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, alpha: float, eps0: float = 0.001)[source]#

Bases: BaseSDESolver

ZeroEnds SDE, from Table 1 of http://arxiv.org/abs/2410.02217, with the change of variable for time t -> 1-t to match our time notation.

get_f_tilde(**kwargs) Callable[source]#

Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.

get_g_tilde() Callable[source]#

Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.