gensbi.flow_matching.path.scheduler package
Submodules
gensbi.flow_matching.path.scheduler.schedule_transform module
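- class gensbi.flow_matching.path.scheduler.schedule_transform.ScheduleTransformedModel(*args: Any, **kwargs: Any)
Bases: ModelWrapper
Change of scheduler for a velocity model.
This class wraps a velocity model trained with one scheduler and exposes it as a velocity model for a different scheduler. Time is remapped so that both schedulers have the same signal-to-noise ratio \(\alpha_t / \sigma_t\), which is why the original scheduler must implement snr_inverse, and the state is rescaled accordingly. The wrapped model then follows the new scheduler's probability path while, in exact arithmetic, targeting the same distribution at t = 1 as the original model.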
Example:
    import jax
    import jax.numpy as jnp
    from gensbi.flow_matching.path.scheduler import (
        CondOTScheduler,
        CosineScheduler,
        ScheduleTransformedModel,
    )
    from gensbi.flow_matching.solver import ODESolver

    # Initialize the model and schedulers
    model = ...  # a velocity model trained with CondOTScheduler
    original_scheduler = CondOTScheduler()
    new_scheduler = CosineScheduler()

    # Create the transformed model
    transformed_model = ScheduleTransformedModel(
        velocity_model=model,
        original_scheduler=original_scheduler,
        new_scheduler=new_scheduler,
    )

    # Set up the solver
    solver = ODESolver(velocity_model=transformed_model)

    key = jax.random.PRNGKey(0)
    x_0 = jax.random.normal(key, shape=(10, 2))  # Example initial condition
    x_1 = solver.sample(
        time_steps=jnp.array([0.0, 1.0]),
        x_init=x_0,
        step_size=1 / 1000,
    )[1]
- Parameters:
velocity_model (ModelWrapper) – The original velocity model to be transformed.
original_scheduler (Scheduler) – The scheduler used by the original model. Must implement the snr_inverse function.
new_scheduler (Scheduler) – The new scheduler to be applied to the model.
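The change of scheduler can be written in closed form. With \(\bar\alpha_r, \bar\sigma_r\) for the new scheduler and \(\alpha_t, \sigma_t\) for the original one, time is matched through the SNR, \(t_r = \rho^{-1}(\bar\alpha_r / \bar\sigma_r)\) with \(\rho(t) = \alpha_t / \sigma_t\), the state is scaled by \(s_r = \bar\sigma_r / \sigma_{t_r}\), and the new velocity is \(\bar u_r(x) = \frac{\dot s_r}{s_r} x + s_r \dot t_r\, u_{t_r}(x / s_r)\). The sketch below implements that formula as a plain JAX function; it is not the library's implementation. The helpers cond_ot, cosine, cond_ot_snr_inverse, transformed_velocity and conditional_velocity are hypothetical, and they assume the CondOT path \(\alpha_t = t, \sigma_t = 1 - t\) and the cosine path \(\alpha_t = \sin(\pi t/2), \sigma_t = \cos(\pi t/2)\).

```python
import jax.numpy as jnp


# Hypothetical stand-ins for two schedulers. Each returns
# (alpha_t, sigma_t, d_alpha_t, d_sigma_t) at time t.
def cond_ot(t):
    # Assumed CondOT path: alpha_t = t, sigma_t = 1 - t.
    return t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t)


def cond_ot_snr_inverse(snr):
    # snr = t / (1 - t)  =>  t = snr / (1 + snr)
    return snr / (1.0 + snr)


def cosine(t):
    # Assumed cosine path: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2).
    a = jnp.pi / 2
    return jnp.sin(a * t), jnp.cos(a * t), a * jnp.cos(a * t), -a * jnp.sin(a * t)


def transformed_velocity(u, x, r, orig, orig_snr_inverse, new):
    """Velocity for the `new` scheduler, built from a velocity u(x, t) for `orig`."""
    alpha_r, sigma_r, d_alpha_r, d_sigma_r = new(r)
    # Match times by SNR: find t with alpha_t / sigma_t == alpha_r / sigma_r.
    t = orig_snr_inverse(alpha_r / sigma_r)
    alpha_t, sigma_t, d_alpha_t, d_sigma_t = orig(t)
    # Scale between the two paths, and the derivatives of t_r and s_r in r.
    s_r = sigma_r / sigma_t
    dt_r = (sigma_t**2 * (sigma_r * d_alpha_r - alpha_r * d_sigma_r)) / (
        sigma_r**2 * (sigma_t * d_alpha_t - alpha_t * d_sigma_t)
    )
    ds_r = (sigma_t * d_sigma_r - sigma_r * d_sigma_t * dt_r) / sigma_t**2
    return ds_r * x / s_r + dt_r * s_r * u(x / s_r, t)


def conditional_velocity(sched, x1):
    # Exact velocity of the path x_t = alpha_t * x1 + sigma_t * x0 for a fixed x1.
    def u(x, t):
        alpha, sigma, d_alpha, d_sigma = sched(t)
        return d_alpha * x1 + d_sigma * (x - alpha * x1) / sigma

    return u


x1 = jnp.array([1.5, -0.5])
x = jnp.array([0.2, 0.7])
r = jnp.array(0.3)
# Map the CondOT conditional velocity onto the cosine scheduler.
u_new = transformed_velocity(
    conditional_velocity(cond_ot, x1), x, r, cond_ot, cond_ot_snr_inverse, cosine
)
```

For a single fixed data point the two conditional paths are related exactly by this rescaling, so u_new should agree with conditional_velocity(cosine, x1)(x, r) up to floating-point error, which makes it a convenient sanity check for the formula.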
gensbi.flow_matching.path.scheduler.scheduler module
- class gensbi.flow_matching.path.scheduler.scheduler.CondOTScheduler
Bases: ConvexScheduler
CondOT Scheduler.
- class gensbi.flow_matching.path.scheduler.scheduler.ConvexScheduler
Bases: Scheduler
- class gensbi.flow_matching.path.scheduler.scheduler.CosineScheduler
Bases: Scheduler
Cosine Scheduler.
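A minimal sketch of a cosine schedule, assuming the common form \(\alpha_t = \sin(\pi t/2)\), \(\sigma_t = \cos(\pi t/2)\); the exact definition used by CosineScheduler is not stated on this page. Hand-written derivatives are easy to get wrong, so the sketch checks them against jax.grad. The names cosine_schedule and cosine_snr_inverse are hypothetical.

```python
import jax
import jax.numpy as jnp


def cosine_schedule(t):
    """Hypothetical cosine schedule: returns (alpha_t, sigma_t, d_alpha_t, d_sigma_t)."""
    a = jnp.pi / 2
    alpha_t = jnp.sin(a * t)
    sigma_t = jnp.cos(a * t)
    # Hand-written time derivatives.
    d_alpha_t = a * jnp.cos(a * t)
    d_sigma_t = -a * jnp.sin(a * t)
    return alpha_t, sigma_t, d_alpha_t, d_sigma_t


def cosine_snr_inverse(snr):
    # snr = tan(pi t / 2)  =>  t = (2 / pi) * arctan(snr)
    return (2.0 / jnp.pi) * jnp.arctan(snr)


t = jnp.linspace(0.05, 0.95, 7)
alpha_t, sigma_t, d_alpha_t, d_sigma_t = cosine_schedule(t)
# Derivatives from automatic differentiation, one scalar time at a time.
d_alpha_auto = jax.vmap(jax.grad(lambda s: cosine_schedule(s)[0]))(t)
d_sigma_auto = jax.vmap(jax.grad(lambda s: cosine_schedule(s)[1]))(t)
```

Under this assumed form the path also satisfies \(\alpha_t^2 + \sigma_t^2 = 1\), and its SNR \(\tan(\pi t/2)\) has the closed-form inverse cosine_snr_inverse.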
- class gensbi.flow_matching.path.scheduler.scheduler.LinearVPScheduler
Bases: Scheduler
Linear Variance Preserving Scheduler.
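A sketch of a linear variance-preserving schedule, assuming \(\alpha_t = t\) and \(\sigma_t = \sqrt{1 - t^2}\), so that \(\alpha_t^2 + \sigma_t^2 = 1\) for every t. The SNR \(t / \sqrt{1 - t^2}\) inverts to \(t = \mathrm{snr} / \sqrt{1 + \mathrm{snr}^2}\), the kind of snr_inverse that ScheduleTransformedModel requires from an original scheduler. The function names are hypothetical.

```python
import jax.numpy as jnp


def linear_vp_schedule(t):
    """Hypothetical linear VP schedule: returns (alpha_t, sigma_t, d_alpha_t, d_sigma_t)."""
    sigma_t = jnp.sqrt(1.0 - t**2)
    # d/dt sqrt(1 - t^2) = -t / sqrt(1 - t^2); undefined at t = 1.
    return t, sigma_t, jnp.ones_like(t), -t / sigma_t


def linear_vp_snr_inverse(snr):
    # snr = t / sqrt(1 - t^2)  =>  t = snr / sqrt(1 + snr^2)  (valid for t >= 0)
    return snr / jnp.sqrt(1.0 + snr**2)


t = jnp.linspace(0.0, 0.9, 10)
alpha_t, sigma_t, d_alpha_t, d_sigma_t = linear_vp_schedule(t)
variance = alpha_t**2 + sigma_t**2  # 1 everywhere for a variance-preserving path
t_recovered = linear_vp_snr_inverse(alpha_t / sigma_t)
```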
- class gensbi.flow_matching.path.scheduler.scheduler.PolynomialConvexScheduler(n: float | int)
Bases: ConvexScheduler
Polynomial Scheduler.
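The sketch below assumes that a convex scheduler satisfies \(\sigma_t = 1 - \alpha_t\), so the SNR is \(\alpha_t / (1 - \alpha_t)\) and snr_inverse reduces to inverting \(\alpha_t\) at \(\kappa = \mathrm{snr} / (1 + \mathrm{snr})\). It further assumes that PolynomialConvexScheduler(n) uses \(\alpha_t = t^n\), which makes that inverse \(\kappa^{1/n}\). The class PolynomialSchedule and its methods are hypothetical, not the library's API.

```python
import jax.numpy as jnp


class PolynomialSchedule:
    """Hypothetical convex schedule with alpha_t = t**n and sigma_t = 1 - t**n."""

    def __init__(self, n):
        self.n = n

    def __call__(self, t):
        alpha_t = t**self.n
        d_alpha_t = self.n * t ** (self.n - 1)
        # Convexity: sigma_t = 1 - alpha_t, hence d_sigma_t = -d_alpha_t.
        return alpha_t, 1.0 - alpha_t, d_alpha_t, -d_alpha_t

    def alpha_inverse(self, kappa):
        # Invert alpha_t = t**n on [0, 1].
        return kappa ** (1.0 / self.n)

    def snr_inverse(self, snr):
        # snr = alpha / (1 - alpha)  =>  alpha = snr / (1 + snr)
        return self.alpha_inverse(snr / (1.0 + snr))


sched = PolynomialSchedule(n=2)
t = jnp.linspace(0.1, 0.9, 9)
alpha_t, sigma_t, _, _ = sched(t)
t_recovered = sched.snr_inverse(alpha_t / sigma_t)
```

Routing snr_inverse through the inverse of \(\alpha_t\) means any convex schedule only needs that one inverse to be usable as the original scheduler of ScheduleTransformedModel.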
- class gensbi.flow_matching.path.scheduler.scheduler.Scheduler
Bases: ABC
Base Scheduler class.
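This page implies a simple contract: a scheduler maps a time t to a SchedulerOutput and, to act as the original scheduler of ScheduleTransformedModel, also provides snr_inverse. A minimal sketch of such an interface follows, assuming abstract __call__ and snr_inverse methods. SchedulerSketch and CondOTSketch are hypothetical names, and a plain tuple stands in for SchedulerOutput to keep the block self-contained.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp


class SchedulerSketch(ABC):
    """Hypothetical scheduler interface (not the gensbi base class)."""

    @abstractmethod
    def __call__(self, t):
        """Return (alpha_t, sigma_t, d_alpha_t, d_sigma_t) at time t."""

    @abstractmethod
    def snr_inverse(self, snr):
        """Return the time t at which alpha_t / sigma_t equals snr."""


class CondOTSketch(SchedulerSketch):
    # Assumed CondOT path: alpha_t = t, sigma_t = 1 - t.
    def __call__(self, t):
        return t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t)

    def snr_inverse(self, snr):
        return snr / (1.0 + snr)


sched = CondOTSketch()
alpha_t, sigma_t, d_alpha_t, d_sigma_t = sched(jnp.array([0.25, 0.5]))
```

Instantiating SchedulerSketch itself raises TypeError because its abstract methods are unimplemented, so a subclass that forgets snr_inverse fails at construction rather than inside a later schedule transformation.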
- class gensbi.flow_matching.path.scheduler.scheduler.SchedulerOutput(alpha_t: Array, sigma_t: Array, d_alpha_t: Array, d_sigma_t: Array)
Bases: object
Holds the scheduler coefficients \(\alpha_t\), \(\sigma_t\) and their time derivatives at time t, which together define a conditional-flow generated probability path.
- Attributes:
alpha_t (Array) – \(\alpha_t\), shape (…).
sigma_t (Array) – \(\sigma_t\), shape (…).
d_alpha_t (Array) – \(\frac{\partial}{\partial t}\alpha_t\), shape (…).
d_sigma_t (Array) – \(\frac{\partial}{\partial t}\sigma_t\), shape (…).
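The four fields are enough to build an affine conditional path. Assuming the convention \(x_t = \alpha_t x_1 + \sigma_t x_0\), with \(x_0\) noise and \(x_1\) data, its time derivative is \(\dot x_t = \dot\alpha_t x_1 + \dot\sigma_t x_0\). Because the fields have shape (…), typically the shape of t, they need extra trailing axes to broadcast over data dimensions. The frozen dataclass SchedulerOutputSketch and the function affine_sample are hypothetical stand-ins, filled here with CondOT values (\(\alpha_t = t, \sigma_t = 1 - t\)) for illustration.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass(frozen=True)
class SchedulerOutputSketch:
    # Same four fields as SchedulerOutput; each has the shape of t.
    alpha_t: jax.Array
    sigma_t: jax.Array
    d_alpha_t: jax.Array
    d_sigma_t: jax.Array


def affine_sample(out, x_0, x_1):
    """Return (x_t, dx_t) for the path x_t = alpha_t * x_1 + sigma_t * x_0."""

    def expand(a):
        # Append trailing axes so per-sample coefficients broadcast over data dims.
        return a.reshape(a.shape + (1,) * (x_0.ndim - a.ndim))

    x_t = expand(out.alpha_t) * x_1 + expand(out.sigma_t) * x_0
    dx_t = expand(out.d_alpha_t) * x_1 + expand(out.d_sigma_t) * x_0
    return x_t, dx_t


key_0, key_1, key_t = jax.random.split(jax.random.PRNGKey(0), 3)
x_0 = jax.random.normal(key_0, (4, 2))  # noise
x_1 = jax.random.normal(key_1, (4, 2))  # stand-in for data
t = jax.random.uniform(key_t, (4,))  # one time per sample
out = SchedulerOutputSketch(t, 1.0 - t, jnp.ones_like(t), -jnp.ones_like(t))
x_t, dx_t = affine_sample(out, x_0, x_1)
# With CondOT coefficients, dx_t equals x_1 - x_0 regardless of t.
```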
Module contents
The package namespace gensbi.flow_matching.path.scheduler also exposes CondOTScheduler, ConvexScheduler, CosineScheduler, LinearVPScheduler, PolynomialConvexScheduler, ScheduleTransformedModel, and SchedulerOutput, so each can be imported directly from the package (for example, from gensbi.flow_matching.path.scheduler import CosineScheduler). Their documentation is identical to the submodule entries above.