gensbi.models.losses.unconditional
Classes
- UnconditionalCFMLoss – Computes the continuous flow matching loss for the Unconditional model.
- UnconditionalDiffLoss – Computes the diffusion score matching loss for the Unconditional model.
Module Contents
- class gensbi.models.losses.unconditional.UnconditionalCFMLoss(path, reduction='mean')
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
UnconditionalCFMLoss computes the continuous flow matching loss for the Unconditional model.
- Parameters:
path – Probability path for training.
reduction (str) – Reduction method (‘none’, ‘mean’, or ‘sum’). Defaults to ‘mean’.
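The three reduction modes can be illustrated with a small sketch. It assumes the loss is first computed per sample and then reduced over the batch; `reduce_loss` is a hypothetical helper, not part of GenSBI, and NumPy is used in place of `jax.numpy` for portability.

```python
import numpy as np

def reduce_loss(per_sample: np.ndarray, reduction: str = "mean"):
    """Apply a reduction mode to per-sample losses (hypothetical helper)."""
    if reduction == "none":
        return per_sample          # keep one loss value per sample
    if reduction == "mean":
        return per_sample.mean()   # average over the batch
    if reduction == "sum":
        return per_sample.sum()    # total over the batch
    raise ValueError(f"Unknown reduction: {reduction!r}")

losses = np.array([0.5, 1.0, 1.5])
print(reduce_loss(losses, "none"))  # [0.5 1.  1.5]
print(reduce_loss(losses, "mean"))  # 1.0
print(reduce_loss(losses, "sum"))   # 3.0
```

‘mean’ keeps the loss scale independent of batch size, while ‘none’ is useful when per-sample weights are applied afterwards.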
- __call__(vf, batch, *args, **kwargs)
Evaluate the continuous flow matching loss.
- Parameters:
vf (Callable) – Vector field model.
batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).
*args – Additional positional arguments.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
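A minimal sketch of what a continuous flow matching loss computes, assuming a linear (optimal-transport) path x_t = (1 - t) x_0 + t x_1 whose target velocity is x_1 - x_0. The class itself delegates this to its `path` object, so the linear path, the `vf(x, t)` call signature, and the helper name `cfm_loss` are assumptions for illustration. NumPy stands in for `jax.numpy`, which provides the same functions used here.

```python
import numpy as np

def cfm_loss(vf, batch, reduction="mean"):
    """Flow matching loss on a linear path (illustrative sketch, not the GenSBI code).

    Assumes batch = (x_0, x_1, t), x_0 and x_1 of shape (B, D), t of shape (B,).
    """
    x_0, x_1, t = batch
    t_b = t[:, None]                     # broadcast time over feature dims
    x_t = (1.0 - t_b) * x_0 + t_b * x_1  # sample on the straight-line path
    u_t = x_1 - x_0                      # target velocity along that path
    per_sample = np.mean((vf(x_t, t) - u_t) ** 2, axis=-1)  # MSE per sample
    if reduction == "none":
        return per_sample
    if reduction == "sum":
        return per_sample.sum()
    if reduction == "mean":
        return per_sample.mean()
    raise ValueError(f"Unknown reduction: {reduction!r}")

# Usage: a constant field equal to x_1 - x_0 matches the target, so the loss is ~0.
rng = np.random.default_rng(0)
x_0 = rng.standard_normal((4, 2))
x_1 = x_0 + 1.0
t = rng.uniform(size=4)
const_vf = lambda x, t: np.ones_like(x)
print(cfm_loss(const_vf, (x_0, x_1, t)))
```

Regressing on the per-pair velocity x_1 - x_0 is what makes the objective tractable: the model never needs the marginal vector field, only samples of (x_0, x_1, t).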
- class gensbi.models.losses.unconditional.UnconditionalDiffLoss(path)
Bases: flax.nnx.Module
UnconditionalDiffLoss computes the diffusion score matching loss for the Unconditional model.
- Parameters:
path – Probability path for training.
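This page does not say how the path supplies noise levels for the diffusion loss. If it follows the EDM setup of Karras et al. (2022), which the "F model" terminology suggests, noise levels are commonly drawn with log σ ~ N(P_mean, P_std²), using P_mean = -1.2 and P_std = 1.2. The sketch below is a hypothetical illustration of that choice; `sample_sigma_lognormal` is not part of the GenSBI path API, and a NumPy `Generator` stands in for a JAX key.

```python
import numpy as np

def sample_sigma_lognormal(rng, n, p_mean=-1.2, p_std=1.2):
    """Draw n noise levels with log(sigma) ~ Normal(p_mean, p_std**2) (assumed EDM-style)."""
    return np.exp(p_mean + p_std * rng.standard_normal(n))

rng = np.random.default_rng(0)
sigma = sample_sigma_lognormal(rng, 5)
print(sigma.shape)  # (5,)
```

Sampling in log space concentrates training on intermediate noise levels while still covering very small and very large σ.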
- __call__(key, model, batch, **kwargs)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – The F model, i.e. the network being trained.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
args (Optional[dict]) – Additional arguments.
condition_mask (Optional[Array]) – Mask for conditioning.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
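A hedged sketch of an EDM-style weighted denoising score matching loss, assuming `batch = (x_1, sigma)` as in the parameter list and that the "F model" is the raw network wrapped by EDM preconditioning D(x; σ) = c_skip(σ) x + c_out(σ) F(c_in(σ) x, c_noise(σ)). The coefficients and the weight λ(σ) = (σ² + σ_data²) / (σ σ_data)² follow Karras et al. (2022); σ_data = 0.5 and the helper names `edm_denoiser` and `edm_loss` are assumptions. A NumPy `Generator` stands in for the JAX `key`; with JAX one would draw the noise with `jax.random.normal(key, x_1.shape)`.

```python
import numpy as np

SIGMA_DATA = 0.5  # assumed data standard deviation (EDM default)

def edm_denoiser(F, x_noisy, sigma, sigma_data=SIGMA_DATA):
    """Wrap a raw network F into a denoiser D using EDM preconditioning."""
    s = sigma[:, None]  # broadcast noise level over feature dims
    c_skip = sigma_data**2 / (s**2 + sigma_data**2)
    c_out = s * sigma_data / np.sqrt(s**2 + sigma_data**2)
    c_in = 1.0 / np.sqrt(s**2 + sigma_data**2)
    c_noise = np.log(sigma) / 4.0
    return c_skip * x_noisy + c_out * F(c_in * x_noisy, c_noise)

def edm_loss(rng, F, batch, sigma_data=SIGMA_DATA):
    """Weighted denoising score matching loss, assuming batch = (x_1, sigma)."""
    x_1, sigma = batch
    noise = rng.standard_normal(x_1.shape) * sigma[:, None]  # n ~ N(0, sigma^2 I)
    denoised = edm_denoiser(F, x_1 + noise, sigma, sigma_data)
    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2  # lambda(sigma)
    per_sample = weight * np.mean((denoised - x_1) ** 2, axis=-1)  # weighted MSE
    return per_sample.mean()

# Usage: a trivial all-zero network as a smoke test.
rng = np.random.default_rng(0)
x_1 = rng.standard_normal((8, 3))
sigma = np.exp(-1.2 + 1.2 * rng.standard_normal(8))
zero_F = lambda x, c_noise: np.zeros_like(x)
loss = edm_loss(rng, zero_F, (x_1, sigma))
print(loss)
```

Because λ(σ) c_out(σ)² = 1, this weighting gives every noise level unit effective weight on the error of the network's own output, which is the motivation for this choice of λ.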