gensbi.flow_matching.solver.ode_solver
Classes
ODESolver – A class to solve ordinary differential equations (ODEs) using a specified velocity model.
Module Contents
- class gensbi.flow_matching.solver.ode_solver.ODESolver(velocity_model)
Bases: gensbi.flow_matching.solver.solver.Solver
A class to solve ordinary differential equations (ODEs) using a specified velocity model.
It integrates the ODE \(\mathrm{d}x/\mathrm{d}t = u_t(x)\) defined by the velocity field model over a given time grid, using the numerical ODE solvers provided by diffrax.
- Parameters:
velocity_model (Union[ModelWrapper, Callable]) – a velocity field model receiving \((x,t)\) and returning \(u_t(x)\)
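The `velocity_model` contract can be illustrated without gensbi: any callable that maps \((x, t)\) to an array \(u_t(x)\) of the same shape as \(x\) defines the ODE \(\mathrm{d}x/\mathrm{d}t = u_t(x)\). The sketch below assumes a hypothetical toy field \(u_t(x) = -x\), chosen because its exact solution \(x_0 e^{-t}\) is known, and integrates it with a hand-written fixed-step Euler loop. It is only an illustration; `ODESolver` itself delegates to diffrax solvers such as Dopri5.

```python
import jax
import jax.numpy as jnp


def toy_velocity(x: jax.Array, t: jax.Array) -> jax.Array:
    """Hypothetical velocity field u_t(x) = -x; returns an array shaped like x."""
    return -x


def euler_solve(velocity, x_init: jax.Array, t0: float, t1: float, n_steps: int) -> jax.Array:
    """Fixed-step explicit Euler integration of dx/dt = velocity(x, t) from t0 to t1."""
    dt = (t1 - t0) / n_steps

    def step(x, k):
        t = t0 + k * dt  # time at the start of step k
        return x + dt * velocity(x, t), None

    x_final, _ = jax.lax.scan(step, x_init, jnp.arange(n_steps))
    return x_final


x_init = jnp.ones((10, 2))  # batch of 10 two-dimensional states
x_final = euler_solve(toy_velocity, x_init, 0.0, 1.0, n_steps=1000)
print(x_final.shape)  # (10, 2)
print(float(x_final[0, 0]), float(jnp.exp(-1.0)))  # Euler result vs. exact e^{-1}
```

With 1000 steps the Euler result agrees with \(e^{-1}\) to roughly three decimal places; fewer steps trade accuracy for speed, which is the trade-off adaptive solvers like Dopri5 manage automatically.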
Example
from gensbi.flow_matching.solver import ODESolver
from gensbi.utils.model_wrapping import ModelWrapper
import jax
import jax.numpy as jnp

class DummyModel:
    def __call__(self, obs, t, *args, **kwargs):
        return jnp.squeeze(obs + t, axis=-1)

# replace with your actual velocity field model, Simformer or Flux1
vf_model = DummyModel()
# use the appropriate ModelWrapper for your model, either FluxWrapper or
# SimformerWrapper, or a custom subclass of ModelWrapper
model_wrapped = ModelWrapper(vf_model)

solver = ODESolver(velocity_model=model_wrapped)
x_init = jnp.zeros((10, 2))
time_grid = jnp.linspace(0, 1, 5)
# return_intermediates=True returns the state at every point of time_grid
sol = solver.sample(
    x_init=x_init,
    step_size=0.05,
    time_grid=time_grid,
    return_intermediates=True,
)
print(sol.shape)  # (5, 10, 2)
- get_sampler(step_size, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras={})
Obtain a sampler to solve the ODE with the velocity field.
- Parameters:
x_init (Array) – initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …]. Passed to the returned sampler, not to get_sampler itself.
step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.
method (str) – A method supported by diffrax. Defaults to “Dopri5”. Another commonly used solver is “Euler”. For a complete list, see diffrax.
atol (float) – Absolute tolerance, used for adaptive step solvers.
rtol (float) – Relative tolerance, used for adaptive step solvers.
time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)] and, if step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to jnp.array([0.0, 1.0]).
return_intermediates (bool, optional) – If True, return the states at the intermediate time steps given by time_grid. Defaults to False.
model_extras (dict) – Additional input for the model.
- Returns:
A function that takes initial conditions and returns the solution at final time or intermediate times.
- Return type:
Callable
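`get_sampler` returns a callable instead of samples. One plausible motivation in JAX is that the solver settings are fixed once, and the resulting function can be compiled with `jax.jit` and reused across batches of the same shape. The sketch below illustrates that factory pattern with a hypothetical `make_euler_sampler` helper and a toy constant-drift field; it uses fixed-step Euler rather than diffrax and is not gensbi's implementation.

```python
import jax
import jax.numpy as jnp


def toy_velocity(x, t):
    # Hypothetical velocity field: constant drift of +1 in every coordinate.
    return jnp.ones_like(x)


def make_euler_sampler(velocity, step_size: float, t0: float = 0.0, t1: float = 1.0):
    """Hypothetical factory: fixes the solver settings and returns a jitted sampler."""
    n_steps = int(round((t1 - t0) / step_size))
    dt = (t1 - t0) / n_steps

    @jax.jit
    def sampler(x_init):
        def step(x, k):
            return x + dt * velocity(x, t0 + k * dt), None

        x_final, _ = jax.lax.scan(step, x_init, jnp.arange(n_steps))
        return x_final

    return sampler


sampler = make_euler_sampler(toy_velocity, step_size=0.05)
x0 = jnp.zeros((4, 3))
x1 = sampler(x0)            # traced and compiled on the first call
x1_again = sampler(x0 + 1)  # same shape and dtype, so the compiled function is reused
```

Because the settings are baked into the closure, only the initial conditions cross the call boundary, which keeps the jitted signature simple.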
- get_unnormalized_logprob(log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=[1.0, 0.0], return_intermediates=False, *, model_extras={})
Solve for the log likelihood given a target sample at \(t=1\).
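For reference, the standard continuous-normalizing-flow identity behind this computation is the instantaneous change-of-variables formula for \(\mathrm{d}x_t/\mathrm{d}t = u_t(x_t)\). Here \(x_0\) is obtained by solving the ODE backwards from \(x_1\), consistent with a time_grid that runs from 1.0 to 0.0, and log_p0 supplies \(\log p_0\). The exact discretization gensbi uses is not stated here.

```latex
\frac{\mathrm{d}}{\mathrm{d}t}\log p_t(x_t) = -\nabla\cdot u_t(x_t)
\quad\Longrightarrow\quad
\log p_1(x_1) = \log p_0(x_0) - \int_0^1 \nabla\cdot u_t(x_t)\,\mathrm{d}t
```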
- Parameters:
x_1 (Array) – target sample (e.g., samples \(X_1 \sim p_1\)).
log_p0 (Callable[[Array], Array]) – Log probability function of source distribution.
step_size (Optional[float]) – Step size for fixed-step solvers.
method (str) – Integration method to use.
atol (float) – Absolute tolerance for adaptive solvers.
rtol (float) – Relative tolerance for adaptive solvers.
time_grid (Array) – Must start at 1.0 and end at 0.0.
return_intermediates (bool) – Whether to return intermediate steps.
exact_divergence (bool) – If True, compute the divergence of the velocity field exactly; otherwise estimate it with the Hutchinson trace estimator.
model_extras (dict) – Additional model inputs.
- Returns:
Samples and log likelihood values.
- Return type:
Union[Tuple[Array, Array], Tuple[Sequence[Array], Array]]
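The log-likelihood computation can be sketched by integrating from \(t=1\) down to \(t=0\): step the sample backwards through the velocity field, accumulate the divergence of the field along that trajectory, and subtract the accumulated integral from log_p0 evaluated at the recovered source point. The sketch assumes a hypothetical linear field \(u_t(x) = 0.5\,x\), a standard-normal log_p0, single unbatched samples, and explicit Euler steps (gensbi documents diffrax solvers such as Dopri5 instead). It also illustrates the exact_divergence option: `hutchinson_divergence` estimates the Jacobian trace with random sign probes and Jacobian-vector products, avoiding the full Jacobian. All function names are illustrative, not gensbi APIs.

```python
import jax
import jax.numpy as jnp


def toy_velocity(x, t):
    # Hypothetical linear field u_t(x) = 0.5 * x; its divergence is 0.5 * dim.
    return 0.5 * x


def std_normal_logpdf(x):
    # Log density of a standard normal source distribution, for one sample x.
    return -0.5 * jnp.sum(x**2) - 0.5 * x.size * jnp.log(2.0 * jnp.pi)


def exact_divergence(velocity, x, t):
    # Trace of the Jacobian of u_t at x (single sample of shape (dim,)).
    jac = jax.jacfwd(lambda y: velocity(y, t))(x)
    return jnp.trace(jac)


def hutchinson_divergence(velocity, x, t, key, n_probes=1):
    # Unbiased estimate E[eps^T J eps] with Rademacher probes eps, using JVPs only.
    signs = jax.random.bernoulli(key, 0.5, (n_probes,) + x.shape)
    eps = 2.0 * signs.astype(x.dtype) - 1.0

    def one_probe(e):
        _, jvp_out = jax.jvp(lambda y: velocity(y, t), (x,), (e,))
        return jnp.vdot(e, jvp_out)

    return jnp.mean(jax.vmap(one_probe)(eps))


def log_likelihood(velocity, log_p0, x_1, n_steps=100):
    # Explicit Euler from t=1 down to t=0, accumulating the divergence integral.
    dt = 1.0 / n_steps

    def step(carry, k):
        x, div_integral = carry
        t = 1.0 - k * dt
        div_integral = div_integral + dt * exact_divergence(velocity, x, t)
        x = x - dt * velocity(x, t)  # backward step in time
        return (x, div_integral), None

    init = (x_1, jnp.zeros((), dtype=x_1.dtype))
    (x_0, div_integral), _ = jax.lax.scan(step, init, jnp.arange(n_steps))
    return x_0, log_p0(x_0) - div_integral


x_1 = jnp.array([1.0, -2.0])
x_0, logp = log_likelihood(toy_velocity, std_normal_logpdf, x_1)
key = jax.random.PRNGKey(0)
div_exact = exact_divergence(toy_velocity, x_1, 1.0)
div_hutch = hutchinson_divergence(toy_velocity, x_1, 1.0, key, n_probes=4)
```

For this linear field the Hutchinson estimate matches the exact divergence for every probe, since \(\epsilon^\top (aI)\epsilon = a\,\lVert\epsilon\rVert^2 = a\,d\) for sign vectors \(\epsilon\); for nonlinear fields it is only correct in expectation, trading variance for cost.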
- sample(x_init, step_size, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=jnp.array([0.0, 1.0]), return_intermediates=False, model_extras={})
Sample from the ODE defined by the velocity field.
- Parameters:
x_init (Array) – initial conditions (e.g., source samples \(X_0 \sim p\)). Shape: [batch_size, …].
step_size (Optional[float]) – The step size. Must be None for adaptive step solvers.
method (str) – A method supported by diffrax. Defaults to “Dopri5”. Another commonly used solver is “Euler”. For a complete list, see diffrax.
atol (float) – Absolute tolerance, used for adaptive step solvers.
rtol (float) – Relative tolerance, used for adaptive step solvers.
time_grid (Array) – The process is solved in the interval [min(time_grid), max(time_grid)] and, if step_size is None, the time discretization is set by the time grid. A descending time_grid solves in the reverse direction. Defaults to jnp.array([0.0, 1.0]).
return_intermediates (bool, optional) – If True, return the states at the intermediate time steps given by time_grid. Defaults to False.
model_extras (dict) – Additional input for the model.
- Returns:
The final state or the states at all intermediate time steps.
- Return type:
Union[Array, Sequence[Array]]
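The time_grid and return_intermediates semantics can be mimicked with a small fixed-step solver. The sketch assumes a hypothetical `solve_on_grid` helper that takes explicit midpoint (second-order Runge-Kutta) steps of at most step_size between consecutive grid points and stacks the state at every grid point; it is neither diffrax nor gensbi's implementation. With `jnp.linspace(0, 1, 5)` and a batch of shape (10, 2), the stacked result has shape (5, 10, 2), and a descending grid runs the same dynamics backwards because the signed step is negative.

```python
import math

import jax.numpy as jnp


def toy_velocity(x, t):
    # Hypothetical velocity field u_t(x) = -x.
    return -x


def midpoint_step(velocity, x, t, dt):
    # Explicit midpoint step; dt may be negative when integrating backwards.
    x_mid = x + 0.5 * dt * velocity(x, t)
    return x + dt * velocity(x_mid, t + 0.5 * dt)


def solve_on_grid(velocity, x_init, time_grid, step_size):
    """Hypothetical helper mimicking return_intermediates=True on a time grid."""
    states = [x_init]
    x = x_init
    for t_prev, t_next in zip(time_grid[:-1], time_grid[1:]):
        t_prev, t_next = float(t_prev), float(t_next)
        n_sub = max(1, math.ceil(abs(t_next - t_prev) / step_size))
        dt = (t_next - t_prev) / n_sub  # negative for a descending grid
        for k in range(n_sub):
            x = midpoint_step(velocity, x, t_prev + k * dt, dt)
        states.append(x)  # state at grid point t_next
    return jnp.stack(states)


time_grid = jnp.linspace(0.0, 1.0, 5)
x_init = jnp.ones((10, 2))
traj = solve_on_grid(toy_velocity, x_init, time_grid, step_size=0.05)
print(traj.shape)  # (5, 10, 2)

# Descending grid: integrate from t=1 back to t=0, approximately recovering x_init.
back = solve_on_grid(toy_velocity, traj[-1], time_grid[::-1], step_size=0.05)
```

Keeping only `traj[-1]` corresponds to the return_intermediates=False case, where just the final state is returned.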
- unnormalized_logprob(x_1, log_p0, step_size=0.01, method='Dopri5', atol=1e-05, rtol=1e-05, time_grid=[1.0, 0.0], return_intermediates=False, *, model_extras={})
- Parameters:
x_1 (jax.Array)
log_p0 (Callable[[jax.Array], jax.Array])
step_size (float)
method (Union[str, diffrax.AbstractERK])
atol (float)
rtol (float)
return_intermediates (bool)
model_extras (dict)
- Return type:
Union[Tuple[jax.Array, jax.Array], Tuple[Sequence[jax.Array], jax.Array]]