# Source code for gensbi.models.simformer.loss
import jax.numpy as jnp
from flax import nnx
from typing import Callable, Tuple, Optional
from jax.numpy import ndarray as Array
from gensbi.flow_matching.loss import ContinuousFMLoss
[docs]
class SimformerCFMLoss(ContinuousFMLoss):
    """Continuous flow matching loss with an optional conditioning mask.

    When a mask is supplied, masked entries of the interpolated sample are
    overwritten with the corresponding ``x_1`` values before the model is
    evaluated, and the same entries contribute zero to the squared error.
    """

    def __init__(self, path, reduction: str = "mean"):
        """
        Build the loss by forwarding both arguments to the parent class.

        Args:
            path: Probability path for training.
            reduction (str): Reduction method ('none', 'mean', 'sum').
        """
        super().__init__(path, reduction)

    def __call__(
        self,
        vf: Callable,
        batch: Tuple[Array, Array, Array],
        args: Optional[dict] = None,
        condition_mask: Optional[Array] = None,
        **kwargs
    ) -> Array:
        """
        Compute the flow matching loss for one batch.

        Args:
            vf (Callable): Vector field model, called as
                ``vf(x_t, t, args=args, **kwargs)``.
            batch (Tuple[Array, Array, Array]): Input data (x_0, x_1, t),
                forwarded positionally to ``self.path.sample``.
            args (Optional[dict]): Additional arguments passed to ``vf``.
            condition_mask (Optional[Array]): Mask for conditioning. It must
                be reshapeable to the shape of ``x_t``; truthy entries mark
                conditioned positions.
            **kwargs: Additional keyword arguments forwarded to ``vf``.

        Returns:
            Array: Reduced squared difference between the model output and
            the path velocity ``dx_t``.
        """
        _, target, _ = batch
        sample = self.path.sample(*batch)
        noisy = sample.x_t

        # Unconditional case: plain regression onto the path velocity.
        if condition_mask is None:
            residual = vf(noisy, sample.t, args=args, **kwargs) - sample.dx_t
            return self.reduction(jnp.square(residual))  # type: ignore

        # The model receives the mask exactly as the caller provided it; only
        # the local copy is reshaped to line up element-wise with x_t.
        kwargs["condition_mask"] = condition_mask
        mask = condition_mask.reshape(noisy.shape)

        # Conditioned positions are fed their x_1 values instead of x_t.
        model_input = jnp.where(mask, target, noisy)
        prediction = vf(model_input, sample.t, args=args, **kwargs)

        # Conditioned positions are excluded from the error by zeroing them.
        residual = jnp.where(mask, 0.0, prediction - sample.dx_t)
        return self.reduction(jnp.square(residual))  # type: ignore