import abc
import jax
import jax.numpy as jnp
from jax import Array
from typing import Callable, Any, Optional
# from .samplers import sampler  # moved to samplers module
# Abstract SDE base class covering the VP, VE, and EDM formulations, following
# https://github.com/NVlabs/edm/. Each scheduler also defines its preconditioning
# functions (c_skip, c_out, c_in, c_noise).
# TODO still need to test
class BaseSDE(abc.ABC):
    def __init__(self) -> None:
        """Base class for SDE schedulers."""
@property
@abc.abstractmethod
def name(self) -> str:
"""Returns the name of the SDE scheduler."""
...
@abc.abstractmethod
def time_schedule(self, u: Array) -> Array:
"""
        Map a uniform random variable u ~ U(0, 1) to the time t in the schedule.
Args:
u (Array): Uniform random variable in [0, 1].
Returns:
Array: Time in the schedule.
"""
...
def timesteps(self, i: Array, N: int) -> Array:
"""
Compute the time steps for a given index array and total number of steps.
Args:
i (Array): Step indices.
N (int): Total number of steps.
Returns:
Array: Time steps.
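        Example:
            A hypothetical call, assuming a concrete subclass such as ``EDMScheduler``::

                sde = EDMScheduler()
                t = sde.timesteps(jnp.arange(4), 4)  # 4 times from sigma_max down to sigma_min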
"""
u = i / (N - 1)
return self.time_schedule(u)
@abc.abstractmethod
def sigma(self, t: Array) -> Array:
"""
Returns the noise scale (schedule) at time t.
Args:
t (Array): Time.
Returns:
Array: Noise scale.
"""
...
@abc.abstractmethod
def sigma_inv(self, sigma: Array) -> Array:
"""
Inverse of the noise scale function.
Args:
sigma (Array): Noise scale.
Returns:
Array: Time corresponding to the given sigma.
"""
...
@abc.abstractmethod
def sigma_deriv(self, t: Array) -> Array:
"""
Derivative of the noise scale with respect to time.
Args:
t (Array): Time.
Returns:
Array: Derivative of sigma.
"""
...
@abc.abstractmethod
def s(self, t: Array) -> Array:
"""
Scaling function as in EDM paper.
Args:
t (Array): Time.
Returns:
Array: Scaling value.
"""
...
@abc.abstractmethod
def s_deriv(self, t: Array) -> Array:
"""
Derivative of the scaling function.
Args:
t (Array): Time.
Returns:
Array: Derivative of scaling.
"""
...
@abc.abstractmethod
def c_skip(self, sigma: Array) -> Array:
"""
Preconditioning skip connection coefficient.
Args:
sigma (Array): Noise scale.
Returns:
Array: Skip coefficient.
"""
...
@abc.abstractmethod
def c_out(self, sigma: Array) -> Array:
"""
Preconditioning output coefficient.
Args:
sigma (Array): Noise scale.
Returns:
Array: Output coefficient.
"""
...
@abc.abstractmethod
def c_in(self, sigma: Array) -> Array:
"""
Preconditioning input coefficient.
Args:
sigma (Array): Noise scale.
Returns:
Array: Input coefficient.
"""
...
@abc.abstractmethod
def c_noise(self, sigma: Array) -> Array:
"""
Preconditioning noise coefficient.
Args:
sigma (Array): Noise scale.
Returns:
Array: Noise coefficient.
"""
...
@abc.abstractmethod
def sample_sigma(self, key: Array, shape: Any) -> Array:
"""
Sample sigma from the prior noise distribution.
Args:
key (Array): JAX random key.
shape (Any): Shape of the output.
Returns:
Array: Sampled sigma.
"""
...
def sample_noise(self, key: Array, shape: Any, sigma: Array) -> Array:
"""
Sample noise from the prior noise distribution with noise scale sigma(t).
Args:
key (Array): JAX random key.
shape (Any): Shape of the output.
sigma (Array): Noise scale.
Returns:
Array: Sampled noise.
"""
n = jax.random.normal(key, shape) * sigma
return n
def sample_prior(self, key: Array, shape: Any) -> Array:
"""
Sample x from the prior distribution.
Args:
key (Array): JAX random key.
shape (Any): Shape of the output.
Returns:
Array: Sampled prior.
"""
return jax.random.normal(key, shape)
@abc.abstractmethod
def loss_weight(self, sigma: Array) -> Array:
"""
        Weight for the training loss, used for maximum-likelihood estimation; known as λ(σ) in the EDM paper.
Args:
sigma (Array): Noise scale.
Returns:
Array: Loss weight.
"""
...
def f(self, x: Array, t: Array) -> Array:
r"""
Drift term for the forward diffusion process.
Computes the drift term :math:`f(x, t) = x \frac{ds}{dt} / s(t)` as used in the SDE formulation.
Args:
x (Array): Input data.
t (Array): Time.
Returns:
Array: Drift term.
"""
return x * self.s_deriv(t) / self.s(t)
def g(self, x: Array, t: Array) -> Array:
r"""
Diffusion term for the forward diffusion process.
Computes the diffusion term :math:`g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}` as used in the SDE formulation.
Args:
x (Array): Input data.
t (Array): Time.
Returns:
Array: Diffusion term.
"""
return self.s(t) * jnp.sqrt(2 * self.sigma_deriv(t) * self.sigma(t))
def denoise(self, F: Callable, x: Array, sigma: Array, *args, **kwargs) -> Array:
r"""
Denoise function, :math:`D` in the EDM paper, which shares a connection with the score function:
.. math::
\nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}
This function includes the preconditioning and is connected to the NN objective :math:`F`:
.. math::
D_\theta(x; \sigma) = c_\text{skip}(\sigma) x + c_\text{out}(\sigma) F_\theta (c_\text{in}(\sigma) x; c_\text{noise}(\sigma))
Args:
F (Callable): Model function.
x (Array): Input data.
sigma (Array): Noise scale.
*args: Additional arguments.
**kwargs: Additional keyword arguments.
Returns:
Array: Denoised output.
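        Example:
            A minimal sketch with a dummy stand-in for the trained network (``F``
            here is hypothetical; any callable with this signature works)::

                F = lambda x_in, c_noise: x_in
                D = scheduler.denoise(F, x, sigma)  # same shape as x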
"""
return self.c_skip(sigma) * x + self.c_out(sigma) * F(
self.c_in(sigma) * x, self.c_noise(sigma), *args, **kwargs
)
def get_score_function(self, F: Callable) -> Callable:
r"""
Returns the score function :math:`\nabla_x \log p(x; \sigma)` as described in the EDM paper.
The score function is computed as:
.. math::
\nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}
where :math:`D(x; \sigma)` is the denoised output (see `denoise` method).
Args:
F (Callable): Model function.
Returns:
Callable: Score function.
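        Example:
            A sketch of building and evaluating the score function (``F`` is a
            hypothetical trained network)::

                score = scheduler.get_score_function(F)
                grad_log_p = score(x, u)  # u in [0, 1] is mapped through time_schedule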
"""
def score(x: Array, u: Array, *args, **kwargs) -> Array:
t = self.time_schedule(u)
sigma = self.sigma(t)
return (self.denoise(F, x, sigma, *args, **kwargs) - x) / (sigma**2)
return score
def get_loss_fn(self) -> Callable:
r"""
Returns the loss function for EDM training, as described in the EDM paper.
The loss is computed as (see Eq. 8 in the EDM paper):
.. math::
\lambda(\sigma) \, c_\text{out}^2(\sigma) \left[
F(c_\text{in}(\sigma) x_t, c_\text{noise}(\sigma), \ldots)
- \frac{1}{c_\text{out}(\sigma)} (x_1 - c_\text{skip}(\sigma) x_t)
\right]^2
Returns:
Callable: Loss function.
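        Example:
            A sketch of one training-loss evaluation (``F``, ``x_1``, ``x_t``, and
            ``sigma`` are hypothetical)::

                loss_fn = scheduler.get_loss_fn()
                loss = loss_fn(F, (x_1, x_t, sigma))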
"""
        def loss_fn(F: Callable, batch: tuple, loss_mask: Any = None, model_extras: Optional[dict] = None) -> Array:
            # avoid the mutable-default-argument pitfall; fall back to an empty dict
            model_extras = model_extras if model_extras is not None else {}
            (x_1, x_t, sigma) = batch
lam = self.loss_weight(sigma)
c_out = self.c_out(sigma)
c_in = self.c_in(sigma)
c_noise = self.c_noise(sigma)
c_skip = self.c_skip(sigma)
if loss_mask is not None:
loss_mask = jnp.broadcast_to(loss_mask, x_1.shape)
x_t = jnp.where(loss_mask, x_1, x_t)
loss = (
lam
* c_out**2
* (
F(c_in * (x_t), c_noise, **model_extras)
- 1 / c_out * (x_1 - c_skip * (x_t))
)
** 2
)
if loss_mask is not None:
loss = jnp.where(loss_mask, 0.0, loss)
            # sum the loss over every non-batch dimension, then average over the batch (first) dimension
            return jnp.mean(jnp.sum(loss, axis=tuple(range(1, len(x_1.shape)))))  # type: ignore
return loss_fn
class VPScheduler(BaseSDE):
def __init__(self, beta_min=0.1, beta_max=20.0, e_s=1e-3, e_t=1e-5, M=1000):
super().__init__()
self.beta_min = beta_min
self.beta_max = beta_max
self.beta_d = beta_max - beta_min
self.e_s = e_s
self.e_t = e_t
self.M = M
@property
def name(self):
return "EDM-VP"
def time_schedule(self, u):
return 1 + u * (self.e_s - 1)
def sigma(self, t):
        # also known as the schedule, as in Table 1 of the EDM paper
return jnp.sqrt(jnp.exp(0.5 * self.beta_d * t**2 + self.beta_min * t) - 1)
def sigma_inv(self, sigma):
return (
jnp.sqrt(self.beta_min**2 + 2 * self.beta_d * jnp.log(1 + sigma**2))
- self.beta_min
) / self.beta_d
def sigma_deriv(self, t):
# also known as the schedule derivative
return (
0.5
* (self.beta_min + self.beta_d * t)
* (self.sigma(t) + 1 / self.sigma(t))
)
def s(self, t):
        # also known as scaling, as in Table 1 of the EDM paper
return 1 / jnp.sqrt(jnp.exp(0.5 * self.beta_d * t**2 + self.beta_min * t))
def s_deriv(self, t):
# also known as scaling derivative
return -self.sigma(t) * self.sigma_deriv(t) * (self.s(t) ** 3)
def f(self, x, t):
        # f(x, t) in the SDE: drift term of the forward diffusion process
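        # consistency check with BaseSDE.f: ln s(t) = -(0.25 * beta_d * t**2 + 0.5 * beta_min * t),
        # so s'(t) / s(t) = -0.5 * (beta_min + beta_d * t), i.e. x * s'(t) / s(t) gives the same drift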
return -x * 0.5 * (self.beta_min + self.beta_d * t)
def g(self, x, t):
        # g(t) in the SDE: diffusion term of the forward diffusion process
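        # consistency check with BaseSDE.g: sigma'(t) * sigma(t) = 0.5 * (beta_min + beta_d * t) * (sigma**2 + 1)
        # and s(t)**2 = 1 / (sigma**2 + 1), so s(t) * sqrt(2 * sigma'(t) * sigma(t)) = sqrt(beta_min + beta_d * t)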
return jnp.sqrt(self.beta_min + self.beta_d * t)
def c_skip(self, sigma):
# c_skip for preconditioning
return jnp.ones_like(sigma)
def c_out(self, sigma):
# c_out for preconditioning
return -sigma
def c_in(self, sigma):
# c_in for preconditioning
return 1 / jnp.sqrt(sigma**2 + 1)
def c_noise(self, sigma):
# c_noise for preconditioning
return (self.M - 1) * self.sigma_inv(sigma)
def loss_weight(self, sigma):
return 1 / sigma**2
def sample_sigma(self, key, shape):
        # sample t uniformly in [e_t, 1], then use sigma(t) as the training noise level
        t = jax.random.uniform(key, shape, minval=self.e_t, maxval=1)
        return self.sigma(t)
class VEScheduler(BaseSDE):
def __init__(self, sigma_min=1e-3, sigma_max=15.0):
super().__init__()
self.sigma_min = sigma_min
self.sigma_max = sigma_max
@property
def name(self):
return "EDM-VE"
def time_schedule(self, u):
return self.sigma_max**2 * (self.sigma_min / self.sigma_max) ** (2 * u)
def sigma(self, t):
        # also known as the schedule, as in Table 1 of the EDM paper
return jnp.sqrt(t)
def sigma_inv(self, sigma):
return sigma**2
def sigma_deriv(self, t):
return 1 / (2 * jnp.sqrt(t))
def s(self, t):
        # also known as scaling, as in Table 1 of the EDM paper
return jnp.ones_like(t)
def s_deriv(self, t):
# also known as scaling derivative
return jnp.zeros_like(t)
def c_skip(self, sigma):
# c_skip for preconditioning
return jnp.ones_like(sigma)
def c_out(self, sigma):
# c_out for preconditioning
return sigma
def c_in(self, sigma):
# c_in for preconditioning
return jnp.ones_like(sigma)
def c_noise(self, sigma):
# c_noise for preconditioning
return jnp.log(0.5 * sigma)
def loss_weight(self, sigma):
return 1 / sigma**2
def sample_sigma(self, key, shape):
        # sample sigma log-uniformly in [sigma_min, sigma_max]
log_sigma = jax.random.uniform(
key, shape, minval=jnp.log(self.sigma_min), maxval=jnp.log(self.sigma_max)
)
return jnp.exp(log_sigma)
class EDMScheduler(BaseSDE):
def __init__(
self,
sigma_min=0.002,
sigma_max=80.0,
sigma_data=1.0,
rho=7,
P_mean=-1.2,
P_std=1.2,
):
super().__init__()
self.sigma_min = sigma_min
self.sigma_max = sigma_max
self.sigma_data = sigma_data
self.rho = rho
self.P_mean = P_mean
self.P_std = P_std
@property
def name(self):
return "EDM"
def time_schedule(self, u):
return (
self.sigma_max ** (1 / self.rho)
+ u * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
) ** self.rho
def sigma(self, t):
        # also known as the schedule, as in Table 1 of the EDM paper
return t
def sigma_inv(self, sigma):
return sigma
def sigma_deriv(self, t):
# also known as the schedule derivative
return jnp.ones_like(t)
def s(self, t):
        # also known as scaling, as in Table 1 of the EDM paper
return jnp.ones_like(t)
def s_deriv(self, t):
# also known as scaling derivative
return jnp.zeros_like(t)
def c_skip(self, sigma):
# c_skip for preconditioning
        # EDM Table 1: c_skip(sigma) = sigma_data**2 / (sigma**2 + sigma_data**2) (no square root)
        return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
def c_out(self, sigma):
# c_out for preconditioning
return sigma * self.sigma_data / jnp.sqrt(sigma**2 + self.sigma_data**2)
def c_in(self, sigma):
# c_in for preconditioning
return 1 / jnp.sqrt(sigma**2 + self.sigma_data**2)
def c_noise(self, sigma):
# c_noise for preconditioning
return 0.25 * jnp.log(sigma)
def loss_weight(self, sigma):
# weight for the loss function, for MLE estimation, also known as λ(σ) in the EDM paper
return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2
def sample_sigma(self, key, shape):
        # sample sigma for training; for EDM this is log-normal rather than uniform, see
        # https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L66
rnd_normal = jax.random.normal(key, shape)
sigma = jnp.exp(rnd_normal * self.P_std + self.P_mean)
return sigma
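

# A minimal end-to-end usage sketch (illustrative only; `F` below is a
# hypothetical stand-in for a trained network with signature F(x, c_noise)):
#
#   key = jax.random.PRNGKey(0)
#   k1, k2, k3 = jax.random.split(key, 3)
#   sde = EDMScheduler()
#   x_1 = jax.random.normal(k1, (8, 16))                 # clean data batch
#   sigma = sde.sample_sigma(k2, (8, 1))                 # per-sample noise levels
#   x_t = x_1 + sde.sample_noise(k3, x_1.shape, sigma)   # noised data
#   F = lambda x, c_noise: x                             # dummy model
#   loss = sde.get_loss_fn()(F, (x_1, x_t, sigma))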