gensbi.flow_matching.path.scheduler#

Submodules#

Classes#

CondOTScheduler

CondOT Scheduler.

ConvexScheduler

Base class for convex schedulers.

CosineScheduler

Cosine Scheduler.

LinearVPScheduler

Linear Variance Preserving Scheduler.

PolynomialConvexScheduler

Polynomial Scheduler.

ScheduleTransformedModel

Change of scheduler for a velocity model.

Scheduler

Base Scheduler class.

SchedulerOutput

Holds the scheduler coefficients \(\alpha_t\), \(\sigma_t\) and their time derivatives.

VPScheduler

Variance Preserving Scheduler.

Package Contents#

class gensbi.flow_matching.path.scheduler.CondOTScheduler[source]#

Bases: ConvexScheduler

CondOT Scheduler.

__call__(t)[source]#

Scheduler for convex paths.

Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

kappa_inverse(kappa)[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array
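As a rough illustration of what this scheduler computes, the sketch below assumes the usual conditional optimal-transport schedule, \(\alpha_t = t\) and \(\sigma_t = 1 - t\), so that \(\kappa_t = t\) and its inverse is the identity. The function names are hypothetical and work on plain floats; this is not the gensbi implementation.

```python
def cond_ot(t: float) -> tuple[float, float, float, float]:
    # Assumed CondOT schedule: alpha_t = t, sigma_t = 1 - t.
    # Returned in the documented order: alpha, sigma, d_alpha, d_sigma.
    return t, 1.0 - t, 1.0, -1.0


def cond_ot_kappa_inverse(kappa: float) -> float:
    # With kappa_t = alpha_t = t, recovering t from kappa is the identity.
    return kappa


alpha_t, sigma_t, d_alpha_t, d_sigma_t = cond_ot(0.25)
t_recovered = cond_ot_kappa_inverse(alpha_t)
```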

class gensbi.flow_matching.path.scheduler.ConvexScheduler[source]#

Bases: Scheduler

Base class for convex schedulers.

abstract __call__(t)[source]#

Scheduler for convex paths.

Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

abstract kappa_inverse(kappa)[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

snr_inverse(snr)[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
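For a convex path, \(\alpha_t = \kappa_t\) and \(\sigma_t = 1 - \kappa_t\), so the signal-to-noise ratio is \(\kappa_t / (1 - \kappa_t)\) and therefore \(\kappa_t = \mathrm{snr} / (1 + \mathrm{snr})\). The sketch below shows how `snr_inverse` can be built from `kappa_inverse` under that assumption. The helper name is hypothetical, and the identity `kappa_inverse` used in the check corresponds to the assumed CondOT case \(\kappa_t = t\).

```python
from typing import Callable


def snr_inverse_from_kappa(snr: float, kappa_inverse: Callable[[float], float]) -> float:
    # snr = kappa / (1 - kappa)  =>  kappa = snr / (1 + snr)
    kappa = snr / (1.0 + snr)
    return kappa_inverse(kappa)


# Round trip for kappa_t = t (identity inverse).
t = 0.8
snr = t / (1.0 - t)
t_recovered = snr_inverse_from_kappa(snr, lambda kappa: kappa)
```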

class gensbi.flow_matching.path.scheduler.CosineScheduler[source]#

Bases: Scheduler

Cosine Scheduler.

__call__(t)[source]#
Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

snr_inverse(snr)[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
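A minimal sketch of a cosine schedule, assuming the common form \(\alpha_t = \sin(\tfrac{\pi}{2} t)\) and \(\sigma_t = \cos(\tfrac{\pi}{2} t)\). Under that assumption the signal-to-noise ratio is \(\tan(\tfrac{\pi}{2} t)\), which inverts to \(t = \tfrac{2}{\pi}\arctan(\mathrm{snr})\). The functions are hypothetical scalar stand-ins, not the library code.

```python
import math

HALF_PI = 0.5 * math.pi


def cosine(t: float) -> tuple[float, float, float, float]:
    # Assumed schedule: alpha = sin(pi t / 2), sigma = cos(pi t / 2).
    alpha = math.sin(HALF_PI * t)
    sigma = math.cos(HALF_PI * t)
    # Derivatives follow from the chain rule.
    return alpha, sigma, HALF_PI * sigma, -HALF_PI * alpha


def cosine_snr_inverse(snr: float) -> float:
    # snr = tan(pi t / 2)  =>  t = 2 * atan(snr) / pi
    return 2.0 * math.atan(snr) / math.pi


alpha_t, sigma_t, _, _ = cosine(0.5)
t_recovered = cosine_snr_inverse(alpha_t / sigma_t)
```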

class gensbi.flow_matching.path.scheduler.LinearVPScheduler[source]#

Bases: Scheduler

Linear Variance Preserving Scheduler.

__call__(t)[source]#
Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

snr_inverse(snr)[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
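The following sketch assumes the linear variance-preserving form \(\alpha_t = t\), \(\sigma_t = \sqrt{1 - t^2}\), for which \(\alpha_t^2 + \sigma_t^2 = 1\) and \(t = \mathrm{snr} / \sqrt{1 + \mathrm{snr}^2}\). Function names are hypothetical and operate on plain floats.

```python
import math


def linear_vp(t: float) -> tuple[float, float, float, float]:
    # Assumed schedule: alpha = t, sigma = sqrt(1 - t^2).
    sigma = math.sqrt(1.0 - t * t)
    # d_sigma/dt = -t / sqrt(1 - t^2); it diverges as t -> 1.
    return t, sigma, 1.0, -t / sigma


def linear_vp_snr_inverse(snr: float) -> float:
    # snr^2 = t^2 / (1 - t^2)  =>  t = snr / sqrt(1 + snr^2)
    return snr / math.sqrt(1.0 + snr * snr)


alpha_t, sigma_t, d_alpha_t, d_sigma_t = linear_vp(0.6)
t_recovered = linear_vp_snr_inverse(alpha_t / sigma_t)
```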

class gensbi.flow_matching.path.scheduler.PolynomialConvexScheduler(n)[source]#

Bases: ConvexScheduler

Polynomial Scheduler.

Parameters:

n (Union[float, int])

__call__(t)[source]#

Scheduler for convex paths.

Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

kappa_inverse(kappa)[source]#

Computes \(t\) from \(\kappa_t\).

Parameters:

kappa (Array) – \(\kappa\), shape (…)

Returns:

t, shape (…)

Return type:

Array

n#
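To make the role of `n` concrete, the sketch below assumes the polynomial schedule \(\kappa_t = t^n\), so \(\alpha_t = t^n\), \(\sigma_t = 1 - t^n\), and \(t = \kappa^{1/n}\). The function names are hypothetical scalar stand-ins rather than the class itself.

```python
def polynomial(t: float, n: float) -> tuple[float, float, float, float]:
    # Assumed schedule: alpha = t^n, sigma = 1 - t^n.
    alpha = t ** n
    d_alpha = n * t ** (n - 1)
    return alpha, 1.0 - alpha, d_alpha, -d_alpha


def polynomial_kappa_inverse(kappa: float, n: float) -> float:
    # kappa = t^n  =>  t = kappa^(1/n)
    return kappa ** (1.0 / n)


alpha_t, sigma_t, d_alpha_t, d_sigma_t = polynomial(0.5, n=2)
t_recovered = polynomial_kappa_inverse(alpha_t, n=2)
```

With `n=1` this reduces to the linear schedule \(\alpha_t = t\), \(\sigma_t = 1 - t\).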
class gensbi.flow_matching.path.scheduler.ScheduleTransformedModel(velocity_model, original_scheduler, new_scheduler)[source]#

Bases: gensbi.utils.model_wrapping.ModelWrapper

Change of scheduler for a velocity model.

This class wraps a trained velocity model so that it can be sampled under a different scheduler without retraining. It re-maps time and rescales the input so that the wrapped model returns the marginal velocity field associated with the new scheduler.

Example

import jax
import jax.numpy as jnp
from gensbi.flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
from gensbi.flow_matching.solver import ODESolver

# Initialize the model and schedulers
model = ...

original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)

key = jax.random.PRNGKey(0)
x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition

x_1 = solver.sample(
    time_steps=jnp.array([0.0, 1.0]),
    x_init=x_0,
    step_size=1/1000
    )[1]
Parameters:
  • velocity_model (ModelWrapper) – The original velocity model to be transformed.

  • original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse function.

  • new_scheduler (Scheduler) – The new scheduler to be applied to the model.

__call__(x, t, **extras)[source]#

Compute the transformed marginal velocity field for a new scheduler. This method implements a post-training velocity scheduler change for affine conditional flows.

Parameters:
  • x (Array) – \(x_t\), the input array.

  • t (Array) – The time array, interpreted as time under the new scheduler (denoted \(r\)).

  • **extras – Additional arguments for the model.

Returns:

The transformed velocity.

Return type:

Array
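The scheduler change can be sketched on scalars as follows. The idea, assumed here to match the standard treatment of affine conditional flows, is to find the original time \(t\) whose signal-to-noise ratio equals that of the new scheduler at time \(r\), rescale the input by \(s_r = \sigma_r / \sigma_t\), and combine the original velocity with the derivatives of \(t\) and \(s\) with respect to \(r\). All names below are hypothetical, and the CondOT and cosine coefficient forms are assumptions. The check uses a toy target concentrated at a single point `c`, where both velocity fields are known in closed form.

```python
import math


def cond_ot(t):
    # Assumed CondOT coefficients: alpha, sigma, d_alpha, d_sigma.
    return t, 1.0 - t, 1.0, -1.0


def cond_ot_snr_inverse(snr):
    # snr = t / (1 - t)  =>  t = snr / (1 + snr)
    return snr / (1.0 + snr)


def cosine(r):
    h = 0.5 * math.pi
    return math.sin(h * r), math.cos(h * r), h * math.cos(h * r), -h * math.sin(h * r)


def transformed_velocity(velocity, x, r):
    """Cosine-schedule velocity built from a CondOT-schedule velocity."""
    alpha_r, sigma_r, d_alpha_r, d_sigma_r = cosine(r)
    # Original time with the same signal-to-noise ratio as the new time r.
    t = cond_ot_snr_inverse(alpha_r / sigma_r)
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = cond_ot(t)
    s_r = sigma_r / sigma_t  # scale between the two paths
    # dt/dr and ds/dr, obtained by differentiating the SNR-matching condition.
    dt_r = (sigma_t ** 2 * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)) / (
        sigma_r ** 2 * (sigma_t * d_alpha_t - alpha_t * d_sigma_t)
    )
    ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / sigma_t ** 2
    u_t = velocity(x / s_r, t)
    return ds_r * x / s_r + dt_r * s_r * u_t


# Toy case: the target is a point mass at c, so the CondOT marginal
# velocity is (c - x) / (1 - t).
c = 2.0
u_cond_ot = lambda x, t: (c - x) / (1.0 - t)
x, r = 0.3, 0.4
u_new = transformed_velocity(u_cond_ot, x, r)

# Closed-form marginal velocity of the cosine path for the same target.
alpha_r, sigma_r, d_alpha_r, d_sigma_r = cosine(r)
u_expected = (d_sigma_r / sigma_r) * (x - alpha_r * c) + d_alpha_r * c
```

Here `u_new` and `u_expected` agree up to floating-point error, which is the property the wrapper relies on.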

new_scheduler#
original_scheduler#
class gensbi.flow_matching.path.scheduler.Scheduler[source]#

Bases: abc.ABC

Base Scheduler class.

abstract __call__(t)[source]#
Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

abstract snr_inverse(snr)[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array
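When writing a custom subclass, it can help to check that the reported derivatives and `snr_inverse` are consistent with the coefficients. The sketch below uses a hypothetical scalar analogue of the interface (`ToyScheduler`) and a made-up schedule (`SqrtScheduler`, with \(\alpha_t = \sqrt{t}\), \(\sigma_t = \sqrt{1 - t}\)). Neither is part of gensbi, and the real class works on arrays and returns a SchedulerOutput.

```python
import math
from abc import ABC, abstractmethod


class ToyScheduler(ABC):
    """Hypothetical scalar analogue of the two abstract methods."""

    @abstractmethod
    def __call__(self, t): ...

    @abstractmethod
    def snr_inverse(self, snr): ...


class SqrtScheduler(ToyScheduler):
    """Made-up schedule: alpha_t = sqrt(t), sigma_t = sqrt(1 - t)."""

    def __call__(self, t):
        alpha, sigma = math.sqrt(t), math.sqrt(1.0 - t)
        return alpha, sigma, 0.5 / alpha, -0.5 / sigma

    def snr_inverse(self, snr):
        # snr^2 = t / (1 - t)  =>  t = snr^2 / (1 + snr^2)
        return snr ** 2 / (1.0 + snr ** 2)


def derivative_error(scheduler, t, h=1e-6):
    """Largest gap between reported derivatives and central differences."""
    a_hi, s_hi, _, _ = scheduler(t + h)
    a_lo, s_lo, _, _ = scheduler(t - h)
    _, _, d_alpha, d_sigma = scheduler(t)
    return max(
        abs((a_hi - a_lo) / (2 * h) - d_alpha),
        abs((s_hi - s_lo) / (2 * h) - d_sigma),
    )


sched = SqrtScheduler()
err = derivative_error(sched, 0.4)
alpha_t, sigma_t, _, _ = sched(0.4)
t_recovered = sched.snr_inverse(alpha_t / sigma_t)
```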

class gensbi.flow_matching.path.scheduler.SchedulerOutput[source]#

Holds the scheduler coefficients \(\alpha_t\), \(\sigma_t\) and their time derivatives.

alpha_t#

\(\alpha_t\), shape (…).

Type:

Array

sigma_t#

\(\sigma_t\), shape (…).

Type:

Array

d_alpha_t#

\(\frac{\partial}{\partial t}\alpha_t\), shape (…).

Type:

Array

d_sigma_t#

\(\frac{\partial}{\partial t}\sigma_t\), shape (…).

Type:

Array

alpha_t: jax.Array#
d_alpha_t: jax.Array#
d_sigma_t: jax.Array#
sigma_t: jax.Array#
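As an illustration of how these four fields are typically consumed, the sketch below forms a point on an affine conditional path, \(x_t = \alpha_t x_1 + \sigma_t x_0\), together with its time derivative. `Coefficients` is a hypothetical scalar stand-in for SchedulerOutput, and `affine_sample` is not a gensbi function.

```python
from typing import NamedTuple


class Coefficients(NamedTuple):
    """Hypothetical scalar stand-in with the same field names."""
    alpha_t: float
    sigma_t: float
    d_alpha_t: float
    d_sigma_t: float


def affine_sample(c: Coefficients, x_0: float, x_1: float) -> tuple[float, float]:
    # Point on the conditional path and its derivative with respect to t.
    x_t = c.alpha_t * x_1 + c.sigma_t * x_0
    dx_t = c.d_alpha_t * x_1 + c.d_sigma_t * x_0
    return x_t, dx_t


coeffs = Coefficients(alpha_t=0.25, sigma_t=0.75, d_alpha_t=1.0, d_sigma_t=-1.0)
x_t, dx_t = affine_sample(coeffs, x_0=4.0, x_1=8.0)
```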
class gensbi.flow_matching.path.scheduler.VPScheduler(beta_min=0.1, beta_max=20.0)[source]#

Bases: Scheduler

Variance Preserving Scheduler.

Parameters:
  • beta_min (float)

  • beta_max (float)

__call__(t)[source]#
Parameters:

t (Array) – times in [0,1], shape (…).

Returns:

\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)

Return type:

SchedulerOutput

snr_inverse(snr)[source]#

Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).

Parameters:

snr (Array) – The signal-to-noise ratio, shape (…)

Returns:

t, shape (…)

Return type:

Array

beta_max = 20.0#
beta_min = 0.1#
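The sketch below shows one way `beta_min` and `beta_max` can enter a variance-preserving schedule. It assumes the common parameterization \(T(t) = \tfrac{1}{2}(1-t)^2(\beta_{\max} - \beta_{\min}) + (1-t)\beta_{\min}\), with \(\alpha_t = e^{-T/2}\) and \(\sigma_t = \sqrt{1 - e^{-T}}\). Whether gensbi uses exactly this form is an assumption, and the function names are hypothetical.

```python
import math

BETA_MIN, BETA_MAX = 0.1, 20.0  # the documented defaults


def vp(t, b=BETA_MIN, B=BETA_MAX):
    # Assumed integrated noise level T(t) and its derivative dT/dt.
    T = 0.5 * (1.0 - t) ** 2 * (B - b) + (1.0 - t) * b
    dT = -(1.0 - t) * (B - b) - b
    alpha = math.exp(-0.5 * T)
    sigma = math.sqrt(1.0 - math.exp(-T))
    d_alpha = -0.5 * dT * alpha
    d_sigma = 0.5 * dT * math.exp(-T) / sigma
    return alpha, sigma, d_alpha, d_sigma


def vp_snr_inverse(snr, b=BETA_MIN, B=BETA_MAX):
    # snr^2 / (1 + snr^2) = exp(-T), then solve the quadratic in (1 - t).
    T = -math.log(snr ** 2 / (snr ** 2 + 1.0))
    return 1.0 - (-b + math.sqrt(b ** 2 + 2.0 * (B - b) * T)) / (B - b)


alpha_t, sigma_t, _, _ = vp(0.5)
t_recovered = vp_snr_inverse(alpha_t / sigma_t)
```

Under this parameterization \(\alpha_t^2 + \sigma_t^2 = 1\) for every \(t\), which is what "variance preserving" refers to.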