Source code for gensbi.flow_matching.path.path_sample

from dataclasses import dataclass, field
from jax import Array


@dataclass
class PathSample:
    r"""Container for one sample drawn from a conditional-flow probability path.

    The fields are positional in the order ``x_1, x_0, t, x_t, dx_t`` and
    none of them has a default value. Each is annotated as ``Array``; the
    dataclass does not check this at runtime.

    Attributes:
        x_1 (Array): target sample :math:`X_1`.
        x_0 (Array): source sample :math:`X_0`.
        t (Array): time sample :math:`t`.
        x_t (Array): path sample :math:`X_t \sim p_t(X_t)`, shape
            ``(batch_size, ...)``.
        dx_t (Array): conditional target
            :math:`\frac{\partial X}{\partial t}`, shape ``(batch_size, ...)``.
    """

    # Every field stores a human-readable "help" string in its metadata,
    # retrievable through ``dataclasses.fields(PathSample)``.
    x_1: Array = field(
        metadata=dict(help="target samples X_1 (batch_size, ...).")
    )
    x_0: Array = field(
        metadata=dict(help="source samples X_0 (batch_size, ...).")
    )
    t: Array = field(
        metadata=dict(help="time samples t (batch_size, ...).")
    )
    x_t: Array = field(
        metadata=dict(help="samples x_t ~ p_t(X_t), shape (batch_size, ...).")
    )
    dx_t: Array = field(
        metadata=dict(help="conditional target dX_t, shape: (batch_size, ...).")
    )
# @dataclass
# class DiscretePathSample:
#     r"""
#     Represents a sample of a conditional-flow generated discrete probability path.
#     Attributes:
#         x_1 (Array): the target sample :math:`X_1`.
#         x_0 (Array): the source sample :math:`X_0`.
#         t (Array): the time sample :math:`t`.
#         x_t (Array): the sample along the path :math:`X_t \sim p_t`.
#     """
#     x_1: Array = field(metadata={"help": "target samples X_1 (batch_size, ...)."})
#     x_0: Array = field(metadata={"help": "source samples X_0 (batch_size, ...)."})
#     t: Array = field(metadata={"help": "time samples t (batch_size, ...)."})
#     x_t: Array = field(
#         metadata={"help": "samples X_t ~ p_t(X_t), shape (batch_size, ...)."}
#     )