Source code for gensbi.diffusion.path.path

from abc import ABC, abstractmethod
from jax import Array
from gensbi.diffusion.path.path_sample import EDMPathSample
from typing import Any

class ProbPath(ABC):
    r"""
    Abstract base class for a probability path driven by a scheduler.

    The scheduler passed at construction is stored on the instance, and
    prior sampling and the ``name`` property are delegated to it.
    Subclasses must implement :meth:`sample`.
    """

    def __init__(self, scheduler: Any) -> None:
        r"""
        Initialize the probability path.

        Args:
            scheduler: Scheduler object. It must expose a
                ``sample_prior(key, shape)`` method and a ``name``
                attribute, since both are forwarded by this class.
        """
        self.scheduler = scheduler

    def sample_prior(self, key: Array, shape: Any) -> Array:
        r"""
        Sample from the prior distribution.

        The call is forwarded unchanged to ``self.scheduler.sample_prior``.

        Args:
            key (Array): JAX random key.
            shape (Any): Shape of the samples to generate, should be
                (nsamples, ndim).

        Returns:
            Array: Samples from the prior distribution, shape
            (nsamples, ndim).
        """
        return self.scheduler.sample_prior(key, shape)

    @property
    def name(self) -> str:
        r"""
        Returns the name of the scheduler.

        Returns:
            str: Scheduler name, read from ``self.scheduler.name``.
        """
        return self.scheduler.name

    @abstractmethod
    def sample(self, *args: Any, **kwargs: Any) -> "EDMPathSample":
        r"""
        Abstract method to sample from the probability path.

        Concrete subclasses define the accepted arguments.

        Returns:
            EDMPathSample: Sample from the path.
        """
        ...