"""Loss functions for the Simformer model (``gensbi.models.simformer.loss``)."""

import jax
import jax.numpy as jnp
from flax import nnx
from typing import Callable, Tuple, Optional
from jax.numpy import ndarray as Array

from gensbi.flow_matching.loss import ContinuousFMLoss


class SimformerCFMLoss(ContinuousFMLoss):
    """
    Continuous flow matching loss specialised for the Simformer model.

    Extends ``ContinuousFMLoss`` with optional support for a condition mask:
    entries selected by the mask are treated as observed (clamped to the
    clean data) and excluded from the regression target.

    Args:
        path: Probability path for training.
        reduction (str): Reduction method ('none', 'mean', 'sum').
    """

    def __init__(self, path, reduction: str = "mean"):
        super().__init__(path, reduction)

    def __call__(
        self,
        vf: Callable,
        batch: Tuple[Array, Array, Array],
        *args,
        condition_mask: Optional[Array] = None,
        **kwargs
    ) -> Array:
        """
        Evaluate the continuous flow matching loss.

        Args:
            vf (Callable): Vector field model, called as
                ``vf(t, x_t, *args, **kwargs)``.
            batch (Tuple[Array, Array, Array]): Input data (x_0, x_1, t);
                forwarded unchanged to ``self.path.sample``.
            *args: Extra positional arguments forwarded to ``vf``.
            condition_mask (Optional[Array]): Mask of conditioned entries.
                It is forwarded to ``vf`` as-is under the ``condition_mask``
                keyword, and a copy reshaped to ``x_t.shape`` (so it must
                have the same number of elements) is used to clamp
                conditioned entries of ``x_t`` to ``x_1`` and to zero their
                residuals.
            **kwargs: Extra keyword arguments forwarded to ``vf``.

        Returns:
            Array: ``self.reduction`` applied to the squared residual
            ``vf(...) - dx_t``.
        """
        # Unpacking first (before sampling) also enforces a 3-element batch.
        _, clean, _ = batch
        sample = self.path.sample(*batch)

        noisy = sample.x_t
        mask = None
        if condition_mask is not None:
            # The model receives the mask in the caller's original shape.
            kwargs["condition_mask"] = condition_mask
            mask = condition_mask.reshape(noisy.shape)
            # Conditioned (observed) entries are replaced with clean data.
            noisy = jnp.where(mask, clean, noisy)

        residual = vf(sample.t, noisy, *args, **kwargs) - sample.dx_t

        if mask is not None:
            # NOTE(review): masked entries become zeros but are not removed
            # from the reduction (e.g. they still count in a mean) — confirm
            # this weighting is intended.
            residual = jnp.where(mask, 0.0, residual)

        return self.reduction(jnp.square(residual))  # type: ignore
class SimformerDiffLoss(nnx.Module):
    """
    SimformerDiffLoss computes the diffusion score matching loss for the
    Simformer model.

    The actual loss computation is delegated to the function returned by
    ``path.get_loss_fn()``; this class samples the path and forwards the
    condition mask.

    Args:
        path: Probability path for training. Must provide ``sample(key, x_1,
            sigma)`` and ``get_loss_fn()``.
    """

    def __init__(self, path):
        self.path = path
        # The loss function is obtained once from the path at construction.
        self.loss_fn = self.path.get_loss_fn()

    def __call__(
        self,
        key: jax.Array,
        model: Callable,
        batch: Tuple[Array, Array],
        condition_mask: Optional[Array] = None,
        **kwargs
    ) -> Array:
        """
        Evaluate the diffusion score matching loss.

        Args:
            key (jax.Array): PRNG key passed to ``self.path.sample``.
            model (Callable): Model evaluated inside ``self.loss_fn``.
            batch (Tuple[Array, Array]): Input data ``(x_1, sigma)``;
                exactly two elements are expected.
            condition_mask (Optional[Array]): Mask for conditioning. When
                given, it is passed to ``self.loss_fn`` as ``loss_mask`` and
                also forwarded to the model via ``model_extras`` under the
                ``condition_mask`` key.
            **kwargs: Additional keyword arguments forwarded to the model via
                ``model_extras``.

        Returns:
            Array: The loss value returned by ``self.loss_fn``.
        """
        x_1, sigma = batch
        path_sample = self.path.sample(key, x_1, sigma)
        # Use a distinct name so the caller's ``batch`` is not shadowed by
        # the sampled training batch.
        sampled_batch = path_sample.get_batch()

        if condition_mask is not None:
            kwargs["condition_mask"] = condition_mask

        loss = self.loss_fn(
            model, sampled_batch, loss_mask=condition_mask, model_extras=kwargs
        )
        return loss  # type: ignore