gensbi.diffusion.path.scheduler.edm

Classes

BaseSDE

Abstract base class for SDE schedulers in the EDM framework.

EDMScheduler

Scheduler implementing the EDM noise schedule and preconditioning, as described in the EDM paper.

VEScheduler

Variance Exploding (VE) SDE scheduler as described in the EDM paper.

VPScheduler

Variance Preserving (VP) SDE scheduler as described in the EDM paper.

Module Contents

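class gensbi.diffusion.path.scheduler.edm.BaseSDE

Bases: abc.ABC

Abstract base class for SDE schedulers in the EDM framework. Subclasses define the noise schedule σ(t), the scaling s(t), the preconditioning coefficients, the loss weight, and the training noise distribution.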

abstract c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

abstract c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

abstract c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

abstract c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

denoise(F, x, sigma, *args, **kwargs)

Denoiser, \(D\) in the EDM paper, which is related to the score function by:

\[\nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}\]

It wraps the raw network \(F_\theta\) in the preconditioning coefficients:

\[D_\theta(x; \sigma) = c_\text{skip}(\sigma) x + c_\text{out}(\sigma) F_\theta (c_\text{in}(\sigma) x; c_\text{noise}(\sigma))\]
Parameters:
  • F (Callable) – Raw network \(F_\theta\) that the preconditioning wraps.

  • x (Array) – Input data.

  • sigma (Array) – Noise scale.

  • *args – Additional arguments.

  • **kwargs – Additional keyword arguments.

Returns:

Denoised output.

Return type:

Array
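The two relations above can be illustrated with a short JAX sketch. It is not the library's implementation: the coefficient functions are the EDM choices from Table 1 of the EDM paper with an assumed module-level SIGMA_DATA = 1.0, and F_zero is a hypothetical placeholder network that always outputs zeros.

```python
import jax.numpy as jnp

# Hypothetical EDM preconditioning (EDM paper, Table 1); SIGMA_DATA
# mirrors the EDMScheduler default and is an assumption of this sketch.
SIGMA_DATA = 1.0

def c_skip(sigma):
    return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)

def c_out(sigma):
    return sigma * SIGMA_DATA / jnp.sqrt(sigma**2 + SIGMA_DATA**2)

def c_in(sigma):
    return 1.0 / jnp.sqrt(sigma**2 + SIGMA_DATA**2)

def c_noise(sigma):
    return 0.25 * jnp.log(sigma)

def denoise(F, x, sigma, *args, **kwargs):
    # D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)
    out = F(c_in(sigma) * x, c_noise(sigma), *args, **kwargs)
    return c_skip(sigma) * x + c_out(sigma) * out

def score(F, x, sigma):
    # grad_x log p(x; sigma) = (D(x; sigma) - x) / sigma^2
    return (denoise(F, x, sigma) - x) / sigma**2

# Placeholder network that ignores its inputs and returns zeros.
F_zero = lambda x_in, noise_cond: jnp.zeros_like(x_in)

x = jnp.ones((4, 2))
sigma = jnp.array(1.0)
d = denoise(F_zero, x, sigma)        # c_skip(1) = 0.5, so d = 0.5 * x
score_est = score(F_zero, x, sigma)  # (0.5 * x - x) / 1 = -0.5 * x
```

For a batch with one noise scale per sample, sigma would be shaped (batch, 1) so that it broadcasts over the feature dimension.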

f(x, t)

Drift term for the forward diffusion process.

Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Drift term.

Return type:

Array

g(x, t)

Diffusion term for the forward diffusion process.

Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Diffusion term.

Return type:

Array
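Because f and g depend only on s(t) and σ(t), both can be obtained with automatic differentiation. The sketch below assumes the VP schedule from Table 1 of the EDM paper with the VPScheduler defaults (β_min = 0.1, β_d = 19.9). For that schedule the two formulas reduce to f(x, t) = -½β(t)x and g(t) = √β(t), with β(t) = β_d t + β_min, which gives a convenient check. The helper names are illustrative, not the library's.

```python
import jax
import jax.numpy as jnp

# Assumed VP schedule (EDM paper, Table 1) with VPScheduler defaults.
BETA_MIN, BETA_D = 0.1, 19.9

def sigma(t):
    return jnp.sqrt(jnp.expm1(0.5 * BETA_D * t**2 + BETA_MIN * t))

def s(t):
    return jnp.exp(-0.5 * (0.5 * BETA_D * t**2 + BETA_MIN * t))

# Time derivatives via autodiff, vectorized over an array of times (t > 0).
sigma_deriv = jax.vmap(jax.grad(sigma))
s_deriv = jax.vmap(jax.grad(s))

def f(x, t):
    # Drift f(x, t) = x * s'(t) / s(t); t has shape (batch,), x has shape (batch, dim).
    return x * (s_deriv(t) / s(t))[:, None]

def g(x, t):
    # Diffusion g(t) = s(t) * sqrt(2 * sigma'(t) * sigma(t)); independent of x.
    return s(t) * jnp.sqrt(2.0 * sigma_deriv(t) * sigma(t))

t = jnp.array([0.1, 0.5, 1.0])
x = jnp.ones((3, 2))
beta_t = BETA_D * t + BETA_MIN  # closed forms: f = -0.5 * beta(t) * x, g = sqrt(beta(t))
```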

get_loss_fn()

Returns the loss function for EDM training, as described in the EDM paper.

The loss is computed as (see Eq. 8 in the EDM paper):

\[\lambda(\sigma) \, c_\text{out}^2(\sigma) \left[ F(c_\text{in}(\sigma) x_t, c_\text{noise}(\sigma), \ldots) - \frac{1}{c_\text{out}(\sigma)} (x_1 - c_\text{skip}(\sigma) x_t) \right]^2\]
where \(x_1\) is a clean data sample and \(x_t\) is its noised counterpart at noise scale \(\sigma\).

Parameters:

None.

Returns:

A function that computes the loss.

Return type:

Callable
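A Monte Carlo estimate of this loss can be sketched as follows. The sketch assumes the EDM preconditioning and log-normal training distribution with the EDMScheduler defaults, no scaling (s(t) = 1, so x_t = x_1 + σn), and a network F called as F(input, noise condition) like the F in the formula. edm_loss and F_zero are hypothetical names; this is not the function returned by get_loss_fn.

```python
import jax
import jax.numpy as jnp

# Assumed EDMScheduler defaults and EDM-paper formulas (Table 1).
SIGMA_DATA, P_MEAN, P_STD = 1.0, -1.2, 1.2

def c_skip(sigma): return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)
def c_out(sigma): return sigma * SIGMA_DATA / jnp.sqrt(sigma**2 + SIGMA_DATA**2)
def c_in(sigma): return 1.0 / jnp.sqrt(sigma**2 + SIGMA_DATA**2)
def c_noise(sigma): return 0.25 * jnp.log(sigma)
def loss_weight(sigma): return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2

def edm_loss(F, key, x_1):
    """Monte Carlo estimate of the loss for clean data x_1 of shape (batch, dim)."""
    key_sigma, key_noise = jax.random.split(key)
    # One noise scale per sample, ln(sigma) ~ N(P_mean, P_std^2), shaped (batch, 1).
    sigma = jnp.exp(P_MEAN + P_STD * jax.random.normal(key_sigma, (x_1.shape[0], 1)))
    x_t = x_1 + sigma * jax.random.normal(key_noise, x_1.shape)  # noised sample
    target = (x_1 - c_skip(sigma) * x_t) / c_out(sigma)          # effective target
    pred = F(c_in(sigma) * x_t, c_noise(sigma))
    sq_err = jnp.sum((pred - target) ** 2, axis=-1, keepdims=True)
    return jnp.mean(loss_weight(sigma) * c_out(sigma) ** 2 * sq_err)

F_zero = lambda x_in, noise_cond: jnp.zeros_like(x_in)  # placeholder network
loss = edm_loss(F_zero, jax.random.PRNGKey(0), jnp.ones((8, 2)))
```

With these EDM coefficients λ(σ)c_out²(σ) = 1, so the weighted loss reduces to the plain squared error between the network output and its effective target.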

abstract loss_weight(sigma)

Weight λ(σ) applied to the loss at noise scale sigma, as defined in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

abstract s(t)

Scaling function s(t), as defined in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

abstract s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_noise(key, shape, sigma)

Sample noise from the noise distribution at noise scale sigma.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

  • sigma (Array) – Noise scale.

Returns:

Sampled noise.

Return type:

Array
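Assuming the noise distribution is an isotropic Gaussian, as in the EDM formulation, sampling at noise scale σ amounts to scaling standard normal draws. The function below is an illustrative stand-in written against jax.random, not the library's code.

```python
import jax
import jax.numpy as jnp

def sample_noise(key, shape, sigma):
    # Gaussian noise n ~ N(0, sigma^2 I); sigma must broadcast against `shape`,
    # e.g. a (batch, 1) array gives one noise scale per row.
    return sigma * jax.random.normal(key, shape)

key = jax.random.PRNGKey(0)
sigma = jnp.array([[0.0], [2.0]])  # one noise scale per row
noise = sample_noise(key, (2, 50_000), sigma)
```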

sample_prior(key, shape)

Sample x from the prior distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled prior.

Return type:

Array

abstract sample_sigma(key, shape)

Sample noise scales sigma from the training noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

abstract sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

abstract sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

abstract sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

abstract time_schedule(u)

Map a uniform random variable u ~ U(0, 1) to the corresponding time t of the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

timesteps(i, N)

Compute the time steps for a given index array and total number of steps.

Parameters:
  • i (Array) – Step indices.

  • N (int) – Total number of steps.

Returns:

Time steps.

Return type:

Array
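For the EDM scheduler, where σ(t) = t, one natural discretization is Eq. 5 of the EDM paper, which spaces the steps uniformly in σ^{1/ρ} between σ_max and σ_min. The sketch below assumes that choice and the EDMScheduler defaults; edm_timesteps is a hypothetical name, and whether timesteps uses exactly this formula is an assumption. The paper additionally appends a final step at σ = 0, omitted here.

```python
import jax.numpy as jnp

def edm_timesteps(i, N, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
    i = jnp.asarray(i, dtype=jnp.float32)
    inv_rho = 1.0 / rho
    start, end = sigma_max**inv_rho, sigma_min**inv_rho
    return (start + i / (N - 1) * (end - start)) ** rho

N = 18
ts = edm_timesteps(jnp.arange(N), N)  # decreasing from sigma_max to sigma_min
```

Larger ρ concentrates steps near σ_min, where the paper reports that discretization errors matter most.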

abstract property name: str

Returns the name of the SDE scheduler.

Return type:

str

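class gensbi.diffusion.path.scheduler.edm.EDMScheduler(sigma_min=0.002, sigma_max=80.0, sigma_data=1.0, rho=7, P_mean=-1.2, P_std=1.2)

Bases: BaseSDE

Scheduler implementing the EDM noise schedule, preconditioning, and training noise distribution, as described in the EDM paper.

Parameters:
  • sigma_min – Minimum noise scale.

  • sigma_max – Maximum noise scale.

  • sigma_data – Standard deviation of the data distribution (σ_data in the EDM paper).

  • rho – Exponent ρ of the time-step discretization used for sampling.

  • P_mean – Mean of ln(σ) under the training noise distribution.

  • P_std – Standard deviation of ln(σ) under the training noise distribution.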

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

loss_weight(sigma)

Weight λ(σ) applied to the loss at noise scale sigma, as defined in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array
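Table 1 of the EDM paper gives closed forms for the preconditioning coefficients and the loss weight. The sketch below writes them out, assuming EDMScheduler follows those formulas with its default sigma_data = 1.0. It also checks two design properties stated in the paper: c_in rescales the noisy input (variance σ_data² + σ²) to unit variance, and λ(σ)c_out²(σ) = 1.

```python
import jax.numpy as jnp

SIGMA_DATA = 1.0  # EDMScheduler default; formulas below are assumed from the EDM paper

def c_skip(sigma): return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)
def c_out(sigma): return sigma * SIGMA_DATA / jnp.sqrt(sigma**2 + SIGMA_DATA**2)
def c_in(sigma): return 1.0 / jnp.sqrt(sigma**2 + SIGMA_DATA**2)
def c_noise(sigma): return 0.25 * jnp.log(sigma)
def loss_weight(sigma): return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2

sigma = jnp.array([0.002, 0.5, 1.0, 80.0])
# Noisy inputs have variance sigma_data^2 + sigma^2; c_in rescales them to unit variance.
input_var = c_in(sigma) ** 2 * (SIGMA_DATA**2 + sigma**2)
# lambda(sigma) is chosen so that the effective weight lambda * c_out^2 equals 1.
effective_weight = loss_weight(sigma) * c_out(sigma) ** 2
```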

s(t)

Scaling function s(t), as defined in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample noise scales sigma from the training noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array
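In the EDM paper the training noise levels follow a log-normal distribution, ln σ ~ N(P_mean, P_std²), which is what the P_mean and P_std attributes parametrize. A minimal sketch assuming that distribution (not the library's code):

```python
import jax
import jax.numpy as jnp

def sample_sigma(key, shape, P_mean=-1.2, P_std=1.2):
    # Log-normal noise scales: ln(sigma) ~ N(P_mean, P_std^2); defaults match EDMScheduler.
    return jnp.exp(P_mean + P_std * jax.random.normal(key, shape))

sigmas = sample_sigma(jax.random.PRNGKey(0), (100_000,))
```

With the defaults the median draw is e^{-1.2} ≈ 0.30, far below sigma_max = 80.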

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

time_schedule(u)

Map a uniform random variable u ~ U(0, 1) to the corresponding time t of the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

P_mean = -1.2
P_std = 1.2
property name

Returns the name of the SDE scheduler.

rho = 7
sigma_data = 1.0
sigma_max = 80.0
sigma_min = 0.002
class gensbi.diffusion.path.scheduler.edm.VEScheduler(sigma_min=0.001, sigma_max=15.0)

Bases: BaseSDE

Variance Exploding (VE) SDE scheduler as described in the EDM paper.

Parameters:
  • sigma_min (float) – Minimum sigma value.

  • sigma_max (float) – Maximum sigma value.

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

loss_weight(sigma)

Weight λ(σ) applied to the loss at noise scale sigma, as defined in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

s(t)

Scaling function s(t), as defined in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample noise scales sigma from the training noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array
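Table 1 of the EDM paper defines the VE schedule by σ(t) = √t with no scaling (s(t) = 1), so the inverse is σ⁻¹(σ) = σ², and draws training noise levels log-uniformly between σ_min and σ_max. The sketch below assumes VEScheduler follows these formulas with its own defaults; it is illustrative only.

```python
import jax
import jax.numpy as jnp

SIGMA_MIN, SIGMA_MAX = 0.001, 15.0  # VEScheduler defaults

def sigma(t):
    return jnp.sqrt(t)  # sigma(t) = sqrt(t)

def sigma_inv(sig):
    return sig**2  # inverse: t = sigma^2

def s(t):
    return jnp.ones_like(t)  # VE applies no scaling

def sample_sigma(key, shape):
    # ln(sigma) ~ U(ln sigma_min, ln sigma_max)
    u = jax.random.uniform(key, shape)
    return SIGMA_MIN * (SIGMA_MAX / SIGMA_MIN) ** u

t = jnp.array([1e-6, 0.25, 225.0])
sigmas = sample_sigma(jax.random.PRNGKey(0), (1000,))
```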

time_schedule(u)

Map a uniform random variable u ~ U(0, 1) to the corresponding time t of the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

property name

Returns the name of the SDE scheduler.

sigma_max = 15.0
sigma_min = 0.001
class gensbi.diffusion.path.scheduler.edm.VPScheduler(beta_min=0.1, beta_max=20.0, e_s=0.001, e_t=1e-05, M=1000)

Bases: BaseSDE

Variance Preserving (VP) SDE scheduler as described in the EDM paper.

Parameters:
  • beta_min (float) – Minimum beta value.

  • beta_max (float) – Maximum beta value.

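  • e_s (float) – Smallest time reached by the sampling time discretization (ε_s in the EDM paper).

  • e_t (float) – Smallest time of the training time distribution t ~ U(ε_t, 1) (ε_t in the EDM paper).

  • M (int) – Scaling factor in the noise preconditioning c_noise(σ) = (M - 1) σ⁻¹(σ).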

References

  • Karras, Tero, et al. “Elucidating the design space of diffusion-based generative models.” arXiv:2206.00364

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

f(x, t)

Drift term for the forward diffusion process.

Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Drift term.

Return type:

Array

g(x, t)

Diffusion term for the forward diffusion process.

Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Diffusion term.

Return type:

Array

loss_weight(sigma)

Weight λ(σ) applied to the loss at noise scale sigma, as defined in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

s(t)

Scaling function s(t), as defined in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample noise scales sigma from the training noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array
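Table 1 of the EDM paper gives the VP schedule in closed form: σ(t) = √(e^{½β_d t² + β_min t} - 1), s(t) = 1/√(1 + σ(t)²), and an inverse obtained by solving a quadratic in t; the preconditioning uses c_in = 1/√(σ² + 1), c_out = -σ and c_noise = (M - 1)σ⁻¹(σ), and training times are drawn as t ~ U(ε_t, 1). The sketch below assumes VPScheduler follows these formulas with its defaults; the helper names are hypothetical. It uses jnp.expm1 and jnp.log1p to avoid cancellation at small t.

```python
import jax.numpy as jnp

# VPScheduler defaults; the formulas follow Table 1 of the EDM paper (assumed).
BETA_MIN, BETA_MAX, E_T, M = 0.1, 20.0, 1e-5, 1000
BETA_D = BETA_MAX - BETA_MIN  # 19.9, matching the beta_d attribute

def _exponent(t):
    return 0.5 * BETA_D * t**2 + BETA_MIN * t

def sigma(t):
    return jnp.sqrt(jnp.expm1(_exponent(t)))  # expm1 avoids cancellation near t = 0

def s(t):
    return jnp.exp(-0.5 * _exponent(t))  # equals 1 / sqrt(1 + sigma(t)^2)

def sigma_inv(sig):
    # Solve 0.5 * beta_d * t^2 + beta_min * t = ln(1 + sigma^2) for t >= 0.
    return (jnp.sqrt(BETA_MIN**2 + 2.0 * BETA_D * jnp.log1p(sig**2)) - BETA_MIN) / BETA_D

def c_in(sig): return 1.0 / jnp.sqrt(sig**2 + 1.0)
def c_out(sig): return -sig
def c_noise(sig): return (M - 1) * sigma_inv(sig)

def time_schedule(u):
    # Maps u ~ U(0, 1) to t ~ U(e_t, 1).
    return E_T + u * (1.0 - E_T)

t = jnp.array([0.01, 0.5, 1.0])
```

Note that c_in(σ(t)) coincides with s(t) for this schedule, so the network input is the scaled state divided by s(t)... in other words, the preconditioning undoes the VP shrinkage.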

time_schedule(u)

Map a uniform random variable u ~ U(0, 1) to the corresponding time t of the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

M = 1000
beta_d = 19.9
beta_max = 20.0
beta_min = 0.1
e_s = 0.001
e_t = 1e-05
property name

Returns the name of the SDE scheduler.