gensbi.flow_matching.path.scheduler
Submodules
Classes
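| Class | Description |
| --- | --- |
| CondOTScheduler | CondOT Scheduler. |
| ConvexScheduler | Base Scheduler class. |
| CosineScheduler | Cosine Scheduler. |
| LinearVPScheduler | Linear Variance Preserving Scheduler. |
| PolynomialConvexScheduler | Polynomial Scheduler. |
| ScheduleTransformedModel | Change of scheduler for a velocity model. |
| Scheduler | Base Scheduler class. |
| SchedulerOutput | Represents a sample of a conditional-flow generated probability path. |
| VPScheduler | Variance Preserving Scheduler. |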
Package Contents
- class gensbi.flow_matching.path.scheduler.CondOTScheduler
Bases:
ConvexScheduler
CondOT Scheduler.
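For orientation, the conditional optimal-transport schedule is commonly written as \(\alpha_t = t\), \(\sigma_t = 1 - t\). The sketch below is a plain-Python stand-in under that assumption and is not taken from the gensbi source. The function name `cond_ot_schedule` is hypothetical, and it returns a bare tuple in the order \(\alpha_t, \sigma_t, \frac{\partial}{\partial t}\alpha_t, \frac{\partial}{\partial t}\sigma_t\) rather than a SchedulerOutput.

```python
def cond_ot_schedule(t):
    """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for a CondOT-style schedule.

    Hypothetical helper for illustration; `t` is a float in [0, 1].
    """
    alpha_t = t        # signal weight grows linearly from 0 to 1
    sigma_t = 1.0 - t  # noise weight shrinks linearly from 1 to 0
    # Both weights are linear in t, so their derivatives are constants.
    return alpha_t, sigma_t, 1.0, -1.0


alpha_t, sigma_t, d_alpha_t, d_sigma_t = cond_ot_schedule(0.25)
# For a convex schedule the two weights sum to one at every time.
print(alpha_t + sigma_t)
```

The same expressions apply elementwise to arrays if the arithmetic is carried out with `jax.numpy` instead of floats.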
- class gensbi.flow_matching.path.scheduler.ConvexScheduler
Bases:
Scheduler
Base Scheduler class.
- abstractmethod __call__(t)
Scheduler for convex paths.
- Parameters:
t (jax.Array) – times in [0,1], shape (…).
- Returns:
\(\alpha_t, \sigma_t, \frac{\partial}{\partial t}\alpha_t, \frac{\partial}{\partial t}\sigma_t\)
- Return type:
SchedulerOutput
- class gensbi.flow_matching.path.scheduler.CosineScheduler
Bases:
Scheduler
Cosine Scheduler.
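A cosine schedule is usually taken to be \(\alpha_t = \sin(\tfrac{\pi}{2} t)\), \(\sigma_t = \cos(\tfrac{\pi}{2} t)\). The sketch below assumes that form; this page does not spell out the formula, so treat it as an illustration rather than the library's code. The names `cosine_schedule` and `cosine_snr_inverse` are hypothetical.

```python
import math


def cosine_schedule(t):
    """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for a cosine-style schedule."""
    half_pi = math.pi / 2
    alpha_t = math.sin(half_pi * t)
    sigma_t = math.cos(half_pi * t)
    d_alpha_t = half_pi * math.cos(half_pi * t)
    d_sigma_t = -half_pi * math.sin(half_pi * t)
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


def cosine_snr_inverse(snr):
    """Recover t from snr = alpha_t / sigma_t = tan(pi * t / 2)."""
    return 2.0 * math.atan(snr) / math.pi


alpha_t, sigma_t, _, _ = cosine_schedule(0.4)
t_recovered = cosine_snr_inverse(alpha_t / sigma_t)
print(t_recovered)
```

Under this form \(\alpha_t^2 + \sigma_t^2 = 1\), so the path is variance preserving rather than convex.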
- class gensbi.flow_matching.path.scheduler.LinearVPScheduler
Bases:
Scheduler
Linear Variance Preserving Scheduler.
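A linear variance-preserving schedule can be sketched as \(\alpha_t = t\), \(\sigma_t = \sqrt{1 - t^2}\), so that \(\alpha_t^2 + \sigma_t^2 = 1\). The code below assumes that form, which this page does not state explicitly. The names `linear_vp_schedule` and `linear_vp_snr_inverse` are hypothetical.

```python
import math


def linear_vp_schedule(t):
    """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for a linear VP-style schedule.

    Valid for t in [0, 1); d_sigma_t diverges as t approaches 1.
    """
    alpha_t = t
    sigma_t = math.sqrt(1.0 - t * t)
    d_alpha_t = 1.0
    d_sigma_t = -t / sigma_t  # derivative of sqrt(1 - t^2)
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


def linear_vp_snr_inverse(snr):
    """Recover t from snr = t / sqrt(1 - t^2)."""
    return math.sqrt(snr ** 2 / (1.0 + snr ** 2))


alpha_t, sigma_t, d_alpha_t, d_sigma_t = linear_vp_schedule(0.6)
print(alpha_t, sigma_t, d_sigma_t)
```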
- class gensbi.flow_matching.path.scheduler.PolynomialConvexScheduler(n)
Bases:
ConvexScheduler
Polynomial Scheduler.
- Parameters:
n (Union[float, int])
- __call__(t)
Scheduler for convex paths.
- Parameters:
t (jax.Array) – times in [0,1], shape (…).
- Returns:
\(\alpha_t, \sigma_t, \frac{\partial}{\partial t}\alpha_t, \frac{\partial}{\partial t}\sigma_t\)
- Return type:
SchedulerOutput
- kappa_inverse(kappa)
Computes \(t\) from \(\kappa_t\).
- Parameters:
kappa (Array) – \(\kappa\), shape (…)
- Returns:
t, shape (…)
- Return type:
Array
- n
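To make the role of `n` and `kappa_inverse` concrete, the sketch below assumes the polynomial convex schedule \(\alpha_t = t^n\), \(\sigma_t = 1 - t^n\), with \(\kappa_t\) denoting \(\alpha_t\). Both are assumptions, since this page does not give the formulas. The function names are hypothetical and work on floats.

```python
def polynomial_schedule(t, n):
    """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for alpha_t = t**n."""
    alpha_t = t ** n
    sigma_t = 1.0 - alpha_t           # convex path: sigma_t = 1 - alpha_t
    d_alpha_t = n * t ** (n - 1)
    d_sigma_t = -d_alpha_t
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


def polynomial_kappa_inverse(kappa, n):
    """Recover t from kappa_t = t**n (assumed definition of kappa)."""
    return kappa ** (1.0 / n)


alpha_t, sigma_t, d_alpha_t, d_sigma_t = polynomial_schedule(0.5, n=3)
print(polynomial_kappa_inverse(alpha_t, n=3))
```

With `n = 1` this reduces to the linear CondOT-style schedule.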
- class gensbi.flow_matching.path.scheduler.ScheduleTransformedModel(velocity_model, original_scheduler, new_scheduler)
Bases:
gensbi.utils.model_wrapping.ModelWrapper
Change of scheduler for a velocity model.
This class wraps a given velocity model and transforms its scheduling to a new scheduler function. It modifies the time dynamics of the model according to the new scheduler while maintaining the original model’s behavior.
Example
    import jax
    import jax.numpy as jnp
    from gensbi.flow_matching.path.scheduler import (
        CondOTScheduler,
        CosineScheduler,
        ScheduleTransformedModel,
    )
    from gensbi.flow_matching.solver import ODESolver

    # Initialize the model and schedulers
    model = ...
    original_scheduler = CondOTScheduler()
    new_scheduler = CosineScheduler()

    # Create the transformed model
    transformed_model = ScheduleTransformedModel(
        velocity_model=model,
        original_scheduler=original_scheduler,
        new_scheduler=new_scheduler,
    )

    # Set up the solver
    solver = ODESolver(velocity_model=transformed_model)
    key = jax.random.PRNGKey(0)
    x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition
    x_1 = solver.sample(
        time_steps=jnp.array([0.0, 1.0]),
        x_init=x_0,
        step_size=1/1000,
    )[1]
- Parameters:
velocity_model (ModelWrapper) – The original velocity model to be transformed.
original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse function.
new_scheduler (Scheduler) – The new scheduler to be applied to the model.
- __call__(x, t, **extras)
Compute the transformed marginal velocity field for a new scheduler. This method implements a post-training velocity scheduler change for affine conditional flows.
- Parameters:
x (Array) – \(x_t\), the input array.
t (Array) – The time array, i.e. the time \(r\) under the new scheduler.
**extras – Additional arguments for the model.
- Returns:
The transformed velocity.
- Return type:
Array
- new_scheduler
- original_scheduler
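The scheduler change can be illustrated with the standard signal-to-noise matching argument for affine paths. Given a time \(r\) under the new scheduler, find the time \(t\) at which the original scheduler has the same \(\frac{\alpha}{\sigma}\), rescale the input by \(s_r = \bar{\sigma}_r / \sigma_t\), and combine the original velocity with the derivatives of \(t\) and \(s_r\) with respect to \(r\). The sketch below is an assumption about how such a transform can be written, not a copy of the gensbi implementation. All function names are hypothetical, schedules are plain functions returning tuples, and inputs are floats.

```python
import math


def cond_ot(t):
    # Assumed CondOT-style schedule: (alpha, sigma, d_alpha, d_sigma).
    return t, 1.0 - t, 1.0, -1.0


def cond_ot_snr_inverse(snr):
    # snr = t / (1 - t)  =>  t = snr / (1 + snr)
    return snr / (1.0 + snr)


def cosine(r):
    # Assumed cosine-style schedule.
    h = math.pi / 2
    return math.sin(h * r), math.cos(h * r), h * math.cos(h * r), -h * math.sin(h * r)


def transform_velocity(velocity, original, original_snr_inverse, new, x, r):
    """Evaluate the velocity under `new` using a model trained under `original`."""
    alpha_r, sigma_r, d_alpha_r, d_sigma_r = new(r)
    t = original_snr_inverse(alpha_r / sigma_r)       # match signal-to-noise ratios
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = original(t)
    s_r = sigma_r / sigma_t                           # rescaling between the two paths
    dt_r = (sigma_t ** 2 * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)) / (
        sigma_r ** 2 * (sigma_t * d_alpha_t - alpha_t * d_sigma_t)
    )                                                 # dt/dr
    ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / sigma_t ** 2  # ds/dr
    return ds_r * x / s_r + dt_r * s_r * velocity(x / s_r, t)


# Toy check: all data sits at one point C, so the CondOT marginal velocity
# has the closed form (C - x) / (1 - t).
C = 2.0


def point_mass_velocity(x, t):
    return (C - x) / (1.0 - t)


x, r = 0.7, 0.3
u_new = transform_velocity(point_mass_velocity, cond_ot, cond_ot_snr_inverse, cosine, x, r)

# Closed-form velocity for the same point-mass target under the cosine schedule.
alpha_r, sigma_r, d_alpha_r, d_sigma_r = cosine(r)
u_expected = d_alpha_r * C + d_sigma_r * (x - alpha_r * C) / sigma_r
print(u_new, u_expected)
```

The two printed values should agree up to floating-point error, which is a quick sanity check that the transform preserves the data coupling. The need to invert the signal-to-noise ratio is why the original scheduler must provide `snr_inverse`.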
- class gensbi.flow_matching.path.scheduler.Scheduler
Bases:
abc.ABC
Base Scheduler class.
- class gensbi.flow_matching.path.scheduler.SchedulerOutput
Represents a sample of a conditional-flow generated probability path.
- alpha_t: jax.Array
\(\alpha_t\), shape (…).
- sigma_t: jax.Array
\(\sigma_t\), shape (…).
- d_alpha_t: jax.Array
\(\frac{\partial}{\partial t}\alpha_t\), shape (…).
- d_sigma_t: jax.Array
\(\frac{\partial}{\partial t}\sigma_t\), shape (…).
- class gensbi.flow_matching.path.scheduler.VPScheduler(beta_min=0.1, beta_max=20.0)
Bases:
Scheduler
Variance Preserving Scheduler.
- Parameters:
beta_min (float)
beta_max (float)
- __call__(t)
- Parameters:
t (Array) – times in [0,1], shape (…).
- Returns:
\(\alpha_t,\sigma_t,\frac{\partial}{\partial t}\alpha_t,\frac{\partial}{\partial t}\sigma_t\)
- Return type:
SchedulerOutput
- snr_inverse(snr)
Computes \(t\) from the signal-to-noise ratio \(\frac{\alpha_t}{\sigma_t}\).
- Parameters:
snr (Array) – The signal-to-noise, shape (…)
- Returns:
t, shape (…)
- Return type:
Array
- beta_max = 20.0
- beta_min = 0.1
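To show how `beta_min` and `beta_max` could enter the schedule, the sketch below assumes the usual variance-preserving parametrization with a linear noise rate, integrated from \(t\) to 1: \(T(t) = \tfrac{1}{2}(1-t)^2(\beta_{\max}-\beta_{\min}) + (1-t)\beta_{\min}\), \(\alpha_t = e^{-T/2}\), \(\sigma_t = \sqrt{1 - e^{-T}}\). This form is an assumption, since the page does not state the formulas, and the function names are hypothetical.

```python
import math


def vp_schedule(t, beta_min=0.1, beta_max=20.0):
    """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) for an assumed VP schedule.

    Valid for t in [0, 1); sigma_t is 0 at t = 1, where d_sigma_t is undefined.
    """
    b, B = beta_min, beta_max
    T = 0.5 * (1.0 - t) ** 2 * (B - b) + (1.0 - t) * b  # integrated noise rate
    dT = -(1.0 - t) * (B - b) - b                        # dT/dt
    alpha_t = math.exp(-0.5 * T)
    sigma_t = math.sqrt(1.0 - math.exp(-T))
    d_alpha_t = -0.5 * dT * math.exp(-0.5 * T)
    d_sigma_t = 0.5 * dT * math.exp(-T) / sigma_t
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


def vp_snr_inverse(snr, beta_min=0.1, beta_max=20.0):
    """Recover t from snr = alpha_t / sigma_t by solving the quadratic in (1 - t)."""
    b, B = beta_min, beta_max
    T = -math.log(snr ** 2 / (snr ** 2 + 1.0))  # since alpha_t**2 = exp(-T)
    return 1.0 - (-b + math.sqrt(b ** 2 + 2.0 * (B - b) * T)) / (B - b)


alpha_t, sigma_t, d_alpha_t, d_sigma_t = vp_schedule(0.5)
print(vp_snr_inverse(alpha_t / sigma_t))
```

Under these assumptions \(\alpha_t^2 + \sigma_t^2 = 1\) for every \(t\), which is the variance-preserving property, and `vp_snr_inverse` undoes the map from \(t\) to \(\frac{\alpha_t}{\sigma_t}\).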