gensbi.flow_matching.path.scheduler
Submodules
Classes
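Class | Description
CondOTScheduler | CondOT Scheduler.
ConvexScheduler | Base class for convex schedulers.
CosineScheduler | Cosine Scheduler.
LinearVPScheduler | Linear Variance Preserving Scheduler.
PolynomialConvexScheduler | Polynomial Scheduler.
ScheduleTransformedModel | Change of scheduler for a velocity model.
Scheduler | Base Scheduler class.
SchedulerOutput | Container for the scheduler values \(\alpha_t\), \(\sigma_t\) and their time derivatives.
VPScheduler | Variance Preserving Scheduler.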
Package Contents
- class gensbi.flow_matching.path.scheduler.CondOTScheduler
Bases: ConvexScheduler
CondOT Scheduler.
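A minimal sketch of the values a CondOT scheduler produces, assuming the usual conditional optimal-transport path \(\alpha_t = t\), \(\sigma_t = 1 - t\) (not confirmed by this page). SchedulerValues and cond_ot_schedule are hypothetical names standing in for SchedulerOutput and CondOTScheduler.__call__.

```python
from typing import NamedTuple

import jax.numpy as jnp


class SchedulerValues(NamedTuple):
    """Hypothetical stand-in for SchedulerOutput (same four fields)."""
    alpha_t: jnp.ndarray
    sigma_t: jnp.ndarray
    d_alpha_t: jnp.ndarray
    d_sigma_t: jnp.ndarray


def cond_ot_schedule(t):
    """Assumed CondOT schedule: alpha_t = t, sigma_t = 1 - t."""
    t = jnp.asarray(t)
    return SchedulerValues(
        alpha_t=t,
        sigma_t=1.0 - t,
        d_alpha_t=jnp.ones_like(t),   # d/dt t = 1
        d_sigma_t=-jnp.ones_like(t),  # d/dt (1 - t) = -1
    )


# With these values, x_t = alpha_t * x_1 + sigma_t * x_0 moves along a straight line.
t = jnp.array([0.0, 0.25, 1.0])
out = cond_ot_schedule(t)
```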
- class gensbi.flow_matching.path.scheduler.ConvexScheduler
Bases: Scheduler
Base class for convex schedulers.
- abstract __call__(t)
Scheduler for convex paths.
- Parameters:
t (jax.Array) – times in [0,1], shape (…).
- Returns:
\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)
- Return type:
SchedulerOutput
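A sketch of why a convex path can be inverted from its signal-to-noise ratio, assuming the convex-path form \(\alpha_t = \kappa_t\), \(\sigma_t = 1 - \kappa_t\) (an assumption; this page does not show the class's helpers). Then \(\mathrm{snr} = \kappa_t / (1 - \kappa_t)\), so \(\kappa_t = \mathrm{snr} / (1 + \mathrm{snr})\), and \(t\) follows from a kappa_inverse such as the one documented for PolynomialConvexScheduler. The helper names are hypothetical.

```python
import jax.numpy as jnp


def snr_to_kappa(snr):
    """For alpha_t = kappa_t, sigma_t = 1 - kappa_t: snr = kappa / (1 - kappa)."""
    return snr / (1.0 + snr)


def snr_inverse_convex(snr, kappa_inverse):
    """Hypothetical helper: map an SNR back to t through kappa_inverse."""
    return kappa_inverse(snr_to_kappa(snr))


# CondOT example: kappa_t = t, so kappa_inverse is the identity.
t = jnp.array([0.1, 0.5, 0.9])
snr = t / (1.0 - t)
t_recovered = snr_inverse_convex(snr, kappa_inverse=lambda k: k)
```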
- class gensbi.flow_matching.path.scheduler.CosineScheduler
Bases: Scheduler
Cosine Scheduler.
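A minimal sketch, assuming the common cosine schedule \(\alpha_t = \sin(\pi t/2)\), \(\sigma_t = \cos(\pi t/2)\); the exact formulas used by CosineScheduler are not shown on this page, and cosine_schedule is a hypothetical name. The analytic derivative is cross-checked with jax.grad.

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Assumed cosine schedule: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2)."""
    t = jnp.asarray(t)
    half_pi = 0.5 * jnp.pi
    return {
        "alpha_t": jnp.sin(half_pi * t),
        "sigma_t": jnp.cos(half_pi * t),
        "d_alpha_t": half_pi * jnp.cos(half_pi * t),
        "d_sigma_t": -half_pi * jnp.sin(half_pi * t),
    }


t = jnp.linspace(0.0, 1.0, 5)
out = cosine_schedule(t)
# Autodiff check of the hand-written derivative of alpha_t.
d_alpha_autodiff = jax.vmap(jax.grad(lambda s: cosine_schedule(s)["alpha_t"]))(t)
```

Because \(\alpha_t^2 + \sigma_t^2 = 1\), this path is also variance preserving.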
- class gensbi.flow_matching.path.scheduler.LinearVPScheduler
Bases: Scheduler
Linear Variance Preserving Scheduler.
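A minimal sketch, assuming the linear variance-preserving schedule \(\alpha_t = t\), \(\sigma_t = \sqrt{1 - t^2}\), so that \(\alpha_t^2 + \sigma_t^2 = 1\). These formulas are an assumption, not confirmed by this page, and linear_vp_schedule is a hypothetical name. Note that \(\frac{\partial}{\partial t}\sigma_t\) diverges at \(t = 1\).

```python
import jax
import jax.numpy as jnp


def linear_vp_schedule(t):
    """Assumed linear VP schedule: alpha_t = t, sigma_t = sqrt(1 - t**2)."""
    t = jnp.asarray(t)
    sigma_t = jnp.sqrt(1.0 - t**2)
    return {
        "alpha_t": t,
        "sigma_t": sigma_t,
        "d_alpha_t": jnp.ones_like(t),
        # d/dt sqrt(1 - t^2) = -t / sqrt(1 - t^2); singular at t = 1.
        "d_sigma_t": -t / sigma_t,
    }


t = jnp.array([0.0, 0.3, 0.6, 0.9])  # stay away from t = 1
out = linear_vp_schedule(t)
d_sigma_autodiff = jax.vmap(jax.grad(lambda s: linear_vp_schedule(s)["sigma_t"]))(t)
```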
- class gensbi.flow_matching.path.scheduler.PolynomialConvexScheduler(n)
Bases: ConvexScheduler
Polynomial Scheduler.
- Parameters:
n (Union[float, int])
- __call__(t)
Scheduler for convex paths.
- Parameters:
t (jax.Array) – times in [0,1], shape (…).
- Returns:
\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)
- Return type:
SchedulerOutput
- kappa_inverse(kappa)
Computes \(t\) from \(\kappa_t\).
- Parameters:
kappa (Array) – \(\kappa\), shape (…)
- Returns:
t, shape (…)
- Return type:
Array
- n
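A minimal sketch, assuming the polynomial convex path \(\kappa_t = t^n\), \(\alpha_t = \kappa_t\), \(\sigma_t = 1 - \kappa_t\), whose inverse is \(t = \kappa^{1/n}\). PolynomialSketch is a hypothetical stand-in, not the library class.

```python
import jax.numpy as jnp


class PolynomialSketch:
    """Hypothetical stand-in for PolynomialConvexScheduler(n).

    Assumes kappa_t = t**n, alpha_t = kappa_t, sigma_t = 1 - kappa_t.
    """

    def __init__(self, n):
        self.n = n

    def __call__(self, t):
        t = jnp.asarray(t)
        kappa = t**self.n
        # For n < 1 this derivative is infinite at t = 0.
        d_kappa = self.n * t ** (self.n - 1)
        return kappa, 1.0 - kappa, d_kappa, -d_kappa

    def kappa_inverse(self, kappa):
        # Inverts kappa = t**n on [0, 1].
        return jnp.power(kappa, 1.0 / self.n)


sched = PolynomialSketch(n=2)
t = jnp.array([0.0, 0.5, 1.0])
alpha_t, sigma_t, d_alpha_t, d_sigma_t = sched(t)
t_back = sched.kappa_inverse(alpha_t)
```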
- class gensbi.flow_matching.path.scheduler.ScheduleTransformedModel(velocity_model, original_scheduler, new_scheduler)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Change of scheduler for a velocity model.
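This class wraps a velocity model trained with one scheduler so that it can be sampled with another, without retraining. It rescales time and space by matching the signal-to-noise ratios \(\alpha_t/\sigma_t\) of the two schedulers, so the wrapped model follows the new scheduler's probability path while still relying on the original model's learned velocity.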
Example
    import jax
    import jax.numpy as jnp

    from gensbi.flow_matching.path.scheduler import (
        CondOTScheduler,
        CosineScheduler,
        ScheduleTransformedModel,
    )
    from gensbi.flow_matching.solver import ODESolver  # assumed module path

    # Initialize the model and schedulers
    model = ...  # a trained velocity model (ModelWrapper)
    original_scheduler = CondOTScheduler()
    new_scheduler = CosineScheduler()

    # Create the transformed model
    transformed_model = ScheduleTransformedModel(
        velocity_model=model,
        original_scheduler=original_scheduler,
        new_scheduler=new_scheduler,
    )

    # Set up the solver
    solver = ODESolver(velocity_model=transformed_model)

    key = jax.random.PRNGKey(0)
    x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition
    x_1 = solver.sample(
        time_steps=jnp.array([0.0, 1.0]),
        x_init=x_0,
        step_size=1 / 1000,
    )[1]
- Parameters:
velocity_model (ModelWrapper) – The original velocity model to be transformed.
original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse method.
new_scheduler (Scheduler) – The new scheduler to be applied to the model.
- __call__(x, t, **extras)
Compute the transformed marginal velocity field for a new scheduler. This method implements a post-training velocity scheduler change for affine conditional flows.
- Parameters:
x (Array) – \(x_t\), the input array.
t (Array) – The time array, i.e. the new scheduler's time \(r\).
**extras – Additional arguments for the model.
- Returns:
The transformed velocity.
- Return type:
Array
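The method measures time on the new scheduler's clock \(r\). For affine paths, a post-training scheduler change is usually written as the scale-time transformation below, where bars mark the new scheduler, dots mark time derivatives, and \(u_t\) is the original model's velocity. That this method implements exactly these expressions is an assumption, not something this page confirms.

```latex
\rho(t) = \frac{\alpha_t}{\sigma_t}, \qquad
t_r = \rho^{-1}\big(\bar{\rho}(r)\big), \qquad
s_r = \frac{\bar{\sigma}_r}{\sigma_{t_r}}, \qquad
\bar{X}_r = s_r X_{t_r},

\bar{u}_r(x) = \frac{\dot{s}_r}{s_r}\, x + s_r\, \dot{t}_r\, u_{t_r}\!\left(\frac{x}{s_r}\right),

\dot{t}_r = \frac{\sigma_{t_r}^2 \left(\bar{\sigma}_r \dot{\bar{\alpha}}_r - \bar{\alpha}_r \dot{\bar{\sigma}}_r\right)}
                 {\bar{\sigma}_r^2 \left(\sigma_{t_r} \dot{\alpha}_{t_r} - \alpha_{t_r} \dot{\sigma}_{t_r}\right)},
\qquad
\dot{s}_r = \frac{\sigma_{t_r} \dot{\bar{\sigma}}_r - \bar{\sigma}_r \dot{\sigma}_{t_r} \dot{t}_r}{\sigma_{t_r}^2}.
```

The expression for \(\dot{t}_r\) follows from differentiating \(\rho(t_r) = \bar{\rho}(r)\) with respect to \(r\), and \(\dot{s}_r\) from the quotient rule applied to \(s_r\).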
- new_scheduler
- original_scheduler
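A minimal JAX sketch of a scheduler change from a CondOT path to a cosine path. It maps \(r\) to \(t_r\) by matching SNRs, rescales space by \(s_r = \bar{\sigma}_r / \sigma_{t_r}\), and combines the original velocity as \(\bar{u}_r(x) = (\dot{s}_r/s_r)\,x + s_r \dot{t}_r\, u_{t_r}(x/s_r)\). The closed-form schedulers (\(\alpha_t = t, \sigma_t = 1 - t\) and \(\alpha_r = \sin(\pi r/2), \sigma_r = \cos(\pi r/2)\)) are assumptions, and transformed_velocity is a hypothetical helper, not the class's API. With a point-mass data distribution both velocities are known exactly, which makes the result checkable.

```python
import jax.numpy as jnp

HALF_PI = 0.5 * jnp.pi


# Hypothetical schedulers returning (alpha, sigma, d_alpha, d_sigma).
def cond_ot(t):
    return t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t)


def cond_ot_snr_inverse(snr):
    # snr = t / (1 - t)  =>  t = snr / (1 + snr)
    return snr / (1.0 + snr)


def cosine(r):
    return (jnp.sin(HALF_PI * r), jnp.cos(HALF_PI * r),
            HALF_PI * jnp.cos(HALF_PI * r), -HALF_PI * jnp.sin(HALF_PI * r))


def transformed_velocity(u, x, r, original, original_snr_inverse, new):
    """Velocity on the `new` path built from u(x, t) trained on the `original` path."""
    alpha_bar, sigma_bar, d_alpha_bar, d_sigma_bar = new(r)
    t = original_snr_inverse(alpha_bar / sigma_bar)  # t_r = rho^{-1}(rho_bar(r))
    alpha, sigma, d_alpha, d_sigma = original(t)
    s = sigma_bar / sigma                             # s_r
    dt = sigma**2 * (sigma_bar * d_alpha_bar - alpha_bar * d_sigma_bar) / (
        sigma_bar**2 * (sigma * d_alpha - alpha * d_sigma))
    ds = (sigma * d_sigma_bar - sigma_bar * d_sigma * dt) / sigma**2
    return (ds / s) * x + s * dt * u(x / s, t)


x1 = jnp.array([1.0, -2.0])


def u_ot(x, t):
    # Exact CondOT marginal velocity when the data is a point mass at x1.
    return (x1 - x) / (1.0 - t)


r = jnp.asarray(0.3)
x = jnp.array([[0.5, 0.1], [-1.0, 2.0], [0.0, 0.0], [3.0, -1.0]])
u_new = transformed_velocity(u_ot, x, r, cond_ot, cond_ot_snr_inverse, cosine)
# Exact cosine-path velocity for the same point mass, for comparison.
u_exact = HALF_PI * (x1 - jnp.sin(HALF_PI * r) * x) / jnp.cos(HALF_PI * r)
```

The comparison with u_exact only holds because the point-mass velocities are available in closed form; with a learned model, u_ot would be the network.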
- class gensbi.flow_matching.path.scheduler.Scheduler
Bases: abc.ABC
Base Scheduler class.
- class gensbi.flow_matching.path.scheduler.SchedulerOutput
Container for the scheduler values \(\alpha_t\), \(\sigma_t\) and their time derivatives at the requested times.
- alpha_t
\(\alpha_t\), shape (…).
- Type:
Array
- sigma_t
\(\sigma_t\), shape (…).
- Type:
Array
- d_alpha_t
\(\frac{\partial}{\partial t}\alpha_t\), shape (…).
- Type:
Array
- d_sigma_t
\(\frac{\partial}{\partial t}\sigma_t\), shape (…).
- Type:
Array
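A sketch of how the four fields relate, using jax.jvp to obtain \(\frac{\partial}{\partial t}\alpha_t\) and \(\frac{\partial}{\partial t}\sigma_t\) from plain functions of \(t\). SchedulerOutputSketch and output_from_functions are hypothetical; representing the container as a NamedTuple is an assumption (it also makes it a JAX pytree).

```python
from typing import Callable, NamedTuple

import jax
import jax.numpy as jnp


class SchedulerOutputSketch(NamedTuple):
    """Hypothetical stand-in mirroring SchedulerOutput's four fields."""
    alpha_t: jax.Array
    sigma_t: jax.Array
    d_alpha_t: jax.Array
    d_sigma_t: jax.Array


def output_from_functions(alpha_fn: Callable, sigma_fn: Callable, t: jax.Array):
    """Evaluate alpha_t, sigma_t and their time derivatives elementwise.

    A tangent of ones yields the elementwise derivative because alpha_fn and
    sigma_fn act independently on each entry of t.
    """
    ones = jnp.ones_like(t)
    alpha_t, d_alpha_t = jax.jvp(alpha_fn, (t,), (ones,))
    sigma_t, d_sigma_t = jax.jvp(sigma_fn, (t,), (ones,))
    return SchedulerOutputSketch(alpha_t, sigma_t, d_alpha_t, d_sigma_t)


t = jnp.linspace(0.0, 1.0, 4)
out = output_from_functions(lambda s: s**2, lambda s: 1.0 - s**2, t)
```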
- class gensbi.flow_matching.path.scheduler.VPScheduler(beta_min=0.1, beta_max=20.0)
Bases: Scheduler
Variance Preserving Scheduler.
- Parameters:
beta_min (float)
beta_max (float)
- __call__(t)
- Parameters:
t (Array) – times in [0,1], shape (…).
- Returns:
\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)
- Return type:
SchedulerOutput
- snr_inverse(snr)
Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).
- Parameters:
snr (Array) – The signal-to-noise ratio, shape (…)
- Returns:
t, shape (…)
- Return type:
Array
- beta_max = 20.0
- beta_min = 0.1
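A minimal sketch of a variance-preserving scheduler and its snr_inverse, assuming the common linear-\(\beta\) schedule with integrated rate \(T(t) = \tfrac12 (1-t)^2(\beta_{\max}-\beta_{\min}) + (1-t)\beta_{\min}\), \(\alpha_t = e^{-T/2}\), \(\sigma_t = \sqrt{1 - e^{-T}}\), and data at \(t = 1\). These formulas are not confirmed by this page; VPSketch is a hypothetical stand-in that reuses the documented defaults beta_min = 0.1 and beta_max = 20.0.

```python
import jax
import jax.numpy as jnp


class VPSketch:
    """Hypothetical VP schedule with a linear beta from beta_min to beta_max."""

    def __init__(self, beta_min=0.1, beta_max=20.0):
        self.beta_min = beta_min
        self.beta_max = beta_max

    def __call__(self, t):
        b, B = self.beta_min, self.beta_max
        T = 0.5 * (1.0 - t) ** 2 * (B - b) + (1.0 - t) * b  # integrated noise rate
        dT = -(1.0 - t) * (B - b) - b                        # dT/dt
        alpha_t = jnp.exp(-0.5 * T)
        sigma_t = jnp.sqrt(1.0 - jnp.exp(-T))  # zero at t = 1
        d_alpha_t = -0.5 * dT * alpha_t
        d_sigma_t = 0.5 * dT * jnp.exp(-T) / sigma_t
        return alpha_t, sigma_t, d_alpha_t, d_sigma_t

    def snr_inverse(self, snr):
        b, B = self.beta_min, self.beta_max
        # snr^2 = exp(-T) / (1 - exp(-T))  =>  T = -log(snr^2 / (snr^2 + 1))
        T = -jnp.log(snr**2 / (snr**2 + 1.0))
        # Positive root of 0.5 (B - b) u^2 + b u - T = 0, with u = 1 - t.
        u = (-b + jnp.sqrt(b**2 + 2.0 * (B - b) * T)) / (B - b)
        return 1.0 - u


sched = VPSketch()
t = jnp.array([0.2, 0.5, 0.8])  # avoid t = 1, where sigma_t = 0
alpha_t, sigma_t, d_alpha_t, d_sigma_t = sched(t)
t_back = sched.snr_inverse(alpha_t / sigma_t)
```

Solving the quadratic in \(u = 1 - t\) keeps snr_inverse in closed form, which is what ScheduleTransformedModel needs from an original scheduler.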