gensbi.flow_matching.solver.ode_solver

Classes

ODESolver

A class to solve ordinary differential equations (ODEs) using a specified velocity model.

Module Contents

class gensbi.flow_matching.solver.ode_solver.ODESolver(velocity_model)

Bases: gensbi.flow_matching.solver.solver.Solver

A class to solve ordinary differential equations (ODEs) using a specified velocity model.

This class uses a velocity field model to solve ODEs over a given time grid with numerical ODE solvers.

Parameters:

velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)

Example

from gensbi.flow_matching.solver import ODESolver
from gensbi.utils.model_wrapping import ModelWrapper
import jax, jax.numpy as jnp

class DummyModel:
    def __call__(self, obs, t, *args, **kwargs):
        return jnp.squeeze(obs + t, axis=-1)

vf_model = DummyModel() # replace with your actual velocity field model, Simformer or Flux1

model_wrapped = ModelWrapper(vf_model) # you should use the appropriate ModelWrapper for your model, either FluxWrapper or SimformerWrapper, or a custom subclass of ModelWrapper
solver = ODESolver(velocity_model=model_wrapped)
x_init = jnp.zeros((10, 2))
time_grid = jnp.linspace(0, 1, 5)
sol = solver.sample(x_init=x_init, step_size=0.05, time_grid=time_grid, return_intermediates=True)
print(sol.shape)
# (5, 10, 2)
get_sampler(step_size, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras={})

Obtain a sampler to solve the ODE with the velocity field.

Parameters:
  • step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.

  • method (str) – A method supported by diffrax. Defaults to “Dopri5”. Another commonly used solver is “Euler”. For a complete list, see diffrax.

  • atol (float) – Absolute tolerance, used for adaptive step solvers.

  • rtol (float) – Relative tolerance, used for adaptive step solvers.

  • time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)], and if step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to jnp.array([0.0, 1.0]).

  • return_intermediates (bool, optional) – If True, return the states at every time in time_grid. Defaults to False.

  • model_extras (dict) – Additional inputs for the model. Defaults to {}.

Returns:

A function that takes the initial conditions x_init (e.g., source samples \(X_0 \sim p\), shape [batch_size, …]) and returns the solution at the final time or at all intermediate times.

Return type:

Callable
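The callable returned by get_sampler closes over the solver settings and is then applied to x_init. The sketch below mimics that factory pattern in plain JAX rather than diffrax. make_euler_sampler is a hypothetical name, not part of gensbi; it takes one explicit Euler step per interval of time_grid, so it has no counterpart to step_size, atol or rtol.

```python
import jax
import jax.numpy as jnp


def make_euler_sampler(velocity_fn, time_grid, return_intermediates=False):
    """Hypothetical stand-in for get_sampler: builds a jitted closure over x_init.

    Takes one explicit Euler step per interval of `time_grid` (fixed-step only).
    """
    time_grid = jnp.asarray(time_grid, dtype=jnp.float32)

    def step(x, ts):
        t0, t1 = ts
        # Explicit Euler: x_{k+1} = x_k + (t_{k+1} - t_k) * u_{t_k}(x_k)
        x_next = x + (t1 - t0) * velocity_fn(x, t0)
        return x_next, x_next  # (new carry, value stacked by scan)

    @jax.jit
    def sampler(x_init):
        x_final, xs = jax.lax.scan(step, x_init, (time_grid[:-1], time_grid[1:]))
        if return_intermediates:
            # Prepend x_init so there is one state per grid point.
            return jnp.concatenate([x_init[None], xs], axis=0)
        return x_final

    return sampler


velocity = lambda x, t: -x  # toy velocity field u_t(x) = -x
grid = jnp.linspace(0.0, 1.0, 5)
traj = make_euler_sampler(velocity, grid, return_intermediates=True)(jnp.ones((10, 2)))
final = make_euler_sampler(velocity, grid)(jnp.ones((10, 2)))
print(traj.shape, final.shape)  # (5, 10, 2) (10, 2)
```

In this sketch, return_intermediates=True adds a leading axis with one entry per point of time_grid, while the default returns only the final state.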

get_unnormalized_logprob(log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=[1.0, 0.0], return_intermediates=False, *, model_extras={})

Solve for the log-likelihood of a target sample at \(t=1\) by integrating the ODE backward to \(t=0\).

Parameters:
  • x_1 (Array) – target sample (e.g., samples \(X_1 \sim p_1\)).

  • log_p0 (Callable[[Array], Array]) – Log probability function of source distribution.

  • step_size (Optional[float]) – Step size for fixed-step solvers.

  • method (str) – Integration method to use.

  • atol (float) – Absolute tolerance for adaptive solvers.

  • rtol (float) – Relative tolerance for adaptive solvers.

  • time_grid (Array) – Must start at 1.0 and end at 0.0.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Additional model inputs (keyword-only).

Returns:

Samples and log likelihood values.

Return type:

Union[Tuple[Array, Array], Tuple[Sequence[Array], Array]]
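The log-likelihood rests on the instantaneous change of variables: along \(\dot x_t = u_t(x_t)\), \(\frac{d}{dt}\log p_t(x_t) = -\nabla\cdot u_t(x_t)\), so integrating from \(t=1\) back to \(t=0\) gives \(\log p_1(x_1) = \log p_0(x_0) + \int_1^0 \nabla\cdot u_t(x_t)\,dt\). The sketch below illustrates that identity under stated assumptions and is not the gensbi implementation: euler_logprob is a hypothetical helper that uses fixed-step Euler instead of diffrax, computes the divergence exactly as the trace of the Jacobian, assumes velocity_fn acts on one sample of shape (d,), and assumes a standard normal source.

```python
import jax
import jax.numpy as jnp


def log_p0(x):
    # Standard normal source density (an assumption for this sketch).
    return -0.5 * jnp.sum(x**2, axis=-1) - 0.5 * x.shape[-1] * jnp.log(2.0 * jnp.pi)


def euler_logprob(velocity_fn, log_p0, x_1, n_steps=1000):
    """Hypothetical helper: returns (x_0, log p_1(x_1)) for x_1 of shape (batch, d).

    velocity_fn(x, t) maps one sample x of shape (d,) and a scalar t to u_t(x).
    """
    ts = jnp.linspace(1.0, 0.0, n_steps + 1)  # descending grid: t = 1 -> t = 0

    def divergence(x, t):
        # Exact divergence: trace of the Jacobian du/dx (costs d forward passes).
        return jnp.trace(jax.jacfwd(velocity_fn)(x, t))

    def single(x1):
        def step(carry, t_pair):
            x, div_int = carry
            t0, t1 = t_pair
            dt = t1 - t0  # negative, since we integrate backwards in time
            div_int = div_int + dt * divergence(x, t0)
            x = x + dt * velocity_fn(x, t0)
            return (x, div_int), None

        init = (x1, jnp.zeros((), dtype=x1.dtype))
        (x0, div_int), _ = jax.lax.scan(step, init, (ts[:-1], ts[1:]))
        # log p_1(x_1) = log p_0(x_0) + int_1^0 div u_t(x_t) dt
        return x0, log_p0(x0) + div_int

    return jax.vmap(single)(x_1)


a = 0.5
velocity = lambda x, t: a * x  # linear field: divergence is a * d everywhere
x_1 = jnp.array([[0.3, -1.2], [1.0, 0.5]])
x_0, logp = euler_logprob(velocity, log_p0, x_1)

# Closed form for comparison: x_1 = e^a x_0, so p_1 = N(0, e^{2a} I).
d = x_1.shape[-1]
exact = -0.5 * d * jnp.log(2.0 * jnp.pi) - d * a - jnp.sum(x_1**2, axis=-1) / (2.0 * jnp.exp(2.0 * a))
```

The exact Jacobian trace is cheap here because d is tiny; its cost grows with the dimension, which is why stochastic trace estimators are common in higher dimensions.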

sample(x_init, step_size, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras={})

Sample from the ODE defined by the velocity field.

Parameters:
  • x_init (Array) – initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …].

  • step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.

  • method (str) – A method supported by diffrax. Defaults to “Dopri5”. Another commonly used solver is “Euler”. For a complete list, see diffrax.

  • atol (float) – Absolute tolerance, used for adaptive step solvers.

  • rtol (float) – Relative tolerance, used for adaptive step solvers.

  • time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)], and if step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to jnp.array([0.0, 1.0]).

  • return_intermediates (bool, optional) – If True, return the states at every time in time_grid. Defaults to False.

  • model_extras (dict) – Additional inputs for the model. Defaults to {}.

Returns:

The final state or the states at all intermediate time steps.

Return type:

Union[Array, Sequence[Array]]
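One reading of the time_grid description is that the grid fixes where states are reported and which way time runs, while a fixed step_size controls the integration inside each grid interval. The plain-JAX sketch below illustrates those semantics with a hypothetical euler_solve helper (not a gensbi function; fixed-step Euler only, no adaptive control): solving forward on an ascending grid and then backward on a descending grid approximately recovers x_init.

```python
import jax.numpy as jnp


def euler_solve(velocity_fn, x_init, time_grid, step_size):
    """Hypothetical fixed-step solver.

    States are recorded at every point of `time_grid`; `step_size` sets the size of
    the Euler sub-steps taken inside each grid interval. A descending grid runs
    the ODE backwards in time.
    """
    states = [x_init]
    x = x_init
    for t_start, t_end in zip(time_grid[:-1], time_grid[1:]):
        span = float(t_end - t_start)
        n_sub = max(1, round(abs(span) / step_size))
        h = span / n_sub  # negative when the grid is descending
        t = float(t_start)
        for _ in range(n_sub):
            x = x + h * velocity_fn(x, t)
            t += h
        states.append(x)
    return jnp.stack(states)  # shape: (len(time_grid), *x_init.shape)


velocity = lambda x, t: -x + t  # toy time-dependent velocity field
x_init = jnp.ones((10, 2))
forward = euler_solve(velocity, x_init, jnp.linspace(0.0, 1.0, 5), step_size=1e-3)
backward = euler_solve(velocity, forward[-1], jnp.linspace(1.0, 0.0, 5), step_size=1e-3)
print(forward.shape)  # (5, 10, 2)
```

For this toy field the exact solution from x(0) = 1 is x(t) = t - 1 + 2e^{-t}, so forward[-1] should sit close to 2/e, and the round trip returns to x_init up to the O(step_size) Euler error.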

unnormalized_logprob(x_1, log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=[1.0, 0.0], return_intermediates=False, *, model_extras={})

Parameters:
  • x_1 (jax.Array)

  • log_p0 (Callable[[jax.Array], jax.Array])

  • step_size (float)

  • method (Union[str, diffrax.AbstractERK])

  • atol (float)

  • rtol (float)

  • return_intermediates (bool)

  • model_extras (dict)

Return type:

Union[Tuple[jax.Array, jax.Array], Tuple[Sequence[jax.Array], jax.Array]]

velocity_model