gensbi.models.simformer.loss
Classes
- SimformerCFMLoss: computes the continuous flow matching loss for the Simformer model.
- SimformerDiffLoss: computes the diffusion score matching loss for the Simformer model.
Module Contents
- class gensbi.models.simformer.loss.SimformerCFMLoss(path, reduction='mean')
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
Computes the continuous flow matching loss for the Simformer model.
- Parameters:
path – Probability path for training.
reduction (str) – Reduction applied to the loss: ‘none’, ‘mean’, or ‘sum’ (default ‘mean’).
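The three reduction modes usually follow the common loss-function convention: ‘none’ keeps the per-element losses, ‘mean’ averages them, and ‘sum’ adds them up. The sketch below assumes gensbi follows that convention; reduce_loss is a hypothetical helper for illustration, not part of the library.

```python
import numpy as np

def reduce_loss(per_element: np.ndarray, reduction: str = "mean") -> np.ndarray:
    """Apply a 'none' / 'mean' / 'sum' reduction (hypothetical helper, illustration only)."""
    if reduction == "none":
        return per_element            # keep per-element losses unchanged
    if reduction == "mean":
        return np.mean(per_element)   # average over all elements
    if reduction == "sum":
        return np.sum(per_element)    # total over all elements
    raise ValueError(f"unknown reduction: {reduction!r}")

losses = np.array([0.5, 1.0, 1.5, 3.0])
print(reduce_loss(losses, "mean"))  # 1.5
print(reduce_loss(losses, "sum"))   # 6.0
```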
- __call__(vf, batch, *args, condition_mask=None, **kwargs)
Evaluate the continuous flow matching loss.
- Parameters:
vf (Callable) – Vector field model.
batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).
*args – Additional positional arguments.
condition_mask (Optional[Array]) – Mask indicating which variables are conditioned on (observed) rather than generated.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
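The idea behind a condition-masked flow matching loss can be sketched as follows. This is a minimal NumPy illustration, not the gensbi implementation, and it rests on several assumptions: a linear path x_t = (1 - t) x_0 + t x_1 with target velocity x_1 - x_0, conditioned entries clamped to their observed values x_1, and the squared error averaged over the unconditioned (latent) entries only. The function masked_cfm_loss and the vf(x_t, t, condition_mask) call signature are hypothetical.

```python
import numpy as np

def masked_cfm_loss(vf, x_0, x_1, t, condition_mask):
    """Sketch of a condition-masked flow matching loss (hypothetical, not gensbi's code).

    x_0, x_1: arrays of shape (batch, n_vars); t: shape (batch, 1);
    condition_mask: boolean, shape (batch, n_vars), True = conditioned (observed).
    """
    # Assumed linear probability path and its target velocity.
    x_t = (1.0 - t) * x_0 + t * x_1
    target = x_1 - x_0
    # Conditioned entries are clamped to their observed values.
    x_t = np.where(condition_mask, x_1, x_t)
    pred = vf(x_t, t, condition_mask)
    # Only latent (unconditioned) entries contribute to the squared error.
    latent = ~condition_mask
    sq_err = (pred - target) ** 2 * latent
    return sq_err.sum() / np.maximum(latent.sum(), 1)

# Toy usage: a vector field that always predicts zero velocity.
x_0 = np.zeros((2, 3))
x_1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
t = np.full((2, 1), 0.5)
mask = np.array([[True, False, False], [False, False, True]])
zero_vf = lambda x, tt, m: np.zeros_like(x)
print(masked_cfm_loss(zero_vf, x_0, x_1, t, mask))  # 13.5 = (2**2 + 3**2 + 4**2 + 5**2) / 4
```

Masking the loss rather than dropping conditioned variables keeps a single fixed-size input, so one network can be trained for arbitrary conditioning patterns by sampling different masks.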
- class gensbi.models.simformer.loss.SimformerDiffLoss(path)
Bases: flax.nnx.Module
Computes the diffusion score matching loss for the Simformer model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, condition_mask=None, **kwargs)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – The F model (the network being trained).
batch (Tuple[Array, Array]) – Input data (x_1, sigma).
condition_mask (Optional[Array]) – Mask indicating which variables are conditioned on (observed) rather than generated.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
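A condition-masked denoising score matching objective can be sketched in the same spirit. This NumPy illustration is not the gensbi implementation. It assumes a variance-exploding perturbation x_sigma = x_1 + sigma * eps, whose target score is -eps / sigma, and a weighting of sigma**2. Under that weighting the residual reduces to sigma * score + eps. Conditioned entries stay clean and are excluded from the loss. A NumPy Generator stands in for the JAX PRNG key. masked_dsm_loss and the score_fn(x, sigma, condition_mask) signature are hypothetical; gensbi's F model may instead be a preconditioned network whose output is not a raw score.

```python
import numpy as np

def masked_dsm_loss(rng, score_fn, x_1, sigma, condition_mask):
    """Sketch of a condition-masked denoising score matching loss (hypothetical).

    x_1: clean data, shape (batch, n_vars); sigma: noise levels, shape (batch, 1);
    condition_mask: boolean, shape (batch, n_vars), True = conditioned (observed);
    rng: NumPy Generator standing in for the JAX PRNG key.
    """
    eps = rng.standard_normal(x_1.shape)
    # Assumed variance-exploding perturbation; conditioned entries stay clean.
    x_noisy = np.where(condition_mask, x_1, x_1 + sigma * eps)
    score = score_fn(x_noisy, sigma, condition_mask)
    # sigma**2 * |score - (-eps / sigma)|**2 == |sigma * score + eps|**2
    latent = ~condition_mask
    residual = (sigma * score + eps) * latent
    return (residual ** 2).sum() / np.maximum(latent.sum(), 1)

# Toy usage: the exact score of the Gaussian perturbation gives (near) zero loss.
rng = np.random.default_rng(0)
x_1 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
sigma = np.array([[0.5], [2.0]])
mask = np.array([[True, False, False], [False, False, True]])
oracle = lambda x, s, m: -(x - x_1) / s**2
print(masked_dsm_loss(rng, oracle, x_1, sigma, mask))  # close to 0 (floating-point rounding)
```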