Source code for gensbi.flow_matching.solver.ode_solver

# FIXME: first pass

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from typing import Callable, Optional, Sequence, Tuple, Union

# import torch
# from torch import Tensor
# from torchdiffeq import odeint

import jax
import jax.numpy as jnp
from jax import Array
import diffrax
from diffrax import AbstractERK

from gensbi.flow_matching.solver.solver import Solver
from gensbi.utils.model_wrapping import ModelWrapper


def _resolve_solver(method):
    """Return a diffrax solver instance for a solver name or a ready-made solver.

    Args:
        method: Either one of the supported solver names (``"Euler"``,
            ``"Dopri5"``) or an already constructed diffrax solver, which is
            returned unchanged.

    Raises:
        KeyError: If ``method`` is a string that is not a supported solver name.
    """
    if isinstance(method, str):
        # Plain dict lookup on purpose: unknown names keep raising KeyError.
        return {"Euler": diffrax.Euler, "Dopri5": diffrax.Dopri5}[method]()
    return method


def _select_stepsize_controller(solver, step_size, rtol, atol):
    """Return a step size controller that is compatible with ``solver``.

    Adaptive controllers need the solver to produce an error estimate. Solvers
    that do not (e.g. ``diffrax.Euler``) make ``diffrax.diffeqsolve`` fail when
    combined with ``PIDController``, so they are run on a constant step instead.

    Args:
        solver: The diffrax solver that will perform the integration.
        step_size: The step size requested by the caller (may be None).
        rtol: Relative tolerance for the adaptive controller.
        atol: Absolute tolerance for the adaptive controller.

    Raises:
        ValueError: If a fixed-step solver is requested without a step size.
    """
    if isinstance(solver, diffrax.AbstractAdaptiveSolver):
        return diffrax.PIDController(rtol=rtol, atol=atol)
    if step_size is None:
        raise ValueError(
            "step_size must be provided when using a fixed-step solver "
            f"({type(solver).__name__})."
        )
    return diffrax.ConstantStepSize()


