gensbi.flow_matching.solver package
Submodules
gensbi.flow_matching.solver.ode_solver module
- class gensbi.flow_matching.solver.ode_solver.ODESolver(velocity_model: ModelWrapper)
Bases: Solver
A class to solve ordinary differential equations (ODEs) using a specified velocity model.
This class uses a velocity field model to solve ODEs over a given time grid with numerical ODE solvers.
- Parameters:
velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)
- get_sampler(step_size: float | None, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Array = Array([0., 1.], dtype=float32), return_intermediates: bool = False, model_extras: dict = {}) -> Callable
Solve the ODE with the velocity field.
Example:
import torch
from flow_matching.utils import ModelWrapper
from flow_matching.solver import ODESolver

class DummyModel(ModelWrapper):
    def __init__(self):
        super().__init__(None)

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        return torch.ones_like(x) * 3.0 * t**2

velocity_model = DummyModel()
solver = ODESolver(velocity_model=velocity_model)
x_init = torch.tensor([0.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])
result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)
- Parameters:
x_init (Array) – Initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …].
step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.
method (str | AbstractERK) – A solver name or an AbstractERK instance. Defaults to “Dopri5”. Other commonly used solvers are “Euler”, “midpoint” and “heun3”.
atol (float) – Absolute tolerance, used for adaptive step solvers.
rtol (float) – Relative tolerance, used for adaptive step solvers.
time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to Array([0., 1.]).
return_intermediates (bool, optional) – If True then return intermediate time steps according to time_grid. Defaults to False.
model_extras (dict) – Additional input for the model. Defaults to {}.
- Returns:
The last timestep when return_intermediates=False, otherwise all values specified in time_grid.
- Return type:
Union[Array, Sequence[Array]]
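The fixed-step case described above can be illustrated in plain JAX. The sketch below is not gensbi's implementation and does not call ODESolver. It assumes a forward Euler scheme, and the names `velocity` and `euler_solve` are hypothetical. It integrates the same velocity field as the example, \(u_t(x) = 3t^2\), whose exact solution from \(x(0) = 0\) is \(x(1) = 1\).

```python
import jax
import jax.numpy as jnp


def velocity(x, t):
    # Hypothetical velocity field u_t(x) = 3 t^2, the same as in the example above.
    return jnp.ones_like(x) * 3.0 * t**2


def euler_solve(velocity_fn, x_init, t0=0.0, t1=1.0, step_size=0.001):
    # Forward Euler with a fixed step: x_{k+1} = x_k + h * u_{t_k}(x_k).
    n_steps = int(round((t1 - t0) / step_size))
    ts = t0 + step_size * jnp.arange(n_steps)  # left endpoints of each step

    def step(x, t):
        return x + step_size * velocity_fn(x, t), None

    x_final, _ = jax.lax.scan(step, x_init, ts)
    return x_final


x_init = jnp.zeros(2)
result = euler_solve(velocity, x_init)  # close to [1., 1.], up to Euler error
```

An adaptive solver such as Dopri5 chooses its own steps from atol and rtol, which is why step_size must be None in that case.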
- get_unnormalized_logprob(log_p0: Callable[[Array], Array], step_size: float = 0.01, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid=[1.0, 0.0], return_intermediates: bool = False, *, model_extras: dict = {}) -> Callable
Solve for the log likelihood given a target sample at \(t=1\), integrating backward in time to \(t=0\).
- Parameters:
x_1 (Array) – target sample (e.g., samples \(X_1 \sim p_1\)).
log_p0 (Callable[[Array], Array]) – Log probability function of source distribution.
step_size (Optional[float]) – Step size for fixed-step solvers.
method (str) – Integration method to use.
atol (float) – Absolute tolerance for adaptive solvers.
rtol (float) – Relative tolerance for adaptive solvers.
time_grid (Array) – Must start at 1.0 and end at 0.0.
return_intermediates (bool) – Whether to return intermediate steps.
exact_divergence (bool) – Use exact divergence vs Hutchinson estimator.
**model_extras – Additional model inputs.
- Returns:
Samples and log likelihood values.
- Return type:
Union[Tuple[Array, Array], Tuple[Sequence[Array], Array]]
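The computation described above follows the continuous change-of-variables formula \(\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla \cdot u_t(x_t)\, dt\). The following plain-JAX sketch illustrates it and is not gensbi's implementation. It assumes a backward Euler-style sweep from \(t=1\) to \(t=0\), an exact divergence obtained from the Jacobian trace, and a hypothetical linear field \(u_t(x) = A x\). The names `A`, `velocity`, `divergence` and `logprob_sketch` are illustrative.

```python
import jax
import jax.numpy as jnp

A = 0.5  # hypothetical constant defining the linear velocity field


def velocity(x, t):
    # Hypothetical velocity field u_t(x) = A * x (independent of t).
    return A * x


def log_p0(x):
    # Log density of a standard normal source distribution.
    return -0.5 * jnp.sum(x**2) - 0.5 * x.shape[-1] * jnp.log(2.0 * jnp.pi)


def divergence(x, t):
    # Exact divergence: trace of the Jacobian of the velocity with respect to x.
    return jnp.trace(jax.jacfwd(velocity, argnums=0)(x, t))


def logprob_sketch(x_1, n_steps=1000):
    h = 1.0 / n_steps
    ts = 1.0 - h * jnp.arange(n_steps)  # 1.0, 1.0 - h, ..., h

    def step(carry, t):
        x, div_int = carry
        div_int = div_int + h * divergence(x, t)  # accumulate the divergence integral
        x = x - h * velocity(x, t)                # move the sample backward in time
        return (x, div_int), None

    (x_0, div_int), _ = jax.lax.scan(step, (x_1, jnp.zeros(())), ts)
    return x_0, log_p0(x_0) - div_int


x_1 = jnp.array([1.0, -0.5])
x_0, logp = logprob_sketch(x_1)
```

For this linear field the result can be checked analytically, since \(p_1\) is a Gaussian with covariance \(e^{2A} I\). In high dimensions the exact Jacobian trace becomes expensive, which is the usual motivation for a Hutchinson estimator.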
- sample(x_init: Array, step_size: float | None, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Array = Array([0., 1.], dtype=float32), return_intermediates: bool = False, model_extras: dict = {}) -> Array | Sequence[Array]
- unnormalized_logprob(x_1: Array, log_p0: Callable[[Array], Array], step_size: float = 0.01, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid=[1.0, 0.0], return_intermediates: bool = False, *, model_extras: dict = {}) -> Tuple[Array, Array] | Tuple[Sequence[Array], Array]
gensbi.flow_matching.solver.sde_solver module
- class gensbi.flow_matching.solver.sde_solver.BaseSDESolver(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, eps0: float = 1e-05)
Bases: Solver
A base class to solve stochastic differential equations (SDEs) using a specified velocity model.
This class uses a velocity field model to build the drift and diffusion terms of the SDE, which is then solved with numerical SDE solvers.
- Parameters:
velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)
- abstractmethod get_f_tilde() -> Callable
Get the function \(\tilde{f}\) for the velocity model, also known as the “drift” term in the SDE context. See arXiv:2410.02217.
- abstractmethod get_g_tilde() -> Callable
Get the function \(\tilde{g}\) for the velocity model, also known as the “diffusion” term in the SDE context. See arXiv:2410.02217.
- get_sampler(args=None, nsteps=300, method='SEA', adaptive=False, **kwargs) -> Callable
Stochastic sampler for the SDE.
- Parameters:
args – Additional arguments to pass to the velocity model.
nsteps (int) – Number of steps for the SDE solver.
method (str) – The method to use for the SDE solver, one of “Euler”, “SEA” or “ShARK”. Defaults to “SEA”. Euler is the simplest algorithm. SEA (Shifted Euler method) has a better constant factor in the global error and an improved local error. ShARK (Shifted Additive-noise Runge-Kutta) provides a more accurate solution at a higher computational cost and implements adaptive step-size control.
adaptive (bool) – Whether to use adaptive step-size control (only for ShARK). Defaults to False.
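As an illustration of the simplest of the listed methods, the following plain-JAX sketch implements an Euler (Euler-Maruyama) step for an SDE \(dX_t = f(X_t, t)\,dt + g(t)\,dW_t\). It is not gensbi's sampler. The drift and diffusion below are hypothetical stand-ins (an Ornstein-Uhlenbeck process) rather than the \(\tilde{f}\) and \(\tilde{g}\) built from a velocity model, and the name `euler_maruyama` is illustrative.

```python
import jax
import jax.numpy as jnp


def drift(x, t):
    # Hypothetical drift f(x, t) = -x (Ornstein-Uhlenbeck process).
    return -x


def diffusion(t):
    # Hypothetical constant diffusion coefficient g(t) = 0.5.
    return 0.5


def euler_maruyama(key, x_init, nsteps=300, t0=0.0, t1=1.0):
    dt = (t1 - t0) / nsteps
    ts = t0 + dt * jnp.arange(nsteps)
    keys = jax.random.split(key, nsteps)  # one random key per step

    def step(x, inputs):
        t, k = inputs
        noise = jax.random.normal(k, x.shape)
        # X_{k+1} = X_k + f(X_k, t_k) dt + g(t_k) sqrt(dt) * N(0, I)
        x = x + drift(x, t) * dt + diffusion(t) * jnp.sqrt(dt) * noise
        return x, None

    x_final, _ = jax.lax.scan(step, x_init, (ts, keys))
    return x_final


key = jax.random.PRNGKey(0)
x_init = jnp.ones((1000, 2))  # 1000 samples in 2 dimensions, all started at 1
samples = euler_maruyama(key, x_init)
```

For this process the sample mean at \(t=1\) should be close to \(e^{-1}\). Shifted and Runge-Kutta schemes such as SEA and ShARK refine this basic update to reduce the discretization error.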
- get_score(**kwargs)
Obtain the score function given the velocity model. See arXiv:2410.02217.
- sample(key: Array, nsamples: int, nsteps: int = 300, method='SEA', adaptive=True, **kwargs) -> Array
Sample from the SDE using the provided key and number of samples.
- Parameters:
key (jax.Array) – JAX random key for sampling.
nsamples (int) – Number of samples to generate.
nsteps (int) – Number of steps for the SDE solver.
**kwargs – Additional arguments to pass to the velocity model.
- Returns:
Sampled trajectories from the SDE.
- Return type:
jax.Array
- class gensbi.flow_matching.solver.sde_solver.NonSingular(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, alpha: float)
Bases: BaseSDESolver
NonSingular SDE, from Table 1 of http://arxiv.org/abs/2410.02217, with the change of time variable t -> 1-t to match our time notation.
- class gensbi.flow_matching.solver.sde_solver.ZeroEnds(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, alpha: float, eps0: float = 0.001)
Bases: BaseSDESolver
ZeroEnds SDE, from Table 1 of http://arxiv.org/abs/2410.02217, with the change of time variable t -> 1-t to match our time notation.
gensbi.flow_matching.solver.solver module
gensbi.flow_matching.solver.utils module
- gensbi.flow_matching.solver.utils.get_nearest_times(time_grid: Array, t_discretization: Array) -> Array
Find the nearest times in t_discretization for each time in time_grid.
- Parameters:
time_grid (Array) – Query times to find nearest neighbors for, shape (N,)
t_discretization (Array) – Reference time points to match against, shape (M,)
- Returns:
Nearest times from t_discretization for each point in time_grid, shape (N,)
- Return type:
Array
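The lookup described above can be sketched in plain JAX as a broadcasted distance matrix followed by an argmin. This is an illustrative reimplementation under that assumption, not the library's source, and the name `nearest_times` is hypothetical.

```python
import jax.numpy as jnp


def nearest_times(time_grid, t_discretization):
    # Pairwise absolute distances, shape (N, M).
    distances = jnp.abs(time_grid[:, None] - t_discretization[None, :])
    # Index of the closest reference time for each query time, shape (N,).
    nearest_idx = jnp.argmin(distances, axis=1)
    return t_discretization[nearest_idx]


time_grid = jnp.array([0.0, 0.33, 1.0])   # query times, shape (3,)
t_disc = jnp.linspace(0.0, 1.0, 11)       # reference times 0.0, 0.1, ..., 1.0
nearest = nearest_times(time_grid, t_disc)
```

Here the query time 0.33 maps to the reference time 0.3. Snapping a requested grid onto the solver's discretization in this way lets intermediate states be read off at times the solver actually visited.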
Module contents
- class gensbi.flow_matching.solver.NonSingular(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, alpha: float)
Bases: BaseSDESolver
NonSingular SDE, from Table 1 of http://arxiv.org/abs/2410.02217, with the change of time variable t -> 1-t to match our time notation.
- class gensbi.flow_matching.solver.ODESolver(velocity_model: ModelWrapper)
Bases: Solver
A class to solve ordinary differential equations (ODEs) using a specified velocity model.
This class uses a velocity field model to solve ODEs over a given time grid with numerical ODE solvers.
- Parameters:
velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)
- get_sampler(step_size: float | None, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Array = Array([0., 1.], dtype=float32), return_intermediates: bool = False, model_extras: dict = {}) -> Callable
Solve the ODE with the velocity field.
Example:
import torch
from flow_matching.utils import ModelWrapper
from flow_matching.solver import ODESolver

class DummyModel(ModelWrapper):
    def __init__(self):
        super().__init__(None)

    def forward(self, x: torch.Tensor, t: torch.Tensor, **extras) -> torch.Tensor:
        return torch.ones_like(x) * 3.0 * t**2

velocity_model = DummyModel()
solver = ODESolver(velocity_model=velocity_model)
x_init = torch.tensor([0.0, 0.0])
step_size = 0.001
time_grid = torch.tensor([0.0, 1.0])
result = solver.sample(x_init=x_init, step_size=step_size, time_grid=time_grid)
- Parameters:
x_init (Array) – Initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …].
step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.
method (str | AbstractERK) – A solver name or an AbstractERK instance. Defaults to “Dopri5”. Other commonly used solvers are “Euler”, “midpoint” and “heun3”.
atol (float) – Absolute tolerance, used for adaptive step solvers.
rtol (float) – Relative tolerance, used for adaptive step solvers.
time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)]. If step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to Array([0., 1.]).
return_intermediates (bool, optional) – If True then return intermediate time steps according to time_grid. Defaults to False.
model_extras (dict) – Additional input for the model. Defaults to {}.
- Returns:
The last timestep when return_intermediates=False, otherwise all values specified in time_grid.
- Return type:
Union[Array, Sequence[Array]]
- get_unnormalized_logprob(log_p0: Callable[[Array], Array], step_size: float = 0.01, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid=[1.0, 0.0], return_intermediates: bool = False, *, model_extras: dict = {}) -> Callable
Solve for the log likelihood given a target sample at \(t=1\), integrating backward in time to \(t=0\).
- Parameters:
x_1 (Array) – target sample (e.g., samples \(X_1 \sim p_1\)).
log_p0 (Callable[[Array], Array]) – Log probability function of source distribution.
step_size (Optional[float]) – Step size for fixed-step solvers.
method (str) – Integration method to use.
atol (float) – Absolute tolerance for adaptive solvers.
rtol (float) – Relative tolerance for adaptive solvers.
time_grid (Array) – Must start at 1.0 and end at 0.0.
return_intermediates (bool) – Whether to return intermediate steps.
exact_divergence (bool) – Use exact divergence vs Hutchinson estimator.
**model_extras – Additional model inputs.
- Returns:
Samples and log likelihood values.
- Return type:
Union[Tuple[Array, Array], Tuple[Sequence[Array], Array]]
- sample(x_init: Array, step_size: float | None, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid: Array = Array([0., 1.], dtype=float32), return_intermediates: bool = False, model_extras: dict = {}) -> Array | Sequence[Array]
- unnormalized_logprob(x_1: Array, log_p0: Callable[[Array], Array], step_size: float = 0.01, method: str | AbstractERK = 'Dopri5', atol: float = 1e-05, rtol: float = 1e-05, time_grid=[1.0, 0.0], return_intermediates: bool = False, *, model_extras: dict = {}) -> Tuple[Array, Array] | Tuple[Sequence[Array], Array]
- class gensbi.flow_matching.solver.ZeroEnds(velocity_model: ModelWrapper, mu0: Array, sigma0: Array, alpha: float, eps0: float = 0.001)
Bases: BaseSDESolver
ZeroEnds SDE, from Table 1 of http://arxiv.org/abs/2410.02217, with the change of time variable t -> 1-t to match our time notation.