Source code for gensbi.models.flux1.loss

import jax.numpy as jnp
import jax
from flax import nnx

from gensbi.flow_matching.loss import ContinuousFMLoss


[docs] class FluxCFMLoss(ContinuousFMLoss): def __init__(self, path, reduction="mean", cfg_scale=None): """ ContinuousFMLoss is a class that computes the continuous flow matching loss. Args: path: Probability path (x-prediction training). reduction (str, optional): Specify the reduction to apply to the output ``'none'`` | ``'mean'`` | ``'sum'``. ``'none'``: no reduction is applied to the output, ``'mean'``: the output is reduced by mean over sequence elements, ``'sum'``: the output is reduced by sum over sequence elements. Defaults to 'mean'. """ # self.path = path # if reduction not in ["None", "mean", "sum"]: # raise ValueError(f"{reduction} is not a valid value for reduction") # if reduction == "mean": # self.reduction = jnp.mean # elif reduction == "sum": # self.reduction = jnp.sum # else: # self.reduction = lambda x: x super().__init__(path, reduction) self.cfg_scale = cfg_scale def __call__(self, vf, batch, cond, obs_ids, cond_ids): """ Evaluates the continuous flow matching loss. Args: vf (callable): The vector field model to evaluate. batch (tuple): A tuple containing the input data (x_0, x_1, t). cond (jnp.ndarray): The conditioning data. obs_ids (jnp.ndarray): The observation IDs. cond_ids (jnp.ndarray): The conditioning IDs. Returns: jnp.ndarray: The computed loss. """ x_0, x_1, t = batch path_sample = self.path.sample(x_0, x_1, t) x_t = path_sample.x_t model_output = vf(x_t, obs_ids, cond, cond_ids, t, conditioned=True) loss_cond = model_output - path_sample.dx_t if self.cfg_scale is not None: model_output_uncond = vf(x_t, obs_ids, cond, cond_ids, t, conditioned=False) loss_uncond = model_output_uncond - path_sample.dx_t weight = self.cfg_scale loss = weight*jnp.square(loss_cond) + (1-weight)*jnp.square(loss_uncond) else: loss = jnp.square(loss_cond) return self.reduction(loss)