Source code for gensbi.diffusion.path.scheduler.edm

import abc
import jax
import jax.numpy as jnp
from jax import Array
from typing import Callable, Any

# from .samplers import sampler  # moved to samplers module

# This module defines an abstract SDE scheduler class with VP, VE, and EDM
# implementations, following https://github.com/NVlabs/edm/.
# Each scheduler also defines the preconditioning functions for its method.

# TODO: still needs testing


class BaseSDE(abc.ABC):
    """Base class for SDE schedulers."""

    def __init__(self) -> None:
        pass

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Returns the name of the SDE scheduler."""
        ...
    @abc.abstractmethod
    def time_schedule(self, u: Array) -> Array:
        """Given the value of the random uniform variable u ~ U(0, 1), return the time t in the schedule.

        Args:
            u (Array): Uniform random variable in [0, 1].

        Returns:
            Array: Time in the schedule.
        """
        ...

    def timesteps(self, i: Array, N: int) -> Array:
        """Compute the time steps for a given index array and total number of steps.

        Args:
            i (Array): Step indices.
            N (int): Total number of steps.

        Returns:
            Array: Time steps.
        """
        u = i / (N - 1)
        return self.time_schedule(u)
    @abc.abstractmethod
    def sigma(self, t: Array) -> Array:
        """Returns the noise scale (schedule) at time t.

        Args:
            t (Array): Time.

        Returns:
            Array: Noise scale.
        """
        ...

    @abc.abstractmethod
    def sigma_inv(self, sigma: Array) -> Array:
        """Inverse of the noise scale function.

        Args:
            sigma (Array): Noise scale.

        Returns:
            Array: Time corresponding to the given sigma.
        """
        ...

    @abc.abstractmethod
    def sigma_deriv(self, t: Array) -> Array:
        """Derivative of the noise scale with respect to time.

        Args:
            t (Array): Time.

        Returns:
            Array: Derivative of sigma.
        """
        ...

    @abc.abstractmethod
    def s(self, t: Array) -> Array:
        """Scaling function as in the EDM paper.

        Args:
            t (Array): Time.

        Returns:
            Array: Scaling value.
        """
        ...

    @abc.abstractmethod
    def s_deriv(self, t: Array) -> Array:
        """Derivative of the scaling function.

        Args:
            t (Array): Time.

        Returns:
            Array: Derivative of scaling.
        """
        ...
    @abc.abstractmethod
    def c_skip(self, sigma: Array) -> Array:
        """Preconditioning skip connection coefficient.

        Args:
            sigma (Array): Noise scale.

        Returns:
            Array: Skip coefficient.
        """
        ...

    @abc.abstractmethod
    def c_out(self, sigma: Array) -> Array:
        """Preconditioning output coefficient.

        Args:
            sigma (Array): Noise scale.

        Returns:
            Array: Output coefficient.
        """
        ...

    @abc.abstractmethod
    def c_in(self, sigma: Array) -> Array:
        """Preconditioning input coefficient.

        Args:
            sigma (Array): Noise scale.

        Returns:
            Array: Input coefficient.
        """
        ...

    @abc.abstractmethod
    def c_noise(self, sigma: Array) -> Array:
        """Preconditioning noise coefficient.

        Args:
            sigma (Array): Noise scale.

        Returns:
            Array: Noise coefficient.
        """
        ...

    @abc.abstractmethod
    def sample_sigma(self, key: Array, shape: Any) -> Array:
        """Sample sigma from the prior noise distribution.

        Args:
            key (Array): JAX random key.
            shape (Any): Shape of the output.

        Returns:
            Array: Sampled sigma.
        """
        ...
    def sample_noise(self, key: Array, shape: Any, sigma: Array) -> Array:
        """Sample noise from the prior noise distribution with noise scale sigma.

        Args:
            key (Array): JAX random key.
            shape (Any): Shape of the output.
            sigma (Array): Noise scale.

        Returns:
            Array: Sampled noise.
        """
        return jax.random.normal(key, shape) * sigma

    def sample_prior(self, key: Array, shape: Any) -> Array:
        """Sample x from the prior distribution.

        Args:
            key (Array): JAX random key.
            shape (Any): Shape of the output.

        Returns:
            Array: Sampled prior.
        """
        return jax.random.normal(key, shape)

    @abc.abstractmethod
    def loss_weight(self, sigma: Array) -> Array:
        r"""Weight for the loss function, for MLE estimation, also known as :math:`\lambda(\sigma)` in the EDM paper.

        Args:
            sigma (Array): Noise scale.

        Returns:
            Array: Loss weight.
        """
        ...
    def f(self, x: Array, t: Array) -> Array:
        r"""Drift term of the forward diffusion process.

        Computes the drift term :math:`f(x, t) = x \frac{ds}{dt} / s(t)` as used in the SDE formulation.

        Args:
            x (Array): Input data.
            t (Array): Time.

        Returns:
            Array: Drift term.
        """
        return x * self.s_deriv(t) / self.s(t)

    def g(self, x: Array, t: Array) -> Array:
        r"""Diffusion term of the forward diffusion process.

        Computes the diffusion term :math:`g(x, t) = s(t) \sqrt{2 \frac{d\sigma}{dt} \sigma(t)}` as used in the SDE formulation.

        Args:
            x (Array): Input data.
            t (Array): Time.

        Returns:
            Array: Diffusion term.
        """
        return self.s(t) * jnp.sqrt(2 * self.sigma_deriv(t) * self.sigma(t))
    def denoise(self, F: Callable, x: Array, sigma: Array, *args, **kwargs) -> Array:
        r"""Denoise function, :math:`D` in the EDM paper, which shares a connection with the score function:

        .. math::
            \nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}

        This function includes the preconditioning and is connected to the NN objective :math:`F`:

        .. math::
            D_\theta(x; \sigma) = c_\text{skip}(\sigma) x + c_\text{out}(\sigma) F_\theta (c_\text{in}(\sigma) x; c_\text{noise}(\sigma))

        Args:
            F (Callable): Model function.
            x (Array): Input data.
            sigma (Array): Noise scale.
            *args: Additional arguments passed to F.
            **kwargs: Additional keyword arguments passed to F.

        Returns:
            Array: Denoised output.
        """
        return self.c_skip(sigma) * x + self.c_out(sigma) * F(
            self.c_in(sigma) * x, self.c_noise(sigma), *args, **kwargs
        )

    def get_score_function(self, F: Callable) -> Callable:
        r"""Returns the score function :math:`\nabla_x \log p(x; \sigma)` as described in the EDM paper.

        The score function is computed as:

        .. math::
            \nabla_x \log p(x; \sigma) = \frac{D(x; \sigma) - x}{\sigma^2}

        where :math:`D(x; \sigma)` is the denoised output (see the `denoise` method).

        Args:
            F (Callable): Model function.

        Returns:
            Callable: Score function.
        """

        def score(x: Array, u: Array, *args, **kwargs) -> Array:
            t = self.time_schedule(u)
            sigma = self.sigma(t)
            return (self.denoise(F, x, sigma, *args, **kwargs) - x) / (sigma**2)

        return score
    def get_loss_fn(self) -> Callable:
        r"""Returns the loss function for EDM training, as described in the EDM paper.

        The loss is computed as (see Eq. 8 in the EDM paper):

        .. math::
            \lambda(\sigma) \, c_\text{out}^2(\sigma) \left[ F(c_\text{in}(\sigma) x_t, c_\text{noise}(\sigma), \ldots) - \frac{1}{c_\text{out}(\sigma)} (x_1 - c_\text{skip}(\sigma) x_t) \right]^2

        Returns:
            Callable: Loss function taking ``(F, batch, loss_mask, model_extras)``,
            where ``batch = (x_1, x_t, sigma)``.
        """

        def loss_fn(F: Callable, batch: tuple, loss_mask: Any = None, model_extras: dict = {}) -> Array:
            (x_1, x_t, sigma) = batch

            lam = self.loss_weight(sigma)
            c_out = self.c_out(sigma)
            c_in = self.c_in(sigma)
            c_noise = self.c_noise(sigma)
            c_skip = self.c_skip(sigma)

            if loss_mask is not None:
                # on masked (conditioned) entries, feed the clean values to the model
                loss_mask = jnp.broadcast_to(loss_mask, x_1.shape)
                x_t = jnp.where(loss_mask, x_1, x_t)

            loss = (
                lam
                * c_out**2
                * (
                    F(c_in * x_t, c_noise, **model_extras)
                    - 1 / c_out * (x_1 - c_skip * x_t)
                )
                ** 2
            )

            if loss_mask is not None:
                loss = jnp.where(loss_mask, 0.0, loss)

            # sum the loss over every non-batch dimension, then take the mean over the batch dimension (the first)
            return jnp.mean(jnp.sum(loss, axis=tuple(range(1, len(x_1.shape)))))  # type: ignore

        return loss_fn
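
# Usage sketch (illustrative, not part of the library API): a network F is
# wrapped by `denoise` for sampling, while `get_loss_fn` gives the training
# objective. The placeholder `F` and the shapes below are assumptions chosen
# only to show how the pieces fit together.
def _example_training_step(sched: BaseSDE, key: Array) -> Array:
    def F(x, c_noise):
        return x  # placeholder standing in for a real network

    key_sigma, key_noise = jax.random.split(key)
    x_1 = jnp.ones((4, 2))  # toy "clean" batch
    sigma = sched.sample_sigma(key_sigma, (4, 1))
    x_t = x_1 + sched.sample_noise(key_noise, x_1.shape, sigma)  # corrupt the data
    return sched.get_loss_fn()(F, (x_1, x_t, sigma))
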
class VPScheduler(BaseSDE):
    """Variance Preserving (VP) scheduler, as in Table 1 of the EDM paper."""

    def __init__(self, beta_min=0.1, beta_max=20.0, e_s=1e-3, e_t=1e-5, M=1000):
        super().__init__()
        self.beta_min = beta_min
        self.beta_max = beta_max
        self.beta_d = beta_max - beta_min
        self.e_s = e_s
        self.e_t = e_t
        self.M = M

    @property
    def name(self):
        return "EDM-VP"
    def time_schedule(self, u):
        return 1 + u * (self.e_s - 1)

    def sigma(self, t):
        # also known as the schedule, as in Table 1 of the EDM paper
        return jnp.sqrt(jnp.exp(0.5 * self.beta_d * t**2 + self.beta_min * t) - 1)

    def sigma_inv(self, sigma):
        return (
            jnp.sqrt(self.beta_min**2 + 2 * self.beta_d * jnp.log(1 + sigma**2))
            - self.beta_min
        ) / self.beta_d

    def sigma_deriv(self, t):
        # also known as the schedule derivative
        return (
            0.5
            * (self.beta_min + self.beta_d * t)
            * (self.sigma(t) + 1 / self.sigma(t))
        )

    def s(self, t):
        # also known as scaling, as in Table 1 of the EDM paper
        return 1 / jnp.sqrt(jnp.exp(0.5 * self.beta_d * t**2 + self.beta_min * t))

    def s_deriv(self, t):
        # also known as the scaling derivative
        return -self.sigma(t) * self.sigma_deriv(t) * (self.s(t) ** 3)
    def f(self, x, t):
        # f(x, t) in the SDE: drift term of the forward diffusion process, in closed form for VP
        return -x * 0.5 * (self.beta_min + self.beta_d * t)

    def g(self, x, t):
        # g(t) in the SDE: diffusion term of the forward diffusion process, in closed form for VP
        return jnp.sqrt(self.beta_min + self.beta_d * t)
    def c_skip(self, sigma):
        # c_skip for preconditioning
        return jnp.ones_like(sigma)

    def c_out(self, sigma):
        # c_out for preconditioning
        return -sigma

    def c_in(self, sigma):
        # c_in for preconditioning
        return 1 / jnp.sqrt(sigma**2 + 1)

    def c_noise(self, sigma):
        # c_noise for preconditioning
        return (self.M - 1) * self.sigma_inv(sigma)

    def loss_weight(self, sigma):
        return 1 / sigma**2

    def sample_sigma(self, key, shape):
        # sample sigma from the prior noise distribution
        u = jax.random.uniform(key, shape, minval=self.e_t, maxval=1)
        return self.sigma(u)
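
# Quick numerical sanity check (illustrative, not part of the library): for VP
# the scaling satisfies s(t) = 1 / sqrt(sigma(t)**2 + 1) and sigma_inv inverts
# sigma exactly; the helper below verifies both identities on a few points.
def _check_vp_identities() -> None:
    sched = VPScheduler()
    t = jnp.linspace(sched.e_t, 1.0, 5)
    assert jnp.allclose(sched.s(t), 1 / jnp.sqrt(sched.sigma(t) ** 2 + 1))
    assert jnp.allclose(sched.sigma_inv(sched.sigma(t)), t, atol=1e-4)
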
class VEScheduler(BaseSDE):
    """Variance Exploding (VE) scheduler, as in Table 1 of the EDM paper."""

    def __init__(self, sigma_min=1e-3, sigma_max=15.0):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    @property
    def name(self):
        return "EDM-VE"
    def time_schedule(self, u):
        return self.sigma_max**2 * (self.sigma_min / self.sigma_max) ** (2 * u)

    def sigma(self, t):
        # also known as the schedule, as in Table 1 of the EDM paper
        return jnp.sqrt(t)

    def sigma_inv(self, sigma):
        return sigma**2

    def sigma_deriv(self, t):
        return 1 / (2 * jnp.sqrt(t))

    def s(self, t):
        # also known as scaling, as in Table 1 of the EDM paper
        return jnp.ones_like(t)

    def s_deriv(self, t):
        # also known as the scaling derivative
        return jnp.zeros_like(t)

    def c_skip(self, sigma):
        # c_skip for preconditioning
        return jnp.ones_like(sigma)

    def c_out(self, sigma):
        # c_out for preconditioning
        return sigma

    def c_in(self, sigma):
        # c_in for preconditioning
        return jnp.ones_like(sigma)

    def c_noise(self, sigma):
        # c_noise for preconditioning
        return jnp.log(0.5 * sigma)

    def loss_weight(self, sigma):
        return 1 / sigma**2

    def sample_sigma(self, key, shape):
        # sample sigma from the prior noise distribution
        log_sigma = jax.random.uniform(
            key, shape, minval=jnp.log(self.sigma_min), maxval=jnp.log(self.sigma_max)
        )
        return jnp.exp(log_sigma)
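
# Discretization sketch (illustrative, not part of the library): for VE the
# time variable is sigma**2, so `timesteps` decays geometrically from
# sigma_max**2 to sigma_min**2 and the recovered noise levels form a geometric
# sequence. The helper name and N are assumptions for the demo only.
def _example_ve_noise_levels(N: int = 5) -> Array:
    sched = VEScheduler()
    t = sched.timesteps(jnp.arange(N), N)  # sigma_max**2 * (sigma_min / sigma_max) ** (2 * i / (N - 1))
    return sched.sigma(t)  # geometric sequence from sigma_max down to sigma_min
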
class EDMScheduler(BaseSDE):
    """EDM scheduler, as in Table 1 of the EDM paper."""

    def __init__(
        self,
        sigma_min=0.002,
        sigma_max=80.0,
        sigma_data=1.0,
        rho=7,
        P_mean=-1.2,
        P_std=1.2,
    ):
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.sigma_data = sigma_data
        self.rho = rho
        self.P_mean = P_mean
        self.P_std = P_std

    @property
    def name(self):
        return "EDM"
    def time_schedule(self, u):
        return (
            self.sigma_max ** (1 / self.rho)
            + u * (self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho))
        ) ** self.rho

    def sigma(self, t):
        # also known as the schedule, as in Table 1 of the EDM paper
        return t

    def sigma_inv(self, sigma):
        return sigma

    def sigma_deriv(self, t):
        # also known as the schedule derivative
        return jnp.ones_like(t)

    def s(self, t):
        # also known as scaling, as in Table 1 of the EDM paper
        return jnp.ones_like(t)

    def s_deriv(self, t):
        # also known as the scaling derivative
        return jnp.zeros_like(t)
    def c_skip(self, sigma):
        # c_skip for preconditioning; Table 1 of the EDM paper gives
        # sigma_data**2 / (sigma**2 + sigma_data**2), without a square root
        return self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
    def c_out(self, sigma):
        # c_out for preconditioning
        return sigma * self.sigma_data / jnp.sqrt(sigma**2 + self.sigma_data**2)

    def c_in(self, sigma):
        # c_in for preconditioning
        return 1 / jnp.sqrt(sigma**2 + self.sigma_data**2)

    def c_noise(self, sigma):
        # c_noise for preconditioning
        return 0.25 * jnp.log(sigma)

    def loss_weight(self, sigma):
        # weight for the loss function, for MLE estimation, also known as λ(σ) in the EDM paper
        return (sigma**2 + self.sigma_data**2) / (sigma * self.sigma_data) ** 2

    def sample_sigma(self, key, shape):
        # sample sigma from the prior noise distribution; here it is no longer
        # uniform but log-normal, see
        # https://github.com/NVlabs/edm/blob/008a4e5316c8e3bfe61a62f874bddba254295afb/training/loss.py#L66
        rnd_normal = jax.random.normal(key, shape)
        sigma = jnp.exp(rnd_normal * self.P_std + self.P_mean)
        return sigma
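
# End-to-end sketch (illustrative, not part of the library): ties the pieces
# together for the EDM scheduler with a placeholder identity "network". The
# shapes and the lambda below are assumptions chosen for the demo only.
if __name__ == "__main__":
    sched = EDMScheduler()
    key_sigma, key_noise = jax.random.split(jax.random.PRNGKey(0))

    x_1 = jnp.zeros((8, 3))  # toy "clean" batch
    sigma = sched.sample_sigma(key_sigma, (8, 1))  # log-normal noise levels
    x_t = x_1 + sched.sample_noise(key_noise, x_1.shape, sigma)

    F = lambda x, c_noise: x  # placeholder network
    loss = sched.get_loss_fn()(F, (x_1, x_t, sigma))
    print(f"{sched.name} demo loss: {float(loss):.4f}")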