gensbi.flow_matching.path package#
Subpackages#
- gensbi.flow_matching.path.scheduler package
- Submodules
- gensbi.flow_matching.path.scheduler.schedule_transform module
- gensbi.flow_matching.path.scheduler.scheduler module
- Module contents
Submodules#
gensbi.flow_matching.path.affine module#
- class gensbi.flow_matching.path.affine.AffineProbPath(scheduler: Scheduler)[source]#
Bases:
ProbPath
The AffineProbPath class represents a probability path in which the transformation between distributions is affine. An affine transformation can be written as:
\[X_t = \alpha_t X_1 + \sigma_t X_0,\]
where \(X_t\) is the transformed data point at time t, \(X_0\) and \(X_1\) are the source and target data points, respectively, and \(\alpha_t\) and \(\sigma_t\) are the parameters of the affine transformation at time t.
The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.
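Using AffineProbPath in the flow matching framework:

```python
import jax
import jax.numpy as jnp

# Instantiate a probability path (pass a Scheduler instance)
my_path = AffineProbPath(...)

def mse_loss(params, x_t, t, dx_t):
    # my_model and params are placeholders for a velocity model and its parameters
    return jnp.mean((my_model(params, x_t, t) - dx_t) ** 2)

key = jax.random.PRNGKey(0)
for x_1 in dataset:
    key, key_x0, key_t = jax.random.split(key, 3)
    # Set x_0 to random noise
    x_0 = jax.random.normal(key_x0, x_1.shape)
    # Set t to random values in [0,1], one per batch element
    t = jax.random.uniform(key_t, (x_1.shape[0],))
    # Sample the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
    # MSE loss against the conditional velocity; gradients are taken w.r.t. params
    grads = jax.grad(mse_loss)(params, path_sample.x_t, path_sample.t, path_sample.dx_t)
```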
- Parameters:
scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.
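The affine formula can be sketched in plain JAX without the library. The helper names `expand_like` and `affine_path` are hypothetical, and the coefficients are passed in directly rather than read from a Scheduler, whose interface is not shown here. The sketch mainly illustrates how per-sample coefficients of shape (batch_size,) broadcast against data of shape (batch_size, …).

```python
import jax.numpy as jnp

def expand_like(coeff, x):
    """Reshape a (batch_size,) coefficient to (batch_size, 1, ..., 1) so it broadcasts against x."""
    return coeff.reshape(coeff.shape + (1,) * (x.ndim - coeff.ndim))

def affine_path(x_0, x_1, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    """Return X_t = alpha_t X_1 + sigma_t X_0 and its time derivative."""
    a, s = expand_like(alpha_t, x_1), expand_like(sigma_t, x_1)
    da, ds = expand_like(d_alpha_t, x_1), expand_like(d_sigma_t, x_1)
    x_t = a * x_1 + s * x_0
    dx_t = da * x_1 + ds * x_0
    return x_t, dx_t

# Assumed scheduler outputs at two times: alpha_t = t and sigma_t = 1 - t.
t = jnp.array([0.0, 0.25])
alpha_t, sigma_t = t, 1.0 - t
d_alpha_t, d_sigma_t = jnp.ones_like(t), -jnp.ones_like(t)

x_0 = -jnp.ones((2, 3))  # source batch
x_1 = jnp.ones((2, 3))   # target batch
x_t, dx_t = affine_path(x_0, x_1, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
```

At t = 0 the sample coincides with the source point, and the velocity here is \(X_1 - X_0\) for every batch element.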
- epsilon_to_target(epsilon: Array, x_t: Array, t: Array) → Array [source]#
Convert from epsilon representation to x_1 representation.
- Parameters:
epsilon (Array) – Noise in the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Target data point.
- Return type:
Array
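One way to read this conversion, assuming the noise \(\epsilon\) plays the role of \(X_0\) in \(X_t = \alpha_t X_1 + \sigma_t X_0\), is to solve that equation for \(X_1\). The function below is a hypothetical stand-in that takes the coefficients directly instead of a time t, so it is a sketch of the algebra rather than the library's implementation.

```python
import jax.numpy as jnp

def epsilon_to_target_sketch(epsilon, x_t, alpha_t, sigma_t):
    # Solve x_t = alpha_t * x_1 + sigma_t * epsilon for x_1 (requires alpha_t != 0).
    return (x_t - sigma_t * epsilon) / alpha_t

alpha_t, sigma_t = 0.25, 0.75  # assumed coefficient values at some time t
x_1 = jnp.array([1.0, -2.0])
epsilon = jnp.array([0.4, 0.8])
x_t = alpha_t * x_1 + sigma_t * epsilon

x_1_recovered = epsilon_to_target_sketch(epsilon, x_t, alpha_t, sigma_t)
```

The division by \(\alpha_t\) suggests the conversion is ill-conditioned wherever \(\alpha_t\) approaches zero, for example near t = 0 when \(\alpha_t = t\).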
- epsilon_to_velocity(epsilon: Array, x_t: Array, t: Array) → Array [source]#
Convert from epsilon representation to velocity.
- Parameters:
epsilon (Array) – Noise in the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Velocity.
- Return type:
Array
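A short derivation shows where such a conversion can come from. It assumes, as above, that \(\epsilon\) stands for \(X_0\) and that \(\alpha_t \neq 0\); the library's exact implementation is not shown in this page.

```latex
\begin{aligned}
X_t &= \alpha_t X_1 + \sigma_t \epsilon
  \quad\Longrightarrow\quad
  X_1 = \frac{X_t - \sigma_t \epsilon}{\alpha_t}, \\
\dot{X}_t &= \dot{\alpha}_t X_1 + \dot{\sigma}_t \epsilon
  = \frac{\dot{\alpha}_t}{\alpha_t} X_t
  + \left( \dot{\sigma}_t - \frac{\sigma_t \dot{\alpha}_t}{\alpha_t} \right) \epsilon .
\end{aligned}
```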
- sample(x_0: Array, x_1: Array, t: Array) → PathSample [source]#
Sample from the affine probability path.
Given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\), returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).
- Parameters:
x_0 (Array) – Source data point, shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Times in [0,1], shape (batch_size,).
- Returns:
A conditional sample at \(X_t \sim p_t\).
- Return type:
PathSample
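The conditional velocity is the time derivative of \(X_t\) with \(X_0\) and \(X_1\) held fixed. The sketch below checks this with automatic differentiation for a hypothetical cosine-style schedule (`alpha` and `sigma` are illustrative functions, not part of the library).

```python
import jax
import jax.numpy as jnp

def alpha(t):
    # Hypothetical schedule: alpha_0 = 0, alpha_1 = 1
    return jnp.sin(0.5 * jnp.pi * t)

def sigma(t):
    # Hypothetical schedule: sigma_0 = 1, sigma_1 = 0
    return jnp.cos(0.5 * jnp.pi * t)

def x_t_fn(t, x_0, x_1):
    return alpha(t) * x_1 + sigma(t) * x_0

x_0 = jnp.array([0.5, -1.0])
x_1 = jnp.array([2.0, 3.0])
t = jnp.array(0.3)

# Closed-form velocity: d(alpha)/dt * X_1 + d(sigma)/dt * X_0
dx_t = jax.grad(alpha)(t) * x_1 + jax.grad(sigma)(t) * x_0

# Forward-mode derivative of X_t with respect to t, for comparison
_, dx_t_autodiff = jax.jvp(lambda s: x_t_fn(s, x_0, x_1), (t,), (jnp.ones_like(t),))
```

Both expressions agree up to floating-point error, which is the property a velocity-matching loss relies on.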
- target_to_epsilon(x_1: Array, x_t: Array, t: Array) → Array [source]#
Convert from x_1 representation to noise.
- Parameters:
x_1 (Array) – Target data point.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Noise in the path sample.
- Return type:
Array
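Under the assumption that the noise is the \(X_0\) term of the affine path, this conversion amounts to solving \(X_t = \alpha_t X_1 + \sigma_t \epsilon\) for \(\epsilon\). The function name and the explicit coefficient arguments below are hypothetical.

```python
import jax.numpy as jnp

def target_to_epsilon_sketch(x_1, x_t, alpha_t, sigma_t):
    # Solve x_t = alpha_t * x_1 + sigma_t * epsilon for epsilon (requires sigma_t != 0).
    return (x_t - alpha_t * x_1) / sigma_t

alpha_t, sigma_t = 0.6, 0.4  # assumed coefficient values at some time t
x_1 = jnp.array([2.0, -1.0])
epsilon = jnp.array([0.5, 1.5])
x_t = alpha_t * x_1 + sigma_t * epsilon

epsilon_recovered = target_to_epsilon_sketch(x_1, x_t, alpha_t, sigma_t)
```

The division by \(\sigma_t\) means the result is unreliable where \(\sigma_t\) approaches zero, for example near t = 1 when \(\sigma_t = 1 - t\).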
- target_to_velocity(x_1: Array, x_t: Array, t: Array) → Array [source]#
Convert from x_1 representation to velocity.
- Parameters:
x_1 (Array) – Target data point.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Velocity.
- Return type:
Array
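Eliminating \(X_0\) from \(X_t = \alpha_t X_1 + \sigma_t X_0\) and \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\) gives the velocity as a linear combination of \(X_t\) and \(X_1\). The sketch below implements that algebra with hypothetical names and explicit coefficients, and checks it against the direct velocity for \(\alpha_t = t\), \(\sigma_t = 1 - t\).

```python
import jax.numpy as jnp

def target_to_velocity_sketch(x_1, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t):
    # Substitute x_0 = (x_t - alpha_t * x_1) / sigma_t into the velocity formula.
    a_t = d_sigma_t / sigma_t
    b_t = d_alpha_t - alpha_t * d_sigma_t / sigma_t
    return a_t * x_t + b_t * x_1

t = 0.5
alpha_t, sigma_t, d_alpha_t, d_sigma_t = t, 1.0 - t, 1.0, -1.0

x_0 = jnp.array([0.0, 2.0])
x_1 = jnp.array([4.0, -2.0])
x_t = alpha_t * x_1 + sigma_t * x_0

velocity = target_to_velocity_sketch(x_1, x_t, alpha_t, sigma_t, d_alpha_t, d_sigma_t)
direct_velocity = d_alpha_t * x_1 + d_sigma_t * x_0
```

For this linear schedule the expression reduces to \((X_1 - X_t) / (1 - t)\).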
- class gensbi.flow_matching.path.affine.CondOTProbPath[source]#
Bases:
AffineProbPath
The CondOTProbPath class represents a conditional optimal transport probability path. This class is a specialized version of AffineProbPath that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation. The parameters \(\alpha_t\) and \(\sigma_t\) for the conditional optimal transport path are defined as:
\[\alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.\]
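With these coefficients the path is a straight line from \(X_0\) to \(X_1\), and its velocity \(X_1 - X_0\) does not depend on t. The following sketch uses a hypothetical helper `cond_ot_sample` with a scalar time and unbatched points to keep the arithmetic visible.

```python
import jax.numpy as jnp

def cond_ot_sample(x_0, x_1, t):
    # alpha_t = t and sigma_t = 1 - t, so their derivatives are +1 and -1.
    x_t = t * x_1 + (1.0 - t) * x_0
    dx_t = x_1 - x_0
    return x_t, dx_t

x_0 = jnp.array([1.0, -1.0])
x_1 = jnp.array([3.0, 5.0])

x_start, _ = cond_ot_sample(x_0, x_1, 0.0)    # equals x_0
x_end, _ = cond_ot_sample(x_0, x_1, 1.0)      # equals x_1
x_mid, v_mid = cond_ot_sample(x_0, x_1, 0.5)  # midpoint of the segment

# Because the velocity is constant, a single Euler step over the
# remaining time (1 - t) lands exactly on the target.
x_reached = x_mid + (1.0 - 0.5) * v_mid
```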
gensbi.flow_matching.path.path module#
- class gensbi.flow_matching.path.path.ProbPath[source]#
Bases:
ABC
Abstract class, representing a probability path.
A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) over \(t=0\rightarrow 1\).
The ProbPath class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives. Here is a high-level example:

```python
import jax

# Instantiate a probability path (ProbPath is abstract; use a concrete subclass)
my_path = ProbPath(...)

key = jax.random.PRNGKey(0)
for x_0, x_1 in dataset:
    # Set t to random values in [0,1], one per batch element
    key, subkey = jax.random.split(key)
    t = jax.random.uniform(subkey, (x_1.shape[0],))
    # Sample the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
    # Optimize the model. The loss function varies, depending on model and path.
    # loss_fn(params, path_sample) is a placeholder that evaluates my_model on
    # (path_sample.x_t, path_sample.t) and compares it to the training target.
    loss, grads = jax.value_and_grad(loss_fn)(params, path_sample)
```
- assert_sample_shape(x_0: Array, x_1: Array, t: Array) → None [source]#
Checks that the shapes of x_0, x_1, and t are compatible for sampling.
- Parameters:
x_0 (Array) – Source data point.
x_1 (Array) – Target data point.
t (Array) – Time vector.
- Raises:
AssertionError – If the shapes are not compatible.
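The exact checks are not listed here. A plausible sketch, assuming the shapes documented for sample (x_0 and x_1 of shape (batch_size, …) and t of shape (batch_size,)), could look like this; the function name is hypothetical.

```python
import jax.numpy as jnp

def assert_sample_shape_sketch(x_0, x_1, t):
    # Assumed conventions: t is a 1-D vector with one entry per batch element,
    # and x_0 and x_1 share the same shape.
    assert t.ndim == 1, f"t must be 1-D, got shape {t.shape}"
    assert t.shape[0] == x_0.shape[0] == x_1.shape[0], "batch sizes differ"
    assert x_0.shape == x_1.shape, "x_0 and x_1 must have the same shape"

x_0 = jnp.zeros((4, 2))
x_1 = jnp.ones((4, 2))
t = jnp.linspace(0.0, 1.0, 4)
assert_sample_shape_sketch(x_0, x_1, t)  # compatible shapes pass silently

try:
    assert_sample_shape_sketch(x_0, x_1, jnp.zeros((3,)))  # wrong batch size
    mismatch_detected = False
except AssertionError:
    mismatch_detected = True
```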
- abstractmethod sample(x_0: Array, x_1: Array, t: Array) → PathSample [source]#
Sample from an abstract probability path.
Given \((X_0,X_1) \sim \pi(X_0,X_1)\), returns \(X_0\), \(X_1\), \(X_t \sim p_t(X_t|X_0,X_1)\), and a conditional target \(Y\), all packaged in a PathSample.
- Parameters:
x_0 (Array) – Source data point, shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Times in [0,1], shape (batch_size,).
- Returns:
A conditional sample.
- Return type:
PathSample
gensbi.flow_matching.path.path_sample module#
- class gensbi.flow_matching.path.path_sample.DiscretePathSample(x_1: Array, x_0: Array, t: Array, x_t: Array)[source]#
Bases:
object
Represents a sample of a conditional-flow generated discrete probability path.
- x_1#
the target sample \(X_1\).
- Type:
Array
- x_0#
the source sample \(X_0\).
- Type:
Array
- t#
the time sample \(t\).
- Type:
Array
- x_t#
the sample along the path \(X_t \sim p_t\).
- Type:
Array
- t: Array#
- x_0: Array#
- x_1: Array#
- x_t: Array#
- class gensbi.flow_matching.path.path_sample.PathSample(x_1: Array, x_0: Array, t: Array, x_t: Array, dx_t: Array)[source]#
Bases:
object
Represents a sample of a conditional-flow generated probability path.
- x_1#
the target sample \(X_1\).
- Type:
Array
- x_0#
the source sample \(X_0\).
- Type:
Array
- t#
the time sample \(t\).
- Type:
Array
- x_t#
samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).
- Type:
Array
- dx_t#
conditional target \(\frac{\partial X_t}{\partial t}\), shape (batch_size, …).
- Type:
Array
- dx_t: Array#
- t: Array#
- x_0: Array#
- x_1: Array#
- x_t: Array#
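A container with these five fields can be sketched as a plain dataclass. `PathSampleSketch` is a hypothetical stand-in used only to show how the fields are typically consumed in a velocity-matching loss; the zero prediction replaces a real model.

```python
from dataclasses import dataclass

import jax.numpy as jnp
from jax import Array

@dataclass
class PathSampleSketch:
    x_1: Array   # target sample
    x_0: Array   # source sample
    t: Array     # time sample
    x_t: Array   # sample along the path
    dx_t: Array  # conditional velocity target

x_0 = jnp.zeros((2, 2))
x_1 = jnp.ones((2, 2))
t = jnp.array([0.5, 0.5])

sample = PathSampleSketch(
    x_1=x_1,
    x_0=x_0,
    t=t,
    x_t=t[:, None] * x_1 + (1.0 - t[:, None]) * x_0,
    dx_t=x_1 - x_0,
)

# Stand-in for a model output evaluated at (sample.x_t, sample.t)
prediction = jnp.zeros_like(sample.x_t)
loss = jnp.mean((prediction - sample.dx_t) ** 2)
```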
Module contents#
- class gensbi.flow_matching.path.AffineProbPath(scheduler: Scheduler)[source]#
Bases:
ProbPath
The AffineProbPath class represents a probability path in which the transformation between distributions is affine. An affine transformation can be written as:
\[X_t = \alpha_t X_1 + \sigma_t X_0,\]
where \(X_t\) is the transformed data point at time t, \(X_0\) and \(X_1\) are the source and target data points, respectively, and \(\alpha_t\) and \(\sigma_t\) are the parameters of the affine transformation at time t.
The scheduler is responsible for providing the time-dependent parameters \(\alpha_t\) and \(\sigma_t\), as well as their derivatives, which define the affine transformation at any given time t.
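Using AffineProbPath in the flow matching framework:

```python
import jax
import jax.numpy as jnp

# Instantiate a probability path (pass a Scheduler instance)
my_path = AffineProbPath(...)

def mse_loss(params, x_t, t, dx_t):
    # my_model and params are placeholders for a velocity model and its parameters
    return jnp.mean((my_model(params, x_t, t) - dx_t) ** 2)

key = jax.random.PRNGKey(0)
for x_1 in dataset:
    key, key_x0, key_t = jax.random.split(key, 3)
    # Set x_0 to random noise
    x_0 = jax.random.normal(key_x0, x_1.shape)
    # Set t to random values in [0,1], one per batch element
    t = jax.random.uniform(key_t, (x_1.shape[0],))
    # Sample the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
    # MSE loss against the conditional velocity; gradients are taken w.r.t. params
    grads = jax.grad(mse_loss)(params, path_sample.x_t, path_sample.t, path_sample.dx_t)
```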
- Parameters:
scheduler (Scheduler) – An instance of a scheduler that provides the parameters \(\alpha_t\), \(\sigma_t\), and their derivatives over time.
- epsilon_to_target(epsilon: Array, x_t: Array, t: Array) → Array [source]#
Convert from epsilon representation to x_1 representation.
- Parameters:
epsilon (Array) – Noise in the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Target data point.
- Return type:
Array
- epsilon_to_velocity(epsilon: Array, x_t: Array, t: Array) → Array [source]#
Convert from epsilon representation to velocity.
- Parameters:
epsilon (Array) – Noise in the path sample.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Velocity.
- Return type:
Array
- sample(x_0: Array, x_1: Array, t: Array) → PathSample [source]#
Sample from the affine probability path.
Given \((X_0,X_1) \sim \pi(X_0,X_1)\) and a scheduler \((\alpha_t,\sigma_t)\), returns \(X_0\), \(X_1\), \(X_t = \alpha_t X_1 + \sigma_t X_0\), and the conditional velocity at \(X_t\), \(\dot{X}_t = \dot{\alpha}_t X_1 + \dot{\sigma}_t X_0\).
- Parameters:
x_0 (Array) – Source data point, shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Times in [0,1], shape (batch_size,).
- Returns:
A conditional sample at \(X_t \sim p_t\).
- Return type:
PathSample
- target_to_epsilon(x_1: Array, x_t: Array, t: Array) → Array [source]#
Convert from x_1 representation to noise.
- Parameters:
x_1 (Array) – Target data point.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Noise in the path sample.
- Return type:
Array
- target_to_velocity(x_1: Array, x_t: Array, t: Array) → Array [source]#
Convert from x_1 representation to velocity.
- Parameters:
x_1 (Array) – Target data point.
x_t (Array) – Path sample at time t.
t (Array) – Time in [0,1].
- Returns:
Velocity.
- Return type:
Array
- class gensbi.flow_matching.path.CondOTProbPath[source]#
Bases:
AffineProbPath
The CondOTProbPath class represents a conditional optimal transport probability path. This class is a specialized version of AffineProbPath that uses a conditional optimal transport scheduler to determine the parameters of the affine transformation. The parameters \(\alpha_t\) and \(\sigma_t\) for the conditional optimal transport path are defined as:
\[\alpha_t = t \quad \text{and} \quad \sigma_t = 1 - t.\]
- class gensbi.flow_matching.path.PathSample(x_1: Array, x_0: Array, t: Array, x_t: Array, dx_t: Array)[source]#
Bases:
object
Represents a sample of a conditional-flow generated probability path.
- x_1#
the target sample \(X_1\).
- Type:
Array
- x_0#
the source sample \(X_0\).
- Type:
Array
- t#
the time sample \(t\).
- Type:
Array
- x_t#
samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).
- Type:
Array
- dx_t#
conditional target \(\frac{\partial X_t}{\partial t}\), shape (batch_size, …).
- Type:
Array
- dx_t: Array#
- t: Array#
- x_0: Array#
- x_1: Array#
- x_t: Array#
- class gensbi.flow_matching.path.ProbPath[source]#
Bases:
ABC
Abstract class, representing a probability path.
A probability path transforms the distribution \(p(X_0)\) into \(p(X_1)\) over \(t=0\rightarrow 1\).
The ProbPath class is designed to support model training in the flow matching framework. It supports two key functionalities: (1) sampling the conditional probability path and (2) conversion between various training objectives. Here is a high-level example:

```python
import jax

# Instantiate a probability path (ProbPath is abstract; use a concrete subclass)
my_path = ProbPath(...)

key = jax.random.PRNGKey(0)
for x_0, x_1 in dataset:
    # Set t to random values in [0,1], one per batch element
    key, subkey = jax.random.split(key)
    t = jax.random.uniform(subkey, (x_1.shape[0],))
    # Sample the conditional path X_t ~ p_t(X_t|X_0,X_1)
    path_sample = my_path.sample(x_0=x_0, x_1=x_1, t=t)
    # Optimize the model. The loss function varies, depending on model and path.
    # loss_fn(params, path_sample) is a placeholder that evaluates my_model on
    # (path_sample.x_t, path_sample.t) and compares it to the training target.
    loss, grads = jax.value_and_grad(loss_fn)(params, path_sample)
```
- assert_sample_shape(x_0: Array, x_1: Array, t: Array) → None [source]#
Checks that the shapes of x_0, x_1, and t are compatible for sampling.
- Parameters:
x_0 (Array) – Source data point.
x_1 (Array) – Target data point.
t (Array) – Time vector.
- Raises:
AssertionError – If the shapes are not compatible.
- abstractmethod sample(x_0: Array, x_1: Array, t: Array) → PathSample [source]#
Sample from an abstract probability path.
Given \((X_0,X_1) \sim \pi(X_0,X_1)\), returns \(X_0\), \(X_1\), \(X_t \sim p_t(X_t|X_0,X_1)\), and a conditional target \(Y\), all packaged in a PathSample.
- Parameters:
x_0 (Array) – Source data point, shape (batch_size, …).
x_1 (Array) – Target data point, shape (batch_size, …).
t (Array) – Times in [0,1], shape (batch_size,).
- Returns:
A conditional sample.
- Return type:
PathSample