# Source code for gensbi.diffusion.path.edm_path

from abc import ABC, abstractmethod
import jax
from jax import Array
from jax import numpy as jnp
from typing import Callable
import chex

from gensbi.diffusion.path.path import ProbPath
from gensbi.diffusion.path.path_sample import EDMPathSample


class EDMPath(ProbPath):
    r"""EDM probability path.

    A point on the path is obtained by adding scheduler-generated noise of
    scale ``sigma`` to a clean data point ``x_1``. Noise generation, sigma
    sampling and the training loss are all delegated to ``scheduler``.
    """

    # Scheduler names accepted by this path (validated in ``__init__``).
    _SUPPORTED_SCHEDULERS = ("EDM", "EDM-VP", "EDM-VE")

    def __init__(self, scheduler) -> None:
        r"""
        Initialize the EDMPath with a scheduler.

        Args:
            scheduler: The scheduler object. It must expose a ``name``
                attribute equal to one of ``"EDM"``, ``"EDM-VP"`` or
                ``"EDM-VE"``, plus the ``sample_noise``, ``sample_sigma`` and
                ``get_loss_fn`` methods used by this class.

        Raises:
            ValueError: If ``scheduler.name`` is not a supported EDM scheduler.
        """
        self.scheduler = scheduler
        # Explicit check instead of ``assert``: assertions are removed when
        # Python runs with -O, which would silently accept a wrong scheduler.
        if self.scheduler.name not in self._SUPPORTED_SCHEDULERS:
            raise ValueError(
                f"Scheduler must be one of ['EDM', 'EDM-VP', 'EDM-VE'], got {self.scheduler.name}."
            )

    def sample(self, key: Array, x_1: Array, sigma: Array) -> EDMPathSample:
        r"""
        Sample from the EDM probability path.

        Args:
            key (Array): JAX random key.
            x_1 (Array): Target data point, shape (batch_size, ...).
            sigma (Array): Noise scale, shape (batch_size, ...); presumably
                broadcastable against ``x_1`` — confirm with the scheduler.

        Returns:
            EDMPathSample: Holds ``x_1``, ``sigma`` and the noised point
            ``x_t = x_1 + noise``.
        """
        noise = self.scheduler.sample_noise(key, x_1.shape, sigma)
        x_t = x_1 + noise
        return EDMPathSample(
            x_1=x_1,
            sigma=sigma,
            x_t=x_t,
        )

    def sample_sigma(self, key: Array, batch_size: int) -> Array:
        r"""
        Sample the noise scale sigma from the scheduler.

        Args:
            key (Array): JAX random key.
            batch_size (int): Number of samples to generate.

        Returns:
            Array: The scheduler's sigma samples with a trailing singleton
            axis appended (e.g. (batch_size,) -> (batch_size, 1), assuming
            the scheduler returns a 1-D array).
        """
        # The extra trailing axis lets sigma broadcast against data arrays
        # in ``sample``; presumably intended for (batch_size, features)
        # data — confirm against callers.
        return self.scheduler.sample_sigma(key, batch_size)[..., None]

    def get_loss_fn(self) -> Callable:
        r"""
        Returns the loss function for the EDM path.

        Returns:
            Callable: The loss function as provided by the scheduler.
        """
        return self.scheduler.get_loss_fn()