gensbi.flow_matching.path package

Subpackages

Submodules

gensbi.flow_matching.path.affine module

class gensbi.flow_matching.path.affine.AffineProbPath(scheduler: Scheduler)

Bases: ProbPath

The AffineProbPath class represents a specific type of probability path where the transformation between distributions is affine. An affine transformation can be represented as:

\[X_t = \alpha_t X_1 + \sigma_t X_0,\]

where \(X_t\) is the transformed data point at time \(t\), \(X_0\) and \(X_1\) are the source and target data points, respectively, and \(\alpha_t\) and \(\sigma_t\) are the time-dependent coefficients of the affine transformation.

The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.

Using AffineProbPath in the flow matching framework:

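import jax
import jax.numpy as jnp
from gensbi.flow_matching.path import AffineProbPath

# Instantiates a probability path (pass a Scheduler instance)
my_path = AffineProbPath(...)
key = jax.random.PRNGKey(0)

for x_1 in dataset:
    # Splits the PRNG key so every step draws fresh randomness
    key, key_x0, key_t = jax.random.split(key, 3)

    # Sets x_0 to Gaussian noise with the same shape as x_1
    x_0 = jax.random.normal(key_x0, x_1.shape)

    # Sets t to a random value in [0,1] for each batch element
    t = jax.random.uniform(key_t, (x_1.shape[0],))

    # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

    # Computes the MSE loss w.r.t. the velocity; to train, wrap this in a
    # function of the model parameters and differentiate it with jax.grad
    loss = jnp.mean((my_model(path_sample.x_t, t) - path_sample.dx_t) ** 2)

Parameters: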

scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.
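The scheduler can be pictured as a function of \(t\) that returns \(\alpha_t\), \(\sigma_t\) and their time derivatives. The sketch below is a hypothetical NumPy stand-in, not the library's Scheduler interface (which this page does not document). It implements the conditional optimal transport choice \(\alpha_t = t\), \(\sigma_t = 1 - t\) used by CondOTProbPath and checks the endpoint values \(\alpha_0 = 0, \sigma_0 = 1\) and \(\alpha_1 = 1, \sigma_1 = 0\), which make \(X_t\) equal to \(X_0\) at \(t = 0\) and to \(X_1\) at \(t = 1\).

```python
import numpy as np


def cond_ot_schedule(t):
    """Hypothetical scheduler stand-in: returns (alpha_t, sigma_t, d_alpha_t, d_sigma_t).

    Implements the conditional optimal transport choice alpha_t = t, sigma_t = 1 - t.
    """
    t = np.asarray(t, dtype=float)
    alpha_t = t
    sigma_t = 1.0 - t
    d_alpha_t = np.ones_like(t)    # derivative of t
    d_sigma_t = -np.ones_like(t)   # derivative of 1 - t
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


# Endpoints: X_t = X_0 at t = 0 and X_t = X_1 at t = 1.
alpha_0, sigma_0, _, _ = cond_ot_schedule(0.0)
alpha_1, sigma_1, _, _ = cond_ot_schedule(1.0)

# Works elementwise on a batch of times with shape (batch_size,).
alpha_b, sigma_b, d_alpha_b, d_sigma_b = cond_ot_schedule(np.array([0.0, 0.25, 1.0]))
```

Only basic array operations are used, so the same code should carry over to jax.numpy by swapping the import.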

epsilon_to_target(epsilon: Array, x_t: Array, t: Array) → Array

Convert from epsilon representation to x_1 representation.

Parameters:
  • epsilon (Array) – Noise in the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Target data point.

Return type:

Array
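Since \(X_t = \alpha_t X_1 + \sigma_t X_0\), knowing the noise lets one solve for the target: \(X_1 = (X_t - \sigma_t \epsilon) / \alpha_t\). A minimal NumPy sketch, assuming epsilon is the source sample \(X_0\) and, unlike the method above, taking the scheduler values \(\alpha_t, \sigma_t\) directly instead of \(t\). The expression is undefined where \(\alpha_t = 0\), for example at \(t = 0\) on the CondOT path.

```python
import numpy as np


def epsilon_to_target(epsilon, x_t, alpha_t, sigma_t):
    # Solve x_t = alpha_t * x_1 + sigma_t * epsilon for x_1 (requires alpha_t != 0).
    return (x_t - sigma_t * epsilon) / alpha_t


rng = np.random.default_rng(0)
x_0 = rng.normal(size=(4, 3))  # source noise, playing the role of epsilon
x_1 = rng.normal(size=(4, 3))  # target data
alpha_t, sigma_t = 0.3, 0.7    # CondOT coefficients at t = 0.3
x_t = alpha_t * x_1 + sigma_t * x_0

x_1_rec = epsilon_to_target(x_0, x_t, alpha_t, sigma_t)
```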

epsilon_to_velocity(epsilon: Array, x_t: Array, t: Array) → Array

Convert from epsilon representation to velocity.

Parameters:
  • epsilon (Array) – Noise in the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Velocity.

Return type:

Array
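The velocity follows by first recovering \(X_1\) from the noise and then applying \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\). The NumPy sketch below again assumes epsilon is \(X_0\) and passes the scheduler values and derivatives explicitly (a simplification of the documented signature). On the CondOT path the result should equal \(X_1 - X_0\).

```python
import numpy as np


def epsilon_to_velocity(epsilon, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    # Recover the target from the noise, then apply dX_t/dt = alpha'_t X_1 + sigma'_t X_0.
    x_1 = (x_t - sigma_t * epsilon) / alpha_t
    return d_alpha_t * x_1 + d_sigma_t * epsilon


rng = np.random.default_rng(1)
x_0 = rng.normal(size=(4, 3))
x_1 = rng.normal(size=(4, 3))
t = 0.6
# CondOT: alpha_t = t, sigma_t = 1 - t, with derivatives 1 and -1
x_t = t * x_1 + (1 - t) * x_0
v = epsilon_to_velocity(x_0, x_t, t, 1 - t, 1.0, -1.0)
```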

sample(x_0: Array, x_1: Array, t: Array) → PathSample

Sample from the affine probability path.

Given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\), returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).

Parameters:
  • x_0 (Array) – Source data point, shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Times in [0,1], shape (batch_size,).

Returns:

A conditional sample \(X_t \sim p_t(X_t|X_0,X_1)\).

Return type:

PathSample
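The two formulas above can be sketched in NumPy as follows. This is an illustration under stated assumptions, not the library's implementation: the schedule is a hypothetical callable returning \((\alpha_t, \sigma_t, \dot{\alpha}_t, \dot{\sigma}_t)\), a plain tuple stands in for PathSample, and the per-example times of shape (batch_size,) are reshaped so they broadcast over the remaining dimensions of x_0 and x_1.

```python
import numpy as np


def cond_ot(t):
    # Hypothetical schedule: (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for alpha_t = t, sigma_t = 1 - t.
    return t, 1.0 - t, np.ones_like(t), -np.ones_like(t)


def affine_sample(x_0, x_1, t, schedule):
    assert x_0.shape == x_1.shape and t.shape == (x_0.shape[0],)
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = schedule(t)

    def expand(c):
        # (batch_size,) -> (batch_size, 1, ..., 1) so c broadcasts against x_0 and x_1
        return c.reshape(c.shape + (1,) * (x_0.ndim - 1))

    x_t = expand(alpha_t) * x_1 + expand(sigma_t) * x_0
    dx_t = expand(d_alpha_t) * x_1 + expand(d_sigma_t) * x_0
    return x_t, dx_t


rng = np.random.default_rng(2)
x_0 = rng.normal(size=(8, 2, 3))
x_1 = rng.normal(size=(8, 2, 3))
t = rng.uniform(size=8)
x_t, dx_t = affine_sample(x_0, x_1, t, cond_ot)
```

Reshaping the coefficients, rather than requiring t to already have the data's rank, keeps the documented (batch_size,) shape for t while supporting data of any trailing shape.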

target_to_epsilon(x_1: Array, x_t: Array, t: Array) → Array

Convert from x_1 representation to noise.

Parameters:
  • x_1 (Array) – Target data point.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Noise in the path sample.

Return type:

Array
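Solving the same affine relation for the noise gives \(\epsilon = (X_t - \alpha_t X_1) / \sigma_t\), which is undefined where \(\sigma_t = 0\) (at \(t = 1\) on the CondOT path). A NumPy sketch, assuming epsilon is \(X_0\) and passing \(\alpha_t, \sigma_t\) in place of \(t\):

```python
import numpy as np


def target_to_epsilon(x_1, x_t, alpha_t, sigma_t):
    # Solve x_t = alpha_t * x_1 + sigma_t * epsilon for epsilon (requires sigma_t != 0).
    return (x_t - alpha_t * x_1) / sigma_t


rng = np.random.default_rng(3)
x_0 = rng.normal(size=(4, 3))
x_1 = rng.normal(size=(4, 3))
alpha_t, sigma_t = 0.8, 0.2  # CondOT coefficients at t = 0.8
x_t = alpha_t * x_1 + sigma_t * x_0
eps = target_to_epsilon(x_1, x_t, alpha_t, sigma_t)
```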

target_to_velocity(x_1: Array, x_t: Array, t: Array) → Array

Convert from x_1 representation to velocity.

Parameters:
  • x_1 (Array) – Target data point.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Velocity.

Return type:

Array
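Substituting \(\epsilon = (X_t - \alpha_t X_1)/\sigma_t\) into \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\) gives the velocity from the target. A NumPy sketch, with the scheduler values and derivatives passed explicitly as an assumption of this illustration:

```python
import numpy as np


def target_to_velocity(x_1, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    # Recover the noise from the target, then apply dX_t/dt = alpha'_t X_1 + sigma'_t X_0.
    epsilon = (x_t - alpha_t * x_1) / sigma_t
    return d_alpha_t * x_1 + d_sigma_t * epsilon


rng = np.random.default_rng(5)
x_0 = rng.normal(size=(4, 3))
x_1 = rng.normal(size=(4, 3))
t = 0.4
# CondOT: alpha_t = t, sigma_t = 1 - t, with derivatives 1 and -1
x_t = t * x_1 + (1 - t) * x_0
v = target_to_velocity(x_1, x_t, t, 1 - t, 1.0, -1.0)
```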

velocity_to_epsilon(velocity: Array, x_t: Array, t: Array) → Array

Convert from velocity to noise representation.

Parameters:
  • velocity (Array) – Velocity at the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Noise in the path sample.

Return type:

Array
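Going from velocity to noise means eliminating \(X_1\) from the two linear equations \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) and \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon\), which gives \(\epsilon = (\alpha_t \dot{X}_t - \dot{\alpha}_t X_t) / (\alpha_t \dot{\sigma}_t - \dot{\alpha}_t \sigma_t)\). A NumPy sketch under the same assumptions as the other conversion sketches (epsilon is \(X_0\), scheduler values passed explicitly); for CondOT the denominator is \(-1\) at every \(t\).

```python
import numpy as np


def velocity_to_epsilon(velocity, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    # Eliminate x_1 from the linear equations for x_t and the velocity.
    return (alpha_t * velocity - d_alpha_t * x_t) / (alpha_t * d_sigma_t - d_alpha_t * sigma_t)


rng = np.random.default_rng(6)
x_0 = rng.normal(size=(4, 3))
x_1 = rng.normal(size=(4, 3))
t = 0.3
# CondOT: alpha_t = t, sigma_t = 1 - t, velocity x_1 - x_0
x_t = t * x_1 + (1 - t) * x_0
velocity = x_1 - x_0
eps = velocity_to_epsilon(velocity, x_t, t, 1 - t, 1.0, -1.0)
```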

velocity_to_target(velocity: Array, x_t: Array, t: Array) → Array

Convert from velocity to x_1 representation.

Parameters:
  • velocity (Array) – Velocity at the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Target data point.

Return type:

Array
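Eliminating \(\epsilon\) instead of \(X_1\) from the same pair of equations gives \(X_1 = (\sigma_t \dot{X}_t - \dot{\sigma}_t X_t) / (\sigma_t \dot{\alpha}_t - \dot{\sigma}_t \alpha_t)\). A NumPy sketch, again passing the scheduler values explicitly as an illustrative simplification; on the CondOT path the formula reduces to \(X_1 = X_t + (1 - t)\dot{X}_t\).

```python
import numpy as np


def velocity_to_target(velocity, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    # Eliminate epsilon from the linear equations for x_t and the velocity.
    return (sigma_t * velocity - d_sigma_t * x_t) / (sigma_t * d_alpha_t - d_sigma_t * alpha_t)


rng = np.random.default_rng(7)
x_0 = rng.normal(size=(4, 3))
x_1 = rng.normal(size=(4, 3))
t = 0.3
# CondOT: alpha_t = t, sigma_t = 1 - t, velocity x_1 - x_0
x_t = t * x_1 + (1 - t) * x_0
v = x_1 - x_0
x_1_rec = velocity_to_target(v, x_t, t, 1 - t, 1.0, -1.0)
```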

class gensbi.flow_matching.path.affine.CondOTProbPath

Bases: AffineProbPath

The CondOTProbPath class represents a conditional optimal transport probability path.

This class is a specialized version of the AffineProbPath that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation.

The parameters \(\alpha_t\) and \(\sigma_t\) for the conditional optimal transport path are defined as:

\[\alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.\]
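Substituting these coefficients, whose derivatives are \(\dot{\alpha}_t = 1\) and \(\dot{\sigma}_t = -1\), into the affine sample formulas gives a straight-line interpolation whose velocity does not depend on \(t\):

```latex
\[
X_t = t\,X_1 + (1 - t)\,X_0,
\qquad
\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0 = X_1 - X_0 .
\]
```

At \(t = 0\) the sample is the source point \(X_0\), at \(t = 1\) it is the target \(X_1\), and the conditional target dx_t stored in the PathSample is \(X_1 - X_0\).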

gensbi.flow_matching.path.path module

class gensbi.flow_matching.path.path.ProbPath

Bases: ABC

Abstract base class representing a probability path.

A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) over \(t=0\rightarrow 1\).

The ProbPath class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives. Here is a high-level example:

import jax
from gensbi.flow_matching.path import CondOTProbPath

# Instantiate a probability path (ProbPath is abstract, so use a concrete subclass)
my_path = CondOTProbPath()
key = jax.random.PRNGKey(0)

for x_0, x_1 in dataset:
    # Sets t to a random value in [0,1] per batch element, with a fresh key each step
    key, subkey = jax.random.split(key)
    t = jax.random.uniform(subkey, (x_0.shape[0],))

    # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

    # Optimizes the model. The loss function varies, depending on model and path;
    # loss_fn(params, path_sample) and params are placeholders here.
    grads = jax.grad(loss_fn)(params, path_sample)

assert_sample_shape(x_0: Array, x_1: Array, t: Array) → None

Checks that the shapes of x_0, x_1, and t are compatible for sampling.

Parameters:
  • x_0 (Array) – Source data point.

  • x_1 (Array) – Target data point.

  • t (Array) – Time vector.

Raises:

AssertionError – If the shapes are not compatible.
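This page does not spell out what "compatible" means. One plausible reading, based on the shapes documented for sample (x_0 and x_1 of shape (batch_size, …), t of shape (batch_size,)), is sketched below in NumPy; treat the exact checks as an assumption rather than the library's code.

```python
import numpy as np


def assert_sample_shape(x_0, x_1, t):
    # Assumed checks, inferred from the documented shapes of sample()'s arguments
    assert t.ndim == 1, f"t must be 1-D, got shape {t.shape}"
    assert x_0.shape == x_1.shape, f"x_0 {x_0.shape} and x_1 {x_1.shape} differ"
    assert t.shape[0] == x_0.shape[0], "t must have one entry per batch element"


# Passes silently: batch of 4 two-dimensional points, one time per point.
assert_sample_shape(np.zeros((4, 2)), np.zeros((4, 2)), np.zeros(4))
```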

abstractmethod sample(x_0: Array, x_1: Array, t: Array) → PathSample

Sample from an abstract probability path.

Given \((X_0,X_1) \sim \pi(X_0,X_1)\), returns \(X_0\), \(X_1\), \(X_t \sim p_t(X_t|X_0,X_1)\), and a conditional target \(Y\), all packaged in a PathSample.

Parameters:
  • x_0 (Array) – Source data point, shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Times in [0,1], shape (batch_size,).

Returns:

A conditional sample.

Return type:

PathSample

gensbi.flow_matching.path.path_sample module

class gensbi.flow_matching.path.path_sample.DiscretePathSample(x_1: Array, x_0: Array, t: Array, x_t: Array)

Bases: object

Represents a sample of a discrete probability path generated by a conditional flow.

x_1

the target sample \(X_1\).

Type:

Array

x_0

the source sample \(X_0\).

Type:

Array

t

the time sample \(t\).

Type:

Array

x_t

the sample along the path \(X_t \sim p_t\).

Type:

Array

class gensbi.flow_matching.path.path_sample.PathSample(x_1: Array, x_0: Array, t: Array, x_t: Array, dx_t: Array)

Bases: object

Represents a sample of a probability path generated by a conditional flow.

x_1

the target sample \(X_1\).

Type:

Array

x_0

the source sample \(X_0\).

Type:

Array

t

the time sample \(t\).

Type:

Array

x_t

samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).

Type:

Array

dx_t

conditional target \(\frac{\partial X_t}{\partial t}\), shape (batch_size, …).

Type:

Array
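The constructor lists x_1 before x_0, so passing fields positionally invites silent mix-ups. The sketch below uses a hypothetical frozen dataclass with the same field order (a stand-in, not the library's PathSample) to show construction by keyword for a CondOT sample.

```python
from dataclasses import dataclass

import numpy as np


@dataclass(frozen=True)
class PathSampleSketch:
    # Hypothetical stand-in mirroring the documented field order of PathSample
    x_1: np.ndarray
    x_0: np.ndarray
    t: np.ndarray
    x_t: np.ndarray
    dx_t: np.ndarray


rng = np.random.default_rng(4)
x_0 = rng.normal(size=(5, 2))
x_1 = rng.normal(size=(5, 2))
t = rng.uniform(size=5)

# Keyword arguments make the x_0 / x_1 order irrelevant at the call site.
sample = PathSampleSketch(
    x_0=x_0,
    x_1=x_1,
    t=t,
    x_t=t[:, None] * x_1 + (1 - t[:, None]) * x_0,  # CondOT interpolation
    dx_t=x_1 - x_0,                                  # CondOT conditional velocity
)
```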

Module contents

class gensbi.flow_matching.path.AffineProbPath(scheduler: Scheduler)

Bases: ProbPath

The AffineProbPath class represents a specific type of probability path where the transformation between distributions is affine. An affine transformation can be represented as:

\[X_t = \alpha_t X_1 + \sigma_t X_0,\]

where \(X_t\) is the transformed data point at time \(t\), \(X_0\) and \(X_1\) are the source and target data points, respectively, and \(\alpha_t\) and \(\sigma_t\) are the time-dependent coefficients of the affine transformation.

The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.

Using AffineProbPath in the flow matching framework:

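import jax
import jax.numpy as jnp
from gensbi.flow_matching.path import AffineProbPath

# Instantiates a probability path (pass a Scheduler instance)
my_path = AffineProbPath(...)
key = jax.random.PRNGKey(0)

for x_1 in dataset:
    # Splits the PRNG key so every step draws fresh randomness
    key, key_x0, key_t = jax.random.split(key, 3)

    # Sets x_0 to Gaussian noise with the same shape as x_1
    x_0 = jax.random.normal(key_x0, x_1.shape)

    # Sets t to a random value in [0,1] for each batch element
    t = jax.random.uniform(key_t, (x_1.shape[0],))

    # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

    # Computes the MSE loss w.r.t. the velocity; to train, wrap this in a
    # function of the model parameters and differentiate it with jax.grad
    loss = jnp.mean((my_model(path_sample.x_t, t) - path_sample.dx_t) ** 2)

Parameters: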

scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.

epsilon_to_target(epsilon: Array, x_t: Array, t: Array) → Array

Convert from epsilon representation to x_1 representation.

Parameters:
  • epsilon (Array) – Noise in the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Target data point.

Return type:

Array

epsilon_to_velocity(epsilon: Array, x_t: Array, t: Array) → Array

Convert from epsilon representation to velocity.

Parameters:
  • epsilon (Array) – Noise in the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Velocity.

Return type:

Array

sample(x_0: Array, x_1: Array, t: Array) → PathSample

Sample from the affine probability path.

Given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\), returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).

Parameters:
  • x_0 (Array) – Source data point, shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Times in [0,1], shape (batch_size,).

Returns:

A conditional sample \(X_t \sim p_t(X_t|X_0,X_1)\).

Return type:

PathSample

target_to_epsilon(x_1: Array, x_t: Array, t: Array) → Array

Convert from x_1 representation to noise.

Parameters:
  • x_1 (Array) – Target data point.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Noise in the path sample.

Return type:

Array

target_to_velocity(x_1: Array, x_t: Array, t: Array) → Array

Convert from x_1 representation to velocity.

Parameters:
  • x_1 (Array) – Target data point.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Velocity.

Return type:

Array

velocity_to_epsilon(velocity: Array, x_t: Array, t: Array) → Array

Convert from velocity to noise representation.

Parameters:
  • velocity (Array) – Velocity at the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Noise in the path sample.

Return type:

Array

velocity_to_target(velocity: Array, x_t: Array, t: Array) → Array

Convert from velocity to x_1 representation.

Parameters:
  • velocity (Array) – Velocity at the path sample.

  • x_t (Array) – Path sample at time t.

  • t (Array) – Time in [0,1].

Returns:

Target data point.

Return type:

Array

class gensbi.flow_matching.path.CondOTProbPath

Bases: AffineProbPath

The CondOTProbPath class represents a conditional optimal transport probability path.

This class is a specialized version of the AffineProbPath that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation.

The parameters \(\alpha_t\) and \(\sigma_t\) for the conditional optimal transport path are defined as:

\[\alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.\]

class gensbi.flow_matching.path.PathSample(x_1: Array, x_0: Array, t: Array, x_t: Array, dx_t: Array)

Bases: object

Represents a sample of a probability path generated by a conditional flow.

x_1

the target sample \(X_1\).

Type:

Array

x_0

the source sample \(X_0\).

Type:

Array

t

the time sample \(t\).

Type:

Array

x_t

samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).

Type:

Array

dx_t

conditional target \(\frac{\partial X_t}{\partial t}\), shape (batch_size, …).

Type:

Array

class gensbi.flow_matching.path.ProbPath

Bases: ABC

Abstract base class representing a probability path.

A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) over \(t=0\rightarrow 1\).

The ProbPath class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives. Here is a high-level example:

import jax
from gensbi.flow_matching.path import CondOTProbPath

# Instantiate a probability path (ProbPath is abstract, so use a concrete subclass)
my_path = CondOTProbPath()
key = jax.random.PRNGKey(0)

for x_0, x_1 in dataset:
    # Sets t to a random value in [0,1] per batch element, with a fresh key each step
    key, subkey = jax.random.split(key)
    t = jax.random.uniform(subkey, (x_0.shape[0],))

    # Samples the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)

    # Optimizes the model. The loss function varies, depending on model and path;
    # loss_fn(params, path_sample) and params are placeholders here.
    grads = jax.grad(loss_fn)(params, path_sample)

assert_sample_shape(x_0: Array, x_1: Array, t: Array) → None

Checks that the shapes of x_0, x_1, and t are compatible for sampling.

Parameters:
  • x_0 (Array) – Source data point.

  • x_1 (Array) – Target data point.

  • t (Array) – Time vector.

Raises:

AssertionError – If the shapes are not compatible.

abstractmethod sample(x_0: Array, x_1: Array, t: Array) → PathSample

Sample from an abstract probability path.

Given \((X_0,X_1) \sim \pi(X_0,X_1)\), returns \(X_0\), \(X_1\), \(X_t \sim p_t(X_t|X_0,X_1)\), and a conditional target \(Y\), all packaged in a PathSample.

Parameters:
  • x_0 (Array) – Source data point, shape (batch_size, …).

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • t (Array) – Times in [0,1], shape (batch_size,).

Returns:

A conditional sample.

Return type:

PathSample