gensbi.flow_matching.path.scheduler.schedule_transform

Classes

ScheduleTransformedModel

Change of scheduler for a velocity model.

Module Contents

class gensbi.flow_matching.path.scheduler.schedule_transform.ScheduleTransformedModel(velocity_model, original_scheduler, new_scheduler)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Change of scheduler for a velocity model.

This class wraps a trained velocity model and re-expresses its velocity field under a new scheduler, with no retraining. Calling the wrapper returns the velocity that follows the new scheduler's time dynamics; the wrapped model itself is left unchanged and is still evaluated on the original scheduler's time axis.

Example

import jax
import jax.numpy as jnp
from gensbi.flow_matching.path.scheduler import CondOTScheduler, CosineScheduler
from gensbi.flow_matching.path.scheduler.schedule_transform import ScheduleTransformedModel
from gensbi.flow_matching.solver import ODESolver

# Initialize the model and schedulers
model = ...

original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)

key = jax.random.PRNGKey(0)
x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition

x_1 = solver.sample(
    time_steps=jnp.array([0.0, 1.0]),
    x_init=x_0,
    step_size=1/1000
    )[1]
Parameters:
  • velocity_model (ModelWrapper) – The original velocity model to be transformed.

  • original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse function.

  • new_scheduler (Scheduler) – The new scheduler to be applied to the model.
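
The snr_inverse requirement on original_scheduler can be illustrated with a small standalone sketch. It assumes the signal-to-noise ratio is defined as \(\alpha_t / \sigma_t\) and uses a linear schedule \(\alpha_t = t\), \(\sigma_t = 1 - t\) as a stand-in for the original scheduler. The functions `snr` and `snr_inverse` below are hypothetical helpers written for this illustration, not the library's implementation.

```python
import jax.numpy as jnp


def snr(t):
    # Assumed definition: alpha_t / sigma_t with alpha_t = t, sigma_t = 1 - t.
    return t / (1.0 - t)


def snr_inverse(snr_value):
    # Solve t / (1 - t) = snr_value for t.
    return snr_value / (1.0 + snr_value)


t = jnp.array([0.1, 0.5, 0.9])
t_roundtrip = snr_inverse(snr(t))  # should recover t up to float error
```

Mapping a time to its signal-to-noise ratio and back recovers the time, which is the property a scheduler change needs in order to find the original-scheduler time that matches a given point on the new schedule.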

__call__(x, t, **extras)

Compute the transformed marginal velocity field for a new scheduler. This method implements a post-training velocity scheduler change for affine conditional flows.
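
For reference, the standard scale-time relation behind a scheduler change for affine conditional flows can be written as follows. The notation is an assumption introduced here and does not come from this page: \((\alpha_t, \sigma_t)\) is the original scheduler, \((\bar{\alpha}_r, \bar{\sigma}_r)\) the new one, \(u_t\) the original velocity, \(\bar{u}_r\) the transformed velocity, and dots denote derivatives with respect to \(r\).

```latex
\begin{aligned}
\rho(t) &= \frac{\alpha_t}{\sigma_t}, \qquad
t_r = \rho^{-1}\!\left(\frac{\bar{\alpha}_r}{\bar{\sigma}_r}\right), \qquad
s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}, \\
\bar{u}_r(x) &= \frac{\dot{s}_r}{s_r}\, x + s_r\, \dot{t}_r\, u_{t_r}\!\left(\frac{x}{s_r}\right).
\end{aligned}
```

Here \(\rho^{-1}\) plays the role of the original scheduler's snr_inverse, and the definitions of \(t_r\) and \(s_r\) give \(\bar{\alpha}_r = s_r \alpha_{t_r}\) and \(\bar{\sigma}_r = s_r \sigma_{t_r}\).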

Parameters:
  • x (Array) – \(x_t\), the input array.

  • t (Array) – The time array, expressed in the new scheduler's time variable (often written \(r\)).

  • **extras – Additional arguments for the model.

Returns:

The transformed velocity.

Return type:

Array
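
The scheduler change can be sketched in plain JAX without the library. This is a minimal illustration under stated assumptions, not the code behind `__call__`: it uses the scale-time form \(\bar{u}_r(x) = (\dot{s}_r / s_r)\, x + s_r\, \dot{t}_r\, u_{t_r}(x / s_r)\), hand-written linear and cosine schedules, a scalar time strictly inside (0, 1), and a toy velocity field for data concentrated at a single point `c`. All function names are hypothetical.

```python
import jax
import jax.numpy as jnp


def condot(t):
    # Original schedule (stand-in): alpha_t = t, sigma_t = 1 - t.
    return t, 1.0 - t


def condot_snr_inverse(snr):
    # Time t at which alpha_t / sigma_t equals snr for the linear schedule.
    return snr / (1.0 + snr)


def cosine(r):
    # New schedule (stand-in): alpha_r = sin(pi r / 2), sigma_r = cos(pi r / 2).
    return jnp.sin(0.5 * jnp.pi * r), jnp.cos(0.5 * jnp.pi * r)


def time_map(r):
    # t_r: original-schedule time with the same signal-to-noise ratio as r.
    alpha_r, sigma_r = cosine(r)
    return condot_snr_inverse(alpha_r / sigma_r)


def scale_map(r):
    # s_r = sigma_bar_r / sigma_{t_r}.
    _, sigma_r = cosine(r)
    _, sigma_t = condot(time_map(r))
    return sigma_r / sigma_t


def transformed_velocity(u, x, r):
    # Derivatives of t_r and s_r with respect to r via automatic differentiation.
    t_r, dt_r = jax.value_and_grad(time_map)(r)
    s_r, ds_r = jax.value_and_grad(scale_map)(r)
    return (ds_r / s_r) * x + s_r * dt_r * u(x / s_r, t_r)


# Toy check: all data mass at a single point c, where both velocities are known.
c = jnp.array([1.0, -2.0])


def u_condot(x, t):
    # Marginal velocity of the linear path x_t = t c + (1 - t) x_0.
    return (c - x) / (1.0 - t)


x = jnp.array([0.3, 0.7])
r = jnp.asarray(0.3)
u_bar = transformed_velocity(u_condot, x, r)

# Closed-form velocity of the cosine path x_r = sin(pi r / 2) c + cos(pi r / 2) x_0.
alpha_r, sigma_r = cosine(r)
u_bar_exact = 0.5 * jnp.pi * (c - alpha_r * x) / sigma_r
```

In this toy setting `u_bar` and `u_bar_exact` agree up to floating-point error, which shows that the wrapped velocity only needs to be evaluated at the rescaled input `x / s_r` and the remapped time `t_r`.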

new_scheduler
original_scheduler