gensbi.diffusion.path.scheduler package#
Submodules#
gensbi.diffusion.path.scheduler.edm module#
- class gensbi.diffusion.path.scheduler.edm.BaseSDE[source]#
Bases:
ABC
- abstractmethod c_in(sigma: Array) Array [source]#
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- abstractmethod c_noise(sigma: Array) Array [source]#
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- abstractmethod c_out(sigma: Array) Array [source]#
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- abstractmethod c_skip(sigma: Array) Array [source]#
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- denoise(F: Callable, x: Array, sigma: Array, *args, **kwargs) Array [source]#
Denoiser \(D\) from the EDM paper, which is related to the score function by:
\[\nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}\]
This function applies the preconditioning and wraps the raw neural network \(F\):
\[D_\theta(x; \sigma) = c_\text{skip}(\sigma) x + c_\text{out}(\sigma) F_\theta (c_\text{in}(\sigma) x; c_\text{noise}(\sigma))\]
- Parameters:
F (Callable) – Model function.
x (Array) – Input data.
sigma (Array) – Noise scale.
*args – Additional arguments.
**kwargs – Additional keyword arguments.
- Returns:
Denoised output.
- Return type:
Array
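The preconditioning formula above can be sketched in a few lines of JAX. This is a minimal illustration, not the library's implementation: the coefficient definitions are the ones the EDM paper lists for its own preconditioning, `SIGMA_DATA` and `zero_model` are hypothetical names introduced here, and the real `denoise` method also forwards `*args` and `**kwargs` to the model.

```python
import jax.numpy as jnp

SIGMA_DATA = 1.0  # assumed standard deviation of the data (illustrative)

def c_skip(sigma):
    return SIGMA_DATA**2 / (sigma**2 + SIGMA_DATA**2)

def c_out(sigma):
    return sigma * SIGMA_DATA / jnp.sqrt(sigma**2 + SIGMA_DATA**2)

def c_in(sigma):
    return 1.0 / jnp.sqrt(sigma**2 + SIGMA_DATA**2)

def c_noise(sigma):
    return 0.25 * jnp.log(sigma)

def denoise(F, x, sigma):
    # D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise)
    return c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, c_noise(sigma))

# Hypothetical stand-in for a trained network: always predicts zero.
def zero_model(x_in, noise_cond):
    return jnp.zeros_like(x_in)

x = jnp.array([[1.0, -2.0]])
sigma = jnp.array([[0.5]])
d = denoise(zero_model, x, sigma)  # reduces to c_skip(sigma) * x
```

With a model that outputs zero, only the skip connection contributes, which makes the role of \(c_\text{skip}\) easy to inspect in isolation.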
- f(x: Array, t: Array) Array [source]#
Drift term for the forward diffusion process.
Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Drift term.
- Return type:
Array
- g(x: Array, t: Array) Array [source]#
Diffusion term for the forward diffusion process.
Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Diffusion term.
- Return type:
Array
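The drift and diffusion formulas can be checked numerically with a small sketch. The helper names `drift` and `diffusion` are hypothetical and take the schedule functions as arguments, unlike the methods documented here. The schedule used is the variance-exploding choice \(\sigma(t) = \sqrt{t}\), \(s(t) = 1\) from the EDM paper, assumed purely for illustration.

```python
import jax.numpy as jnp

def drift(x, t, s, s_deriv):
    # f(x, t) = x * (ds/dt) / s(t)
    return x * s_deriv(t) / s(t)

def diffusion(t, s, sigma, sigma_deriv):
    # g(t) = s(t) * sqrt(2 * (dsigma/dt) * sigma(t))
    return s(t) * jnp.sqrt(2.0 * sigma_deriv(t) * sigma(t))

# Assumed VE-style schedule: sigma(t) = sqrt(t), s(t) = 1.
s = lambda t: jnp.ones_like(t)
s_deriv = lambda t: jnp.zeros_like(t)
sigma = lambda t: jnp.sqrt(t)
sigma_deriv = lambda t: 0.5 / jnp.sqrt(t)

x = jnp.array([1.0, 2.0])
t = jnp.array([0.25, 4.0])
f_val = drift(x, t, s, s_deriv)              # no drift, because s is constant
g_val = diffusion(t, s, sigma, sigma_deriv)  # constant diffusion coefficient
```

For this schedule the drift vanishes and the diffusion coefficient equals one at every time, because \(2 \sigma \, d\sigma/dt = 1\).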
- get_loss_fn() Callable [source]#
Returns the loss function for EDM training, as described in the EDM paper.
The loss is computed as (see Eq. 8 in the EDM paper):
\[\lambda(\sigma) \, c_\text{out}^2(\sigma) \left[ F(c_\text{in}(\sigma) x_t, c_\text{noise}(\sigma), \ldots) - \frac{1}{c_\text{out}(\sigma)} (x_1 - c_\text{skip}(\sigma) x_t) \right]^2\]
- Parameters:
None. The method returns a function that computes the loss.
- Returns:
Loss function.
- Return type:
Callable
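A rough sketch of the loss above, written as a standalone function. The signature of the function returned by `get_loss_fn` is not documented here, so `edm_loss(F, key, x_1)` is a hypothetical interface. The log-normal sampling of \(\sigma\) and the coefficient formulas follow the EDM paper and are assumptions about what the library does; here \(x_1\) is the clean sample and \(x_t\) its noised version.

```python
import jax
import jax.numpy as jnp

SIGMA_DATA = 1.0  # assumed data standard deviation (illustrative)

def edm_loss(F, key, x_1, p_mean=-1.2, p_std=1.2):
    k_sigma, k_noise = jax.random.split(key)
    # One noise level per batch element, log-normally distributed.
    sigma = jnp.exp(p_mean + p_std * jax.random.normal(k_sigma, (x_1.shape[0], 1)))
    # Noised sample x_t = x_1 + sigma * n, with n ~ N(0, I).
    x_t = x_1 + sigma * jax.random.normal(k_noise, x_1.shape)

    denom = sigma**2 + SIGMA_DATA**2
    c_skip = SIGMA_DATA**2 / denom
    c_out = sigma * SIGMA_DATA / jnp.sqrt(denom)
    c_in = 1.0 / jnp.sqrt(denom)
    c_noise = 0.25 * jnp.log(sigma)
    lam = denom / (sigma * SIGMA_DATA) ** 2  # loss weight lambda(sigma)

    target = (x_1 - c_skip * x_t) / c_out
    residual = F(c_in * x_t, c_noise) - target
    return jnp.mean(lam * c_out**2 * residual**2)

# Hypothetical stand-in for a network: always predicts zero.
zero_model = lambda x_in, noise_cond: jnp.zeros_like(x_in)

x_1 = jnp.array([[0.5, -1.0], [2.0, 0.0]])
loss = edm_loss(zero_model, jax.random.PRNGKey(0), x_1)
```

With these coefficient choices the product \(\lambda(\sigma) c_\text{out}^2(\sigma)\) equals one, so the network target has the same effective weight at every noise level.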
- get_score_function(F: Callable) Callable [source]#
Returns the score function \(\nabla_x \log p(x; \sigma)\) as described in the EDM paper.
The score function is computed as:
\[\nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}\]
where \(D(x; \sigma)\) is the denoised output (see denoise method).
- Parameters:
F (Callable) – Model function.
- Returns:
Score function.
- Return type:
Callable
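The denoiser-to-score relation can be verified on a case with a known answer. In this sketch, `get_score_function` takes a denoiser directly rather than the raw model \(F\), a simplification that differs from the documented method. The closed-form `gaussian_denoiser` is the ideal denoiser for zero-mean Gaussian data with standard deviation `SIGMA_DATA`, a hypothetical setup chosen for illustration.

```python
import jax.numpy as jnp

def get_score_function(denoiser):
    # score(x; sigma) = (D(x; sigma) - x) / sigma^2
    def score(x, sigma):
        return (denoiser(x, sigma) - x) / sigma**2
    return score

SIGMA_DATA = 2.0  # assumed data standard deviation (illustrative)

def gaussian_denoiser(x, sigma):
    # Posterior mean of the clean sample when the data is N(0, SIGMA_DATA^2).
    return x * SIGMA_DATA**2 / (SIGMA_DATA**2 + sigma**2)

score = get_score_function(gaussian_denoiser)
x = jnp.array([1.0, -3.0])
sigma = jnp.array(1.5)
s_val = score(x, sigma)

# The noised marginal is N(0, SIGMA_DATA^2 + sigma^2), so its score is known.
expected = -x / (SIGMA_DATA**2 + sigma**2)
```

The two values agree, which confirms that the formula recovers the score of the noised marginal.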
- abstractmethod loss_weight(sigma: Array) Array [source]#
Loss weight for maximum likelihood training, denoted λ(σ) in the EDM paper.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- abstract property name: str#
Returns the name of the SDE scheduler.
- abstractmethod s(t: Array) Array [source]#
Scaling function \(s(t)\), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- abstractmethod s_deriv(t: Array) Array [source]#
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_noise(key: Array, shape: Any, sigma: Array) Array [source]#
Sample noise from the prior noise distribution with noise scale sigma(t).
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
sigma (Array) – Noise scale.
- Returns:
Sampled noise.
- Return type:
Array
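One plausible reading of this method is zero-mean Gaussian noise scaled by `sigma`. The sketch below assumes exactly that and is not taken from the library source. It also shows how a per-sample `sigma` of shape `(batch, 1)` broadcasts against the requested output shape.

```python
import jax
import jax.numpy as jnp

def sample_noise(key, shape, sigma):
    # Assumption: noise is N(0, sigma^2), drawn independently per element.
    return sigma * jax.random.normal(key, shape)

key = jax.random.PRNGKey(0)
sigma = jnp.array([[0.1], [10.0]])  # one noise scale per batch element
noise = sample_noise(key, (2, 3), sigma)
```

Because JAX random keys are explicit, calling the function twice with the same key returns identical noise; split the key with `jax.random.split` to obtain independent draws.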
- sample_prior(key: Array, shape: Any) Array [source]#
Sample x from the prior distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled prior.
- Return type:
Array
- abstractmethod sample_sigma(key: Array, shape: Any) Array [source]#
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
- abstractmethod sigma(t: Array) Array [source]#
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- abstractmethod sigma_deriv(t: Array) Array [source]#
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- abstractmethod sigma_inv(sigma: Array) Array [source]#
Inverse of the noise scale function.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Time corresponding to the given sigma.
- Return type:
Array
- class gensbi.diffusion.path.scheduler.edm.EDMScheduler(sigma_min=0.002, sigma_max=80.0, sigma_data=1.0, rho=7, P_mean=-1.2, P_std=1.2)[source]#
Bases:
BaseSDE
- c_in(sigma)[source]#
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)[source]#
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)[source]#
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)[source]#
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- loss_weight(sigma)[source]#
Loss weight for maximum likelihood training, denoted λ(σ) in the EDM paper.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- property name#
Returns the name of the SDE scheduler.
- s(t)[source]#
Scaling function \(s(t)\), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)[source]#
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)[source]#
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
- sigma(t)[source]#
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)[source]#
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
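The constructor arguments match the hyperparameters of the EDM paper, which suggests the following two ingredients. Both functions are sketches based on the paper, not on the library source: `karras_sigmas` is a hypothetical helper for the paper's \(\rho\)-spaced sampling grid, and `sample_sigma` draws training noise levels with \(\ln \sigma \sim \mathcal{N}(P_\text{mean}, P_\text{std}^2)\).

```python
import jax
import jax.numpy as jnp

def karras_sigmas(n, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    # Interpolate linearly in sigma^(1/rho), then raise back to the power rho.
    ramp = jnp.linspace(0.0, 1.0, n)
    min_inv = sigma_min ** (1.0 / rho)
    max_inv = sigma_max ** (1.0 / rho)
    return (max_inv + ramp * (min_inv - max_inv)) ** rho

def sample_sigma(key, shape, P_mean=-1.2, P_std=1.2):
    # Log-normal noise levels used during training.
    return jnp.exp(P_mean + P_std * jax.random.normal(key, shape))

sigmas = karras_sigmas(5)  # decreasing grid from sigma_max to sigma_min
train_sigmas = sample_sigma(jax.random.PRNGKey(0), (4, 1))
```

Larger values of `rho` concentrate more grid points near `sigma_min`, where the paper argues that discretization errors matter most.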
- class gensbi.diffusion.path.scheduler.edm.VEScheduler(sigma_min=0.001, sigma_max=15.0)[source]#
Bases:
BaseSDE
- c_in(sigma)[source]#
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)[source]#
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)[source]#
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)[source]#
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- loss_weight(sigma)[source]#
Loss weight for maximum likelihood training, denoted λ(σ) in the EDM paper.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- property name#
Returns the name of the SDE scheduler.
- s(t)[source]#
Scaling function \(s(t)\), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)[source]#
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)[source]#
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
- sigma(t)[source]#
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)[source]#
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
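For orientation, the variance-exploding (VE) column of the EDM paper's design table can be written out as plain functions. Whether `VEScheduler` uses exactly these definitions is an assumption. The free functions below are illustrative stand-ins for the methods documented above.

```python
import jax
import jax.numpy as jnp

SIGMA_MIN, SIGMA_MAX = 0.001, 15.0  # defaults from the constructor signature

def sigma(t):
    return jnp.sqrt(t)        # VE noise schedule

def sigma_inv(sig):
    return sig**2             # time at which sigma(t) equals sig

def sample_sigma(key, shape):
    # ln(sigma) is uniform on [ln(SIGMA_MIN), ln(SIGMA_MAX)].
    log_sigma = jax.random.uniform(
        key, shape, minval=jnp.log(SIGMA_MIN), maxval=jnp.log(SIGMA_MAX)
    )
    return jnp.exp(log_sigma)

def c_skip(sig):
    return jnp.ones_like(sig)

def c_out(sig):
    return sig

def c_in(sig):
    return jnp.ones_like(sig)

def c_noise(sig):
    return jnp.log(0.5 * sig)

def loss_weight(sig):
    return 1.0 / sig**2

sig = sample_sigma(jax.random.PRNGKey(0), (8, 1))
t = sigma_inv(sig)  # sigma(t) recovers sig
```

With these definitions the product `loss_weight(sig) * c_out(sig)**2` equals one, the same effective weighting that the EDM preconditioning produces.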
- class gensbi.diffusion.path.scheduler.edm.VPScheduler(beta_min=0.1, beta_max=20.0, e_s=0.001, e_t=1e-05, M=1000)[source]#
Bases:
BaseSDE
- c_in(sigma)[source]#
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)[source]#
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)[source]#
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)[source]#
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- f(x, t)[source]#
Drift term for the forward diffusion process.
Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Drift term.
- Return type:
Array
- g(x, t)[source]#
Diffusion term for the forward diffusion process.
Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Diffusion term.
- Return type:
Array
- loss_weight(sigma)[source]#
Loss weight for maximum likelihood training, denoted λ(σ) in the EDM paper.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- property name#
Returns the name of the SDE scheduler.
- s(t)[source]#
Scaling function \(s(t)\), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)[source]#
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)[source]#
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
- sigma(t)[source]#
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)[source]#
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
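The drift and diffusion formulas reduce to the familiar variance-preserving (VP) SDE under the schedule from the EDM paper. The sketch below assumes that schedule, with \(\beta_d = \beta_\text{max} - \beta_\text{min}\), which may not match `VPScheduler` exactly. It uses `jax.grad` for the derivatives instead of closed-form `s_deriv` and `sigma_deriv` methods, and the helper name `big_b` is hypothetical.

```python
import jax
import jax.numpy as jnp

BETA_MIN, BETA_MAX = 0.1, 20.0  # defaults from the constructor signature
BETA_D = BETA_MAX - BETA_MIN

def big_b(t):
    # Integral of beta(u) = BETA_MIN + BETA_D * u from 0 to t.
    return 0.5 * BETA_D * t**2 + BETA_MIN * t

def sigma(t):
    return jnp.sqrt(jnp.expm1(big_b(t)))

def s(t):
    return jnp.exp(-0.5 * big_b(t))

s_deriv = jax.grad(s)          # ds/dt for scalar t
sigma_deriv = jax.grad(sigma)  # dsigma/dt for scalar t

def f(x, t):
    return x * s_deriv(t) / s(t)

def g(t):
    return s(t) * jnp.sqrt(2.0 * sigma_deriv(t) * sigma(t))

t = jnp.array(0.5)
x = jnp.array([1.0, -2.0])
beta_t = BETA_MIN + BETA_D * t
f_val = f(x, t)  # matches -0.5 * beta(t) * x
g_val = g(t)     # matches sqrt(beta(t))
```

Under this schedule \(s(t)^2 (1 + \sigma(t)^2) = 1\), so unit-variance data keeps unit variance at every time, which is where the name "variance preserving" comes from.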
Module contents#
- class gensbi.diffusion.path.scheduler.EDMScheduler(sigma_min=0.002, sigma_max=80.0, sigma_data=1.0, rho=7, P_mean=-1.2, P_std=1.2)[source]#
Bases:
BaseSDE
- c_in(sigma)[source]#
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)[source]#
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)[source]#
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)[source]#
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- loss_weight(sigma)[source]#
Loss weight for maximum likelihood training, denoted λ(σ) in the EDM paper.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- property name#
Returns the name of the SDE scheduler.
- s(t)[source]#
Scaling function \(s(t)\), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)[source]#
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)[source]#
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
- sigma(t)[source]#
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)[source]#
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- class gensbi.diffusion.path.scheduler.VEScheduler(sigma_min=0.001, sigma_max=15.0)[source]#
Bases:
BaseSDE
- c_in(sigma)[source]#
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)[source]#
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)[source]#
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)[source]#
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- loss_weight(sigma)[source]#
Loss weight for maximum likelihood training, denoted λ(σ) in the EDM paper.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- property name#
Returns the name of the SDE scheduler.
- s(t)[source]#
Scaling function \(s(t)\), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)[source]#
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)[source]#
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
- sigma(t)[source]#
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)[source]#
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array
- class gensbi.diffusion.path.scheduler.VPScheduler(beta_min=0.1, beta_max=20.0, e_s=0.001, e_t=1e-05, M=1000)[source]#
Bases:
BaseSDE
- c_in(sigma)[source]#
Preconditioning input coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Input coefficient.
- Return type:
Array
- c_noise(sigma)[source]#
Preconditioning noise coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Noise coefficient.
- Return type:
Array
- c_out(sigma)[source]#
Preconditioning output coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Output coefficient.
- Return type:
Array
- c_skip(sigma)[source]#
Preconditioning skip connection coefficient.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Skip coefficient.
- Return type:
Array
- f(x, t)[source]#
Drift term for the forward diffusion process.
Computes the drift term \(f(x, t) = x \frac{ds}{dt} / s(t)\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Drift term.
- Return type:
Array
- g(x, t)[source]#
Diffusion term for the forward diffusion process.
Computes the diffusion term \(g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}\) as used in the SDE formulation.
- Parameters:
x (Array) – Input data.
t (Array) – Time.
- Returns:
Diffusion term.
- Return type:
Array
- loss_weight(sigma)[source]#
Loss weight for maximum likelihood training, denoted λ(σ) in the EDM paper.
- Parameters:
sigma (Array) – Noise scale.
- Returns:
Loss weight.
- Return type:
Array
- property name#
Returns the name of the SDE scheduler.
- s(t)[source]#
Scaling function \(s(t)\), as defined in the EDM paper.
- Parameters:
t (Array) – Time.
- Returns:
Scaling value.
- Return type:
Array
- s_deriv(t)[source]#
Derivative of the scaling function.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of scaling.
- Return type:
Array
- sample_sigma(key, shape)[source]#
Sample sigma from the prior noise distribution.
- Parameters:
key (Array) – JAX random key.
shape (Any) – Shape of the output.
- Returns:
Sampled sigma.
- Return type:
Array
- sigma(t)[source]#
Returns the noise scale (schedule) at time t.
- Parameters:
t (Array) – Time.
- Returns:
Noise scale.
- Return type:
Array
- sigma_deriv(t)[source]#
Derivative of the noise scale with respect to time.
- Parameters:
t (Array) – Time.
- Returns:
Derivative of sigma.
- Return type:
Array