gensbi.flow_matching.solver

Submodules

Classes

NonSingular

NonSingular SDE solver.

ODESolver

A class to solve ordinary differential equations (ODEs) using a specified velocity model.

Solver

Abstract base class for solvers.

ZeroEnds

ZeroEnds SDE solver.

Package Contents

class gensbi.flow_matching.solver.NonSingular(velocity_model, mu0, sigma0, alpha)

Bases: BaseSDESolver

NonSingular SDE solver.

From Table 1 of arXiv:2410.02217, with the change of time variable \(t \to 1 - t\) to match our time convention.
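As a worked illustration of what the relabeling does to a deterministic velocity, assume the paper's time \(s\) runs opposite to ours, so \(s = 1 - t\), and let \(\hat{v}(x, s)\) be a hypothetical velocity written in the paper's time. The chain rule then gives the following (this covers the drift direction only; the diffusion term is not derived here):

```latex
t = 1 - s \quad\Rightarrow\quad
\frac{\mathrm{d}x}{\mathrm{d}t}
  = \frac{\mathrm{d}x}{\mathrm{d}s}\,\frac{\mathrm{d}s}{\mathrm{d}t}
  = -\frac{\mathrm{d}x}{\mathrm{d}s},
\qquad\text{so}\qquad
u_t(x) = -\hat{v}(x,\, 1 - t).
```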

Parameters:
  • velocity_model

  • mu0

  • sigma0

  • alpha

get_f_tilde(**kwargs)

Get the function \(\tilde{f}\) for the velocity model (see arXiv:2410.02217). Also known as the “drift” term in the SDE context.

Return type:

Callable

get_g_tilde()

Get the function \(\tilde{g}\) for the velocity model (see arXiv:2410.02217). Also known as the “diffusion” term in the SDE context.

Return type:

Callable
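get_f_tilde and get_g_tilde supply the drift \(\tilde{f}\) and diffusion \(\tilde{g}\) of an SDE. As a rough sketch of how such a pair can be integrated, the generic Euler–Maruyama scheme below steps \(\mathrm{d}x = \tilde{f}(x, t)\,\mathrm{d}t + \tilde{g}(t)\,\mathrm{d}W\) along a fixed time grid. It assumes f_tilde(x, t) returns an array shaped like x and g_tilde(t) returns a scalar; those signatures and the euler_maruyama helper are hypothetical and are not the BaseSDESolver implementation.

```python
import jax
import jax.numpy as jnp


def euler_maruyama(f_tilde, g_tilde, x0, time_grid, key):
    """Integrate dx = f_tilde(x, t) dt + g_tilde(t) dW on a fixed time grid.

    Hypothetical helper: f_tilde(x, t) -> array shaped like x,
    g_tilde(t) -> scalar diffusion coefficient.
    """
    # One PRNG key per step so every Brownian increment is independent.
    keys = jax.random.split(key, time_grid.shape[0] - 1)

    def step(x, inputs):
        t, t_next, k = inputs
        dt = t_next - t
        # Brownian increment dW ~ N(0, |dt| I)
        dw = jnp.sqrt(jnp.abs(dt)) * jax.random.normal(k, x.shape)
        x_next = x + f_tilde(x, t) * dt + g_tilde(t) * dw
        return x_next, None

    x_final, _ = jax.lax.scan(step, x0, (time_grid[:-1], time_grid[1:], keys))
    return x_final
```

With g_tilde returning 0 the scheme reduces to explicit Euler on the ODE, which makes it easy to sanity-check a drift before switching the noise on.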

alpha

class gensbi.flow_matching.solver.ODESolver(velocity_model)

Bases: gensbi.flow_matching.solver.solver.Solver

A class to solve ordinary differential equations (ODEs) using a specified velocity model.

This class uses a velocity field model to solve ODEs over a given time grid with numerical ODE solvers from diffrax.

Parameters:

velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)

Example

from gensbi.flow_matching.solver import ODESolver
from gensbi.utils.model_wrapping import ModelWrapper
import jax, jax.numpy as jnp

class DummyModel:
    def __call__(self, obs, t, *args, **kwargs):
        return jnp.squeeze(obs + t, axis=-1)

vf_model = DummyModel() # replace with your actual velocity field model, Simformer or Flux1

model_wrapped = ModelWrapper(vf_model) # you should use the appropriate ModelWrapper for your model, either FluxWrapper or SimformerWrapper, or a custom subclass of ModelWrapper
solver = ODESolver(velocity_model=model_wrapped)
x_init = jnp.zeros((10, 2))
time_grid = jnp.linspace(0, 1, 5)
sol = solver.sample(x_init=x_init, step_size=0.05, time_grid=time_grid, return_intermediates=True)  # one state per time_grid point
print(sol.shape)
# (5, 10, 2)

get_sampler(step_size, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras={})

Obtain a sampler to solve the ODE with the velocity field.

Parameters:
  • step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.

  • method (str) – A method supported by diffrax. Defaults to “Dopri5”. Other commonly used solvers are “Euler”. For a complete list, see diffrax.

  • atol (float) – Absolute tolerance, used for adaptive step solvers.

  • rtol (float) – Relative tolerance, used for adaptive step solvers.

  • time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)] and, if step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to jnp.array([0.0, 1.0]).

  • return_intermediates (bool, optional) – If True then return intermediate time steps according to time_grid. Defaults to False.

  • model_extras (dict) – Additional input for the model.

Returns:

A function that takes initial conditions x_init (e.g., source samples \(X_0 \sim p\), shape [batch_size, …]) and returns the solution at the final time or at the intermediate times.

Return type:

Callable
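get_sampler returns a callable instead of a result. A minimal sketch of that factory pattern in plain JAX, assuming a fixed-step explicit midpoint rule with one step per time_grid interval (the documented step_size=None behaviour): make_midpoint_sampler is a hypothetical helper, not the diffrax-backed implementation, and velocity_fn(x, t) is assumed to return an array shaped like x.

```python
import jax
import jax.numpy as jnp


def make_midpoint_sampler(velocity_fn, time_grid, return_intermediates=False):
    """Return a jitted function x_init -> solution (hypothetical helper).

    Each interval of time_grid is covered by one explicit midpoint step,
    so the grid itself sets the discretization.
    """
    time_grid = jnp.asarray(time_grid)

    def step(x, ts):
        t, t_next = ts
        h = t_next - t  # signed, so descending grids integrate backward
        x_mid = x + 0.5 * h * velocity_fn(x, t)
        x_next = x + h * velocity_fn(x_mid, t + 0.5 * h)
        return x_next, x_next

    @jax.jit
    def sampler(x_init):
        x_final, path = jax.lax.scan(step, x_init, (time_grid[:-1], time_grid[1:]))
        if return_intermediates:
            # Prepend the initial state so the output has len(time_grid) entries.
            return jnp.concatenate([x_init[None], path], axis=0)
        return x_final

    return sampler
```

Building the closure once lets jax.jit compile the integration loop a single time and reuse it across batches of initial conditions, which is presumably the motivation for offering get_sampler alongside sample.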

get_unnormalized_logprob(log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=[1.0, 0.0], return_intermediates=False, *, model_extras={})

Solve for the log-likelihood of a target sample at \(t=1\) by integrating the ODE backward to \(t=0\).

Parameters:
  • x_1 (Array) – target sample (e.g., samples \(X_1 \sim p_1\)).

  • log_p0 (Callable[[Array], Array]) – Log probability function of source distribution.

  • step_size (Optional[float]) – Step size for fixed-step solvers.

  • method (str) – Integration method to use.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Array) – Must start at 1.0 and end at 0.0.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • exact_divergence (bool) – Use exact divergence vs Hutchinson estimator.

  • model_extras (dict) – Additional model inputs.

Returns:

Samples and log likelihood values.

Return type:

Union[Tuple[Array, Array], Tuple[Sequence[Array], Array]]
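The likelihood computation rests on the instantaneous change of variables, \(\frac{\mathrm{d}}{\mathrm{d}t}\log p_t(x_t) = -\nabla\cdot u_t(x_t)\), hence \(\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla\cdot u_t(x_t)\,\mathrm{d}t\). A minimal sketch, assuming a single unbatched sample, fixed-step Euler from \(t=1\) to \(t=0\), and the exact divergence from a dense Jacobian (only practical in low dimension); log_likelihood and exact_divergence are hypothetical helpers, not get_unnormalized_logprob itself.

```python
import jax
import jax.numpy as jnp


def exact_divergence(velocity_fn, x, t):
    """Trace of the Jacobian d u_t / d x for a single (unbatched) sample."""
    jac = jax.jacfwd(lambda y: velocity_fn(y, t))(x)
    return jnp.trace(jac)


def log_likelihood(velocity_fn, log_p0, x_1, n_steps=1000):
    """Estimate log p_1(x_1) by Euler-integrating from t=1 back to t=0.

    Uses log p_1(x_1) = log p_0(x_0) - int_0^1 div u_t(x_t) dt.
    """
    ts = jnp.linspace(1.0, 0.0, n_steps + 1)

    def step(carry, tt):
        x, div_int = carry
        t, t_next = tt
        h = t_next - t  # negative: moving backward in time
        div = exact_divergence(velocity_fn, x, t)
        # -h > 0, so this accumulates the forward-time integral of div u_t
        return (x + h * velocity_fn(x, t), div_int - h * div), None

    init = (x_1, jnp.zeros((), dtype=x_1.dtype))
    (x_0, div_int), _ = jax.lax.scan(step, init, (ts[:-1], ts[1:]))
    return x_0, log_p0(x_0) - div_int
```

For a linear field \(u_t(x) = a x\) in \(d\) dimensions the answer is known in closed form (\(x_0 = x_1 e^{-a}\), divergence integral \(a d\)), which makes a convenient check.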

sample(x_init, step_size, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras={})

Sample from the ODE defined by the velocity field.

Parameters:
  • x_init (Array) – initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …].

  • step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.

  • method (str) – A method supported by diffrax. Defaults to “Dopri5”. Other commonly used solvers are “Euler”. For a complete list, see diffrax.

  • atol (float) – Absolute tolerance, used for adaptive step solvers.

  • rtol (float) – Relative tolerance, used for adaptive step solvers.

  • time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)] and, if step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to jnp.array([0.0, 1.0]).

  • return_intermediates (bool, optional) – If True then return intermediate time steps according to time_grid. Defaults to False.

  • model_extras (dict) – Additional input for the model.

Returns:

The final state or the states at all intermediate time steps.

Return type:

Union[Array, Sequence[Array]]
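To illustrate the descending time_grid behaviour described above, here is a sketch assuming a fixed-step Heun (RK2) integrator written directly in JAX; heun_solve is a hypothetical helper, not the diffrax call sample makes. Because each step uses the signed interval h = t_next - t, reversing the grid integrates backward, and a forward-then-backward round trip should approximately recover the starting point.

```python
import jax
import jax.numpy as jnp


def heun_solve(velocity_fn, x_init, time_grid):
    """Integrate dx/dt = velocity_fn(x, t) along time_grid with Heun's method.

    Works for ascending or descending grids because h = t_next - t is signed.
    """
    def step(x, ts):
        t, t_next = ts
        h = t_next - t
        k1 = velocity_fn(x, t)
        k2 = velocity_fn(x + h * k1, t_next)
        return x + 0.5 * h * (k1 + k2), None

    time_grid = jnp.asarray(time_grid)
    x_final, _ = jax.lax.scan(step, x_init, (time_grid[:-1], time_grid[1:]))
    return x_final


# Usage: push samples forward from t=0 to t=1, then pull them back.
u = lambda x, t: jnp.sin(x) + t
x0 = jnp.linspace(-1.0, 1.0, 6)
fwd = jnp.linspace(0.0, 1.0, 201)
x1 = heun_solve(u, x0, fwd)
x0_back = heun_solve(u, x1, fwd[::-1])  # descending grid
```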

unnormalized_logprob(x_1, log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=[1.0, 0.0], return_intermediates=False, *, model_extras={})

Parameters:
  • x_1 (jax.Array)

  • log_p0 (Callable[[jax.Array], jax.Array])

  • step_size (float)

  • method (Union[str, diffrax.AbstractERK])

  • atol (float)

  • rtol (float)

  • return_intermediates (bool)

  • model_extras (dict)

Return type:

Union[Tuple[jax.Array, jax.Array], Tuple[Sequence[jax.Array], jax.Array]]
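The exact_divergence flag listed for get_unnormalized_logprob contrasts the exact divergence with the Hutchinson estimator. A minimal sketch of the latter, assuming Rademacher probes \(\epsilon\) so that \(\mathbb{E}[\epsilon^\top J \epsilon] = \operatorname{tr}(J) = \nabla\cdot u_t(x)\); it needs only Jacobian-vector products, which is why it scales to high dimensions. hutchinson_divergence is a hypothetical helper, not the library's code.

```python
import jax
import jax.numpy as jnp


def hutchinson_divergence(velocity_fn, x, t, key, n_probes=1):
    """Unbiased estimate of div u_t(x) = tr(J) from random probes.

    Each probe eps has i.i.d. +/-1 entries; eps^T J eps is computed with a
    single Jacobian-vector product instead of the full Jacobian.
    """
    def one_probe(k):
        eps = jax.random.rademacher(k, x.shape).astype(x.dtype)
        _, jvp = jax.jvp(lambda y: velocity_fn(y, t), (x,), (eps,))
        return jnp.vdot(eps, jvp)

    keys = jax.random.split(key, n_probes)
    return jnp.mean(jax.vmap(one_probe)(keys))
```

With one probe the estimate is noisy but unbiased; averaging more probes trades compute for lower variance. For a diagonal Jacobian every probe returns the exact trace, since \(\epsilon_i^2 = 1\).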

velocity_model

class gensbi.flow_matching.solver.Solver

Bases: abc.ABC

Abstract base class for solvers.

abstract sample(x_0)
Parameters:

x_0 (jax.Array)

Return type:

jax.Array
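Concrete solvers implement the abstract sample(x_0) method. The sketch below defines a local stand-in for Solver with the same documented interface (so it runs without gensbi installed) and a hypothetical EulerSolver subclass that integrates a velocity field from t=0 to t=1 with fixed Euler steps; both classes are illustrative assumptions, not library code.

```python
import abc

import jax
import jax.numpy as jnp


class Solver(abc.ABC):
    """Local stand-in mirroring the documented interface (not imported from gensbi)."""

    @abc.abstractmethod
    def sample(self, x_0: jax.Array) -> jax.Array:
        ...


class EulerSolver(Solver):
    """Hypothetical concrete solver: fixed-step Euler for dx/dt = velocity_fn(x, t)."""

    def __init__(self, velocity_fn, n_steps: int = 100):
        self.velocity_fn = velocity_fn
        self.n_steps = n_steps

    def sample(self, x_0: jax.Array) -> jax.Array:
        h = 1.0 / self.n_steps

        def step(x, t):
            return x + h * self.velocity_fn(x, t), None

        # Left endpoints of the n_steps intervals covering [0, 1].
        ts = jnp.arange(self.n_steps) * h
        x_1, _ = jax.lax.scan(step, x_0, ts)
        return x_1
```

Because sample is declared abstract, Python refuses to instantiate Solver directly (it raises TypeError), which forces every concrete solver to provide its own sampling routine.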

class gensbi.flow_matching.solver.ZeroEnds(velocity_model, mu0, sigma0, alpha, eps0=0.001)

Bases: BaseSDESolver

ZeroEnds SDE solver.

From Table 1 of arXiv:2410.02217, with the change of time variable \(t \to 1 - t\) to match our time convention.

Parameters:
  • velocity_model

  • mu0

  • sigma0

  • alpha

  • eps0

get_f_tilde(**kwargs)

Get the function \(\tilde{f}\) for the velocity model (see arXiv:2410.02217). Also known as the “drift” term in the SDE context.

Return type:

Callable

get_g_tilde()

Get the function \(\tilde{g}\) for the velocity model (see arXiv:2410.02217). Also known as the “diffusion” term in the SDE context.

Return type:

Callable

alpha