gensbi.flow_matching.loss.continuous_loss

Classes

ContinuousFMLoss

Computes the continuous flow matching loss.

Module Contents

class gensbi.flow_matching.loss.continuous_loss.ContinuousFMLoss(path, reduction='mean')

Bases: flax.nnx.Module

Computes the continuous flow matching loss. The module is callable: it scores a vector field model vf on a batch (x_0, x_1, t) using the given probability path.

Parameters:
  • path (AffineProbPath) – Continuous probability path used to build the training targets, such as the AffineProbPath passed in the example below.

  • reduction (str, optional) – Reduction applied to the output: 'none' | 'mean' | 'sum'. 'none': no reduction is applied, 'mean': the output is averaged over its elements, 'sum': the output is summed over its elements. Defaults to 'mean'.
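
The three reduction modes can be illustrated with a small stand-alone sketch. The helper reduce_loss is hypothetical and is not part of gensbi; it only mirrors the documented behaviour of 'none', 'mean' and 'sum' on an array of per-element losses, assuming the reduction runs over all elements.

```python
import jax.numpy as jnp


def reduce_loss(per_element, reduction="mean"):
    """Hypothetical helper mirroring the three documented reduction modes."""
    if reduction == "none":
        return per_element            # keep the per-element losses
    if reduction == "mean":
        return jnp.mean(per_element)  # average over all elements
    if reduction == "sum":
        return jnp.sum(per_element)   # total over all elements
    raise ValueError(f"{reduction} is not a valid value for reduction")


# Made-up per-element squared errors for a batch of 2 samples with 2 features.
per_element = jnp.array([[1.0, 4.0], [9.0, 16.0]])

none_out = reduce_loss(per_element, "none")
mean_out = reduce_loss(per_element, "mean")
sum_out = reduce_loss(per_element, "sum")
print(none_out.shape, mean_out.shape, sum_out.shape)
```

With 'none' the result keeps the input shape, while 'mean' and 'sum' return a scalar, which is what an optimizer step usually needs.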

Example

from gensbi.flow_matching.loss import ContinuousFMLoss
from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.path.scheduler import CondOTScheduler
import jax.numpy as jnp

# Build the loss from an affine path with the conditional-OT scheduler.
scheduler = CondOTScheduler()
path = AffineProbPath(scheduler)
loss_fn = ContinuousFMLoss(path)

# Toy vector field; assumes t arrives in a shape that broadcasts against x.
def vf(x, t, args=None):
    return x + t

x_0 = jnp.zeros((8, 2))    # source samples
x_1 = jnp.ones((8, 2))     # target samples
t = jnp.linspace(0, 1, 8)  # one time value per batch element
batch = (x_0, x_1, t)

loss = loss_fn(vf, batch)
print(loss.shape)
# ()

__call__(vf, batch, args=None, **kwargs)

Evaluates the continuous flow matching loss.
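
For orientation, the textbook conditional flow matching objective for an affine path is written out below. The symbols v_θ (the vector field model), α_t and σ_t (the scheduler coefficients) are notation introduced here, not names from the gensbi API, and this page does not confirm that the implementation matches the formula term for term.

```latex
\mathcal{L}(\theta)
  = \mathbb{E}_{t,\,x_0,\,x_1}
    \left\| v_\theta(x_t, t) - \dot{x}_t \right\|^2,
\qquad
x_t = \alpha_t x_1 + \sigma_t x_0,
\qquad
\dot{x}_t = \dot{\alpha}_t x_1 + \dot{\sigma}_t x_0 .
```

For the conditional-OT schedule, \alpha_t = t and \sigma_t = 1 - t, so the regression target reduces to x_1 - x_0.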

Parameters:
  • vf (callable) – The vector field model to evaluate.

  • batch (tuple) – A tuple (x_0, x_1, t) of source samples, target samples, and times.

  • args (optional) – Additional arguments forwarded to vf.

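  • condition_mask (optional) – A mask to apply to the input data. It is not a named parameter in the signature above, so it would have to be supplied through **kwargs.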

  • **kwargs – Additional keyword arguments for the function.

Returns:

The computed loss.

Return type:

Array
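
To make the returned quantity concrete, here is a hand-written sketch of a flow matching loss in plain JAX. It assumes the straight-line (conditional-OT) path x_t = (1 - t) x_0 + t x_1 and a mean reduction; condot_fm_loss and make_vf are hypothetical names for illustration, not gensbi functions, and the library's own computation may differ in detail.

```python
import jax
import jax.numpy as jnp


def condot_fm_loss(vf, x_0, x_1, t):
    """Hand-written flow matching loss for x_t = (1 - t) x_0 + t x_1."""
    t_col = jnp.reshape(t, (-1, 1))          # (batch, 1), broadcasts over features
    x_t = (1.0 - t_col) * x_0 + t_col * x_1  # point on the straight-line path
    target = x_1 - x_0                       # time derivative of x_t on this path
    pred = vf(x_t, t)
    return jnp.mean(jnp.square(pred - target))


def make_vf(w):
    """Hypothetical toy model: a constant velocity w that ignores x and t."""
    def vf(x, t, args=None):
        return jnp.broadcast_to(w, x.shape)
    return vf


x_0 = jnp.zeros((8, 2))
x_1 = jnp.ones((8, 2))
t = jnp.linspace(0.0, 1.0, 8)

# A zero velocity misses the target (all ones) by 1 in every element.
loss_at_zero = condot_fm_loss(make_vf(jnp.zeros(2)), x_0, x_1, t)
# A velocity equal to x_1 - x_0 matches the target exactly.
loss_at_one = condot_fm_loss(make_vf(jnp.ones(2)), x_0, x_1, t)

# The scalar loss is differentiable with respect to the toy parameter w.
grad_w = jax.grad(lambda w: condot_fm_loss(make_vf(w), x_0, x_1, t))(jnp.zeros(2))
print(loss_at_zero, loss_at_one, grad_w)
```

Because the result is a scalar array under the mean reduction, it can be passed directly to jax.grad, as the last step shows.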

path