gensbi.models.losses
Submodules
Classes
- ConditionalCFMLoss: computes the continuous flow matching loss for the Conditional model.
- ConditionalDiffLoss: computes the diffusion score matching loss for the Conditional model.
- JointCFMLoss: computes the continuous flow matching loss for the Joint model.
- JointDiffLoss: computes the diffusion score matching loss for the Joint model.
- UnconditionalCFMLoss: computes the continuous flow matching loss for the Unconditional model.
- UnconditionalDiffLoss: computes the diffusion score matching loss for the Unconditional model.
Package Contents
- class gensbi.models.losses.ConditionalCFMLoss(path, reduction='mean', cfg_scale=None)
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
ConditionalCFMLoss is a class that computes the continuous flow matching loss for the Conditional model.
- Parameters:
path – Probability path (x-prediction training).
reduction (str, optional) – Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'.
'none': no reduction is applied to the output; 'mean': the output is reduced by the mean over sequence elements; 'sum': the output is reduced by the sum over sequence elements. Defaults to 'mean'.
- __call__(vf, batch, cond, obs_ids, cond_ids)
Evaluates the continuous flow matching loss.
- Parameters:
vf (callable) – The vector field model to evaluate.
batch (tuple) – A tuple containing the input data (x_0, x_1, t).
cond (jnp.ndarray) – The conditioning data.
obs_ids (jnp.ndarray) – The observation IDs.
cond_ids (jnp.ndarray) – The conditioning IDs.
- Returns:
The computed loss.
- Return type:
jnp.ndarray
- cfg_scale = None
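A minimal sketch of what a conditional flow matching loss with the documented reduction options can look like, written in plain JAX so it runs on its own. It assumes a linear path x_t = (1 - t) x_0 + t x_1 with velocity target x_1 - x_0, a hypothetical vf(x_t, t, cond) signature, and that 'mean' and 'sum' reduce over all entries. The documented path is described as x-prediction, so the real training target may differ, and the documented obs_ids and cond_ids arguments are omitted here.

```python
import jax.numpy as jnp


def conditional_cfm_loss(vf, batch, cond, reduction="mean"):
    """Sketch of a conditional flow matching loss (not the gensbi implementation).

    Assumptions: linear path x_t = (1 - t) * x_0 + t * x_1, target velocity
    u_t = x_1 - x_0, and a hypothetical vf(x_t, t, cond) returning an array
    shaped like x_t.
    """
    x_0, x_1, t = batch
    # Broadcast the per-sample time t of shape (batch,) over the other axes.
    t_b = jnp.reshape(t, (-1,) + (1,) * (x_0.ndim - 1))
    x_t = (1.0 - t_b) * x_0 + t_b * x_1
    u_t = x_1 - x_0
    sq_err = jnp.square(vf(x_t, t, cond) - u_t)
    # Assumed reduction semantics: 'mean'/'sum' collapse every entry.
    if reduction == "none":
        return sq_err
    if reduction == "mean":
        return jnp.mean(sq_err)
    if reduction == "sum":
        return jnp.sum(sq_err)
    raise ValueError(f"unknown reduction: {reduction!r}")
```

With reduction='none' the per-entry squared errors come back unreduced, which leaves room for custom weighting before averaging.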
- class gensbi.models.losses.ConditionalDiffLoss(path)
Bases: flax.nnx.Module
ConditionalDiffLoss is a class that computes the diffusion score matching loss for the Conditional model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, cond, obs_ids, cond_ids)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
cond (jnp.ndarray) – The conditioning data.
obs_ids (jnp.ndarray) – The observation IDs.
cond_ids (jnp.ndarray) – The conditioning IDs.
- Returns:
Computed loss.
- Return type:
Array
- loss_fn
- path
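The "F model" wording matches the F_θ notation of Karras et al. (2022, "Elucidating the Design Space of Diffusion-Based Generative Models"), where a raw network F is wrapped by noise-dependent preconditioning to form the denoiser D. Assuming ConditionalDiffLoss follows that scheme (this page does not confirm it), a self-contained sketch of the weighted denoising loss looks like this; SIGMA_DATA, the F(x, c_noise, cond) signature, and the helper names are assumptions.

```python
import jax
import jax.numpy as jnp

SIGMA_DATA = 0.5  # assumed data standard deviation (the EDM paper's default)


def _expand(sigma, ndim):
    # Reshape per-sample sigma of shape (batch,) so it broadcasts against x.
    return jnp.reshape(sigma, (-1,) + (1,) * (ndim - 1))


def edm_denoiser(F, x_noisy, sigma, cond):
    # EDM preconditioning: D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise, cond)
    s = _expand(sigma, x_noisy.ndim)
    c_skip = SIGMA_DATA**2 / (s**2 + SIGMA_DATA**2)
    c_out = s * SIGMA_DATA / jnp.sqrt(s**2 + SIGMA_DATA**2)
    c_in = 1.0 / jnp.sqrt(s**2 + SIGMA_DATA**2)
    c_noise = jnp.log(sigma) / 4.0
    return c_skip * x_noisy + c_out * F(c_in * x_noisy, c_noise, cond)


def conditional_edm_loss(key, F, batch, cond):
    # batch = (x_1, sigma): clean samples and one noise level per sample.
    x_1, sigma = batch
    s = _expand(sigma, x_1.ndim)
    noise = jax.random.normal(key, x_1.shape)  # the key drives the perturbation
    x_noisy = x_1 + s * noise
    weight = (s**2 + SIGMA_DATA**2) / (s * SIGMA_DATA) ** 2  # EDM lambda(sigma)
    denoised = edm_denoiser(F, x_noisy, sigma, cond)
    return jnp.mean(weight * jnp.square(denoised - x_1))
```

The PRNG key is what makes this loss stochastic: it draws the Gaussian perturbation, so reusing a key reproduces the same noise.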
- class gensbi.models.losses.JointCFMLoss(path, reduction='mean')
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
JointCFMLoss is a class that computes the continuous flow matching loss for the Joint model.
- Parameters:
path – Probability path for training.
reduction (str) – Reduction method ('none', 'mean', 'sum').
- __call__(vf, batch, *args, condition_mask=None, **kwargs)
Evaluate the continuous flow matching loss.
- Parameters:
vf (Callable) – Vector field model.
batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).
*args – Additional positional arguments.
condition_mask (Optional[Array]) – Mask for conditioning.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
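The role of condition_mask is not spelled out above. One common convention for joint models, assumed here rather than taken from gensbi, is that masked entries are observed: they are held at their clean values along the path and excluded from the loss. A minimal sketch under that assumption, with a linear path and a hypothetical vf(x_t, t, condition_mask) signature:

```python
import jax.numpy as jnp


def joint_cfm_loss(vf, batch, condition_mask):
    """Sketch of a masked joint flow matching loss (not the gensbi implementation).

    Assumption: condition_mask is boolean, shaped like x_1, and True marks
    observed entries that act as conditions.
    """
    x_0, x_1, t = batch
    t_b = jnp.reshape(t, (-1,) + (1,) * (x_0.ndim - 1))
    x_t = (1.0 - t_b) * x_0 + t_b * x_1
    # Conditioned entries are clamped to their clean values.
    x_t = jnp.where(condition_mask, x_1, x_t)
    u_t = x_1 - x_0
    pred = vf(x_t, t, condition_mask)
    # Only latent (unconditioned) entries contribute to the loss.
    sq_err = jnp.where(condition_mask, 0.0, jnp.square(pred - u_t))
    n_latent = jnp.maximum(jnp.sum(~condition_mask), 1)
    return jnp.sum(sq_err) / n_latent
```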
- class gensbi.models.losses.JointDiffLoss(path)
Bases: flax.nnx.Module
JointDiffLoss is a class that computes the diffusion score matching loss for the Joint model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, condition_mask=None, **kwargs)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
condition_mask (Optional[Array]) – Mask for conditioning.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
- loss_fn
- path
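Under the same masking assumption (True marks observed entries), a joint denoising loss perturbs only the latent entries and scores the denoiser on those entries alone. The sketch below uses an unweighted error and a hypothetical denoiser D(x_noisy, sigma, condition_mask); a real implementation may add noise-level weighting such as EDM's.

```python
import jax
import jax.numpy as jnp


def joint_denoising_loss(key, D, batch, condition_mask):
    """Sketch of a masked joint denoising loss (not the gensbi implementation)."""
    x_1, sigma = batch
    s = jnp.reshape(sigma, (-1,) + (1,) * (x_1.ndim - 1))
    noise = jax.random.normal(key, x_1.shape)
    # Only latent (unconditioned) entries are perturbed; observed ones stay clean.
    x_noisy = jnp.where(condition_mask, x_1, x_1 + s * noise)
    denoised = D(x_noisy, sigma, condition_mask)
    # Unweighted squared error, averaged over latent entries only.
    sq_err = jnp.where(condition_mask, 0.0, jnp.square(denoised - x_1))
    return jnp.sum(sq_err) / jnp.maximum(jnp.sum(~condition_mask), 1)
```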
- class gensbi.models.losses.UnconditionalCFMLoss(path, reduction='mean')
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
UnconditionalCFMLoss is a class that computes the continuous flow matching loss for the Unconditional model.
- Parameters:
path – Probability path for training.
reduction (str) – Reduction method ('none', 'mean', 'sum').
- __call__(vf, batch, *args, **kwargs)
Evaluate the continuous flow matching loss.
- Parameters:
vf (Callable) – Vector field model.
batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).
*args – Additional positional arguments.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
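In the unconditional case the loss depends only on the path samples. To show how such a loss is typically minimised in JAX, the sketch below differentiates it with jax.value_and_grad and applies a plain SGD step. The linear vector field, its parameter dictionary, and LR are hypothetical stand-ins for a real model and optimiser, and the path is again assumed linear.

```python
import jax
import jax.numpy as jnp

LR = 0.1  # hypothetical step size for the toy SGD update


def linear_vf(params, x_t, t):
    # Hypothetical toy vector field v(x_t, t) = x_t @ W + b; it ignores t.
    del t
    return x_t @ params["W"] + params["b"]


def unconditional_cfm_loss(params, batch):
    # Assumed linear path x_t = (1 - t) x_0 + t x_1 with target u_t = x_1 - x_0.
    x_0, x_1, t = batch
    x_t = (1.0 - t[:, None]) * x_0 + t[:, None] * x_1
    u_t = x_1 - x_0
    return jnp.mean(jnp.square(linear_vf(params, x_t, t) - u_t))


@jax.jit
def sgd_step(params, batch):
    # Loss and gradients with respect to every leaf of the params pytree.
    loss, grads = jax.value_and_grad(unconditional_cfm_loss)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - LR * g, params, grads)
    return params, loss
```

The returned loss is evaluated at the parameters before the update, which is the usual pattern for logging inside a jitted step.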
- class gensbi.models.losses.UnconditionalDiffLoss(path)
Bases: flax.nnx.Module
UnconditionalDiffLoss is a class that computes the diffusion score matching loss for the Unconditional model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, **kwargs)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
- loss_fn
- path
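The diffusion losses take batch = (x_1, sigma), so noise levels are drawn before the call. One plausible way to build such a batch, assuming EDM-style log-normal noise levels (P_MEAN = -1.2 and P_STD = 1.2 are that paper's defaults, not confirmed for gensbi), is shown below together with an unweighted denoising loss; make_diffusion_batch and the D(x_noisy, sigma) signature are hypothetical.

```python
import jax
import jax.numpy as jnp

P_MEAN, P_STD = -1.2, 1.2  # assumed log-normal noise-level parameters (EDM defaults)


def make_diffusion_batch(key, x_1):
    # Hypothetical helper: one noise level per sample, ln(sigma) ~ N(P_MEAN, P_STD**2).
    log_sigma = P_MEAN + P_STD * jax.random.normal(key, (x_1.shape[0],))
    return x_1, jnp.exp(log_sigma)


def unconditional_denoising_loss(key, D, batch):
    """Sketch of an unweighted unconditional denoising loss (not the gensbi implementation)."""
    x_1, sigma = batch
    s = jnp.reshape(sigma, (-1,) + (1,) * (x_1.ndim - 1))
    noise = jax.random.normal(key, x_1.shape)
    denoised = D(x_1 + s * noise, sigma)
    # Mean squared error between the denoised estimate and the clean sample.
    return jnp.mean(jnp.square(denoised - x_1))
```

Splitting one key into two (for example with jax.random.split) keeps the noise-level draw and the perturbation statistically independent.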