# Source code for gensbi.diffusion.path.path_sample
from dataclasses import dataclass, field
from jax import Array
from typing import Tuple
@dataclass
class EDMPathSample:
    r"""Represents a sample of a diffusion-generated probability path.

    The container only stores the three arrays; it performs no validation
    or shape checking on them.

    Attributes:
        x_1 (Array): the target sample :math:`X_1`, shape (batch_size, ...).
        sigma (Array): the noise scale :math:`\sigma`, shape (batch_size, ...).
        x_t (Array): samples :math:`X_t \sim p_t(X_t)`, shape (batch_size, ...).
    """

    # Field order (x_1, sigma, x_t) defines the positional constructor
    # signature; note that it differs from the order returned by get_batch().
    x_1: Array = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
    sigma: Array = field(metadata={"help": "noise scale sigma (batch_size, ...)."})
    x_t: Array = field(
        metadata={"help": "samples x_t ~ p_t(X_t), shape (batch_size, ...)."}
    )

    def get_batch(self) -> Tuple[Array, Array, Array]:
        r"""Return the batch as a tuple ``(x_1, x_t, sigma)``.

        The noise scale comes last, i.e. the tuple order is not the same as
        the dataclass field order.

        Returns:
            Tuple[Array, Array, Array]: The target sample, the noisy sample,
            and the noise scale.
        """
        return self.x_1, self.x_t, self.sigma