gensbi.flow_matching.path.affine
Classes
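- AffineProbPath – The AffineProbPath class represents a probability path where the transformation between distributions is affine.
- CondOTProbPath – The CondOTProbPath class represents a conditional optimal transport probability path.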
Module Contents
- class gensbi.flow_matching.path.affine.AffineProbPath(scheduler)
Bases: gensbi.flow_matching.path.path.ProbPath
The AffineProbPath class represents a probability path where the transformation between distributions is affine, i.e. each point on the path is an affine combination of a source and a target sample:
\[X_t = \alpha_t X_1 + \sigma_t X_0,\]
where \(X_t\) is the transformed data point at time \(t\), \(X_0\) and \(X_1\) are the source and target data points, respectively, and \(\alpha_t\) and \(\sigma_t\) are the time-dependent parameters of the affine transformation.
The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.
Example
    from gensbi.flow_matching.path.scheduler import CondOTScheduler
    from gensbi.flow_matching.path import AffineProbPath
    import jax

    scheduler = CondOTScheduler()
    path = AffineProbPath(scheduler)

    # Use independent keys so that x_0, x_1 and t are not correlated.
    key_1, key_0, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
    # x_1 should come from your dataset (e.g., a batch of real data)
    x_1 = jax.random.normal(key_1, (128, 2))  # replace with your data batch
    # x_0 is typically sampled from a prior, e.g., standard normal noise
    x_0 = jax.random.normal(key_0, (128, 2))
    t = jax.random.uniform(key_t, (128,))  # random times in [0, 1)
    sample = path.sample(x_0, x_1, t)
    print(sample.x_t.shape)  # (128, 2)
- Parameters:
scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.
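To make the scheduler's role concrete, here is a minimal sketch in plain JAX that evaluates \(X_t\) and its time derivative for a hypothetical cosine schedule \(\alpha_t = \sin(\pi t/2)\), \(\sigma_t = \cos(\pi t/2)\). The names `cosine_schedule`, `expand_like` and `affine_sample` are illustrative assumptions, not part of the gensbi API; in the library, a scheduler object plays the role of `cosine_schedule`.

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Hypothetical schedule: returns (alpha_t, sigma_t, d_alpha_t, d_sigma_t)."""
    half_pi = jnp.pi / 2
    alpha = jnp.sin(half_pi * t)
    sigma = jnp.cos(half_pi * t)
    d_alpha = half_pi * jnp.cos(half_pi * t)
    d_sigma = -half_pi * jnp.sin(half_pi * t)
    return alpha, sigma, d_alpha, d_sigma


def expand_like(t, x):
    """Reshape per-sample times (batch,) so they broadcast against x (batch, ...)."""
    return t.reshape(t.shape + (1,) * (x.ndim - t.ndim))


def affine_sample(x_0, x_1, t, schedule=cosine_schedule):
    """Return X_t = alpha_t X_1 + sigma_t X_0 and dX_t = d_alpha_t X_1 + d_sigma_t X_0."""
    alpha, sigma, d_alpha, d_sigma = schedule(expand_like(t, x_0))
    x_t = alpha * x_1 + sigma * x_0
    dx_t = d_alpha * x_1 + d_sigma * x_0
    return x_t, dx_t


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (128, 2))  # source / noise samples
x_1 = jax.random.normal(key_1, (128, 2))  # stand-in for a data batch
t = jax.random.uniform(key_t, (128,))
x_t, dx_t = affine_sample(x_0, x_1, t)
print(x_t.shape, dx_t.shape)  # (128, 2) (128, 2)
```

Swapping in a different schedule only changes the four coefficients; the affine combination itself stays the same, which is why the path can be parameterized by the scheduler alone.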
- epsilon_to_target(epsilon, x_t, t)
Convert from epsilon representation to x_1 representation.
- Parameters:
epsilon (Array) – Noise in the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Target data point.
- Return type:
Array
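Treating the noise \(\epsilon\) as the source sample \(X_0\), the path gives \(X_t = \alpha_t X_1 + \sigma_t \epsilon\), so the target can be recovered as \(X_1 = (X_t - \sigma_t \epsilon)/\alpha_t\) whenever \(\alpha_t \neq 0\). The sketch below checks this algebra in plain JAX with a hypothetical cosine schedule; `epsilon_to_target_sketch` is an illustrative name and is not claimed to match gensbi's implementation.

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Hypothetical schedule: (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for a cosine path."""
    h = jnp.pi / 2
    return jnp.sin(h * t), jnp.cos(h * t), h * jnp.cos(h * t), -h * jnp.sin(h * t)


def epsilon_to_target_sketch(epsilon, x_t, t):
    """Solve X_t = alpha_t * X_1 + sigma_t * epsilon for X_1 (requires alpha_t != 0)."""
    alpha, sigma, _, _ = cosine_schedule(t[:, None])  # (batch, 1) broadcasts over features
    return (x_t - sigma * epsilon) / alpha


key_e, key_1, key_t = jax.random.split(jax.random.PRNGKey(2), 3)
epsilon = jax.random.normal(key_e, (32, 2))
x_1 = jax.random.normal(key_1, (32, 2))
t = jax.random.uniform(key_t, (32,), minval=0.2, maxval=0.8)  # keep alpha_t away from 0

alpha, sigma, _, _ = cosine_schedule(t[:, None])
x_t = alpha * x_1 + sigma * epsilon  # forward: build the path sample
x_1_rec = epsilon_to_target_sketch(epsilon, x_t, t)  # inverse: recover the target
print(jnp.allclose(x_1_rec, x_1, atol=1e-4))  # True
```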
- epsilon_to_velocity(epsilon, x_t, t)
Convert from epsilon representation to velocity.
- Parameters:
epsilon (Array) – Noise in the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Velocity.
- Return type:
Array
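Substituting \(X_1 = (X_t - \sigma_t \epsilon)/\alpha_t\) into the conditional velocity \(\dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\) gives \(v = (\dot{\alpha}_t/\alpha_t) X_t + (\dot{\sigma}_t - \dot{\alpha}_t \sigma_t/\alpha_t)\,\epsilon\). A minimal sketch verifying this identity, again assuming a hypothetical cosine schedule and an illustrative helper name rather than the library's code:

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Hypothetical schedule: (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for a cosine path."""
    h = jnp.pi / 2
    return jnp.sin(h * t), jnp.cos(h * t), h * jnp.cos(h * t), -h * jnp.sin(h * t)


def epsilon_to_velocity_sketch(epsilon, x_t, t):
    """v = (d_alpha / alpha) * x_t + (d_sigma - d_alpha * sigma / alpha) * epsilon."""
    alpha, sigma, d_alpha, d_sigma = cosine_schedule(t[:, None])
    return (d_alpha / alpha) * x_t + (d_sigma - d_alpha * sigma / alpha) * epsilon


key_e, key_1, key_t = jax.random.split(jax.random.PRNGKey(3), 3)
epsilon = jax.random.normal(key_e, (32, 2))
x_1 = jax.random.normal(key_1, (32, 2))
t = jax.random.uniform(key_t, (32,), minval=0.2, maxval=0.8)  # keep alpha_t away from 0

alpha, sigma, d_alpha, d_sigma = cosine_schedule(t[:, None])
x_t = alpha * x_1 + sigma * epsilon
v_direct = d_alpha * x_1 + d_sigma * epsilon  # conditional velocity from the path definition
v_rec = epsilon_to_velocity_sketch(epsilon, x_t, t)  # velocity from (epsilon, x_t, t) only
print(jnp.allclose(v_rec, v_direct, atol=1e-4))  # True
```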
- sample(x_0, x_1, t)
Sample from the affine probability path.
Given \((X_0, X_1) \sim \pi(X_0, X_1)\) and a scheduler \((\alpha_t, \sigma_t)\), returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).
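The conditional velocity is simply the time derivative of \(X_t\) with \(X_0\) and \(X_1\) held fixed. A minimal sketch of this, assuming the conditional optimal transport schedule \(\alpha_t = t\), \(\sigma_t = 1 - t\) (so \(\dot{\alpha}_t = 1\) and \(\dot{\sigma}_t = -1\)), differentiates the interpolant with `jax.jvp` and compares it with the closed form; `x_t_fn` and `velocity_autodiff` are hypothetical helpers, not gensbi functions.

```python
import jax
import jax.numpy as jnp


def x_t_fn(t, x_0, x_1):
    """X_t = alpha_t X_1 + sigma_t X_0 with alpha_t = t, sigma_t = 1 - t (single sample)."""
    return t * x_1 + (1.0 - t) * x_0


def velocity_autodiff(t_i, x0_i, x1_i):
    """Forward-mode derivative of X_t with respect to t for one sample."""
    _, dxdt = jax.jvp(lambda s: x_t_fn(s, x0_i, x1_i), (t_i,), (jnp.ones_like(t_i),))
    return dxdt


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(1), 3)
x_0 = jax.random.normal(key_0, (16, 3))
x_1 = jax.random.normal(key_1, (16, 3))
t = jax.random.uniform(key_t, (16,))

v_auto = jax.vmap(velocity_autodiff)(t, x_0, x_1)  # vectorize over the batch
v_closed = x_1 - x_0  # d_alpha_t X_1 + d_sigma_t X_0 for this schedule
print(jnp.allclose(v_auto, v_closed, atol=1e-6))  # True
```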
- Parameters:
x_0 (Array) – Source data point, shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Times in [0,1], shape (batch_size,).
- Returns:
A conditional sample \(X_t \sim p_t\).
- Return type:
- target_to_epsilon(x_1, x_t, t)
Convert from x_1 representation to noise.
- Parameters:
x_1 (Array) – Target data point.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Noise in the path sample.
- Return type:
Array
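Solving \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) for the noise gives \(\epsilon = (X_t - \alpha_t X_1)/\sigma_t\), which requires \(\sigma_t \neq 0\). The following sketch checks the round trip under the same hypothetical cosine schedule; `target_to_epsilon_sketch` is an illustrative name only.

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Hypothetical schedule: (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for a cosine path."""
    h = jnp.pi / 2
    return jnp.sin(h * t), jnp.cos(h * t), h * jnp.cos(h * t), -h * jnp.sin(h * t)


def target_to_epsilon_sketch(x_1, x_t, t):
    """Solve X_t = alpha_t * X_1 + sigma_t * epsilon for epsilon (requires sigma_t != 0)."""
    alpha, sigma, _, _ = cosine_schedule(t[:, None])
    return (x_t - alpha * x_1) / sigma


key_e, key_1, key_t = jax.random.split(jax.random.PRNGKey(4), 3)
epsilon = jax.random.normal(key_e, (32, 2))
x_1 = jax.random.normal(key_1, (32, 2))
t = jax.random.uniform(key_t, (32,), minval=0.2, maxval=0.8)  # keep sigma_t away from 0

alpha, sigma, _, _ = cosine_schedule(t[:, None])
x_t = alpha * x_1 + sigma * epsilon
eps_rec = target_to_epsilon_sketch(x_1, x_t, t)
print(jnp.allclose(eps_rec, epsilon, atol=1e-4))  # True
```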
- target_to_velocity(x_1, x_t, t)
Convert from x_1 representation to velocity.
- Parameters:
x_1 (Array) – Target data point.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Velocity.
- Return type:
Array
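Eliminating \(\epsilon = (X_t - \alpha_t X_1)/\sigma_t\) from \(v = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\) yields \(v = (\dot{\sigma}_t/\sigma_t) X_t + (\dot{\alpha}_t - \dot{\sigma}_t \alpha_t/\sigma_t) X_1\). A minimal sketch of that conversion, under the hypothetical cosine schedule and with an illustrative function name:

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Hypothetical schedule: (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for a cosine path."""
    h = jnp.pi / 2
    return jnp.sin(h * t), jnp.cos(h * t), h * jnp.cos(h * t), -h * jnp.sin(h * t)


def target_to_velocity_sketch(x_1, x_t, t):
    """v = (d_sigma / sigma) * x_t + (d_alpha - d_sigma * alpha / sigma) * x_1."""
    alpha, sigma, d_alpha, d_sigma = cosine_schedule(t[:, None])
    return (d_sigma / sigma) * x_t + (d_alpha - d_sigma * alpha / sigma) * x_1


key_e, key_1, key_t = jax.random.split(jax.random.PRNGKey(5), 3)
epsilon = jax.random.normal(key_e, (32, 2))
x_1 = jax.random.normal(key_1, (32, 2))
t = jax.random.uniform(key_t, (32,), minval=0.2, maxval=0.8)  # keep sigma_t away from 0

alpha, sigma, d_alpha, d_sigma = cosine_schedule(t[:, None])
x_t = alpha * x_1 + sigma * epsilon
v_direct = d_alpha * x_1 + d_sigma * epsilon
v_rec = target_to_velocity_sketch(x_1, x_t, t)  # velocity from (x_1, x_t, t) only
print(jnp.allclose(v_rec, v_direct, atol=1e-4))  # True
```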
- velocity_to_epsilon(velocity, x_t, t)
Convert from velocity to noise representation.
- Parameters:
velocity (Array) – Velocity at the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Noise in the path sample.
- Return type:
Array
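Inverting the epsilon-to-velocity relation gives \(\epsilon = (\alpha_t v - \dot{\alpha}_t X_t)/(\alpha_t \dot{\sigma}_t - \dot{\alpha}_t \sigma_t)\), valid whenever the denominator is nonzero. For the hypothetical cosine schedule used below the denominator is the constant \(-\pi/2\), and for \(\alpha_t = t\), \(\sigma_t = 1 - t\) it is \(-1\), so the conversion is well defined on all of \([0, 1]\) in both cases. `velocity_to_epsilon_sketch` is an illustrative name, not the library function.

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Hypothetical schedule: (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for a cosine path."""
    h = jnp.pi / 2
    return jnp.sin(h * t), jnp.cos(h * t), h * jnp.cos(h * t), -h * jnp.sin(h * t)


def velocity_to_epsilon_sketch(velocity, x_t, t):
    """epsilon = (alpha * v - d_alpha * x_t) / (alpha * d_sigma - d_alpha * sigma)."""
    alpha, sigma, d_alpha, d_sigma = cosine_schedule(t[:, None])
    return (alpha * velocity - d_alpha * x_t) / (alpha * d_sigma - d_alpha * sigma)


key_e, key_1, key_t = jax.random.split(jax.random.PRNGKey(6), 3)
epsilon = jax.random.normal(key_e, (32, 2))
x_1 = jax.random.normal(key_1, (32, 2))
t = jax.random.uniform(key_t, (32,))  # full [0, 1): the denominator never vanishes here

alpha, sigma, d_alpha, d_sigma = cosine_schedule(t[:, None])
x_t = alpha * x_1 + sigma * epsilon
velocity = d_alpha * x_1 + d_sigma * epsilon
eps_rec = velocity_to_epsilon_sketch(velocity, x_t, t)
print(jnp.allclose(eps_rec, epsilon, atol=1e-4))  # True
```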
- class gensbi.flow_matching.path.affine.CondOTProbPath
Bases: AffineProbPath
The CondOTProbPath class represents a conditional optimal transport probability path. It is a specialized version of AffineProbPath that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation. For the conditional optimal transport path, \(\alpha_t\) and \(\sigma_t\) are defined as:
\[\alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.\]
Example
    from gensbi.flow_matching.path import CondOTProbPath
    import jax

    path = CondOTProbPath()

    # Use independent keys so that x_0, x_1 and t are not correlated.
    key_1, key_0, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
    # x_1 should come from your dataset (e.g., a batch of real data)
    x_1 = jax.random.normal(key_1, (64, 2))  # replace with your data batch
    # x_0 is typically sampled from a prior, e.g., standard normal noise
    x_0 = jax.random.normal(key_0, (64, 2))
    t = jax.random.uniform(key_t, (64,))  # random times in [0, 1)
    sample = path.sample(x_0, x_1, t)
    print(sample.x_t.shape)  # (64, 2)
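With \(\alpha_t = t\) and \(\sigma_t = 1 - t\), the conditional velocity is \(\dot{X}_t = X_1 - X_0\), independent of \(t\), and the representation conversions reduce to \(X_1 = X_t + (1 - t)\,\dot{X}_t\) and \(X_0 = X_t - t\,\dot{X}_t\). The sketch below checks these identities in plain JAX; it re-derives them from the schedule and does not call gensbi, so it makes no claim about how CondOTProbPath computes them internally.

```python
import jax
import jax.numpy as jnp

# Conditional OT schedule: alpha_t = t, sigma_t = 1 - t, so d_alpha_t = 1 and d_sigma_t = -1.
key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(7), 3)
x_0 = jax.random.normal(key_0, (64, 2))
x_1 = jax.random.normal(key_1, (64, 2))
t = jax.random.uniform(key_t, (64,))[:, None]  # (64, 1) broadcasts over features

x_t = t * x_1 + (1.0 - t) * x_0  # straight-line interpolation
v = 1.0 * x_1 + (-1.0) * x_0  # d_alpha_t X_1 + d_sigma_t X_0 = X_1 - X_0

# The general conversions specialize to these simple expressions.
x_1_from_v = x_t + (1.0 - t) * v  # target from velocity
x_0_from_v = x_t - t * v  # noise from velocity
print(jnp.allclose(x_1_from_v, x_1, atol=1e-5), jnp.allclose(x_0_from_v, x_0, atol=1e-5))  # True True
```

Because the velocity target does not depend on \(t\), the conditional OT path gives straight-line trajectories between paired source and target samples.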