class ODESolver(Solver):
    """Solve ordinary differential equations (ODEs) driven by a velocity model.

    The velocity field is integrated with a ``diffrax`` solver, either to
    transport samples along the flow (:meth:`sample`) or to evaluate the
    log-density of target samples (:meth:`unnormalized_logprob`).

    Args:
        velocity_model (ModelWrapper): a wrapped velocity field model. It must
            provide ``get_vector_field(**model_extras)`` and
            ``get_divergence(**model_extras)``, each returning a callable with
            the diffrax signature ``f(t, x, args)``.
    """

    def __init__(self, velocity_model: ModelWrapper):
        super().__init__()
        self.velocity_model = velocity_model

    def get_sampler(
        self,
        step_size: Optional[float],
        method: Union[str, AbstractERK] = "Dopri5",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Array = jnp.array([0.0, 1.0]),
        return_intermediates: bool = False,
        model_extras: Optional[dict] = None,
    ) -> Callable:
        r"""Build a jitted function that solves the ODE with the velocity field.

        Example:

        .. code-block:: python

            import jax.numpy as jnp

            # ``velocity_model`` is a ModelWrapper instance.
            solver = ODESolver(velocity_model=velocity_model)
            sampler = solver.get_sampler(
                step_size=None, time_grid=jnp.array([0.0, 1.0])
            )
            x_1 = sampler(x_init)

        Args:
            step_size (Optional[float]): Initial step size handed to diffrax as
                ``dt0``. May be None for adaptive solvers; required for
                fixed-step solvers such as "Euler", where it is the step size
                of every step. NOTE: diffrax expects its sign to match the
                integration direction (negative for a descending time_grid).
            method (Union[str, AbstractERK]): Either a solver name ("Euler" or
                "Dopri5") or a diffrax solver instance. Defaults to "Dopri5".
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Array): The ODE is solved from ``time_grid[0]`` to
                ``time_grid[-1]``. Defaults to jnp.array([0.0, 1.0]).
            return_intermediates (bool, optional): If True the sampler returns
                the solution at every time in time_grid. Defaults to False.
            model_extras (Optional[dict]): Additional keyword arguments for the
                model, forwarded to ``velocity_model.get_vector_field``.
                Defaults to no extras.

        Returns:
            Callable: ``sampler(x_init)``, where ``x_init`` are the initial
            conditions (e.g., source samples :math:`X_0 \sim p`). It returns the
            state at the last time when return_intermediates=False, otherwise
            the states at all times in time_grid.

        Raises:
            KeyError: If ``method`` is an unknown solver name.
            ValueError: If a fixed-step solver is used with ``step_size=None``.
        """
        extras = {} if model_extras is None else model_extras
        term = diffrax.ODETerm(self.velocity_model.get_vector_field(**extras))
        solver = _resolve_solver(method)
        stepsize_controller = _select_stepsize_controller(
            solver, step_size, rtol, atol
        )

        @jax.jit
        def sampler(x_init):
            solution = diffrax.diffeqsolve(
                term,
                solver,
                t0=time_grid[0],
                t1=time_grid[-1],
                dt0=step_size,
                y0=x_init,
                saveat=(
                    diffrax.SaveAt(ts=time_grid)
                    if return_intermediates
                    else diffrax.SaveAt(t1=True)
                ),
                stepsize_controller=stepsize_controller,
            )
            # Without intermediates only t1 is saved; drop the save-time axis.
            return solution.ys if return_intermediates else solution.ys[-1]  # type: ignore

        return sampler

    def sample(
        self,
        x_init: Array,
        step_size: Optional[float],
        method: Union[str, AbstractERK] = "Dopri5",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Array = jnp.array([0.0, 1.0]),
        return_intermediates: bool = False,
        model_extras: Optional[dict] = None,
    ) -> Union[Array, Sequence[Array]]:
        r"""Solve the ODE with the velocity field starting from ``x_init``.

        Convenience wrapper that builds a sampler with :meth:`get_sampler`
        (same arguments) and applies it once.

        Args:
            x_init (Array): initial conditions (e.g., source samples
                :math:`X_0 \sim p`).
            step_size (Optional[float]): See :meth:`get_sampler`.
            method (Union[str, AbstractERK]): See :meth:`get_sampler`.
            atol (float): Absolute tolerance, used for adaptive step solvers.
            rtol (float): Relative tolerance, used for adaptive step solvers.
            time_grid (Array): See :meth:`get_sampler`.
            return_intermediates (bool, optional): See :meth:`get_sampler`.
            model_extras (Optional[dict]): Additional input for the model.

        Returns:
            Union[Array, Sequence[Array]]: The last timestep when
            return_intermediates=False, otherwise all values specified in
            time_grid.
        """
        sampler = self.get_sampler(
            step_size=step_size,
            method=method,
            atol=atol,
            rtol=rtol,
            time_grid=time_grid,
            return_intermediates=return_intermediates,
            model_extras=model_extras,
        )
        return sampler(x_init)

    def get_unnormalized_logprob(
        self,
        log_p0: Callable[[Array], Array],
        step_size: float = 0.01,
        method: Union[str, AbstractERK] = "Dopri5",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Optional[Union[Array, Sequence[float]]] = None,
        return_intermediates: bool = False,
        *,
        model_extras: Optional[dict] = None,
    ) -> Callable:
        r"""Build a function evaluating the log likelihood of target samples.

        The state and the accumulated divergence are integrated jointly from
        :math:`t=1` back to :math:`t=0`; the result is the source log
        probability of the transported sample plus the accumulated divergence
        term.

        Args:
            log_p0 (Callable[[Array], Array]): Log probability function of the
                source distribution.
            step_size (float): Magnitude of the initial step size; it is
                negated internally because the integration runs backwards in
                time.
            method (Union[str, AbstractERK]): Either a solver name ("Euler" or
                "Dopri5") or a diffrax solver instance. Defaults to "Dopri5".
            atol (float): Absolute tolerance for adaptive solvers.
            rtol (float): Relative tolerance for adaptive solvers.
            time_grid: Must start at 1.0 and end at 0.0. Defaults to
                ``[1.0, 0.0]``.
            return_intermediates (bool): If True the solution is saved at every
                time in time_grid instead of only at the final time.
            model_extras (Optional[dict]): Additional keyword arguments for the
                model, forwarded to ``get_vector_field`` and
                ``get_divergence``. Defaults to no extras.

        Returns:
            Callable: ``logprob(x_1)``, where ``x_1`` are target samples (e.g.,
            samples :math:`X_1 \sim p_1`), returning
            ``log_p0(x) + log_det`` computed on the saved solution.
            NOTE(review): unlike :meth:`get_sampler`, the save-time axis of the
            diffrax solution is not stripped here, so the result presumably
            keeps a leading axis of length 1 when return_intermediates=False —
            confirm against callers.

        Raises:
            AssertionError: If time_grid does not start at 1.0 and end at 0.0.
            KeyError: If ``method`` is an unknown solver name.
            ValueError: If a fixed-step solver is used with ``step_size=None``.
        """
        if time_grid is None:
            time_grid = [1.0, 0.0]
        extras = {} if model_extras is None else model_extras

        assert (
            time_grid[0] == 1.0 and time_grid[-1] == 0.0
        ), f"Time grid must start at 1.0 and end at 0.0. Got {time_grid}"

        vector_field = self.velocity_model.get_vector_field(**extras)
        divergence = self.velocity_model.get_divergence(**extras)

        def dynamics_func(t, states, args):
            # The accumulated divergence does not feed back into the dynamics.
            xt, _ = states
            ut = vector_field(t, xt, args)
            div = divergence(t, xt, args)
            return ut, div

        term = diffrax.ODETerm(dynamics_func)
        solver = _resolve_solver(method)
        stepsize_controller = _select_stepsize_controller(
            solver, step_size, rtol, atol
        )

        def sampler(x_1):
            # The divergence is a scalar per sample, so its accumulator has one
            # entry per element of the leading (batch) axis of x_1.
            y_init = (x_1, jnp.zeros(x_1.shape[0]))

            solution = diffrax.diffeqsolve(
                term,
                solver,
                t0=time_grid[0],
                t1=time_grid[-1],
                dt0=-step_size,  # integrating from t=1 down to t=0
                y0=y_init,
                saveat=(
                    diffrax.SaveAt(ts=time_grid)
                    if return_intermediates
                    else diffrax.SaveAt(t1=True)
                ),
                stepsize_controller=stepsize_controller,
            )

            x_source, log_det = solution.ys[0], solution.ys[1]  # type: ignore
            source_log_p = log_p0(x_source)
            return source_log_p + log_det

        return sampler

    def unnormalized_logprob(
        self,
        x_1: Array,
        log_p0: Callable[[Array], Array],
        step_size: float = 0.01,
        method: Union[str, AbstractERK] = "Dopri5",
        atol: float = 1e-5,
        rtol: float = 1e-5,
        time_grid: Optional[Union[Array, Sequence[float]]] = None,
        return_intermediates: bool = False,
        *,
        model_extras: Optional[dict] = None,
    ) -> Array:
        r"""Evaluate the log likelihood of the target samples ``x_1``.

        Convenience wrapper that builds the evaluator with
        :meth:`get_unnormalized_logprob` (same arguments) and applies it once.

        Args:
            x_1 (Array): target samples (e.g., samples :math:`X_1 \sim p_1`).
            log_p0 (Callable[[Array], Array]): Log probability function of the
                source distribution.
            step_size (float): See :meth:`get_unnormalized_logprob`.
            method (Union[str, AbstractERK]): See
                :meth:`get_unnormalized_logprob`.
            atol (float): Absolute tolerance for adaptive solvers.
            rtol (float): Relative tolerance for adaptive solvers.
            time_grid: Must start at 1.0 and end at 0.0. Defaults to
                ``[1.0, 0.0]``.
            return_intermediates (bool): See :meth:`get_unnormalized_logprob`.
            model_extras (Optional[dict]): Additional input for the model.

        Returns:
            Array: ``log_p0(x) + log_det`` as produced by the evaluator built
            by :meth:`get_unnormalized_logprob`.
        """
        sampler = self.get_unnormalized_logprob(
            log_p0=log_p0,
            step_size=step_size,
            method=method,
            atol=atol,
            rtol=rtol,
            time_grid=time_grid,
            return_intermediates=return_intermediates,
            model_extras=model_extras,
        )
        return sampler(x_1)