Source code for gensbi.flow_matching.path.scheduler.schedule_transform

#FIXME: first pass

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from jax import Array

from gensbi.flow_matching.path.scheduler.scheduler import Scheduler
from gensbi.utils.model_wrapping import ModelWrapper

from flax import nnx


[docs] class ScheduleTransformedModel(ModelWrapper): r""" Change of scheduler for a velocity model. This class wraps a given velocity model and transforms its scheduling to a new scheduler function. It modifies the time dynamics of the model according to the new scheduler while maintaining the original model's behavior. Example: .. code-block:: python import jax import jax.numpy as jnp from flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel from flow_matching.solver import ODESolver # Initialize the model and schedulers model = ... original_scheduler = CondOTScheduler() new_scheduler = CosineScheduler() # Create the transformed model transformed_model = ScheduleTransformedModel( velocity_model=model, original_scheduler=original_scheduler, new_scheduler=new_scheduler ) # Set up the solver solver = ODESolver(velocity_model=transformed_model) key = jax.random.PRNGKey(0) x_0 = jax.random.normal(key, shape=(10, 2)) # Example initial condition x_1 = solver.sample( time_steps=jnp.array([0.0, 1.0]), x_init=x_0, step_size=1/1000 )[1] Args: velocity_model (ModelWrapper): The original velocity model to be transformed. original_scheduler (Scheduler): The scheduler used by the original model. Must implement the snr_inverse function. new_scheduler (Scheduler): The new scheduler to be applied to the model. """ def __init__( self, velocity_model: nnx.Module, original_scheduler: Scheduler, new_scheduler: Scheduler, ) -> None: """ Initialize the ScheduleTransformedModel. Args: velocity_model (nnx.Module): The original velocity model. original_scheduler (Scheduler): The scheduler used by the original model. new_scheduler (Scheduler): The new scheduler to be applied. """ super().__init__(model=velocity_model) self.original_scheduler = original_scheduler self.new_scheduler = new_scheduler assert hasattr(self.original_scheduler, "snr_inverse") and callable( getattr(self.original_scheduler, "snr_inverse") ), "The original scheduler must have a callable 'snr_inverse' method." def __call__(self, x: Array, t: Array, **extras) -> Array: r""" Compute the transformed marginal velocity field for a new scheduler. This method implements a post-training velocity scheduler change for affine conditional flows. Args: x (Array): :math:`x_t`, the input array. t (Array): The time array (denoted as :math:`r` above). **extras: Additional arguments for the model. Returns: Array: The transformed velocity. """ r = t r_scheduler_output = self.new_scheduler(t=r) alpha_r = r_scheduler_output.alpha_t sigma_r = r_scheduler_output.sigma_t d_alpha_r = r_scheduler_output.d_alpha_t d_sigma_r = r_scheduler_output.d_sigma_t t = self.original_scheduler.snr_inverse(alpha_r / sigma_r) t_scheduler_output = self.original_scheduler(t=t) alpha_t = t_scheduler_output.alpha_t sigma_t = t_scheduler_output.sigma_t d_alpha_t = t_scheduler_output.d_alpha_t d_sigma_t = t_scheduler_output.d_sigma_t s_r = sigma_r / sigma_t dt_r = ( sigma_t * sigma_t * (sigma_r * d_alpha_r - alpha_r * d_sigma_r) / (sigma_r * sigma_r * (sigma_t * d_alpha_t - alpha_t * d_sigma_t)) ) ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / (sigma_t * sigma_t) u_t = self.model(x=x / s_r, t=t, **extras) # type: ignore u_r = ds_r * x / s_r + dt_r * s_r * u_t return u_r