gensbi.models.losses

Submodules

Classes

ConditionalCFMLoss

ConditionalCFMLoss is a class that computes the continuous flow matching loss for the Conditional model.

ConditionalDiffLoss

ConditionalDiffLoss is a class that computes the diffusion score matching loss for the Conditional model.

JointCFMLoss

JointCFMLoss is a class that computes the continuous flow matching loss for the Joint model.

JointDiffLoss

JointDiffLoss is a class that computes the diffusion score matching loss for the Joint model.

UnconditionalCFMLoss

UnconditionalCFMLoss is a class that computes the continuous flow matching loss for the Unconditional model.

UnconditionalDiffLoss

UnconditionalDiffLoss is a class that computes the diffusion score matching loss for the Unconditional model.

Package Contents

class gensbi.models.losses.ConditionalCFMLoss(path, reduction='mean', cfg_scale=None)

Bases: gensbi.flow_matching.loss.ContinuousFMLoss

ConditionalCFMLoss is a class that computes the continuous flow matching loss for the Conditional model.

Parameters:
  • path – Probability path (x-prediction training).

  • reduction (str, optional) – Reduction to apply to the output: 'none' | 'mean' | 'sum'. 'none' applies no reduction, 'mean' averages the output over sequence elements, and 'sum' sums the output over sequence elements. Defaults to ‘mean’.
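
The three reduction modes can be illustrated with a minimal sketch. It assumes the per-element loss is an array of squared errors and uses a hypothetical helper `reduce_loss`; it is written with NumPy rather than `jax.numpy` (the calls used exist in both) and is not the gensbi implementation.

```python
import numpy as np


def reduce_loss(per_element: np.ndarray, reduction: str = "mean"):
    """Hypothetical helper mirroring the documented 'none' | 'mean' | 'sum' options."""
    if reduction == "none":
        return per_element          # keep one loss value per element
    if reduction == "mean":
        return per_element.mean()   # average over all elements
    if reduction == "sum":
        return per_element.sum()    # total over all elements
    raise ValueError(f"unknown reduction: {reduction!r}")


# Squared errors for a batch of 2 samples with 3 elements each.
sq_err = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])
print(reduce_loss(sq_err, "none").shape)  # (2, 3)
print(reduce_loss(sq_err, "mean"))        # 3.5
print(reduce_loss(sq_err, "sum"))         # 21.0
```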

__call__(vf, batch, cond, obs_ids, cond_ids)

Evaluates the continuous flow matching loss.

Parameters:
  • vf (callable) – The vector field model to evaluate.

  • batch (tuple) – A tuple containing the input data (x_0, x_1, t).

  • cond (jnp.ndarray) – The conditioning data.

  • obs_ids (jnp.ndarray) – The observation IDs.

  • cond_ids (jnp.ndarray) – The conditioning IDs.

Returns:

The computed loss.

Return type:

jnp.ndarray

cfg_scale = None
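
To make the batch = (x_0, x_1, t) convention concrete, here is a minimal sketch of a conditional flow matching objective. It assumes a straight-line path x_t = (1 - t) x_0 + t x_1 with target velocity x_1 - x_0, arrays shaped (batch, tokens, features), and a hypothetical vf(x_t, t, cond, obs_ids, cond_ids) call signature. The gensbi path object (described above as x-prediction training) may parameterize the regression target differently.

```python
import numpy as np


def conditional_cfm_loss(vf, batch, cond, obs_ids, cond_ids):
    """Sketch of a velocity-matching CFM loss on a straight-line path (assumption)."""
    x_0, x_1, t = batch
    # Broadcast the per-sample time t over the non-batch axes.
    t_b = t.reshape(-1, *([1] * (x_1.ndim - 1)))
    x_t = (1.0 - t_b) * x_0 + t_b * x_1      # point on the path at time t
    u_t = x_1 - x_0                          # target velocity of the straight path
    v = vf(x_t, t, cond, obs_ids, cond_ids)  # hypothetical vector field call
    return np.mean((v - u_t) ** 2)           # 'mean' reduction


rng = np.random.default_rng(0)
x_0 = rng.standard_normal((4, 3, 1))   # source (noise) samples
x_1 = rng.standard_normal((4, 3, 1))   # target (data) samples
t = rng.uniform(size=4)
cond = rng.standard_normal((4, 2, 1))  # conditioning data
obs_ids = np.arange(3)                 # hypothetical ids for the 3 observed tokens
cond_ids = np.arange(2)                # hypothetical ids for the 2 conditioning tokens

# A field that returns the exact target velocity gives zero loss.
oracle = lambda x_t, t, *args: x_1 - x_0
print(conditional_cfm_loss(oracle, (x_0, x_1, t), cond, obs_ids, cond_ids))  # 0.0
```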
class gensbi.models.losses.ConditionalDiffLoss(path)

Bases: flax.nnx.Module

ConditionalDiffLoss is a class that computes the diffusion score matching loss for the Conditional model.

Parameters:

path – Probability path for training.

__call__(key, model, batch, cond, obs_ids, cond_ids)

Evaluate the diffusion score matching loss.

Parameters:
  • key (jax.random.PRNGKey) – Random key for stochastic operations.

  • model (Callable) – F model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).

  • cond (jnp.ndarray) – The conditioning data.

  • obs_ids (jnp.ndarray) – The observation IDs.

  • cond_ids (jnp.ndarray) – The conditioning IDs.

Returns:

Computed loss.

Return type:

Array

loss_fn
path
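
The (x_1, sigma) batch and the random key suggest a denoising objective. The sketch below is one standard form of denoising score matching, not the gensbi code: it perturbs the clean data with Gaussian noise of scale sigma, asks a hypothetical denoiser(x_noisy, sigma, *model_args) to recover the clean data, and uses a NumPy Generator in place of a JAX PRNGKey. By Tweedie's formula, the score of the perturbed density is (D(x) - x) / sigma², which the second helper computes.

```python
import numpy as np


def denoising_loss(rng, denoiser, batch, *model_args):
    """Sketch of a denoising score matching loss (assumption, not gensbi's code)."""
    x_1, sigma = batch
    s = sigma.reshape(-1, *([1] * (x_1.ndim - 1)))  # per-sample noise level
    eps = rng.standard_normal(x_1.shape)            # stands in for sampling with `key`
    x_noisy = x_1 + s * eps                         # Gaussian perturbation of the data
    x_hat = denoiser(x_noisy, sigma, *model_args)   # hypothetical denoiser call
    return np.mean((x_hat - x_1) ** 2)              # unweighted squared error


def score_from_denoiser(x_hat, x_noisy, sigma):
    """Tweedie's formula: score of the perturbed density from the denoised estimate."""
    s = sigma.reshape(-1, *([1] * (x_noisy.ndim - 1)))
    return (x_hat - x_noisy) / s**2


rng = np.random.default_rng(1)
x_1 = rng.standard_normal((4, 3, 1))
cond = rng.standard_normal((4, 2, 1))
sigma = np.zeros(4)                      # no noise: identity is a perfect denoiser
identity = lambda x, sigma, *args: x
print(denoising_loss(rng, identity, (x_1, sigma), cond))  # 0.0
```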
class gensbi.models.losses.JointCFMLoss(path, reduction='mean')

Bases: gensbi.flow_matching.loss.ContinuousFMLoss

JointCFMLoss is a class that computes the continuous flow matching loss for the Joint model.

Parameters:
  • path – Probability path for training.

  • reduction (str, optional) – Reduction method (‘none’, ‘mean’, ‘sum’). Defaults to ‘mean’.

__call__(vf, batch, *args, condition_mask=None, **kwargs)

Evaluate the continuous flow matching loss.

Parameters:
  • vf (Callable) – Vector field model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).

  • *args – Additional positional arguments.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array
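
One common convention for a condition_mask in a joint model is to clamp the conditioned entries to their data values and exclude them from the loss, so the network only learns to transport the unconditioned entries. The sketch below assumes that convention, a straight-line path, a boolean mask that is True for conditioned entries, and a hypothetical vf(x_t, t, condition_mask) signature; gensbi may apply the mask differently.

```python
import numpy as np


def masked_cfm_loss(vf, batch, condition_mask):
    """Sketch: conditioned entries are clamped to data and excluded from the loss."""
    x_0, x_1, t = batch
    t_b = t.reshape(-1, *([1] * (x_1.ndim - 1)))
    x_t = (1.0 - t_b) * x_0 + t_b * x_1
    x_t = np.where(condition_mask, x_1, x_t)   # conditioned entries keep their data value
    u_t = x_1 - x_0                            # straight-path target velocity
    v = vf(x_t, t, condition_mask)             # hypothetical vector field call
    sq_err = (v - u_t) ** 2
    latent = np.broadcast_to(~condition_mask, sq_err.shape)  # entries to be learned
    return (sq_err * latent).sum() / max(int(latent.sum()), 1)


rng = np.random.default_rng(2)
x_0 = rng.standard_normal((4, 5, 1))
x_1 = rng.standard_normal((4, 5, 1))
t = rng.uniform(size=4)
mask = np.zeros((4, 5, 1), dtype=bool)
mask[:, :2] = True  # condition on the first two of five nodes

# Errors on conditioned entries do not count toward the loss.
vf = lambda x_t, t, m: np.where(m, 100.0, x_1 - x_0)
print(masked_cfm_loss(vf, (x_0, x_1, t), mask))  # 0.0
```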

class gensbi.models.losses.JointDiffLoss(path)

Bases: flax.nnx.Module

JointDiffLoss is a class that computes the diffusion score matching loss for the Joint model.

Parameters:

path – Probability path for training.

__call__(key, model, batch, condition_mask=None, **kwargs)

Evaluate the diffusion score matching loss.

Parameters:
  • key (jax.random.PRNGKey) – Random key for stochastic operations.

  • model (Callable) – F model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array

loss_fn
path
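
For the diffusion variant, a joint training step needs a condition_mask and a noised input. The sketch below shows one plausible way to produce both: it draws a random boolean mask per sample (never conditioning on every node) and perturbs only the unconditioned entries with noise of scale sigma. Both helpers, the mask shape (batch, nodes, 1), and p_condition are hypothetical and are not taken from gensbi.

```python
import numpy as np


def sample_condition_mask(rng, batch_size, n_nodes, p_condition=0.5):
    """Hypothetical helper: draw a random boolean conditioning mask per sample."""
    mask = rng.uniform(size=(batch_size, n_nodes, 1)) < p_condition
    # Never condition on every node, otherwise nothing is left to learn.
    all_conditioned = mask.all(axis=(1, 2))
    mask[all_conditioned, 0, 0] = False
    return mask


def noise_latent_entries(rng, x_1, sigma, condition_mask):
    """Perturb only the unconditioned entries; conditioned ones stay clean."""
    s = sigma.reshape(-1, *([1] * (x_1.ndim - 1)))
    x_noisy = x_1 + s * rng.standard_normal(x_1.shape)
    return np.where(condition_mask, x_1, x_noisy)


rng = np.random.default_rng(3)
x_1 = rng.standard_normal((8, 4, 1))
sigma = np.full(8, 1.0)
mask = sample_condition_mask(rng, batch_size=8, n_nodes=4)
x_in = noise_latent_entries(rng, x_1, sigma, mask)
print(np.array_equal(x_in[mask], x_1[mask]))  # True
```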
class gensbi.models.losses.UnconditionalCFMLoss(path, reduction='mean')

Bases: gensbi.flow_matching.loss.ContinuousFMLoss

UnconditionalCFMLoss is a class that computes the continuous flow matching loss for the Unconditional model.

Parameters:
  • path – Probability path for training.

  • reduction (str, optional) – Reduction method (‘none’, ‘mean’, ‘sum’). Defaults to ‘mean’.

__call__(vf, batch, *args, **kwargs)

Evaluate the continuous flow matching loss.

Parameters:
  • vf (Callable) – Vector field model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).

  • *args – Additional positional arguments.

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array
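
Because batch bundles the source samples, data samples, and times, a training step has to assemble that tuple before calling the loss. The sketch below assumes a standard Gaussian source distribution and times drawn uniformly on [0, 1); make_cfm_batch is a hypothetical helper, and gensbi may sample x_0 and t differently.

```python
import numpy as np


def make_cfm_batch(rng, x_1):
    """Hypothetical helper: pair data with Gaussian noise and uniform times."""
    x_0 = rng.standard_normal(x_1.shape)   # source samples from N(0, I)
    t = rng.uniform(size=x_1.shape[0])     # one time per sample, in [0, 1)
    return x_0, x_1, t


rng = np.random.default_rng(4)
data = rng.standard_normal((16, 3, 1))  # stand-in for a batch of training data
x_0, x_1, t = make_cfm_batch(rng, data)
print(x_0.shape, t.shape)  # (16, 3, 1) (16,)
```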

class gensbi.models.losses.UnconditionalDiffLoss(path)

Bases: flax.nnx.Module

UnconditionalDiffLoss is a class that computes the diffusion score matching loss for the Unconditional model.

Parameters:

path – Probability path for training.

__call__(key, model, batch, **kwargs)

Evaluate the diffusion score matching loss.

Parameters:
  • key (jax.random.PRNGKey) – Random key for stochastic operations.

  • model (Callable) – F model.

  • batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).

  • **kwargs – Additional keyword arguments.

Returns:

Computed loss.

Return type:

Array

loss_fn
path
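
The (x_1, sigma) batch pairs each clean sample with a noise level. As one concrete choice, the sketch below draws sigma log-normally and computes a per-sample loss weight using the defaults from the EDM paper by Karras et al. (2022): ln(sigma) ~ N(-1.2, 1.2²), sigma_data = 0.5, and weight (sigma² + sigma_data²) / (sigma · sigma_data)². Whether gensbi uses this distribution, these constants, or this weighting is an assumption; make_diffusion_batch and edm_loss_weight are hypothetical helpers.

```python
import numpy as np

P_MEAN, P_STD = -1.2, 1.2   # EDM defaults for ln(sigma) ~ N(P_MEAN, P_STD^2)
SIGMA_DATA = 0.5            # EDM default data standard deviation


def make_diffusion_batch(rng, x_1):
    """Hypothetical helper: attach one log-normally distributed sigma per sample."""
    sigma = np.exp(P_MEAN + P_STD * rng.standard_normal(x_1.shape[0]))
    return x_1, sigma


def edm_loss_weight(sigma):
    """EDM weighting lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2."""
    return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2


rng = np.random.default_rng(5)
x_1, sigma = make_diffusion_batch(rng, rng.standard_normal((16, 3, 1)))
print(sigma.shape, bool(np.all(sigma > 0)))  # (16,) True
print(edm_loss_weight(np.array([0.5])))      # [8.]
```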