gensbi.flow_matching.path
Submodules
Classes
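| Class | Description |
| --- | --- |
| AffineProbPath | The AffineProbPath class represents a specific type of probability path where the transformation between distributions is affine. |
| CondOTProbPath | The CondOTProbPath class represents a conditional optimal transport probability path. |
| PathSample | Represents a sample of a conditional-flow generated probability path. |
| ProbPath | Abstract class, representing a probability path. |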
Package Contents
- class gensbi.flow_matching.path.AffineProbPath(scheduler)
Bases: gensbi.flow_matching.path.path.ProbPath
The AffineProbPath class represents a specific type of probability path where the transformation between distributions is affine. An affine transformation can be represented as:
\[X_t = \alpha_t X_1 + \sigma_t X_0,\]
where \(X_t\) is the transformed data point at time \(t\), \(X_0\) and \(X_1\) are the source and target data points, respectively, and \(\alpha_t\) and \(\sigma_t\) are the parameters of the affine transformation at time \(t\).
The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.
Example
```python
import jax
from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.path.scheduler import CondOTScheduler

scheduler = CondOTScheduler()
path = AffineProbPath(scheduler)

# Split the key: reusing one key would make x_0 and x_1 identical.
key_x1, key_x0, key_t = jax.random.split(jax.random.PRNGKey(0), 3)

# x_1 should come from your dataset (e.g., a batch of real data)
x_1 = jax.random.normal(key_x1, (128, 2))  # replace with your data batch
# x_0 is typically sampled from a prior, e.g., standard normal noise
x_0 = jax.random.normal(key_x0, (128, 2))
t = jax.random.uniform(key_t, (128,))  # random times in [0, 1]

sample = path.sample(x_0, x_1, t)
print(sample.x_t.shape)  # (128, 2)
```
- Parameters:
scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.
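To make the scheduler's role concrete, here is a minimal sketch in plain JAX that does not call gensbi. It assumes the conditional-OT schedule \(\alpha_t = t\), \(\sigma_t = 1 - t\) (the one CondOTProbPath uses) and a hypothetical helper `cond_ot_schedule`, which is not the library's Scheduler API. It computes \(X_t = \alpha_t X_1 + \sigma_t X_0\) and the conditional velocity \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\), broadcasting a per-sample time of shape (batch_size,) over the data dimensions.

```python
import jax
import jax.numpy as jnp


def cond_ot_schedule(t):
    """Hypothetical scheduler: returns (alpha_t, sigma_t, d_alpha_t, d_sigma_t)
    for the conditional-OT path alpha_t = t, sigma_t = 1 - t."""
    return t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t)


def affine_path(x_0, x_1, t):
    """Return (x_t, dx_t) for X_t = alpha_t X_1 + sigma_t X_0."""
    # Reshape t from (batch,) to (batch, 1, ..., 1) so it broadcasts over x.
    t = t.reshape(t.shape + (1,) * (x_1.ndim - t.ndim))
    alpha, sigma, d_alpha, d_sigma = cond_ot_schedule(t)
    x_t = alpha * x_1 + sigma * x_0
    dx_t = d_alpha * x_1 + d_sigma * x_0
    return x_t, dx_t


k0, k1, kt = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(k0, (8, 2))  # source (noise) samples
x_1 = jax.random.normal(k1, (8, 2))  # stand-in for a data batch
t = jax.random.uniform(kt, (8,))
x_t, dx_t = affine_path(x_0, x_1, t)
print(x_t.shape, dx_t.shape)  # (8, 2) (8, 2)
```

At \(t = 0\) the path returns \(X_0\) and at \(t = 1\) it returns \(X_1\); swapping in a different schedule only changes `cond_ot_schedule`.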
- epsilon_to_target(epsilon, x_t, t)
Convert from epsilon representation to x_1 representation.
- Parameters:
epsilon (Array) – Noise in the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Target data point.
- Return type:
Array
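If \(\epsilon\) denotes the source sample \(X_0\) (the term scaled by \(\sigma_t\)), inverting \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) gives \(X_1 = (X_t - \sigma_t \epsilon)/\alpha_t\), which is only defined where \(\alpha_t \neq 0\) (for conditional OT, \(t > 0\)). A minimal sketch of this algebra in plain JAX, assuming \(\alpha_t = t\), \(\sigma_t = 1 - t\); the function name is hypothetical and this is not gensbi's implementation:

```python
import jax
import jax.numpy as jnp


def epsilon_to_target_sketch(epsilon, x_t, alpha_t, sigma_t):
    """Recover X_1 from X_t = alpha_t X_1 + sigma_t * epsilon (requires alpha_t != 0)."""
    return (x_t - sigma_t * epsilon) / alpha_t


k0, k1 = jax.random.split(jax.random.PRNGKey(1))
x_0 = jax.random.normal(k0, (4, 3))  # source sample, playing the role of epsilon
x_1 = jax.random.normal(k1, (4, 3))
t = jnp.full((4, 1), 0.7)            # per-sample times, already broadcastable
alpha_t, sigma_t = t, 1.0 - t        # assumed CondOT schedule
x_t = alpha_t * x_1 + sigma_t * x_0
x_1_hat = epsilon_to_target_sketch(x_0, x_t, alpha_t, sigma_t)
print(bool(jnp.allclose(x_1_hat, x_1, atol=1e-5)))  # True
```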
- epsilon_to_velocity(epsilon, x_t, t)
Convert from epsilon representation to velocity.
- Parameters:
epsilon (Array) – Noise in the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Velocity.
- Return type:
Array
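Eliminating \(X_1 = (X_t - \sigma_t \epsilon)/\alpha_t\) from \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\) yields

\[v = \frac{\dot{\alpha}_t}{\alpha_t} X_t + \Big(\dot{\sigma}_t - \frac{\sigma_t \dot{\alpha}_t}{\alpha_t}\Big)\epsilon,\]

which for the conditional-OT schedule reduces to \(v = (X_t - \epsilon)/t\). The sketch below checks this against \(X_1 - X_0\), assuming \(\epsilon = X_0\) and the CondOT schedule; the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def epsilon_to_velocity_sketch(epsilon, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """v = (d_alpha / alpha) x_t + (d_sigma - sigma * d_alpha / alpha) * epsilon."""
    return (d_alpha_t / alpha_t) * x_t + (d_sigma_t - sigma_t * d_alpha_t / alpha_t) * epsilon


k0, k1 = jax.random.split(jax.random.PRNGKey(2))
x_0 = jax.random.normal(k0, (4, 3))
x_1 = jax.random.normal(k1, (4, 3))
t = jnp.full((4, 1), 0.4)
# Assumed CondOT schedule: alpha = t, sigma = 1 - t, d_alpha = 1, d_sigma = -1
alpha_t, sigma_t, d_alpha_t, d_sigma_t = t, 1.0 - t, 1.0, -1.0
x_t = alpha_t * x_1 + sigma_t * x_0
v = epsilon_to_velocity_sketch(x_0, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
print(bool(jnp.allclose(v, x_1 - x_0, atol=1e-4)))  # True
```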
- sample(x_0, x_1, t)
Sample from the affine probability path.
Given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\), returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).
- Parameters:
x_0 (Array) – Source data point, shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Times in [0,1], shape (batch_size,).
- Returns:
A conditional sample at \(X_t \sim p_t\).
- Return type:
PathSample
- target_to_epsilon(x_1, x_t, t)
Convert from x_1 representation to noise.
- Parameters:
x_1 (Array) – Target data point.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Noise in the path sample.
- Return type:
Array
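Solving \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) for the noise gives \(\epsilon = (X_t - \alpha_t X_1)/\sigma_t\), defined wherever \(\sigma_t \neq 0\) (for the conditional-OT schedule, every \(t < 1\)). The sketch below checks this inversion in plain JAX, assuming \(\epsilon\) is the source sample \(X_0\) and \(\alpha_t = t\), \(\sigma_t = 1 - t\); `target_to_epsilon_sketch` is a hypothetical name, not gensbi's code.

```python
import jax
import jax.numpy as jnp


def target_to_epsilon_sketch(x_1, x_t, alpha_t, sigma_t):
    """Recover the noise from X_t = alpha_t X_1 + sigma_t * eps (requires sigma_t != 0)."""
    return (x_t - alpha_t * x_1) / sigma_t


k0, k1 = jax.random.split(jax.random.PRNGKey(5))
x_0 = jax.random.normal(k0, (4, 3))  # source sample, playing the role of eps
x_1 = jax.random.normal(k1, (4, 3))
t = jnp.full((4, 1), 0.25)           # broadcastable per-sample times
alpha_t, sigma_t = t, 1.0 - t        # assumed CondOT schedule
x_t = alpha_t * x_1 + sigma_t * x_0
eps_hat = target_to_epsilon_sketch(x_1, x_t, alpha_t, sigma_t)
print(bool(jnp.allclose(eps_hat, x_0, atol=1e-5)))  # True
```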
- target_to_velocity(x_1, x_t, t)
Convert from x_1 representation to velocity.
- Parameters:
x_1 (Array) – Target data point.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Velocity.
- Return type:
Array
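Substituting \(\epsilon = (X_t - \alpha_t X_1)/\sigma_t\) into \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\) gives

\[v = \frac{\dot{\sigma}_t}{\sigma_t} X_t + \Big(\dot{\alpha}_t - \frac{\alpha_t \dot{\sigma}_t}{\sigma_t}\Big) X_1,\]

which for the conditional-OT schedule reduces to \(v = (X_1 - X_t)/(1 - t)\). A minimal sketch (hypothetical function name; CondOT schedule and \(\epsilon = X_0\) assumed) that checks the result against \(X_1 - X_0\):

```python
import jax
import jax.numpy as jnp


def target_to_velocity_sketch(x_1, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """v = (d_sigma / sigma) x_t + (d_alpha - alpha * d_sigma / sigma) x_1."""
    return (d_sigma_t / sigma_t) * x_t + (d_alpha_t - alpha_t * d_sigma_t / sigma_t) * x_1


k0, k1 = jax.random.split(jax.random.PRNGKey(6))
x_0 = jax.random.normal(k0, (4, 3))
x_1 = jax.random.normal(k1, (4, 3))
t = jnp.full((4, 1), 0.5)
# Assumed CondOT schedule and its derivatives
alpha_t, sigma_t, d_alpha_t, d_sigma_t = t, 1.0 - t, 1.0, -1.0
x_t = alpha_t * x_1 + sigma_t * x_0
v = target_to_velocity_sketch(x_1, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
print(bool(jnp.allclose(v, x_1 - x_0, atol=1e-4)))  # True
```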
- velocity_to_epsilon(velocity, x_t, t)
Convert from velocity to noise representation.
- Parameters:
velocity (Array) – Velocity at the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Noise in the path sample.
- Return type:
Array
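Going from velocity back to noise means solving the 2×2 linear system \(X_t = \alpha_t X_1 + \sigma_t \epsilon\), \(v = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\) for \(\epsilon\):

\[\epsilon = \frac{\alpha_t v - \dot{\alpha}_t X_t}{\alpha_t \dot{\sigma}_t - \sigma_t \dot{\alpha}_t}.\]

To exercise a non-linear schedule, the sketch below assumes a cosine schedule \(\alpha_t = \sin(\pi t/2)\), \(\sigma_t = \cos(\pi t/2)\), chosen purely for illustration (this page does not say whether gensbi provides one); both function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Illustrative schedule: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2), plus derivatives."""
    a = 0.5 * jnp.pi
    return jnp.sin(a * t), jnp.cos(a * t), a * jnp.cos(a * t), -a * jnp.sin(a * t)


def velocity_to_epsilon_sketch(v, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """Solve x_t = alpha x_1 + sigma eps and v = d_alpha x_1 + d_sigma eps for eps."""
    return (alpha_t * v - d_alpha_t * x_t) / (alpha_t * d_sigma_t - sigma_t * d_alpha_t)


k0, k1 = jax.random.split(jax.random.PRNGKey(7))
x_0 = jax.random.normal(k0, (4, 3))
x_1 = jax.random.normal(k1, (4, 3))
alpha_t, sigma_t, d_alpha_t, d_sigma_t = cosine_schedule(jnp.full((4, 1), 0.3))
x_t = alpha_t * x_1 + sigma_t * x_0
v = d_alpha_t * x_1 + d_sigma_t * x_0  # conditional velocity
eps_hat = velocity_to_epsilon_sketch(v, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
print(bool(jnp.allclose(eps_hat, x_0, atol=1e-4)))  # True
```

For this schedule the denominator is the constant \(-\pi/2\), so the inversion is well defined for every \(t\).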
- velocity_to_target(velocity, x_t, t)
Convert from velocity to x_1 representation.
- Parameters:
velocity (Array) – Velocity at the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Target data point.
- Return type:
Array
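Eliminating \(\epsilon\) from the same pair of equations instead gives the target:

\[X_1 = \frac{\sigma_t v - \dot{\sigma}_t X_t}{\sigma_t \dot{\alpha}_t - \alpha_t \dot{\sigma}_t},\]

and for the conditional-OT schedule the denominator equals 1, so \(X_1 = X_t + (1 - t)\,v\). A minimal sketch of the general formula, checked under the assumed CondOT schedule with \(\epsilon = X_0\); the function name is hypothetical.

```python
import jax
import jax.numpy as jnp


def velocity_to_target_sketch(v, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """Solve x_t = alpha x_1 + sigma eps and v = d_alpha x_1 + d_sigma eps for x_1."""
    return (sigma_t * v - d_sigma_t * x_t) / (sigma_t * d_alpha_t - alpha_t * d_sigma_t)


k0, k1 = jax.random.split(jax.random.PRNGKey(8))
x_0 = jax.random.normal(k0, (4, 3))
x_1 = jax.random.normal(k1, (4, 3))
t = jnp.full((4, 1), 0.8)
# Assumed CondOT schedule and its derivatives
alpha_t, sigma_t, d_alpha_t, d_sigma_t = t, 1.0 - t, 1.0, -1.0
x_t = alpha_t * x_1 + sigma_t * x_0
v = x_1 - x_0  # CondOT conditional velocity
x_1_hat = velocity_to_target_sketch(v, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
print(bool(jnp.allclose(x_1_hat, x_1, atol=1e-5)))  # True
```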
- scheduler
- class gensbi.flow_matching.path.CondOTProbPath
Bases: AffineProbPath
The CondOTProbPath class represents a conditional optimal transport probability path. This class is a specialized version of the AffineProbPath that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation. The parameters \(\alpha_t\) and \(\sigma_t\) for the conditional optimal transport path are defined as:
\[\alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.\]
Example
```python
import jax
from gensbi.flow_matching.path import CondOTProbPath

path = CondOTProbPath()

# Split the key so x_0, x_1 and t are drawn independently.
key_x1, key_x0, key_t = jax.random.split(jax.random.PRNGKey(0), 3)

# x_1 should come from your dataset (e.g., a batch of real data)
x_1 = jax.random.normal(key_x1, (64, 2))  # replace with your data batch
# x_0 is typically sampled from a prior, e.g., standard normal noise
x_0 = jax.random.normal(key_x0, (64, 2))
t = jax.random.uniform(key_t, (64,))  # random times in [0, 1]

sample = path.sample(x_0, x_1, t)
print(sample.x_t.shape)  # (64, 2)
```
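With \(\alpha_t = t\) and \(\sigma_t = 1 - t\), the derivatives are \(\dot{\alpha}_t = 1\) and \(\dot{\sigma}_t = -1\), so the conditional velocity \(\dot{X}_t = X_1 - X_0\) does not depend on \(t\): each conditional path is a straight line traversed at constant speed. The sketch below checks this with forward-mode differentiation (`jax.jvp`) of a hand-written \(X_t\); it does not call gensbi, and `cond_ot_x_t` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def cond_ot_x_t(t, x_0, x_1):
    """X_t = t * X_1 + (1 - t) * X_0 for a scalar time t."""
    return t * x_1 + (1.0 - t) * x_0


k0, k1 = jax.random.split(jax.random.PRNGKey(3))
x_0 = jax.random.normal(k0, (5, 2))
x_1 = jax.random.normal(k1, (5, 2))

derivative_matches = []
for t_val in (0.1, 0.5, 0.9):
    t = jnp.asarray(t_val, dtype=jnp.float32)
    # Forward-mode derivative of X_t with respect to t (tangent of 1).
    _, dx_dt = jax.jvp(lambda s: cond_ot_x_t(s, x_0, x_1), (t,), (jnp.ones_like(t),))
    derivative_matches.append(bool(jnp.allclose(dx_dt, x_1 - x_0, atol=1e-5)))
print(derivative_matches)
```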
- scheduler
- class gensbi.flow_matching.path.PathSample
Represents a sample of a conditional-flow generated probability path.
- x_1
the target sample \(X_1\).
- Type:
Array
- x_0
the source sample \(X_0\).
- Type:
Array
- t
the time sample \(t\).
- Type:
Array
- x_t
samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).
- Type:
Array
- dx_t
conditional target \(\frac{\partial X_t}{\partial t}\), shape (batch_size, …).
- Type:
Array
- dx_t: jax.Array
- t: jax.Array
- x_0: jax.Array
- x_1: jax.Array
- x_t: jax.Array
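In flow matching, the conditional target dx_t is typically the regression target: a velocity network \(v_\theta(X_t, t)\) is trained to minimise \(\mathbb{E}\,\|v_\theta(X_t, t) - \dot{X}_t\|^2\). The sketch below assumes that use; `toy_velocity_model` is a hypothetical linear stand-in for a neural network, and x_t, t and dx_t are computed by hand for the CondOT path (rather than read from a PathSample) so the example runs without gensbi.

```python
import jax
import jax.numpy as jnp


def toy_velocity_model(params, x_t, t):
    """Hypothetical stand-in for a neural velocity field v_theta(x_t, t)."""
    features = jnp.concatenate([x_t, t[:, None]], axis=-1)  # (batch, dim + 1)
    return features @ params["W"] + params["b"]


def cfm_loss(params, x_t, t, dx_t):
    """Mean squared error between predicted velocity and the conditional target dx_t."""
    pred = toy_velocity_model(params, x_t, t)
    return jnp.mean(jnp.sum((pred - dx_t) ** 2, axis=-1))


k0, k1, kt = jax.random.split(jax.random.PRNGKey(4), 3)
x_0 = jax.random.normal(k0, (16, 2))
x_1 = jax.random.normal(k1, (16, 2)) + 3.0  # shifted stand-in "data"
t = jax.random.uniform(kt, (16,))
# Fields a PathSample would carry for the CondOT path (alpha_t = t, sigma_t = 1 - t).
x_t = t[:, None] * x_1 + (1.0 - t[:, None]) * x_0
dx_t = x_1 - x_0

params = {"W": jnp.zeros((3, 2)), "b": jnp.zeros(2)}
loss, grads = jax.value_and_grad(cfm_loss)(params, x_t, t, dx_t)
print(loss, grads["W"].shape)
```

In a real training loop, the gradients would feed an optimizer step and the model would be a neural network rather than a linear map.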
- class gensbi.flow_matching.path.ProbPath
Bases: abc.ABC
Abstract class, representing a probability path.
A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) over \(t=0\rightarrow 1\).
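The ProbPath class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives. Here is a high-level example:
```python
import jax
from gensbi.flow_matching.path import CondOTProbPath

# Instantiate a probability path (ProbPath is abstract, so use a concrete subclass)
my_path = CondOTProbPath()

key_x0, key_x1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_x0, (32, 2))  # source samples, e.g. noise
x_1 = jax.random.normal(key_x1, (32, 2))  # target samples, e.g. a data batch

# Draw one time per sample, uniformly in [0, 1], shape (batch_size,)
t = jax.random.uniform(key_t, (32,))

# Sample the conditional path X_t ~ p_t(X_t|X_0,X_1)
path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
```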
- assert_sample_shape(x_0, x_1, t)
Checks that the shapes of x_0, x_1, and t are compatible for sampling.
- Parameters:
x_0 (Array) – Source data point.
x_1 (Array) – Target data point.
t (Array) – Time vector.
- Raises:
AssertionError – If the shapes are not compatible.
- Return type:
None
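The docstring only says the shapes must be "compatible". One plausible reading, based on the shapes documented for sample (x_0 and x_1 of shape (batch_size, …), t of shape (batch_size,)), is sketched below; `check_sample_shapes` is hypothetical and the library's actual rules may be looser or stricter.

```python
import jax.numpy as jnp


def check_sample_shapes(x_0, x_1, t):
    """Hypothetical shape check; raises AssertionError like the documented method."""
    assert x_0.shape == x_1.shape, f"x_0 {x_0.shape} and x_1 {x_1.shape} differ"
    assert t.ndim == 1, f"t must be 1-D (batch_size,), got {t.shape}"
    assert t.shape[0] == x_1.shape[0], (
        f"batch sizes differ: t has {t.shape[0]}, x has {x_1.shape[0]}"
    )


check_sample_shapes(jnp.zeros((8, 2)), jnp.zeros((8, 2)), jnp.zeros(8))  # passes

try:
    check_sample_shapes(jnp.zeros((8, 2)), jnp.zeros((8, 2)), jnp.zeros(4))
except AssertionError as err:
    print("rejected:", err)
```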
- abstract sample(x_0, x_1, t)
Sample from an abstract probability path.
Given \((X_0,X_1) \sim \pi(X_0,X_1)\), returns \(X_0\), \(X_1\), \(X_t \sim p_t(X_t|X_0,X_1)\), and a conditional target \(Y\), all bundled in a PathSample.
- Parameters:
x_0 (Array) – Source data point, shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Times in [0,1], shape (batch_size,).
- Returns:
A conditional sample.
- Return type: