gensbi.diffusion.path.scheduler
Submodules
Classes
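| Class | Description |
|---|---|
| EDMScheduler | SDE scheduler implementing the noise schedule and preconditioning proposed in the EDM paper. |
| VEScheduler | Variance Exploding (VE) SDE scheduler as described in the EDM paper. |
| VPScheduler | Variance Preserving (VP) SDE scheduler as described in the EDM paper. |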
Package Contents
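- class gensbi.diffusion.path.scheduler.EDMScheduler(sigma_min=0.002, sigma_max=80.0, sigma_data=1.0, rho=7, P_mean=-1.2, P_std=1.2)
Bases: BaseSDE
SDE scheduler implementing the noise schedule and preconditioning proposed in the EDM paper.
- Parameters:
sigma_min (float) – Minimum sigma value.
sigma_max (float) – Maximum sigma value.
sigma_data (float) – Standard deviation of the data distribution (σ_data in the EDM paper).
rho (int) – Exponent ρ of the EDM time-step discretization.
P_mean (float) – Mean of ln σ in the log-normal training noise distribution.
P_std (float) – Standard deviation of ln σ in the log-normal training noise distribution.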
- c_in(sigma)
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
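The four preconditioning coefficients are combined into a single denoiser, D(x; σ) = c_skip(σ) x + c_out(σ) F(c_in(σ) x; c_noise(σ)), where F is the raw network. The sketch below assumes the EDM column of Table 1 in Karras et al. (2022). The standalone functions, the SIGMA_DATA constant and the toy network are hypothetical, and EDMScheduler may compute the coefficients differently.

```python
import math

SIGMA_DATA = 1.0  # default sigma_data of EDMScheduler


def c_skip(sigma, sigma_data=SIGMA_DATA):
    # Skip-connection weight: sigma_data^2 / (sigma^2 + sigma_data^2)
    return sigma_data**2 / (sigma**2 + sigma_data**2)


def c_out(sigma, sigma_data=SIGMA_DATA):
    # Output scaling: sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
    return sigma * sigma_data / math.sqrt(sigma**2 + sigma_data**2)


def c_in(sigma, sigma_data=SIGMA_DATA):
    # Input scaling: 1 / sqrt(sigma^2 + sigma_data^2)
    return 1.0 / math.sqrt(sigma**2 + sigma_data**2)


def c_noise(sigma):
    # Noise label fed to the network: ln(sigma) / 4
    return 0.25 * math.log(sigma)


def denoise(F, x, sigma):
    # D(x; sigma) = c_skip x + c_out F(c_in x; c_noise); F is a hypothetical raw network
    return c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))


def zero_network(x_in, noise_label):
    # Hypothetical network that always predicts zero
    return 0.0


print(denoise(zero_network, x=2.0, sigma=1.0))  # 1.0, since c_skip(1) = 0.5
```

Because c_skip(σ) tends to 1 and c_out(σ) to 0 as σ goes to 0, nearly clean inputs pass through almost unchanged, while at large σ the result relies mostly on the network output.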
- loss_weight(sigma)
Weight of the loss function (named λ(σ) in the EDM paper), used for maximum-likelihood estimation.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
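In the EDM paper, λ(σ) is chosen as 1/c_out(σ)², so that the product λ(σ) c_out(σ)², which is the effective weight on the raw network output, equals 1 at every noise level. The sketch assumes the EDM formulas with the default σ_data = 1.0. Here, loss_weight and c_out are hypothetical standalone functions, not the library's methods.

```python
import math

SIGMA_DATA = 1.0  # default sigma_data of EDMScheduler


def c_out(sigma, sigma_data=SIGMA_DATA):
    # EDM output scaling: sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
    return sigma * sigma_data / math.sqrt(sigma**2 + sigma_data**2)


def loss_weight(sigma, sigma_data=SIGMA_DATA):
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 = 1 / c_out(sigma)^2
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


# The effective weight lambda(sigma) * c_out(sigma)^2 is 1 up to rounding error.
for sigma in (0.002, 0.5, 1.0, 80.0):
    print(sigma, loss_weight(sigma) * c_out(sigma) ** 2)
```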
- s(t)
Scaling function s(t), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
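In the EDM paper, training noise levels are drawn from a log-normal distribution, ln σ ~ N(P_mean, P_std²). This matches the P_mean and P_std attributes of this class. The minimal sketch below assumes that distribution. It uses a NumPy Generator in place of the JAX key that sample_sigma takes, and the function shown is not the library's implementation.

```python
import numpy as np

P_MEAN, P_STD = -1.2, 1.2  # defaults of EDMScheduler


def sample_sigma(rng, shape, p_mean=P_MEAN, p_std=P_STD):
    # ln(sigma) ~ N(p_mean, p_std^2); a NumPy Generator stands in for a JAX PRNG key
    return np.exp(p_mean + p_std * rng.standard_normal(shape))


rng = np.random.default_rng(0)
sigmas = sample_sigma(rng, (100_000,))
# The sample median should lie close to exp(P_MEAN).
print(sigmas.shape, np.median(sigmas), np.exp(P_MEAN))
```

With JAX, the same draw can be written as jnp.exp(p_mean + p_std * jax.random.normal(key, shape)).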
- sigma(t)
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- sigma_inv(sigma)
Inverse of the noise scale function.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Time corresponding to the given sigma.
- Return type:
Array
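In the EDM paper's own schedule, σ(t) = t and s(t) = 1. The schedule functions above are exactly what the deterministic sampler needs: Algorithm 1 of the paper evaluates dx/dt = (σ'(t)/σ(t) + s'(t)/s(t)) x - (σ'(t) s(t)/σ(t)) D(x/s(t); σ(t)). The sketch below assumes those formulas. The standalone functions, ode_drift and the dummy denoiser are hypothetical illustrations, not EDMScheduler methods.

```python
# EDM schedule from Table 1 of Karras et al. (2022): sigma(t) = t, s(t) = 1.
def s(t):
    return 1.0  # no signal scaling


def s_deriv(t):
    return 0.0


def sigma(t):
    return t  # noise level equals time


def sigma_deriv(t):
    return 1.0


def sigma_inv(sig):
    return sig  # the identity schedule is its own inverse


def ode_drift(D, x, t):
    # Probability-flow ODE from Algorithm 1 of the EDM paper:
    # dx/dt = (sigma'/sigma + s'/s) x - (sigma' s / sigma) D(x / s; sigma)
    sig, scale = sigma(t), s(t)
    return (sigma_deriv(t) / sig + s_deriv(t) / scale) * x - (
        sigma_deriv(t) * scale / sig
    ) * D(x / scale, sig)


def halving_denoiser(x, sig):
    # Hypothetical stand-in for a trained denoiser D(x; sigma)
    return 0.5 * x


print(ode_drift(halving_denoiser, x=2.0, t=4.0))  # 0.25, i.e. (2.0 - 1.0) / 4.0
```

With σ(t) = t and s(t) = 1, the drift simplifies to (x - D(x; t)) / t.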
- time_schedule(u)
Map a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.
- Parameters:
u (Array) – Uniform random variable in [0, 1].
- Returns:
Time in the schedule.
- Return type:
Array
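The rho attribute points to the time-step discretization of the EDM paper, t_i = (σ_max^{1/ρ} + i/(N-1) (σ_min^{1/ρ} - σ_max^{1/ρ}))^ρ. The sketch below assumes that time_schedule applies this formula with u = i/(N-1), so that u = 0 maps to σ_max and u = 1 maps to σ_min. This mapping is an assumption, since the exact one is not documented here. Because σ(t) = t in EDM, the returned times are also the noise levels.

```python
SIGMA_MIN, SIGMA_MAX, RHO = 0.002, 80.0, 7  # EDMScheduler defaults


def time_schedule(u, sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX, rho=RHO):
    # EDM discretization with u = i / (N - 1):
    # t = (sigma_max^(1/rho) + u * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    start, end = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    return (start + u * (end - start)) ** rho


n_steps = 5
ts = [time_schedule(i / (n_steps - 1)) for i in range(n_steps)]
# Steps are dense near sigma_min and sparse near sigma_max.
print([round(t, 4) for t in ts])
```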
- P_mean = -1.2
- P_std = 1.2
- property name
Returns the name of the SDE scheduler.
- rho = 7
- sigma_data = 1.0
- sigma_max = 80.0
- sigma_min = 0.002
- class gensbi.diffusion.path.scheduler.VEScheduler(sigma_min=0.001, sigma_max=15.0)
Bases: BaseSDE
Variance Exploding (VE) SDE scheduler as described in the EDM paper.
- Parameters:
sigma_min (float) – Minimum sigma value.
sigma_max (float) – Maximum sigma value.
- c_in(sigma)
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
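For the VE scheduler, Table 1 of the EDM paper lists c_skip = 1, c_out = σ, c_in = 1 and c_noise = ln(σ/2). The network output is therefore simply scaled by σ and added to the noisy input. The sketch below assumes those formulas. The functions and the toy network are hypothetical, and VEScheduler may differ in details.

```python
import math


def c_skip(sigma):
    return 1.0  # noisy input passes through unchanged


def c_out(sigma):
    return sigma  # network output is scaled by the noise level


def c_in(sigma):
    return 1.0  # no input scaling


def c_noise(sigma):
    return math.log(0.5 * sigma)  # noise label ln(sigma / 2)


def denoise(F, x, sigma):
    # D(x; sigma) = c_skip x + c_out F(c_in x; c_noise); F is a hypothetical raw network
    return c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))


# Toy case: clean value y = 1.0, noise sample n = 0.5, sigma = 2.0, so x = y + sigma * n = 2.0.
# A network that outputs -n recovers y.
def toy_network(x_in, noise_label):
    return -0.5


print(denoise(toy_network, x=2.0, sigma=2.0))  # 1.0
```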
- loss_weight(sigma)
Weight of the loss function (named λ(σ) in the EDM paper), used for maximum-likelihood estimation.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- s(t)
Scaling function s(t), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
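For VE training, the EDM paper draws ln σ uniformly between ln σ_min and ln σ_max and weights the loss with λ(σ) = 1/σ². The minimal sketch below assumes that distribution and the VEScheduler defaults. It uses a NumPy Generator in place of a JAX key, and both functions are hypothetical stand-ins.

```python
import numpy as np

SIGMA_MIN, SIGMA_MAX = 0.001, 15.0  # VEScheduler defaults


def sample_sigma(rng, shape, sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX):
    # ln(sigma) ~ U(ln sigma_min, ln sigma_max); a NumPy Generator stands in for a JAX key
    return np.exp(rng.uniform(np.log(sigma_min), np.log(sigma_max), size=shape))


def loss_weight(sigma):
    # VE loss weight lambda(sigma) = 1 / sigma^2
    return 1.0 / sigma**2


rng = np.random.default_rng(0)
sigmas = sample_sigma(rng, (4, 3))
print(sigmas.min(), sigmas.max())
```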
- sigma(t)
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- sigma_inv(sigma)
Inverse of the noise scale function.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Time corresponding to the given sigma.
- Return type:
Array
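The VE schedule of the EDM paper uses σ(t) = √t with no signal scaling, s(t) = 1. Time is therefore the noise variance, and σ⁻¹(σ) = σ². The sketch below assumes these formulas. Its functions are hypothetical standalone versions of the methods above, and the test checks sigma_deriv by finite differences.

```python
import math


# VE schedule from Table 1 of Karras et al. (2022): sigma(t) = sqrt(t), s(t) = 1.
def s(t):
    return 1.0


def s_deriv(t):
    return 0.0


def sigma(t):
    return math.sqrt(t)


def sigma_deriv(t):
    return 0.5 / math.sqrt(t)  # d/dt sqrt(t)


def sigma_inv(sig):
    return sig**2  # t = sigma^2


t = 2.25
print(sigma(t), sigma_inv(sigma(t)))  # 1.5 2.25
```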
- time_schedule(u)
Map a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.
- Parameters:
u (Array) – Uniform random variable in [0, 1].
- Returns:
Time in the schedule.
- Return type:
Array
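The EDM paper discretizes VE time as t_i = σ_max² (σ_min²/σ_max²)^{i/(N-1)}. The sketch below assumes that time_schedule applies this formula with u = i/(N-1); the mapping is an assumption, since it is not spelled out above. Under that assumption, the noise levels σ(t) = √t are geometrically spaced between σ_max and σ_min.

```python
import math

SIGMA_MIN, SIGMA_MAX = 0.001, 15.0  # VEScheduler defaults


def time_schedule(u, sigma_min=SIGMA_MIN, sigma_max=SIGMA_MAX):
    # VE discretization with u = i / (N - 1): t = sigma_max^2 (sigma_min^2 / sigma_max^2)^u
    return sigma_max**2 * (sigma_min**2 / sigma_max**2) ** u


# u = 0.5 lands on the geometric mean of sigma_min and sigma_max.
print(math.sqrt(time_schedule(0.5)), math.sqrt(SIGMA_MIN * SIGMA_MAX))
```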
- property name
Returns the name of the SDE scheduler.
- sigma_max = 15.0
- sigma_min = 0.001
- class gensbi.diffusion.path.scheduler.VPScheduler(beta_min=0.1, beta_max=20.0, e_s=0.001, e_t=1e-05, M=1000)
Bases: BaseSDE
Variance Preserving (VP) SDE scheduler as described in the EDM paper.
- Parameters:
beta_min (float) – Minimum beta value.
beta_max (float) – Maximum beta value.
e_s (float) – Starting epsilon value for time schedule.
e_t (float) – Ending epsilon value for time schedule.
M (int) – Scaling factor for noise preconditioning.
References
Karras, Tero, et al. “Elucidating the design space of diffusion-based generative models.” arXiv:2206.00364
- c_in(sigma)
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- f(x, t)
Drift term for the forward diffusion process.
Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Drift term.
- Return type:
Array
- g(x, t)
Diffusion term for the forward diffusion process.
Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Diffusion term.
- Return type:
Array
- loss_weight(sigma)
Weight of the loss function (named λ(σ) in the EDM paper), used for maximum-likelihood estimation.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- s(t)
Scaling function s(t), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
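For VP training, the EDM paper samples t ~ U(ε_t, 1), sets σ = σ(t) and weights the loss with λ(σ) = 1/σ². The sketch below assumes that e_t plays the role of ε_t. It uses a NumPy Generator in place of the JAX key, and its functions are hypothetical.

```python
import numpy as np

BETA_MIN, BETA_D, E_T = 0.1, 19.9, 1e-5  # VPScheduler defaults


def sigma(t):
    # VP schedule sigma(t) = sqrt(exp(beta_d t^2 / 2 + beta_min t) - 1)
    return np.sqrt(np.exp(0.5 * BETA_D * t**2 + BETA_MIN * t) - 1)


def sample_sigma(rng, shape):
    # t ~ U(e_t, 1), then sigma = sigma(t); a NumPy Generator stands in for a JAX key
    t = rng.uniform(E_T, 1.0, size=shape)
    return sigma(t)


def loss_weight(sigma_value):
    # VP loss weight lambda(sigma) = 1 / sigma^2
    return 1.0 / sigma_value**2


rng = np.random.default_rng(0)
sigmas = sample_sigma(rng, (1000,))
print(sigmas.min(), sigmas.max())
```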
- sigma(t)
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- sigma_inv(sigma)
Inverse of the noise scale function.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Time corresponding to the given sigma.
- Return type:
Array
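The VP schedule of the EDM paper is σ(t) = √(e^{B(t)} - 1) with s(t) = 1/√(e^{B(t)}) and B(t) = β_d t²/2 + β_min t. Its inverse is σ⁻¹(σ) = (√(β_min² + 2 β_d ln(σ² + 1)) - β_min)/β_d. The sketch below assumes these formulas and the default β values. It checks the variance-preserving identity s(t)² (1 + σ(t)²) = 1 and the round trip through sigma_inv. The functions are hypothetical standalone versions.

```python
import math

BETA_MIN, BETA_D = 0.1, 19.9  # VPScheduler defaults


def exponent(t):
    # B(t) = beta_d t^2 / 2 + beta_min t
    return 0.5 * BETA_D * t**2 + BETA_MIN * t


def s(t):
    return 1.0 / math.sqrt(math.exp(exponent(t)))


def sigma(t):
    return math.sqrt(math.exp(exponent(t)) - 1)


def sigma_inv(sig):
    # Solve ln(sig^2 + 1) = B(t) for t >= 0
    return (math.sqrt(BETA_MIN**2 + 2 * BETA_D * math.log(sig**2 + 1)) - BETA_MIN) / BETA_D


t = 0.3
# Variance preserving: s(t)^2 (1 + sigma(t)^2) stays equal to 1.
print(s(t) ** 2 * (1 + sigma(t) ** 2), sigma_inv(sigma(t)))
```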
- time_schedule(u)
Map a uniform random variable u ~ U(0, 1) to the corresponding time t in the schedule.
- Parameters:
u (Array) – Uniform random variable in [0, 1].
- Returns:
Time in the schedule.
- Return type:
Array
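The EDM paper discretizes VP sampling time linearly, t_i = 1 + i/(N-1) (ε_s - 1), running from t = 1 down to ε_s. The sketch below assumes that time_schedule applies this formula with u = i/(N-1) and that e_s plays the role of ε_s. Neither assumption is confirmed above.

```python
E_S = 0.001  # VPScheduler default e_s


def time_schedule(u, e_s=E_S):
    # Linear VP discretization: u = 0 gives t = 1, u = 1 gives t = e_s
    return 1 + u * (e_s - 1)


n_steps = 5
ts = [time_schedule(i / (n_steps - 1)) for i in range(n_steps)]
print(ts)
```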
- M = 1000
- beta_d = 19.9
- beta_max = 20.0
- beta_min = 0.1
- e_s = 0.001
- e_t = 1e-05
- property name
Returns the name of the SDE scheduler.