gensbi.diffusion.path.scheduler package

Submodules

gensbi.diffusion.path.scheduler.edm module

class gensbi.diffusion.path.scheduler.edm.BaseSDE

Bases: ABC

abstractmethod c_in(sigma: Array) → Array

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

abstractmethod c_noise(sigma: Array) → Array

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

abstractmethod c_out(sigma: Array) → Array

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

abstractmethod c_skip(sigma: Array) → Array

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

denoise(F: Callable, x: Array, sigma: Array, *args, **kwargs) → Array

Denoiser function, \(D\) in the EDM paper, which is related to the score function by:

\[\nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}\]

This function includes the preconditioning and is expressed in terms of the neural network \(F_\theta\):

\[D_\theta(x; \sigma) = c_\text{skip}(\sigma) x + c_\text{out}(\sigma) F_\theta (c_\text{in}(\sigma) x; c_\text{noise}(\sigma))\]
Parameters:
  • F (Callable) – Model function.

  • x (Array) – Input data.

  • sigma (Array) – Noise scale.

  • *args – Additional arguments.

  • **kwargs – Additional keyword arguments.

Returns:

Denoised output.

Return type:

Array
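The preconditioned denoiser above can be sketched with plain Python scalars. The coefficient formulas below are the EDM paper's own preconditioning choices with an assumed `SIGMA_DATA = 1.0`; they are assumptions for illustration and not necessarily what a given `BaseSDE` subclass implements. `toy_model` is a hypothetical stand-in for the network \(F\).

```python
import math

SIGMA_DATA = 1.0  # assumed data standard deviation (hypothetical value)

def c_skip(sigma):
    return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)

def c_out(sigma):
    return sigma * SIGMA_DATA / math.sqrt(sigma**2 + SIGMA_DATA**2)

def c_in(sigma):
    return 1.0 / math.sqrt(sigma**2 + SIGMA_DATA**2)

def c_noise(sigma):
    return 0.25 * math.log(sigma)

def denoise(F, x, sigma):
    """D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)."""
    return c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))

def toy_model(x_scaled, noise_cond):
    # Hypothetical stand-in for the network: always predicts zero.
    return 0.0

# With F = 0 the denoiser reduces to the skip connection c_skip(sigma) * x.
d = denoise(toy_model, x=2.0, sigma=1.0)
```

The same expressions apply elementwise when `x` and `sigma` are JAX arrays and `math` is swapped for `jax.numpy`.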

f(x: Array, t: Array) → Array

Drift term for the forward diffusion process.

Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Drift term.

Return type:

Array

g(x: Array, t: Array) → Array

Diffusion term for the forward diffusion process.

Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Diffusion term.

Return type:

Array
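As a worked illustration of the drift and diffusion formulas, the sketch below plugs in a hypothetical schedule with \(s(t) = 1\) and \(\sigma(t) = \sqrt{t}\) (the variance-exploding choice listed in the EDM paper). The schedule and the scalar inputs are assumptions made for the example, not values taken from this library.

```python
import math

# Hypothetical schedule for illustration: s(t) = 1, sigma(t) = sqrt(t).
def s(t):
    return 1.0

def s_deriv(t):
    return 0.0

def sigma(t):
    return math.sqrt(t)

def sigma_deriv(t):
    return 0.5 / math.sqrt(t)

def f(x, t):
    # Drift: f(x, t) = x * (ds/dt) / s(t)
    return x * s_deriv(t) / s(t)

def g(x, t):
    # Diffusion: g(x, t) = s(t) * sqrt(2 * (dsigma/dt) * sigma(t))
    return s(t) * math.sqrt(2.0 * sigma_deriv(t) * sigma(t))

drift = f(3.0, 0.25)
diffusion = g(3.0, 0.25)
```

For this schedule the drift vanishes and the diffusion coefficient is constant, which is the usual picture of a variance-exploding process.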

get_loss_fn() → Callable

Returns the loss function for EDM training, as described in the EDM paper.

The loss is computed as (see Eq. 8 in the EDM paper):

\[\lambda(\sigma) \, c_\text{out}^2(\sigma) \left[ F(c_\text{in}(\sigma) x_t, c_\text{noise}(\sigma), \ldots) - \frac{1}{c_\text{out}(\sigma)} (x_1 - c_\text{skip}(\sigma) x_t) \right]^2\]
Parameters:

None directly; returns a function that computes the loss.

Returns:

Loss function.

Return type:

Callable
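A minimal sketch of the loss expression above, using plain Python scalars. The coefficient and weight formulas are the EDM paper's preconditioning choices with an assumed `SIGMA_DATA = 1.0`; `coeffs`, `edm_loss`, `denoiser_loss`, and `toy_model` are hypothetical names for illustration, not this library's API. The sketch also checks that the expression equals \(\lambda(\sigma) (D(x_t; \sigma) - x_1)^2\).

```python
import math

SIGMA_DATA = 1.0  # assumed data standard deviation (hypothetical value)

def coeffs(sigma):
    denom = sigma**2 + SIGMA_DATA**2
    c_skip = SIGMA_DATA**2 / denom
    c_out = sigma * SIGMA_DATA / math.sqrt(denom)
    c_in = 1.0 / math.sqrt(denom)
    c_noise = 0.25 * math.log(sigma)
    return c_skip, c_out, c_in, c_noise

def loss_weight(sigma):
    # EDM paper choice: lambda(sigma) = 1 / c_out(sigma)^2
    return (sigma**2 + SIGMA_DATA**2) / (sigma * SIGMA_DATA) ** 2

def edm_loss(F, x_1, x_t, sigma):
    # lambda * c_out^2 * [F(c_in x_t, c_noise) - (x_1 - c_skip x_t) / c_out]^2
    c_skip, c_out, c_in, c_noise = coeffs(sigma)
    target = (x_1 - c_skip * x_t) / c_out
    residual = F(c_in * x_t, c_noise) - target
    return loss_weight(sigma) * c_out**2 * residual**2

def denoiser_loss(F, x_1, x_t, sigma):
    # Equivalent form: lambda * (D(x_t; sigma) - x_1)^2
    c_skip, c_out, c_in, c_noise = coeffs(sigma)
    D = c_skip * x_t + c_out * F(c_in * x_t, c_noise)
    return loss_weight(sigma) * (D - x_1) ** 2

def toy_model(x_scaled, noise_cond):
    # Hypothetical stand-in for the network.
    return 0.3 * x_scaled

a = edm_loss(toy_model, x_1=1.0, x_t=1.5, sigma=0.5)
b = denoiser_loss(toy_model, x_1=1.0, x_t=1.5, sigma=0.5)
```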

get_score_function(F: Callable) → Callable

Returns the score function \(\nabla_x \log p(x; \sigma)\) as described in the EDM paper.

The score function is computed as:

\[\nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}\]

where \(D(x; \sigma)\) is the denoised output (see denoise method).

Parameters:

F (Callable) – Model function.

Returns:

Score function.

Return type:

Callable
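The relation between denoiser and score can be checked on a case with a known answer. In the sketch below the data are assumed to be Gaussian, \(x \sim \mathcal{N}(0, \sigma_\text{data}^2)\), so the ideal denoiser and the true score are both available in closed form. Note that this hypothetical helper takes a denoiser `D` directly, whereas the method documented here takes the model `F`.

```python
SIGMA_DATA = 2.0  # assumed data standard deviation (hypothetical value)

def get_score_function(D):
    """Wrap a denoiser D(x, sigma) into a score function."""
    def score(x, sigma):
        return (D(x, sigma) - x) / sigma**2
    return score

def ideal_denoiser(x, sigma):
    # Posterior mean of clean data given a noisy x, for Gaussian data.
    return x * SIGMA_DATA**2 / (SIGMA_DATA**2 + sigma**2)

score = get_score_function(ideal_denoiser)
value = score(1.0, 1.0)

# Closed-form score of N(0, SIGMA_DATA^2 + sigma^2) at the same point.
expected = -1.0 / (SIGMA_DATA**2 + 1.0**2)
```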

abstractmethod loss_weight(sigma: Array) → Array

Weight for the loss function, for maximum likelihood estimation; denoted λ(σ) in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

abstract property name: str

Returns the name of the SDE scheduler.

abstractmethod s(t: Array) → Array

Scaling function, as in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

abstractmethod s_deriv(t: Array) → Array

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_noise(key: Array, shape: Any, sigma: Array) → Array

Sample noise from the prior noise distribution with noise scale sigma(t).

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

  • sigma (Array) – Noise scale.

Returns:

Sampled noise.

Return type:

Array

sample_prior(key: Array, shape: Any) → Array

Sample x from the prior distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled prior.

Return type:

Array

abstractmethod sample_sigma(key: Array, shape: Any) → Array

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

abstractmethod sigma(t: Array) → Array

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

abstractmethod sigma_deriv(t: Array) → Array

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

abstractmethod sigma_inv(sigma: Array) → Array

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

abstractmethod time_schedule(u: Array) → Array

Given a uniform random variable u ~ U(0, 1), return the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

timesteps(i: Array, N: int) → Array

Compute the time steps for a given index array and total number of steps.

Parameters:
  • i (Array) – Step indices.

  • N (int) – Total number of steps.

Returns:

Time steps.

Return type:

Array

class gensbi.diffusion.path.scheduler.edm.EDMScheduler(sigma_min=0.002, sigma_max=80.0, sigma_data=1.0, rho=7, P_mean=-1.2, P_std=1.2)

Bases: BaseSDE

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

loss_weight(sigma)

Weight for the loss function, for maximum likelihood estimation; denoted λ(σ) in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

property name

Returns the name of the SDE scheduler.

s(t)

Scaling function, as in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array
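In the EDM paper, training noise levels are drawn from a log-normal distribution, \(\ln \sigma \sim \mathcal{N}(P_\text{mean}, P_\text{std}^2)\). The constructor defaults `P_mean=-1.2` and `P_std=1.2` suggest the same choice here, but that is an assumption. The sketch below uses the standard-library `random` module instead of a JAX key to stay self-contained; the function signature is hypothetical.

```python
import math
import random

P_MEAN, P_STD = -1.2, 1.2  # taken from the constructor defaults

def sample_sigma(rng, n):
    # ln(sigma) ~ Normal(P_MEAN, P_STD^2), so sigma is always positive.
    return [math.exp(P_MEAN + P_STD * rng.gauss(0.0, 1.0)) for _ in range(n)]

rng = random.Random(0)
sigmas = sample_sigma(rng, 1000)
mean_log = sum(math.log(s) for s in sigmas) / len(sigmas)
```

The sample mean of \(\ln \sigma\) should sit close to `P_MEAN`, which concentrates training on intermediate noise levels.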

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

time_schedule(u)

Given a uniform random variable u ~ U(0, 1), return the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array
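For orientation, the EDM paper spaces its sampling noise levels uniformly in \(\sigma^{1/\rho}\), with \(\sigma(t) = t\). The sketch below applies that rule to a value u in [0, 1] using the constructor defaults for `sigma_min`, `sigma_max`, and `rho`. Whether `time_schedule` implements exactly this mapping, and in this direction (u = 0 giving `sigma_max`), is an assumption; `edm_time_schedule` is a hypothetical name.

```python
def edm_time_schedule(u, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Interpolate linearly in sigma^(1/rho), then raise back to the power rho.
    inv = 1.0 / rho
    return (sigma_max**inv + u * (sigma_min**inv - sigma_max**inv)) ** rho

# Five evenly spaced values of u, from the noisiest to the cleanest level.
ts = [edm_time_schedule(i / 4) for i in range(5)]
```

Larger `rho` places more of the steps near `sigma_min`, where the paper argues discretization errors matter most.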

class gensbi.diffusion.path.scheduler.edm.VEScheduler(sigma_min=0.001, sigma_max=15.0)

Bases: BaseSDE

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

loss_weight(sigma)

Weight for the loss function, for maximum likelihood estimation; denoted λ(σ) in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

property name

Returns the name of the SDE scheduler.

s(t)

Scaling function, as in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

time_schedule(u)

Given a uniform random variable u ~ U(0, 1), return the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

class gensbi.diffusion.path.scheduler.edm.VPScheduler(beta_min=0.1, beta_max=20.0, e_s=0.001, e_t=1e-05, M=1000)

Bases: BaseSDE

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

f(x, t)

Drift term for the forward diffusion process.

Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Drift term.

Return type:

Array

g(x, t)

Diffusion term for the forward diffusion process.

Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Diffusion term.

Return type:

Array

loss_weight(sigma)

Weight for the loss function, for maximum likelihood estimation; denoted λ(σ) in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

property name

Returns the name of the SDE scheduler.

s(t)

Scaling function, as in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array
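As an illustration of a noise schedule and its inverse, the EDM paper writes the variance-preserving schedule as \(\sigma(t) = \sqrt{\exp(\tfrac{1}{2} \beta_d t^2 + \beta_\text{min} t) - 1}\) with \(\beta_d = \beta_\text{max} - \beta_\text{min}\). The sketch below uses that form with the constructor defaults and inverts it by solving the quadratic in t. Whether `VPScheduler` uses exactly these expressions is an assumption.

```python
import math

BETA_MIN, BETA_MAX = 0.1, 20.0  # taken from the constructor defaults
BETA_D = BETA_MAX - BETA_MIN

def sigma(t):
    # sigma(t) = sqrt(exp(0.5 * beta_d * t^2 + beta_min * t) - 1)
    return math.sqrt(math.expm1(0.5 * BETA_D * t**2 + BETA_MIN * t))

def sigma_inv(sig):
    # Solve 0.5 * beta_d * t^2 + beta_min * t = log(1 + sig^2) for t >= 0.
    log_term = math.log1p(sig**2)
    return (math.sqrt(BETA_MIN**2 + 2.0 * BETA_D * log_term) - BETA_MIN) / BETA_D

t = 0.5
roundtrip = sigma_inv(sigma(t))
```

A round trip through `sigma` and `sigma_inv` should recover the original time up to floating-point error.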

time_schedule(u)

Given a uniform random variable u ~ U(0, 1), return the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

Module contents

class gensbi.diffusion.path.scheduler.EDMScheduler(sigma_min=0.002, sigma_max=80.0, sigma_data=1.0, rho=7, P_mean=-1.2, P_std=1.2)

Bases: BaseSDE

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

loss_weight(sigma)

Weight for the loss function, for maximum likelihood estimation; denoted λ(σ) in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

property name

Returns the name of the SDE scheduler.

s(t)

Scaling function, as in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

time_schedule(u)

Given a uniform random variable u ~ U(0, 1), return the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

class gensbi.diffusion.path.scheduler.VEScheduler(sigma_min=0.001, sigma_max=15.0)

Bases: BaseSDE

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

loss_weight(sigma)

Weight for the loss function, for maximum likelihood estimation; denoted λ(σ) in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

property name

Returns the name of the SDE scheduler.

s(t)

Scaling function, as in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

time_schedule(u)

Given a uniform random variable u ~ U(0, 1), return the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array

class gensbi.diffusion.path.scheduler.VPScheduler(beta_min=0.1, beta_max=20.0, e_s=0.001, e_t=1e-05, M=1000)

Bases: BaseSDE

c_in(sigma)

Preconditioning input coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Input coefficient.

Return type:

Array

c_noise(sigma)

Preconditioning noise coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Noise coefficient.

Return type:

Array

c_out(sigma)

Preconditioning output coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Output coefficient.

Return type:

Array

c_skip(sigma)

Preconditioning skip connection coefficient.

Parameters:

sigma (Array) – Noise scale.

Returns:

Skip coefficient.

Return type:

Array

f(x, t)

Drift term for the forward diffusion process.

Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Drift term.

Return type:

Array

g(x, t)

Diffusion term for the forward diffusion process.

Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.

Parameters:
  • x (Array) – Input data.

  • t (Array) – Time.

Returns:

Diffusion term.

Return type:

Array

loss_weight(sigma)

Weight for the loss function, for maximum likelihood estimation; denoted λ(σ) in the EDM paper.

Parameters:

sigma (Array) – Noise scale.

Returns:

Loss weight.

Return type:

Array

property name

Returns the name of the SDE scheduler.

s(t)

Scaling function, as in the EDM paper.

Parameters:

t (Array) – Time.

Returns:

Scaling value.

Return type:

Array

s_deriv(t)

Derivative of the scaling function.

Parameters:

t (Array) – Time.

Returns:

Derivative of scaling.

Return type:

Array

sample_sigma(key, shape)

Sample sigma from the prior noise distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the output.

Returns:

Sampled sigma.

Return type:

Array

sigma(t)

Returns the noise scale (schedule) at time t.

Parameters:

t (Array) – Time.

Returns:

Noise scale.

Return type:

Array

sigma_deriv(t)

Derivative of the noise scale with respect to time.

Parameters:

t (Array) – Time.

Returns:

Derivative of sigma.

Return type:

Array

sigma_inv(sigma)

Inverse of the noise scale function.

Parameters:

sigma (Array) – Noise scale.

Returns:

Time corresponding to the given sigma.

Return type:

Array

time_schedule(u)

Given a uniform random variable u ~ U(0, 1), return the corresponding time t in the schedule.

Parameters:

u (Array) – Uniform random variable in [0, 1].

Returns:

Time in the schedule.

Return type:

Array