gensbi.flow_matching.path.affine

Classes

AffineProbPath

The AffineProbPath class represents a specific type of probability path where the transformation between distributions is affine.

CondOTProbPath

The CondOTProbPath class represents a conditional optimal transport probability path.

Module Contents

class gensbi.flow_matching.path.affine.AffineProbPath(scheduler)

Bases: gensbi.flow_matching.path.path.ProbPath

The AffineProbPath class represents a probability path in which each intermediate point is an affine combination of a source point and a target point, with time-dependent coefficients:

\[X_t = \alpha_t X_1 + \sigma_t X_0,\]

where \(X_t\) is the point on the path at time \(t\), \(X_0\) and \(X_1\) are the source and target data points, respectively, and \(\alpha_t\) and \(\sigma_t\) are the time-dependent coefficients of the affine transformation.

The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.
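To make the scheduler's role concrete, the sketch below shows the four quantities a scheduler supplies at a given time. It is not the gensbi `Scheduler` API: `SchedulerOutputSketch` and `cond_ot_schedule` are hypothetical names, and the coefficients are assumed to follow the conditional optimal transport choice \(\alpha_t = t\), \(\sigma_t = 1 - t\) documented for CondOTProbPath. The analytic derivatives are checked against finite differences.

```python
from typing import NamedTuple

import numpy as np


class SchedulerOutputSketch(NamedTuple):
    """Hypothetical container for the four quantities a scheduler provides."""
    alpha_t: np.ndarray
    sigma_t: np.ndarray
    d_alpha_t: np.ndarray
    d_sigma_t: np.ndarray


def cond_ot_schedule(t):
    """Hypothetical conditional-OT schedule: alpha_t = t, sigma_t = 1 - t."""
    t = np.asarray(t, dtype=float)
    return SchedulerOutputSketch(
        alpha_t=t,
        sigma_t=1.0 - t,
        d_alpha_t=np.ones_like(t),   # d/dt of t
        d_sigma_t=-np.ones_like(t),  # d/dt of (1 - t)
    )


# Compare the analytic derivatives with central finite differences.
t = np.linspace(0.1, 0.9, 5)
h = 1e-6
out = cond_ot_schedule(t)
fd_alpha = (cond_ot_schedule(t + h).alpha_t - cond_ot_schedule(t - h).alpha_t) / (2 * h)
fd_sigma = (cond_ot_schedule(t + h).sigma_t - cond_ot_schedule(t - h).sigma_t) / (2 * h)
print(np.allclose(out.d_alpha_t, fd_alpha), np.allclose(out.d_sigma_t, fd_sigma))
# True True
```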

Example

```python
import jax
from gensbi.flow_matching.path import AffineProbPath
from gensbi.flow_matching.path.scheduler import CondOTScheduler

scheduler = CondOTScheduler()
path = AffineProbPath(scheduler)

# Use independent keys so that x_1, x_0 and t are not identical draws
key = jax.random.PRNGKey(0)
key_x1, key_x0, key_t = jax.random.split(key, 3)

# x_1 should come from your dataset (e.g., a batch of real data)
x_1 = jax.random.normal(key_x1, (128, 2))  # replace with your data batch
# x_0 is typically sampled from a prior, e.g., standard normal noise
x_0 = jax.random.normal(key_x0, (128, 2))
t = jax.random.uniform(key_t, (128,))  # random times in [0, 1]

sample = path.sample(x_0, x_1, t)
print(sample.x_t.shape)
# (128, 2)
```

Parameters:

scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.

epsilon_to_target(epsilon, x_t, t)

Convert from epsilon representation to x_1 representation.

Parameters:
  • epsilon (Array) – Noise in the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Target data point.

Return type:

Array
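Assuming `epsilon` plays the role of the source point \(X_0\) in \(X_t = \alpha_t X_1 + \sigma_t X_0\), solving that relation for the target gives the conversion below. This is derived from the path definition, not taken from the implementation; it is undefined where \(\alpha_t = 0\) (e.g. \(t = 0\) for the conditional OT schedule).

\[x_1 = \frac{x_t - \sigma_t\,\epsilon}{\alpha_t}.\]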

epsilon_to_velocity(epsilon, x_t, t)

Convert from epsilon representation to velocity.

Parameters:
  • epsilon (Array) – Noise in the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Velocity.

Return type:

Array
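Taking `epsilon` to be \(X_0\) and substituting \(x_1 = (x_t - \sigma_t \epsilon)/\alpha_t\) into the conditional velocity \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\) expresses the velocity in terms of the noise. This is a derivation from the formulas on this page, not necessarily how the method computes it internally:

\[v = \frac{\dot{\alpha}_t}{\alpha_t}\,x_t + \left(\dot{\sigma}_t - \frac{\dot{\alpha}_t\,\sigma_t}{\alpha_t}\right)\epsilon.\]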

sample(x_0, x_1, t)

Sample from the affine probability path.

Given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\), returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).

Parameters:
  • x_0 (Array) – Source data point, shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Times in [0,1], shape (batch_size,).

Returns:

A conditional sample \(X_t \sim p_t\).

Return type:

PathSample
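The computation described above can be sketched in NumPy for the conditional OT coefficients. This is an illustrative assumption, not the gensbi implementation: `affine_sample` and `PathSampleSketch` are hypothetical names, the schedule is hard-coded to \(\alpha_t = t\), \(\sigma_t = 1 - t\), and only the `x_t` field name is taken from the examples on this page. The key detail is reshaping `t` from `(batch_size,)` so that it broadcasts over the trailing data dimensions.

```python
from typing import NamedTuple

import numpy as np


class PathSampleSketch(NamedTuple):
    """Hypothetical stand-in for a path sample; field names are assumptions."""
    x_0: np.ndarray
    x_1: np.ndarray
    t: np.ndarray
    x_t: np.ndarray
    dx_t: np.ndarray


def affine_sample(x_0, x_1, t):
    """Sketch of an affine path sample with alpha_t = t, sigma_t = 1 - t."""
    # Reshape t from (batch_size,) to (batch_size, 1, ..., 1) so it
    # broadcasts against data of shape (batch_size, ...).
    t_b = t.reshape(t.shape + (1,) * (x_1.ndim - 1))
    alpha_t, sigma_t = t_b, 1.0 - t_b
    d_alpha_t, d_sigma_t = 1.0, -1.0
    x_t = alpha_t * x_1 + sigma_t * x_0
    dx_t = d_alpha_t * x_1 + d_sigma_t * x_0
    return PathSampleSketch(x_0=x_0, x_1=x_1, t=t, x_t=x_t, dx_t=dx_t)


rng = np.random.default_rng(0)
x_0 = rng.standard_normal((4, 3, 2))
x_1 = rng.standard_normal((4, 3, 2))
t = np.array([0.0, 0.25, 0.75, 1.0])
s = affine_sample(x_0, x_1, t)
print(s.x_t.shape, s.dx_t.shape)
# (4, 3, 2) (4, 3, 2)
```

At \(t = 0\) the sample equals \(X_0\) and at \(t = 1\) it equals \(X_1\), which is a quick sanity check for any scheduler with \(\alpha_0 = 0\), \(\sigma_0 = 1\), \(\alpha_1 = 1\), \(\sigma_1 = 0\).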

target_to_epsilon(x_1, x_t, t)

Convert from x_1 representation to noise.

Parameters:
  • x_1 (Array) – Target data point.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Noise in the path sample.

Return type:

Array
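Solving \(X_t = \alpha_t X_1 + \sigma_t X_0\) for the source point instead, and assuming the returned noise is that \(X_0\), gives the inverse direction. As a derivation from the path definition it is undefined where \(\sigma_t = 0\) (e.g. \(t = 1\) for the conditional OT schedule):

\[\epsilon = \frac{x_t - \alpha_t\,x_1}{\sigma_t}.\]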

target_to_velocity(x_1, x_t, t)

Convert from x_1 representation to velocity.

Parameters:
  • x_1 (Array) – Target data point.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Velocity.

Return type:

Array
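Eliminating the noise from \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\) with \(X_0 = (x_t - \alpha_t x_1)/\sigma_t\) yields the velocity in terms of the target. This is a sketch derived from the formulas on this page, valid where \(\sigma_t \neq 0\):

\[v = \frac{\dot{\sigma}_t}{\sigma_t}\,x_t + \left(\dot{\alpha}_t - \frac{\dot{\sigma}_t\,\alpha_t}{\sigma_t}\right)x_1.\]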

velocity_to_epsilon(velocity, x_t, t)

Convert from velocity to noise representation.

Parameters:
  • velocity (Array) – Velocity at the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Noise in the path sample.

Return type:

Array
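Treating \(x_t = \alpha_t x_1 + \sigma_t \epsilon\) and \(v = \dot{\alpha}_t x_1 + \dot{\sigma}_t \epsilon\) as a linear system and eliminating \(x_1\) (multiply the first by \(\dot{\alpha}_t\), the second by \(\alpha_t\), and subtract) gives the noise, assuming \(\epsilon\) denotes \(X_0\). For the conditional OT schedule the denominator is \(-1\) and this reduces to \(\epsilon = x_t - t\,v\):

\[\epsilon = \frac{\alpha_t\,v - \dot{\alpha}_t\,x_t}{\alpha_t\,\dot{\sigma}_t - \dot{\alpha}_t\,\sigma_t}.\]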

velocity_to_target(velocity, x_t, t)

Convert from velocity to x_1 representation.

Parameters:
  • velocity (Array) – Velocity at the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Target data point.

Return type:

Array
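From the same two relations, \(x_t = \alpha_t x_1 + \sigma_t \epsilon\) and \(v = \dot{\alpha}_t x_1 + \dot{\sigma}_t \epsilon\), eliminating \(\epsilon\) (multiply the first by \(\dot{\sigma}_t\), the second by \(\sigma_t\), and subtract) recovers the target. This is derived from the path definition rather than read from the implementation; for the conditional OT schedule it reduces to \(x_1 = x_t + (1 - t)\,v\):

\[x_1 = \frac{\sigma_t\,v - \dot{\sigma}_t\,x_t}{\sigma_t\,\dot{\alpha}_t - \dot{\sigma}_t\,\alpha_t}.\]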

scheduler

class gensbi.flow_matching.path.affine.CondOTProbPath

Bases: AffineProbPath

The CondOTProbPath class represents a conditional optimal transport probability path.

This class is a specialized version of the AffineProbPath that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation.

The parameters \(\alpha_t\) and \(\sigma_t\) for the conditional optimal transport path are defined as:

\[\alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.\]
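With \(\alpha_t = t\) and \(\sigma_t = 1 - t\), the derivatives are \(\dot{\alpha}_t = 1\) and \(\dot{\sigma}_t = -1\), so the conditional velocity reduces to \(\dot{X}_t = X_1 - X_0\), independent of \(t\). The sketch below checks this numerically together with the resulting closed-form recoveries \(X_1 = X_t + (1 - t)\dot{X}_t\) and \(X_0 = X_t - t\,\dot{X}_t\). It uses plain NumPy rather than the gensbi classes, and the helper names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
x_0 = rng.standard_normal((5, 2))  # source (noise) points
x_1 = rng.standard_normal((5, 2))  # target (data) points
t = rng.uniform(size=(5, 1))       # one time per row, broadcast over features

x_t = t * x_1 + (1.0 - t) * x_0    # alpha_t = t, sigma_t = 1 - t
velocity = x_1 - x_0               # d_alpha_t * x_1 + d_sigma_t * x_0 with (1, -1)


# Hypothetical helpers: closed-form conversions specialized to the CondOT path.
def velocity_to_target_ot(v, x_t, t):
    return x_t + (1.0 - t) * v


def velocity_to_epsilon_ot(v, x_t, t):
    return x_t - t * v


print(np.allclose(velocity_to_target_ot(velocity, x_t, t), x_1),
      np.allclose(velocity_to_epsilon_ot(velocity, x_t, t), x_0))
# True True
```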

Example

```python
import jax
from gensbi.flow_matching.path import CondOTProbPath

path = CondOTProbPath()

# Use independent keys so that x_1, x_0 and t are not identical draws
key = jax.random.PRNGKey(0)
key_x1, key_x0, key_t = jax.random.split(key, 3)

# x_1 should come from your dataset (e.g., a batch of real data)
x_1 = jax.random.normal(key_x1, (64, 2))  # replace with your data batch
# x_0 is typically sampled from a prior, e.g., standard normal noise
x_0 = jax.random.normal(key_x0, (64, 2))
t = jax.random.uniform(key_t, (64,))  # random times in [0, 1]

sample = path.sample(x_0, x_1, t)
print(sample.x_t.shape)
# (64, 2)
```

scheduler