gensbi.flow_matching.loss
Submodules
Classes
ContinuousFMLoss is a class that computes the continuous flow matching loss.
Package Contents
- class gensbi.flow_matching.loss.ContinuousFMLoss(path, reduction='mean')
Bases: flax.nnx.Module
ContinuousFMLoss is a class that computes the continuous flow matching loss.
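Assuming the standard conditional flow matching formulation with an affine probability path, the per-element quantity being regressed can be summarized as follows. Here u_θ stands for the vector field vf, and α_t, σ_t come from the path's scheduler. For CondOTScheduler, α_t = t and σ_t = 1 − t, so the target reduces to x_1 − x_0. This is a hedged summary of the usual objective, not a transcription of gensbi's source.

```latex
\begin{aligned}
x_t &= \alpha_t\, x_1 + \sigma_t\, x_0, \\
\dot{x}_t &= \dot{\alpha}_t\, x_1 + \dot{\sigma}_t\, x_0, \\
\mathcal{L}(\theta) &= \mathbb{E}_{t,\, x_0,\, x_1}\,\big\lVert u_\theta(x_t, t) - \dot{x}_t \big\rVert^2 .
\end{aligned}
```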
- Parameters:
path (AffineProbPath) – Probability path used to sample x_t and its target velocity from (x_0, x_1, t).
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none': no reduction is applied to the output; 'mean': the output is reduced by mean over sequence elements; 'sum': the output is reduced by sum over sequence elements. Defaults to 'mean'.
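The three reduction modes can be illustrated with a small standalone sketch. It assumes the per-element loss is the squared error between predicted and target velocities and that 'mean' and 'sum' reduce over all of its elements; reduce_loss is a hypothetical helper written for this illustration, not part of gensbi.

```python
import jax.numpy as jnp


def reduce_loss(per_element_loss, reduction="mean"):
    """Apply a 'none' | 'mean' | 'sum' reduction to a per-element loss array.

    Hypothetical helper for illustration only; not part of gensbi.
    """
    if reduction == "none":
        return per_element_loss            # keep the full array
    if reduction == "mean":
        return jnp.mean(per_element_loss)  # scalar average over all elements
    if reduction == "sum":
        return jnp.sum(per_element_loss)   # scalar total over all elements
    raise ValueError(f"unknown reduction: {reduction!r}")


# Squared error between a predicted and a target velocity, shape (batch, dim).
pred = jnp.array([[1.0, 2.0], [3.0, 4.0]])
target = jnp.array([[1.0, 1.0], [1.0, 1.0]])
sq_err = (pred - target) ** 2  # [[0, 1], [4, 9]]

print(reduce_loss(sq_err, "none").shape)  # (2, 2)
print(reduce_loss(sq_err, "mean"))        # 3.5
print(reduce_loss(sq_err, "sum"))         # 14.0
```

With 'none' the caller keeps the full array, which is useful for custom weighting before reducing.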
Example
```python
from gensbi.flow_matching.loss import ContinuousFMLoss
from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.path.scheduler import CondOTScheduler
import jax
import jax.numpy as jnp

scheduler = CondOTScheduler()
path = AffineProbPath(scheduler)
loss_fn = ContinuousFMLoss(path)

def vf(x, t, args=None):
    return x + t

x_0 = jnp.zeros((8, 2))
x_1 = jnp.ones((8, 2))
t = jnp.linspace(0, 1, 8)
batch = (x_0, x_1, t)

loss = loss_fn(vf, batch)
print(loss.shape)  # ()
```
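For intuition about what the loss measures with a CondOTScheduler path, here is a self-contained pure-JAX sketch of the standard conditional flow matching loss under the straight-line path x_t = (1 − t) x_0 + t x_1. The function cond_ot_fm_loss is hypothetical and written for illustration; it is not gensbi's implementation, and it assumes a 'mean' reduction over all elements.

```python
import jax
import jax.numpy as jnp


def cond_ot_fm_loss(vf, x_0, x_1, t):
    """Conditional flow matching loss for the CondOT path (illustrative sketch).

    Hypothetical standalone function; not gensbi's implementation.
    x_0, x_1: arrays of shape (batch, dim); t: array of shape (batch,).
    """
    t_col = t[:, None]                        # (batch, 1) so it broadcasts over dim
    x_t = (1.0 - t_col) * x_0 + t_col * x_1   # point on the straight-line path
    target = x_1 - x_0                        # dx_t/dt for alpha_t = t, sigma_t = 1 - t
    pred = vf(x_t, t_col)                     # model's velocity prediction
    return jnp.mean((pred - target) ** 2)     # 'mean' reduction over all elements


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (8, 2))        # source samples (Gaussian noise)
x_1 = jax.random.normal(key_1, (8, 2)) + 3.0  # target samples
t = jax.random.uniform(key_t, (8,))

# An oracle field that returns exactly x_1 - x_0 for this batch gives zero loss.
oracle = lambda x, t: x_1 - x_0
print(cond_ot_fm_loss(oracle, x_0, x_1, t))  # 0.0
```

A trained model cannot see x_1 - x_0 directly, so in practice the loss stays above zero; the oracle only checks that the target is the path's time derivative.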
- __call__(vf, batch, args=None, **kwargs)
Evaluates the continuous flow matching loss.
- Parameters:
vf (callable) – The vector field model to evaluate.
batch (tuple) – A tuple containing the input data (x_0, x_1, t).
args (optional) – Additional arguments passed to the vector field vf.
condition_mask (optional) – A mask to apply to the input data; it is not part of the signature, so pass it as a keyword argument.
**kwargs – Additional keyword arguments.
- Returns:
The computed loss.
- Return type:
Array
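How condition_mask enters the loss is not documented here. The sketch below shows one common pattern, assuming that True marks conditioned (observed) entries which should not contribute to the loss and that the remaining entries are averaged. masked_mean_sq_error is a hypothetical helper, and gensbi may mask or normalize differently.

```python
import jax.numpy as jnp


def masked_mean_sq_error(pred, target, condition_mask):
    """Mean squared error over entries NOT marked by condition_mask.

    Hypothetical helper: assumes True in condition_mask marks conditioned
    (observed) entries that should not contribute to the loss.
    """
    sq_err = (pred - target) ** 2
    keep = ~condition_mask                  # entries that are trained on
    sq_err = jnp.where(keep, sq_err, 0.0)   # zero out conditioned entries
    # Average over kept entries only; guard against an all-True mask.
    return jnp.sum(sq_err) / jnp.maximum(jnp.sum(keep), 1)


pred = jnp.array([[1.0, 5.0], [2.0, 7.0]])
target = jnp.array([[0.0, 0.0], [0.0, 0.0]])
# The second column is conditioned on (observed), so only the first column is scored.
mask = jnp.array([[False, True], [False, True]])
print(masked_mean_sq_error(pred, target, mask))  # (1 + 4) / 2 = 2.5
```

Dividing by the number of kept entries, rather than by the full array size, keeps the loss scale independent of how many variables are conditioned on.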
- path