gensbi.flow_matching.path.scheduler.schedule_transform
Classes
ScheduleTransformedModel: Change of scheduler for a velocity model.
Module Contents
- class gensbi.flow_matching.path.scheduler.schedule_transform.ScheduleTransformedModel(velocity_model, original_scheduler, new_scheduler)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Change of scheduler for a velocity model.
This class wraps a trained velocity model and re-expresses it under a new scheduler. It reparametrizes time and rescales the state so that sampling follows the new scheduler's time dynamics, while reusing the original model's learned velocity field without retraining.
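For affine probability paths, the change of scheduler amounts to a time reparametrization plus a rescaling of the state. The derivation below assumes the convention that \(x_0\) is noise and \(x_1\) is data; the bar notation (\(\bar\alpha_r, \bar\sigma_r\)) for the new scheduler and \(\rho\) for the signal-to-noise ratio are introduced here for illustration. Matching the signal-to-noise ratio of the two paths gives the original time \(t_r\) and a scale \(s_r\) with \(\bar x_r = s_r x_{t_r}\); differentiating in \(r\) yields the new velocity.

```latex
\begin{aligned}
x_t &= \alpha_t x_1 + \sigma_t x_0, \qquad \bar x_r = \bar\alpha_r x_1 + \bar\sigma_r x_0,\\
\rho(t) &= \frac{\alpha_t}{\sigma_t}, \qquad \bar\rho(r) = \frac{\bar\alpha_r}{\bar\sigma_r},\\
t_r &= \rho^{-1}\big(\bar\rho(r)\big), \qquad s_r = \frac{\bar\sigma_r}{\sigma_{t_r}}, \qquad \bar x_r = s_r\, x_{t_r},\\
\bar u_r(x) &= \frac{\dot s_r}{s_r}\, x + s_r\, \dot t_r\, u_{t_r}\!\left(\frac{x}{s_r}\right).
\end{aligned}
```

The identity \(\bar x_r = s_r x_{t_r}\) holds because \(s_r \alpha_{t_r} = \bar\sigma_r\,\rho(t_r) = \bar\sigma_r\,\bar\rho(r) = \bar\alpha_r\) and \(s_r \sigma_{t_r} = \bar\sigma_r\).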
Example
import jax
import jax.numpy as jnp
from gensbi.flow_matching.path.scheduler import CondOTScheduler, CosineScheduler, ScheduleTransformedModel
from gensbi.flow_matching.solver import ODESolver

# Initialize the model and schedulers
model = ...  # a trained velocity model wrapped as a ModelWrapper
original_scheduler = CondOTScheduler()
new_scheduler = CosineScheduler()

# Create the transformed model
transformed_model = ScheduleTransformedModel(
    velocity_model=model,
    original_scheduler=original_scheduler,
    new_scheduler=new_scheduler,
)

# Set up the solver
solver = ODESolver(velocity_model=transformed_model)
key = jax.random.PRNGKey(0)
x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition
x_1 = solver.sample(
    time_steps=jnp.array([0.0, 1.0]),
    x_init=x_0,
    step_size=1 / 1000,
)[1]
- Parameters:
velocity_model (ModelWrapper) – The original velocity model to be transformed.
original_scheduler (Scheduler) – The scheduler the original model was trained with. Must implement the snr_inverse method, which maps a signal-to-noise ratio \(\alpha_t / \sigma_t\) back to the time \(t\).
new_scheduler (Scheduler) – The new scheduler to be applied to the model.
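The snr_inverse requirement can be illustrated with two common schedulers. This is a minimal JAX sketch, assuming the CondOT definition \(\alpha_t = t, \sigma_t = 1 - t\) and the cosine definition \(\alpha_t = \sin(\pi t/2), \sigma_t = \cos(\pi t/2)\); the function names are hypothetical and are not gensbi's API. It checks that each inverse undoes its signal-to-noise ratio.

```python
import jax.numpy as jnp


def condot_snr(t):
    # alpha_t / sigma_t for alpha_t = t, sigma_t = 1 - t
    return t / (1.0 - t)


def condot_snr_inverse(snr):
    # Solve t / (1 - t) = snr for t, giving t = snr / (1 + snr)
    return snr / (1.0 + snr)


def cosine_snr(t):
    # alpha_t / sigma_t = tan(pi t / 2) for the cosine scheduler
    return jnp.tan(0.5 * jnp.pi * t)


def cosine_snr_inverse(snr):
    # Solve tan(pi t / 2) = snr for t, giving t = (2 / pi) * arctan(snr)
    return 2.0 / jnp.pi * jnp.arctan(snr)


# Round trip: snr_inverse(snr(t)) should recover t on the open interval (0, 1)
t = jnp.linspace(0.05, 0.95, 7)
t_condot = condot_snr_inverse(condot_snr(t))
t_cosine = cosine_snr_inverse(cosine_snr(t))
```

Both ratios are strictly increasing in \(t\) on \((0, 1)\), which is what makes the inverse well defined.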
- __call__(x, t, **extras)
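Compute the marginal velocity field under the new scheduler. This method implements a post-training scheduler change for affine conditional flows: the wrapped model is evaluated at a reparametrized time and on a rescaled input, and its output is rescaled accordingly, so the model does not need to be retrained.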
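A minimal standalone sketch of this computation in JAX. It assumes two hypothetical scheduler stand-ins (CondOT with \(\alpha_t = t, \sigma_t = 1 - t\), cosine with \(\alpha_t = \sin(\pi t/2), \sigma_t = \cos(\pi t/2)\)) instead of gensbi's scheduler classes. With \(\rho = \alpha/\sigma\) the signal-to-noise ratio, it evaluates the original velocity at \(t_r = \rho^{-1}(\bar\rho(r))\) on \(x / s_r\), where \(s_r = \bar\sigma_r / \sigma_{t_r}\), and returns \(\frac{\dot s_r}{s_r} x + s_r \dot t_r u_{t_r}(x / s_r)\); the derivatives \(\dot t_r, \dot s_r\) come from jax.jvp. The check uses a point-mass data distribution, for which the exact marginal velocity is known in closed form for any affine scheduler.

```python
import jax
import jax.numpy as jnp


class CondOT:
    """Hypothetical stand-in: alpha_t = t, sigma_t = 1 - t."""

    def alpha(self, t):
        return t

    def sigma(self, t):
        return 1.0 - t

    def snr_inverse(self, snr):
        return snr / (1.0 + snr)  # solves t / (1 - t) = snr


class Cosine:
    """Hypothetical stand-in: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2)."""

    def alpha(self, t):
        return jnp.sin(0.5 * jnp.pi * t)

    def sigma(self, t):
        return jnp.cos(0.5 * jnp.pi * t)

    def snr_inverse(self, snr):
        return 2.0 / jnp.pi * jnp.arctan(snr)  # solves tan(pi t / 2) = snr


def transformed_velocity(u, original, new, x, r):
    """Velocity under `new` at time r, built from u learned under `original`."""

    def t_of_r(r):
        # Original time with the same signal-to-noise ratio as the new path at r
        return original.snr_inverse(new.alpha(r) / new.sigma(r))

    def s_of_r(r):
        # Scale that maps the original path at t_r onto the new path at r
        return new.sigma(r) / original.sigma(t_of_r(r))

    one = jnp.ones_like(r)
    t_r, dt_r = jax.jvp(t_of_r, (r,), (one,))
    s_r, ds_r = jax.jvp(s_of_r, (r,), (one,))
    return (ds_r / s_r) * x + s_r * dt_r * u(x / s_r, t_r)


def point_mass_velocity(sched, mu):
    """Exact marginal velocity of an affine path when all data sits at mu."""

    def u(x, t):
        d_alpha = jax.grad(sched.alpha)(t)
        d_sigma = jax.grad(sched.sigma)(t)
        return d_alpha * mu + (d_sigma / sched.sigma(t)) * (x - sched.alpha(t) * mu)

    return u


mu = jnp.array([1.0, -2.0])
x = jnp.array([0.5, 0.3])
r = jnp.array(0.3, dtype=jnp.float32)

# Converting the exact CondOT velocity to the cosine schedule
# should reproduce the exact cosine velocity.
v_transformed = transformed_velocity(point_mass_velocity(CondOT(), mu), CondOT(), Cosine(), x, r)
v_exact = point_mass_velocity(Cosine(), mu)(x, r)
```

Using jax.jvp keeps the sketch independent of hand-written scheduler derivatives, at the cost of tracing the scheduler functions on every call.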
- Parameters:
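x (Array) – The input state at time \(r\) of the new scheduler.
t (Array) – The time \(r\) under the new scheduler at which the transformed velocity is evaluated.
**extras – Additional keyword arguments passed through to the wrapped velocity model.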
- Returns:
The transformed velocity.
- Return type:
Array