gensbi.flow_matching.path.scheduler package#

Submodules#

gensbi.flow_matching.path.scheduler.schedule_transform module#

class gensbi.flow_matching.path.scheduler.schedule_transform.ScheduleTransformedModel(*args: Any, **kwargs: Any)[source]#

Bases: ModelWrapper

Change of scheduler for a velocity model.

This class wraps a velocity model that was built for one scheduler so that it can be evaluated under a different one. It reparameterizes the model's time variable according to the new scheduler, using the original scheduler's snr_inverse, and leaves the wrapped model itself unchanged.

Example:

import jax
import jax.numpy as jnp
from gensbi.flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
from gensbi.flow_matching.solver import ODESolver

# Initialize the model and schedulers
model = ...

original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)

key = jax.random.PRNGKey(0)
x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition

x_1 = solver.sample(
    time_steps=jnp.array([0.0, 1.0]),
    x_init=x_0,
    step_size=1/1000
)[1]

Parameters:
  • velocity_model (ModelWrapper) – The original velocity model to be transformed.

  • original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse function.

  • new_scheduler (Scheduler) – The new scheduler to be applied to the model.
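
The time reparameterization behind the wrapper can be sketched without the library. The snippet below is a minimal scalar illustration, not the gensbi implementation. It assumes the common definitions \(\alpha_t = t\), \(\sigma_t = 1 - t\) for the CondOT schedule and \(\alpha_r = \sin(\pi r / 2)\), \(\sigma_r = \cos(\pi r / 2)\) for the cosine schedule, neither of which is stated on this page, and all helper names are hypothetical. For a time r on the new schedule it finds the time t on the original schedule with the same signal-to-noise ratio, which is the role snr_inverse plays for original_scheduler.

```python
import math


def cosine_alpha_sigma(r):
    """Assumed cosine schedule: alpha_r = sin(pi r / 2), sigma_r = cos(pi r / 2)."""
    return math.sin(0.5 * math.pi * r), math.cos(0.5 * math.pi * r)


def condot_snr_inverse(snr):
    """Assumed CondOT schedule (alpha_t = t, sigma_t = 1 - t), so snr = t / (1 - t)."""
    return snr / (1.0 + snr)


def new_time_to_original_time(r):
    """Map a time r on the new (cosine) schedule to the SNR-matched
    time t on the original (CondOT) schedule."""
    alpha_r, sigma_r = cosine_alpha_sigma(r)
    return condot_snr_inverse(alpha_r / sigma_r)


# Print the matched original time for a few interior times on the new schedule.
for r in (0.25, 0.5, 0.75):
    print(r, new_time_to_original_time(r))
```

Under these assumptions the midpoint r = 0.5 maps to t = 0.5, because both schedules have a signal-to-noise ratio of 1 there; away from the midpoint the two time axes differ.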

gensbi.flow_matching.path.scheduler.scheduler module#

class gensbi.flow_matching.path.scheduler.scheduler.CondOTScheduler[source]#

Bases: ConvexScheduler

CondOT Scheduler.

kappa_inverse(kappa: Array) → Array[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.scheduler.ConvexScheduler[source]#

Bases: Scheduler

abstractmethod kappa_inverse(kappa: Array) → Array[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
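
For a convex scheduler, inverting the signal-to-noise ratio can be reduced to kappa_inverse. The following is a minimal sketch under the assumption, not stated on this page, that a convex scheduler uses \(\alpha_t = \kappa_t\) and \(\sigma_t = 1 - \kappa_t\). The polynomial choice \(\kappa_t = t^n\) and all function names are illustrative only.

```python
def kappa_from_snr(snr):
    # With alpha = kappa and sigma = 1 - kappa:
    # snr = kappa / (1 - kappa)  =>  kappa = snr / (1 + snr)
    return snr / (1.0 + snr)


def polynomial_kappa(t, n):
    # Illustrative convex schedule kappa_t = t^n
    return t ** n


def polynomial_kappa_inverse(kappa, n):
    # Inverse of kappa_t = t^n on [0, 1]
    return kappa ** (1.0 / n)


def polynomial_snr_inverse(snr, n):
    # snr -> kappa -> t
    return polynomial_kappa_inverse(kappa_from_snr(snr), n)


n = 3
t = 0.7
kappa = polynomial_kappa(t, n)
snr = kappa / (1.0 - kappa)
t_recovered = polynomial_snr_inverse(snr, n)
print(t_recovered)  # should match t up to floating-point error
```

This is why a convex subclass only needs to supply kappa_inverse: the step from the ratio to \(\kappa\) is the same for every schedule of this form.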

class gensbi.flow_matching.path.scheduler.scheduler.CosineScheduler[source]#

Bases: Scheduler

Cosine Scheduler.

snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
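
As an illustration of what the inverse looks like for a cosine schedule, assume \(\alpha_t = \sin(\pi t / 2)\) and \(\sigma_t = \cos(\pi t / 2)\) (an assumption, since the page does not give the formulas). The ratio is then \(\tan(\pi t / 2)\), and the inverse is an arctangent. The function names below are hypothetical.

```python
import math


def cosine_snr(t):
    # alpha_t / sigma_t = sin(pi t / 2) / cos(pi t / 2) = tan(pi t / 2)
    return math.tan(0.5 * math.pi * t)


def cosine_snr_inverse(snr):
    # Solve tan(pi t / 2) = snr for t
    return 2.0 * math.atan(snr) / math.pi


# Round trip: t -> snr -> t for a few interior times
ts = [0.1, 0.5, 0.9]
recovered = [cosine_snr_inverse(cosine_snr(t)) for t in ts]
print(recovered)
```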

class gensbi.flow_matching.path.scheduler.scheduler.LinearVPScheduler[source]#

Bases: Scheduler

Linear Variance Preserving Scheduler.

snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
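
A linear variance-preserving schedule is commonly written as \(\alpha_t = t\), \(\sigma_t = \sqrt{1 - t^2}\), so that \(\alpha_t^2 + \sigma_t^2 = 1\). Assuming that form (it is not spelled out on this page), the inverse follows from \(\mathrm{snr}^2 = t^2 / (1 - t^2)\). The sketch below uses hypothetical helper names and plain floats.

```python
import math


def linear_vp_alpha_sigma(t):
    # Assumed schedule: alpha_t = t, sigma_t = sqrt(1 - t^2)
    return t, math.sqrt(1.0 - t * t)


def linear_vp_snr_inverse(snr):
    # snr^2 = t^2 / (1 - t^2)  =>  t = sqrt(snr^2 / (1 + snr^2))
    return math.sqrt(snr * snr / (1.0 + snr * snr))


t = 0.6
alpha, sigma = linear_vp_alpha_sigma(t)
variance = alpha * alpha + sigma * sigma    # stays at 1 for every t
t_back = linear_vp_snr_inverse(alpha / sigma)
print(variance, t_back)
```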

class gensbi.flow_matching.path.scheduler.scheduler.PolynomialConvexScheduler(n: float | int)[source]#

Bases: ConvexScheduler

Polynomial Scheduler.

kappa_inverse(kappa: Array) → Array[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.scheduler.Scheduler[source]#

Bases: ABC

Base Scheduler class.

abstractmethod snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.scheduler.SchedulerOutput(alpha_t: Array, sigma_t: Array, d_alpha_t: Array, d_sigma_t: Array)[source]#

Bases: object

Holds the scheduler coefficients \(\alpha_t\) and \(\sigma_t\) and their time derivatives, evaluated at time \(t\).

alpha_t#

\(\alpha_t\), shape (…).

Type:

Array

sigma_t#

\(\sigma_t\), shape (…).

Type:

Array

d_alpha_t#

\(\frac{\partial}{\partial t}\alpha_t\), shape (…).

Type:

Array

d_sigma_t#

\(\frac{\partial}{\partial t}\sigma_t\), shape (…).

Type:

Array
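
To make the relationship between the four fields concrete, here is a small stand-in container with the same field names. It is a hypothetical local dataclass, not the gensbi class, and it is filled from an assumed CondOT schedule (\(\alpha_t = t\), \(\sigma_t = 1 - t\)). A finite-difference check confirms that the derivative fields match the coefficient fields.

```python
from dataclasses import dataclass


@dataclass
class SchedulerOutputSketch:
    # Hypothetical stand-in mirroring the documented field names
    alpha_t: float
    sigma_t: float
    d_alpha_t: float
    d_sigma_t: float


def condot_schedule(t):
    # Assumed CondOT schedule: alpha_t = t, sigma_t = 1 - t
    return SchedulerOutputSketch(alpha_t=t, sigma_t=1.0 - t, d_alpha_t=1.0, d_sigma_t=-1.0)


t, eps = 0.3, 1e-6
out = condot_schedule(t)
ahead = condot_schedule(t + eps)

# Forward finite differences should approximate the stored derivatives
fd_alpha = (ahead.alpha_t - out.alpha_t) / eps
fd_sigma = (ahead.sigma_t - out.sigma_t) / eps
print(fd_alpha, out.d_alpha_t, fd_sigma, out.d_sigma_t)
```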

class gensbi.flow_matching.path.scheduler.scheduler.VPScheduler(beta_min: float = 0.1, beta_max: float = 20.0)[source]#

Bases: Scheduler

Variance Preserving Scheduler.

snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
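
The role of beta_min and beta_max can be illustrated with one common variance-preserving parameterization. The sketch below is an assumption about the functional form, not taken from this page: with \(u = 1 - t\), it sets \(T = \tfrac{1}{2} u^2 (\beta_{\max} - \beta_{\min}) + u\,\beta_{\min}\), \(\alpha_t = e^{-T/2}\) and \(\sigma_t = \sqrt{1 - e^{-T}}\). The inverse first recovers \(T\) from the ratio and then solves the quadratic for \(u\). Function names are hypothetical.

```python
import math


def vp_alpha_sigma(t, beta_min=0.1, beta_max=20.0):
    # Assumed form: T = 0.5 * u^2 * (beta_max - beta_min) + u * beta_min, with u = 1 - t
    u = 1.0 - t
    T = 0.5 * u * u * (beta_max - beta_min) + u * beta_min
    return math.exp(-0.5 * T), math.sqrt(1.0 - math.exp(-T))


def vp_snr_inverse(snr, beta_min=0.1, beta_max=20.0):
    # alpha^2 = exp(-T) and sigma^2 = 1 - exp(-T)  =>  snr^2 / (1 + snr^2) = exp(-T)
    T = -math.log(snr * snr / (snr * snr + 1.0))
    # Positive root of 0.5 * (beta_max - beta_min) * u^2 + beta_min * u - T = 0
    delta = beta_max - beta_min
    u = (-beta_min + math.sqrt(beta_min ** 2 + 2.0 * delta * T)) / delta
    return 1.0 - u


# Round trip for interior times (t = 1 is excluded because sigma_t = 0 there)
ts = [0.3, 0.6, 0.9]
recovered = []
for t in ts:
    alpha, sigma = vp_alpha_sigma(t)
    recovered.append(vp_snr_inverse(alpha / sigma))
print(recovered)
```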

Module contents#

class gensbi.flow_matching.path.scheduler.CondOTScheduler[source]#

Bases: ConvexScheduler

CondOT Scheduler.

kappa_inverse(kappa: Array) → Array[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.ConvexScheduler[source]#

Bases: Scheduler

abstractmethod kappa_inverse(kappa: Array) → Array[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.CosineScheduler[source]#

Bases: Scheduler

Cosine Scheduler.

snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.LinearVPScheduler[source]#

Bases: Scheduler

Linear Variance Preserving Scheduler.

snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.PolynomialConvexScheduler(n: float | int)[source]#

Bases: ConvexScheduler

Polynomial Scheduler.

kappa_inverse(kappa: Array) → Array[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.ScheduleTransformedModel(*args: Any, **kwargs: Any)[source]#

Bases: ModelWrapper

Change of scheduler for a velocity model.

This class wraps a velocity model that was built for one scheduler so that it can be evaluated under a different one. It reparameterizes the model's time variable according to the new scheduler, using the original scheduler's snr_inverse, and leaves the wrapped model itself unchanged.

Example:

import jax
import jax.numpy as jnp
from gensbi.flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
from gensbi.flow_matching.solver import ODESolver

# Initialize the model and schedulers
model = ...

original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)

key = jax.random.PRNGKey(0)
x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition

x_1 = solver.sample(
    time_steps=jnp.array([0.0, 1.0]),
    x_init=x_0,
    step_size=1/1000
)[1]

Parameters:
  • velocity_model (ModelWrapper) – The original velocity model to be transformed.

  • original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse function.

  • new_scheduler (Scheduler) – The new scheduler to be applied to the model.

class gensbi.flow_matching.path.scheduler.Scheduler[source]#

Bases: ABC

Base Scheduler class.

abstractmethod snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array

class gensbi.flow_matching.path.scheduler.SchedulerOutput(alpha_t: Array, sigma_t: Array, d_alpha_t: Array, d_sigma_t: Array)[source]#

Bases: object

Holds the scheduler coefficients \(\alpha_t\) and \(\sigma_t\) and their time derivatives, evaluated at time \(t\).

alpha_t#

\(\alpha_t\), shape (…).

Type:

Array

sigma_t#

\(\sigma_t\), shape (…).

Type:

Array

d_alpha_t#

\(\frac{\partial}{\partial t}\alpha_t\), shape (…).

Type:

Array

d_sigma_t#

\(\frac{\partial}{\partial t}\sigma_t\), shape (…).

Type:

Array

class gensbi.flow_matching.path.scheduler.VPScheduler(beta_min: float = 0.1, beta_max: float = 20.0)[source]#

Bases: Scheduler

Variance Preserving Scheduler.

snr_inverse(snr: Array) → Array[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array