import jax.numpy as jnp
import jax
from flax import nnx
from typing import Callable, Tuple, Optional
from jax.numpy import ndarray as Array
from gensbi.flow_matching.loss import ContinuousFMLoss
class FluxCFMLoss(ContinuousFMLoss):
    """
    Continuous flow matching loss for the Flux model.

    The loss is the squared difference between the model output and the
    target ``dx_t`` of a sample drawn from ``path``, passed through the
    reduction configured on the parent class.

    Args:
        path: Probability path. Its ``sample(*batch)`` result must expose
            ``x_t``, ``t`` and ``dx_t``.
        reduction (str, optional): Reduction applied to the element-wise loss.
            Forwarded unchanged to ``ContinuousFMLoss.__init__``, which owns
            the validation; documented values are ``'none'`` | ``'mean'`` |
            ``'sum'``. Defaults to 'mean'.
        cfg_scale (float, optional): Probability that a batch element keeps
            its conditioning. Despite the name it is used as the Bernoulli
            probability of the per-sample ``conditioned`` mask, not as a
            guidance scale. ``None`` (default) keeps every sample conditioned.
    """

    def __init__(self, path, reduction="mean", cfg_scale=None):
        # Storing the path and resolving the reduction is done by the parent.
        super().__init__(path, reduction)
        self.cfg_scale = cfg_scale

    def __call__(
        self,
        vf,
        batch,
        cond,
        obs_ids,
        cond_ids,
        key: Optional[Array] = None,
    ):
        """
        Evaluates the continuous flow matching loss.

        Args:
            vf (callable): The vector field model to evaluate. Called as
                ``vf(t, x_t, obs_ids, cond, cond_ids, conditioned=mask)``.
            batch (tuple): A tuple containing the input data (x_0, x_1, t).
            cond (jnp.ndarray): The conditioning data.
            obs_ids (jnp.ndarray): The observation IDs.
            cond_ids (jnp.ndarray): The conditioning IDs.
            key (jax.Array, optional): PRNG key used to draw the conditioning
                mask when ``cfg_scale`` is set. Pass a fresh key on every
                call so different samples are dropped at each step. When
                omitted, a fixed key is used (legacy behaviour), which yields
                the same mask on every call for a given batch size.

        Returns:
            jnp.ndarray: The computed loss.
        """
        path_sample = self.path.sample(*batch)
        x_t = path_sample.x_t
        batch_size = x_t.shape[0]

        if self.cfg_scale is None:
            # No conditioning dropout: every sample is conditioned.
            conditioned = jnp.ones((batch_size,), dtype=jnp.bool_)
        else:
            if key is None:
                # Backward-compatible fallback; deterministic across calls.
                key = jax.random.PRNGKey(0)
            # One Bernoulli draw per batch element: True keeps the conditioning.
            conditioned = jax.random.bernoulli(
                key, p=self.cfg_scale, shape=(batch_size,)
            )

        model_output = vf(
            path_sample.t, x_t, obs_ids, cond, cond_ids, conditioned=conditioned
        )
        squared_error = jnp.square(model_output - path_sample.dx_t)
        return self.reduction(squared_error)
# TODO: WIP
class FluxDiffLoss(nnx.Module):
    """
    Diffusion score matching loss for the Flux model.

    Args:
        path: Probability path for training. It must provide
            ``get_loss_fn()`` (used once at construction) and
            ``sample(key, x_1, sigma)`` (used on every call).
    """

    def __init__(self, path):
        # The path must be stored before it is used: the loss function is
        # derived from it here, and __call__ samples from it later.
        self.path = path
        self.loss_fn = self.path.get_loss_fn()

    def __call__(
        self,
        key: Array,
        model: Callable,
        batch: Tuple[Array, Array],
        cond,
        obs_ids,
        cond_ids,
    ) -> Array:
        """
        Evaluate the diffusion loss.

        Args:
            key (jax.Array): Random key for stochastic operations; forwarded
                to ``path.sample``.
            model (Callable): F model.
            batch (Tuple[Array, Array]): Input data (x_1, sigma).
            cond (jnp.ndarray): The conditioning data.
            obs_ids (jnp.ndarray): The observation IDs.
            cond_ids (jnp.ndarray): The conditioning IDs.

        Returns:
            Array: Computed loss, as returned by the path's loss function.
        """
        x_1, sigma = batch
        path_sample = self.path.sample(key, x_1, sigma)

        # Conditioning inputs are handed to the path's loss function through
        # ``model_extras`` rather than being passed to the model here.
        model_extras = {
            "cond": cond,
            "obs_ids": obs_ids,
            "cond_ids": cond_ids,
        }
        loss = self.loss_fn(
            model,
            path_sample.get_batch(),
            loss_mask=None,
            model_extras=model_extras,
        )
        return loss  # type: ignore