Source code for gensbi.flow_matching.loss.continuous_loss
import jax.numpy as jnp
from flax import nnx
from typing import Callable, Tuple, Any
from jax import Array
[docs]
class ContinuousFMLoss(nnx.Module):
    """Continuous flow matching loss.

    Samples a point on a probability path and penalises the squared
    difference between the model's vector field and the path's target
    velocity ``dx_t``.
    """

    def __init__(self, path, reduction: str = "mean"):
        """
        Build the loss.

        Args:
            path: Continuous probability path. It must expose a ``sample``
                method that accepts the unpacked ``batch`` and returns an
                object with ``x_t``, ``t`` and ``dx_t`` attributes.
                ``dx_t`` is the regression target for the vector field.
            reduction (str, optional): Reduction applied to the elementwise
                squared error. ``'none'`` applies no reduction. ``'mean'``
                averages over all elements. ``'sum'`` sums over all
                elements. ``'None'`` is accepted as an alias of ``'none'``
                for backward compatibility. Defaults to ``'mean'``.

        Raises:
            ValueError: If ``reduction`` is not one of the supported values.
        """
        self.path = path

        # The documented spelling is lowercase "none"; the capitalised "None"
        # was the only spelling accepted previously, so keep it working.
        normalized = "none" if reduction == "None" else reduction

        reductions = {
            "mean": jnp.mean,
            "sum": jnp.sum,
            "none": lambda x: x,
        }
        if normalized not in reductions:
            raise ValueError(f"{reduction} is not a valid value for reduction")
        self.reduction = reductions[normalized]

    def __call__(
        self,
        vf: Callable,
        batch: Tuple[Array, Array, Array],
        args: Any = None,
        **kwargs
    ) -> Array:
        """
        Evaluate the continuous flow matching loss.

        Args:
            vf (callable): Vector field model, called as
                ``vf(x_t, t, args=args, **kwargs)``. Its output is compared
                elementwise against the path's ``dx_t``.
            batch (tuple): Unpacked into ``self.path.sample``. Presumably
                ``(x_0, x_1, t)``; confirm against the path implementation.
            args (optional): Extra positional-style payload forwarded to
                ``vf`` as the ``args`` keyword.
            **kwargs: Additional keyword arguments forwarded unchanged to
                ``vf`` (e.g. a condition mask, if the model accepts one).

        Returns:
            Array: The squared error ``(vf(x_t, t) - dx_t) ** 2`` after the
            configured reduction. This is a scalar for ``'mean'``/``'sum'``
            and an elementwise array for ``'none'``.
        """
        path_sample = self.path.sample(*batch)
        prediction = vf(path_sample.x_t, path_sample.t, args=args, **kwargs)
        residual = prediction - path_sample.dx_t
        return self.reduction(jnp.square(residual))