gensbi.diffusion.path package
Subpackages
- gensbi.diffusion.path.scheduler package
  - Submodules
    - gensbi.diffusion.path.scheduler.edm module
      - BaseSDE
        - BaseSDE.c_in()
        - BaseSDE.c_noise()
        - BaseSDE.c_out()
        - BaseSDE.c_skip()
        - BaseSDE.denoise()
        - BaseSDE.f()
        - BaseSDE.g()
        - BaseSDE.get_loss_fn()
        - BaseSDE.get_score_function()
        - BaseSDE.loss_weight()
        - BaseSDE.name
        - BaseSDE.s()
        - BaseSDE.s_deriv()
        - BaseSDE.sample_noise()
        - BaseSDE.sample_prior()
        - BaseSDE.sample_sigma()
        - BaseSDE.sigma()
        - BaseSDE.sigma_deriv()
        - BaseSDE.sigma_inv()
        - BaseSDE.time_schedule()
        - BaseSDE.timesteps()
      - EDMScheduler
        - EDMScheduler.c_in()
        - EDMScheduler.c_noise()
        - EDMScheduler.c_out()
        - EDMScheduler.c_skip()
        - EDMScheduler.loss_weight()
        - EDMScheduler.name
        - EDMScheduler.s()
        - EDMScheduler.s_deriv()
        - EDMScheduler.sample_sigma()
        - EDMScheduler.sigma()
        - EDMScheduler.sigma_deriv()
        - EDMScheduler.sigma_inv()
        - EDMScheduler.time_schedule()
      - VEScheduler
      - VPScheduler
        - VPScheduler.c_in()
        - VPScheduler.c_noise()
        - VPScheduler.c_out()
        - VPScheduler.c_skip()
        - VPScheduler.f()
        - VPScheduler.g()
        - VPScheduler.loss_weight()
        - VPScheduler.name
        - VPScheduler.s()
        - VPScheduler.s_deriv()
        - VPScheduler.sample_sigma()
        - VPScheduler.sigma()
        - VPScheduler.sigma_deriv()
        - VPScheduler.sigma_inv()
        - VPScheduler.time_schedule()
  - Module contents
    - EDMScheduler
      - EDMScheduler.c_in()
      - EDMScheduler.c_noise()
      - EDMScheduler.c_out()
      - EDMScheduler.c_skip()
      - EDMScheduler.loss_weight()
      - EDMScheduler.name
      - EDMScheduler.s()
      - EDMScheduler.s_deriv()
      - EDMScheduler.sample_sigma()
      - EDMScheduler.sigma()
      - EDMScheduler.sigma_deriv()
      - EDMScheduler.sigma_inv()
      - EDMScheduler.time_schedule()
    - VEScheduler
    - VPScheduler
      - VPScheduler.c_in()
      - VPScheduler.c_noise()
      - VPScheduler.c_out()
      - VPScheduler.c_skip()
      - VPScheduler.f()
      - VPScheduler.g()
      - VPScheduler.loss_weight()
      - VPScheduler.name
      - VPScheduler.s()
      - VPScheduler.s_deriv()
      - VPScheduler.sample_sigma()
      - VPScheduler.sigma()
      - VPScheduler.sigma_deriv()
      - VPScheduler.sigma_inv()
      - VPScheduler.time_schedule()
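The scheduler methods listed above (c_skip, c_out, c_in, c_noise, loss_weight, sample_sigma, time_schedule, denoise) share their names with the preconditioning and discretization terms of Karras et al. (2022), "Elucidating the Design Space of Diffusion-Based Generative Models". The sketch below writes those terms out, assuming EDMScheduler follows the paper's formulas and defaults (sigma_data = 0.5, sigma_min = 0.002, sigma_max = 80, rho = 7, P_mean = -1.2, P_std = 1.2). The gensbi defaults and signatures may differ; these free functions are hypothetical stand-ins, written with NumPy so they run without JAX.

```python
import numpy as np

# Hypothetical stand-ins for EDMScheduler's methods, assuming the
# Karras et al. (2022) formulas and defaults; not gensbi's actual API.
SIGMA_DATA = 0.5  # assumed standard deviation of the (normalized) data


def c_skip(sigma):
    # Weight on the skip path: sigma_data^2 / (sigma^2 + sigma_data^2)
    return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)


def c_out(sigma):
    # Scale on the network output: sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
    return sigma * SIGMA_DATA / np.sqrt(sigma**2 + SIGMA_DATA**2)


def c_in(sigma):
    # Input scale that keeps the network input at roughly unit variance
    return 1.0 / np.sqrt(sigma**2 + SIGMA_DATA**2)


def c_noise(sigma):
    # Noise-level conditioning fed to the network: ln(sigma) / 4
    return 0.25 * np.log(sigma)


def loss_weight(sigma):
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2


def sample_sigma(rng, n, p_mean=-1.2, p_std=1.2):
    # Training noise levels drawn log-normally: ln(sigma) ~ N(p_mean, p_std^2)
    return np.exp(p_mean + p_std * rng.standard_normal(n))


def time_schedule(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Sampling noise levels, decreasing from sigma_max to sigma_min
    i = np.arange(n_steps)
    a, b = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    return (a + i / (n_steps - 1) * (b - a)) ** rho


def denoise(model, x, sigma):
    # D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)
    return c_skip(sigma) * x + c_out(sigma) * model(c_in(sigma) * x, c_noise(sigma))
```

In the paper, c_skip and c_out are chosen so that the network's effective training target has unit variance at every \(\sigma\), and \(\lambda(\sigma) = 1/c_{\text{out}}(\sigma)^2\) keeps the loss magnitude comparable across noise levels.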
Submodules
gensbi.diffusion.path.edm_path module
- class gensbi.diffusion.path.edm_path.EDMPath(scheduler)
Bases: ProbPath
- get_loss_fn() → Callable
Returns the loss function for the EDM path.
- Returns:
The loss function as provided by the scheduler.
- Return type:
Callable
- sample(key: Array, x_1: Array, sigma: Array) → EDMPathSample
Sample from the EDM probability path.
- Parameters:
key (Array) – JAX random key.
x_1 (Array) – Target data point, shape (batch_size, …).
sigma (Array) – Noise scale, shape (batch_size, …).
- Returns:
A sample from the EDM path.
- Return type:
EDMPathSample
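As a rough illustration of what `sample` and the loss returned by `get_loss_fn` might compute, the sketch below assumes the EDM / variance-exploding form of Karras et al. (2022), where the scale \(s(t) = 1\) and a noisy sample is \(x_t = x_1 + \sigma\,\epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\); schedulers with \(s(t) \neq 1\) would scale it differently. It also assumes one noise level per example and a weighted denoising objective \(\lambda(\sigma)\,\lVert D(x_t;\sigma) - x_1\rVert^2\). `PathSample`, `sample_edm_path`, and `weighted_denoising_loss` are hypothetical names, and a NumPy `Generator` stands in for the JAX random key.

```python
from typing import Callable, NamedTuple

import numpy as np


class PathSample(NamedTuple):
    # Hypothetical stand-in for EDMPathSample (constructor field order)
    x_1: np.ndarray
    sigma: np.ndarray
    x_t: np.ndarray


def sample_edm_path(rng: np.random.Generator, x_1: np.ndarray, sigma: np.ndarray) -> PathSample:
    # Assumes one noise level per example; reshape to (batch_size, 1, ..., 1)
    sigma = np.asarray(sigma).reshape((x_1.shape[0],) + (1,) * (x_1.ndim - 1))
    eps = rng.standard_normal(x_1.shape)
    # Assumes s(t) = 1 (EDM / VE form): x_t = x_1 + sigma * eps
    x_t = x_1 + sigma * eps
    return PathSample(x_1=x_1, sigma=sigma, x_t=x_t)


def weighted_denoising_loss(
    denoiser: Callable[[np.ndarray, np.ndarray], np.ndarray],
    weight: Callable[[np.ndarray], np.ndarray],
    sample: PathSample,
) -> float:
    # Squared error per example, summed over all non-batch axes
    err = denoiser(sample.x_t, sample.sigma) - sample.x_1
    sq_err = np.sum(err.reshape(err.shape[0], -1) ** 2, axis=1)
    # Weight each example by lambda(sigma), then average over the batch
    w = weight(sample.sigma).reshape(-1)
    return float(np.mean(w * sq_err))
```

A perfect denoiser (one that returns x_1) drives this loss to zero for any weighting; whether `EDMPath.get_loss_fn` reduces over axes in the same way is not stated here.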
gensbi.diffusion.path.path module
- class gensbi.diffusion.path.path.ProbPath(scheduler: Any)
Bases: ABC
- property name: str
Returns the name of the scheduler.
- Returns:
Scheduler name.
- Return type:
str
- abstractmethod sample(*args: Any, **kwargs: Any) → EDMPathSample
Abstract method to sample from the probability path.
- Returns:
Sample from the path.
- Return type:
EDMPathSample
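`ProbPath` is an abstract base class: it holds a scheduler, reports that scheduler's name through the `name` property, and leaves `sample` to subclasses. The sketch below mirrors that shape with Python's `abc` module. `ProbPathSketch`, `ToyScheduler`, and `AdditiveNoisePath` are hypothetical names, and the real base class may store or validate the scheduler differently.

```python
from abc import ABC, abstractmethod
from typing import Any

import numpy as np


class ProbPathSketch(ABC):
    # Hypothetical mirror of ProbPath's documented interface
    def __init__(self, scheduler: Any):
        self.scheduler = scheduler

    @property
    def name(self) -> str:
        # Reports the scheduler's name, as the `name` docstring describes
        return self.scheduler.name

    @abstractmethod
    def sample(self, *args: Any, **kwargs: Any) -> Any:
        """Subclasses draw a sample from their probability path."""


class ToyScheduler:
    name = "toy"  # hypothetical scheduler that only carries a name


class AdditiveNoisePath(ProbPathSketch):
    # Hypothetical concrete path: x_t = x_1 + sigma * eps
    def sample(self, rng: np.random.Generator, x_1: np.ndarray, sigma: float) -> np.ndarray:
        return x_1 + sigma * rng.standard_normal(x_1.shape)
```

Because `sample` is abstract, instantiating the base class directly raises `TypeError`; only subclasses that implement `sample` can be constructed.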
gensbi.diffusion.path.path_sample module
- class gensbi.diffusion.path.path_sample.EDMPathSample(x_1: Array, sigma: Array, x_t: Array)
Bases: object
Represents a sample drawn from a diffusion probability path: the target sample, its noise scale, and the resulting noisy sample.
- x_1
The target sample \(X_1\).
- Type:
Array
- sigma
The noise scale \(\sigma\).
- Type:
Array
- x_t
Noisy samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).
- Type:
Array
- get_batch() → Tuple[Array, Array, Array]
Returns the batch as a tuple (x_1, x_t, sigma).
- Returns:
The target sample, the noisy sample, and the noise scale.
- Return type:
Tuple[Array, Array, Array]
- sigma: Array
- x_1: Array
- x_t: Array
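The constructor takes the fields in the order (x_1, sigma, x_t), while `get_batch` returns them as (x_1, x_t, sigma). The sketch below reproduces that documented behavior with a standard-library dataclass so the ordering is easy to check; `PathSampleSketch` is a hypothetical name and NumPy arrays stand in for JAX arrays.

```python
from dataclasses import dataclass
from typing import Tuple

import numpy as np


@dataclass
class PathSampleSketch:
    # Hypothetical stand-in for EDMPathSample; fields in constructor order
    x_1: np.ndarray
    sigma: np.ndarray
    x_t: np.ndarray

    def get_batch(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        # Returned order differs from the constructor: (x_1, x_t, sigma)
        return self.x_1, self.x_t, self.sigma


# Usage: build with (x_1, sigma, x_t), unpack as (x_1, x_t, sigma)
x_1 = np.zeros((2, 3))
sigma = np.full((2, 1), 0.5)
x_t = x_1 + sigma * np.ones((2, 3))
x1_b, xt_b, sigma_b = PathSampleSketch(x_1, sigma, x_t).get_batch()
```

Unpacking positionally with the constructor's order would silently swap the noisy sample and the noise scale, so it is worth matching the `get_batch` order explicitly.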
Module contents
- class gensbi.diffusion.path.EDMPath(scheduler)
Bases: ProbPath
- get_loss_fn() → Callable
Returns the loss function for the EDM path.
- Returns:
The loss function as provided by the scheduler.
- Return type:
Callable
- sample(key: Array, x_1: Array, sigma: Array) → EDMPathSample
Sample from the EDM probability path.
- Parameters:
key (Array) – JAX random key.
x_1 (Array) – Target data point, shape (batch_size, …).
sigma (Array) – Noise scale, shape (batch_size, …).
- Returns:
A sample from the EDM path.
- Return type:
EDMPathSample