gensbi.flow_matching.path

Submodules

Classes

AffineProbPath

The AffineProbPath class represents a specific type of probability path where the transformation between distributions is affine.

CondOTProbPath

The CondOTProbPath class represents a conditional optimal transport probability path.

PathSample

Represents a sample of a conditional-flow generated probability path.

ProbPath

Abstract class representing a probability path.

Package Contents

class gensbi.flow_matching.path.AffineProbPath(scheduler)

Bases: gensbi.flow_matching.path.path.ProbPath

The AffineProbPath class represents a specific type of probability path where the transformation between distributions is affine. An affine transformation can be represented as:

\[X_t = \alpha_t X_1 + \sigma_t X_0,\]

where \(X_t\) is the transformed data point at time \(t\), \(X_0\) and \(X_1\) are the source and target data points, respectively, and \(\alpha_t\) and \(\sigma_t\) are the time-dependent coefficients of the affine transformation.

The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.
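
As a rough sketch of what a scheduler supplies, the function below returns \(\alpha_t\), \(\sigma_t\) and their time derivatives for the conditional optimal transport schedule \(\alpha_t = t\), \(\sigma_t = 1 - t\), and cross-checks the hand-written derivatives with jax.grad. The names condot_schedule and ScheduleOutput are hypothetical illustrations, not the gensbi Scheduler interface.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ScheduleOutput(NamedTuple):
    # Hypothetical container; the real scheduler output type may differ.
    alpha_t: jax.Array
    sigma_t: jax.Array
    d_alpha_t: jax.Array
    d_sigma_t: jax.Array


def condot_schedule(t: jax.Array) -> ScheduleOutput:
    """Hypothetical CondOT schedule: alpha_t = t, sigma_t = 1 - t."""
    return ScheduleOutput(
        alpha_t=t,
        sigma_t=1.0 - t,
        d_alpha_t=jnp.ones_like(t),
        d_sigma_t=-jnp.ones_like(t),
    )


# Cross-check the hand-written derivatives against automatic differentiation.
t = jnp.linspace(0.0, 1.0, 5)
out = condot_schedule(t)
d_alpha_auto = jax.vmap(jax.grad(lambda s: condot_schedule(s).alpha_t))(t)
d_sigma_auto = jax.vmap(jax.grad(lambda s: condot_schedule(s).sigma_t))(t)
print(jnp.allclose(out.d_alpha_t, d_alpha_auto), jnp.allclose(out.d_sigma_t, d_sigma_auto))
```

Any other schedule would plug in the same way, as long as it returns all four quantities consistently.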

Example

from gensbi.flow_matching.path.scheduler import CondOTScheduler
from gensbi.flow_matching.path import AffineProbPath
import jax
scheduler = CondOTScheduler()
path = AffineProbPath(scheduler)
# Split the key: reusing one key would make x_0 and x_1 identical
key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
# x_1 should come from your dataset (e.g., a batch of real data)
x_1 = jax.random.normal(key_1, (128, 2))  # replace with your data batch
# x_0 is typically sampled from a prior, e.g., standard normal noise
x_0 = jax.random.normal(key_0, (128, 2))
t = jax.random.uniform(key_t, (128,))  # random times in [0, 1)
sample = path.sample(x_0, x_1, t)
print(sample.x_t.shape)
# (128, 2)

Parameters:

scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.

epsilon_to_target(epsilon, x_t, t)

Convert from epsilon representation to x_1 representation.

Parameters:
  • epsilon (Array) – Noise in the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Target data point.

Return type:

Array
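
Assuming the noise \(\epsilon\) plays the role of \(X_0\), this conversion solves \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) for \(X_1\), which requires \(\alpha_t \neq 0\) (for the CondOT schedule, \(t > 0\)):

\[X_1 = \frac{X_t - \sigma_t \epsilon}{\alpha_t}.\]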

epsilon_to_velocity(epsilon, x_t, t)

Convert from epsilon representation to velocity.

Parameters:
  • epsilon (Array) – Noise in the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Velocity.

Return type:

Array
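
Assuming the noise \(\epsilon\) plays the role of \(X_0\), solving \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) for \(X_1\) and substituting into \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\) gives, for \(\alpha_t \neq 0\):

\[\dot{X}_t = \frac{\dot{\alpha}_t}{\alpha_t} X_t + \left(\dot{\sigma}_t - \frac{\dot{\alpha}_t \sigma_t}{\alpha_t}\right) \epsilon.\]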

sample(x_0, x_1, t)

Sample from the affine probability path.

Given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\), returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).

Parameters:
  • x_0 (Array) – Source data point, shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Times in [0,1], shape (batch_size,).

Returns:

A conditional sample \(X_t \sim p_t\).

Return type:

PathSample
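
The computation that sample describes can be sketched in plain JAX as below. This is not the gensbi implementation: it hard-codes the CondOT schedule as a stand-in for whatever scheduler is supplied, and it assumes that t of shape (batch_size,) is broadcast over the trailing data dimensions. The helpers expand_t and affine_sample are hypothetical.

```python
import jax
import jax.numpy as jnp


def expand_t(t: jax.Array, x: jax.Array) -> jax.Array:
    """Reshape t from (batch_size,) to (batch_size, 1, ..., 1) to match x."""
    return t.reshape(t.shape + (1,) * (x.ndim - 1))


def affine_sample(x_0, x_1, t):
    """Hypothetical affine path sample using alpha_t = t, sigma_t = 1 - t."""
    t = expand_t(t, x_1)
    alpha_t, sigma_t = t, 1.0 - t
    d_alpha_t, d_sigma_t = jnp.ones_like(t), -jnp.ones_like(t)
    x_t = alpha_t * x_1 + sigma_t * x_0        # X_t = alpha_t X_1 + sigma_t X_0
    dx_t = d_alpha_t * x_1 + d_sigma_t * x_0   # conditional velocity
    return x_t, dx_t


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (16, 3, 4))  # source batch
x_1 = jax.random.normal(key_1, (16, 3, 4))  # target batch
t = jax.random.uniform(key_t, (16,))        # one time per sample
x_t, dx_t = affine_sample(x_0, x_1, t)
print(x_t.shape, dx_t.shape)
```

Reshaping t rather than requiring it to match the data shape keeps the interface at one time value per sample, regardless of how many dimensions each data point has.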

target_to_epsilon(x_1, x_t, t)

Convert from x_1 representation to noise.

Parameters:
  • x_1 (Array) – Target data point.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Noise in the path sample.

Return type:

Array
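
Assuming the noise \(\epsilon\) plays the role of \(X_0\), this conversion solves \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) for \(\epsilon\), which requires \(\sigma_t \neq 0\) (for the CondOT schedule, \(t < 1\)):

\[\epsilon = \frac{X_t - \alpha_t X_1}{\sigma_t}.\]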

target_to_velocity(x_1, x_t, t)

Convert from x_1 representation to velocity.

Parameters:
  • x_1 (Array) – Target data point.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Velocity.

Return type:

Array
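
Substituting \(X_0 = (X_t - \alpha_t X_1)/\sigma_t\) into \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\) gives, for \(\sigma_t \neq 0\):

\[\dot{X}_t = \frac{\dot{\sigma}_t}{\sigma_t} X_t + \frac{\dot{\alpha}_t \sigma_t - \dot{\sigma}_t \alpha_t}{\sigma_t} X_1.\]

For the CondOT schedule (\(\alpha_t = t\), \(\sigma_t = 1 - t\)) this reduces to \(\dot{X}_t = (X_1 - X_t)/(1 - t)\).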

velocity_to_epsilon(velocity, x_t, t)

Convert from velocity to noise representation.

Parameters:
  • velocity (Array) – Velocity at the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Noise in the path sample.

Return type:

Array
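
Assuming the noise \(\epsilon\) plays the role of \(X_0\), multiplying \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\) by \(\alpha_t\), multiplying \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) by \(\dot{\alpha}_t\), and subtracting eliminates \(X_1\). Whenever the denominator is nonzero:

\[\epsilon = \frac{\alpha_t \dot{X}_t - \dot{\alpha}_t X_t}{\alpha_t \dot{\sigma}_t - \dot{\alpha}_t \sigma_t}.\]

For the CondOT schedule the denominator is \(-t - (1 - t) = -1\), so \(\epsilon = X_t - t \dot{X}_t\).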

velocity_to_target(velocity, x_t, t)

Convert from velocity to x_1 representation.

Parameters:
  • velocity (Array) – Velocity at the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Target data point.

Return type:

Array
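
Eliminating \(X_0\) from \(X_t = \alpha_t X_1 + \sigma_t X_0\) and \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\) gives \(X_1 = (\sigma_t \dot{X}_t - \dot{\sigma}_t X_t)/(\dot{\alpha}_t \sigma_t - \dot{\sigma}_t \alpha_t)\). The sketch below checks this numerically with a round trip through the matching target-to-velocity formula. It hard-codes the CondOT schedule as an assumption (the gensbi methods read the coefficients from the supplied scheduler), and its functions only mirror the method names; they are not the library implementation.

```python
import jax
import jax.numpy as jnp


def condot_coeffs(t):
    # Assumed CondOT schedule: alpha_t = t, sigma_t = 1 - t, and their derivatives.
    return t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t)


def target_to_velocity(x_1, x_t, t):
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = condot_coeffs(t)
    a_t = d_sigma_t / sigma_t
    b_t = (d_alpha_t * sigma_t - d_sigma_t * alpha_t) / sigma_t
    return a_t * x_t + b_t * x_1


def velocity_to_target(velocity, x_t, t):
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = condot_coeffs(t)
    denom = d_alpha_t * sigma_t - d_sigma_t * alpha_t
    return (sigma_t * velocity - d_sigma_t * x_t) / denom


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(1), 3)
x_0 = jax.random.normal(key_0, (8, 2))
x_1 = jax.random.normal(key_1, (8, 2))
# Keep t away from 1, where sigma_t = 0 and target_to_velocity divides by zero.
t = jax.random.uniform(key_t, (8, 1), maxval=0.9)
x_t = t * x_1 + (1.0 - t) * x_0

u = target_to_velocity(x_1, x_t, t)   # should equal x_1 - x_0 for CondOT
x_1_rec = velocity_to_target(u, x_t, t)
print(jnp.max(jnp.abs(x_1_rec - x_1)))
```

For CondOT the denominator of velocity_to_target is always 1, whereas target_to_velocity divides by \(1 - t\); that asymmetry is why the sketch restricts t below 1.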

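scheduler

The Scheduler instance passed to the constructor; it supplies \(\alpha_t\), \(\sigma_t\), and their derivatives.

class gensbi.flow_matching.path.CondOTProbPath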

Bases: AffineProbPath

The CondOTProbPath class represents a conditional optimal transport probability path.

This class is a specialized version of the AffineProbPath that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation.

The parameters \(\alpha_t\) and \(\sigma_t\) for the conditional optimal transport path are defined as:

\[\alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.\]
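
Substituting these coefficients into the affine path gives a straight-line interpolation with a constant conditional velocity:

\[X_t = t X_1 + (1 - t) X_0, \qquad \dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0 = X_1 - X_0,\]

so \(X_t\) moves from \(X_0\) at \(t = 0\) to \(X_1\) at \(t = 1\) at constant speed.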

Example

from gensbi.flow_matching.path import CondOTProbPath
import jax
path = CondOTProbPath()
# Split the key: reusing one key would make x_0 and x_1 identical
key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
# x_1 should come from your dataset (e.g., a batch of real data)
x_1 = jax.random.normal(key_1, (64, 2))  # replace with your data batch
# x_0 is typically sampled from a prior, e.g., standard normal noise
x_0 = jax.random.normal(key_0, (64, 2))
t = jax.random.uniform(key_t, (64,))  # random times in [0, 1)
sample = path.sample(x_0, x_1, t)
print(sample.x_t.shape)
# (64, 2)

scheduler

The conditional optimal transport scheduler used by this path.

class gensbi.flow_matching.path.PathSample

Represents a sample of a conditional-flow generated probability path.

x_1

the target sample \(X_1\).

Type:

Array

x_0

the source sample \(X_0\).

Type:

Array

t

the time sample \(t\).

Type:

Array

x_t

samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).

Type:

Array

dx_t

the conditional target \(\dot{X}_t = \frac{\partial X_t}{\partial t}\), shape (batch_size, …).

Type:

Array
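
In conditional flow matching, these fields are typically consumed by a regression loss: a velocity model evaluated at (x_t, t) is trained to match dx_t. The sketch below is a minimal illustration under assumptions: velocity_model and cfm_loss are hypothetical stand-ins (a linear map in place of a neural network), and the fields are built by hand with the CondOT schedule rather than by calling gensbi.

```python
import jax
import jax.numpy as jnp


def velocity_model(params, x_t, t):
    # Hypothetical linear "network": v(x_t, t) = x_t @ W + t * b.
    return x_t @ params["W"] + t[:, None] * params["b"]


def cfm_loss(params, x_t, t, dx_t):
    # Mean squared error between the predicted velocity and the conditional target dx_t.
    pred = velocity_model(params, x_t, t)
    return jnp.mean((pred - dx_t) ** 2)


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (64, 2))        # source samples
x_1 = jax.random.normal(key_1, (64, 2)) + 3.0  # toy "data" samples
t = jax.random.uniform(key_t, (64,))
# Fields that a CondOT PathSample would carry, built by hand here.
x_t = t[:, None] * x_1 + (1.0 - t[:, None]) * x_0
dx_t = x_1 - x_0

params = {"W": jnp.zeros((2, 2)), "b": jnp.zeros((2,))}
loss, grads = jax.value_and_grad(cfm_loss)(params, x_t, t, dx_t)
print(loss, grads["W"].shape)
```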

class gensbi.flow_matching.path.ProbPath

Bases: abc.ABC

Abstract class representing a probability path.

A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) over \(t=0\rightarrow 1\).

The ProbPath class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives. Here is a high-level example:

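import jax
from gensbi.flow_matching.path import CondOTProbPath

# ProbPath is abstract, so instantiate a concrete subclass
my_path = CondOTProbPath()

# One random time in [0, 1) per sample, shape (batch_size,)
key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (32, 2))  # source batch, e.g. noise
x_1 = jax.random.normal(key_1, (32, 2))  # target batch, e.g. data
t = jax.random.uniform(key_t, (32,))

# Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

assert_sample_shape(x_0, x_1, t)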

Checks that the shapes of x_0, x_1, and t are compatible for sampling.

Parameters:
  • x_0 (Array) – Source data point.

  • x_1 (Array) – Target data point.

  • t (Array) – Time vector.

Raises:

AssertionError – If the shapes are not compatible.

Return type:

None
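
A minimal sketch of such a check, assuming "compatible" means that t is one-dimensional and that x_0, x_1, and t share the same leading batch size (the exact conditions gensbi enforces may differ; this standalone function is hypothetical):

```python
import jax.numpy as jnp


def assert_sample_shape(x_0, x_1, t):
    """Hypothetical shape check for path sampling inputs."""
    assert t.ndim == 1, f"t must have shape (batch_size,), got {t.shape}"
    assert t.shape[0] == x_0.shape[0] == x_1.shape[0], (
        f"batch sizes differ: t {t.shape}, x_0 {x_0.shape}, x_1 {x_1.shape}"
    )


x = jnp.zeros((8, 2))
assert_sample_shape(x, x, jnp.zeros((8,)))  # passes silently

try:
    assert_sample_shape(x, x, jnp.zeros(()))  # a scalar t is rejected
except AssertionError as err:
    print("rejected:", err)
```

Under this assumption a scalar t must be broadcast explicitly, e.g. with jnp.full((batch_size,), t), before sampling.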

abstract sample(x_0, x_1, t)

Sample from an abstract probability path.

Given \((X_0,X_1) \sim \pi(X_0,X_1)\), returns \(X_0\), \(X_1\), \(X_t \sim p_t(X_t|X_0,X_1)\), and a conditional target \(Y\), all packaged in a PathSample.

Parameters:
  • x_0 (Array) – Source data point, shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Times in [0,1], shape (batch_size,).

Returns:

A conditional sample.

Return type:

PathSample