from typing import Callable, Optional, Sequence, Tuple, Union
import jax
import jax.numpy as jnp
from jax import jit
from jax import Array
from gensbi.diffusion.solver.solver import Solver
from gensbi.diffusion.solver.edm_samplers import edm_sampler, edm_ablation_sampler
from gensbi.diffusion.path import EDMPath
[docs]
class SDESolver(Solver):
def __init__(self, score_model: Callable, path: EDMPath) -> None:
"""
Initialize the SDE solver.
Args:
score_model (Callable): The score model function.
path (EDMPath): The EDMPath object.
"""
self.score_model = score_model
self.path = path
assert self.path.scheduler.name in [
"EDM",
"EDM-VP",
"EDM-VE",
], f"Path must be one of ['EDM', 'EDM-VP', 'EDM-VE'], got {self.path.name}."
[docs]
def get_sampler(
self,
condition_mask: Optional[Array] = None,
condition_value: Optional[Array] = None,
cfg_scale: Optional[float] = None,
nsteps: int = 18,
method: str = "Heun",
return_intermediates: bool = False,
model_extras: dict = {},
solver_params: Optional[dict] = {},
) -> Callable:
"""
Returns a sampler function for the SDE.
Args:
condition_mask (Optional[Array]): Mask for conditioning.
condition_value (Optional[Array]): Value for conditioning.
cfg_scale (Optional[float]): Classifier-free guidance scale (not implemented).
nsteps (int): Number of steps.
method (str): Integration method.
return_intermediates (bool): Whether to return intermediate steps.
model_extras (dict): Additional model arguments.
solver_params (Optional[dict]): Additional solver parameters.
Returns:
Callable: Sampler function.
"""
if self.path.name == "EDM":
sampler_ = edm_sampler
else:
sampler_ = edm_ablation_sampler
if cfg_scale is not None:
raise NotImplementedError(
"CFG scale is not implemented for EDM samplers yet."
)
S_churn = solver_params.get("S_churn", 0) # type: ignore
S_min = solver_params.get("S_min", 0) # type: ignore
S_max = solver_params.get("S_max", float("inf")) # type: ignore
S_noise = solver_params.get("S_noise", 1) # type: ignore
@jit
def sample(key: Array, x_init: Array) -> Array:
return sampler_(
self.path.scheduler,
self.score_model,
x_init,
key=key,
condition_mask=condition_mask,
condition_value=condition_value,
return_intermediates=return_intermediates,
n_steps=nsteps,
S_churn=S_churn,
S_min=S_min,
S_max=S_max,
S_noise=S_noise,
method=method,
model_kwargs=model_extras,
)
return sample
[docs]
def sample(
self,
key: Array,
x_init: Array,
condition_mask: Optional[Array] = None,
condition_value: Optional[Array] = None,
cfg_scale: Optional[float] = None,
nsteps: int = 18,
method: str = "Heun",
return_intermediates: bool = False,
model_extras: dict = {},
solver_params: Optional[dict] = {},
) -> Array:
"""
Sample from the SDE using the sampler.
Args:
key (Array): JAX random key.
x_init (Array): Initial value.
condition_mask (Optional[Array]): Mask for conditioning.
condition_value (Optional[Array]): Value for conditioning.
cfg_scale (Optional[float]): Classifier-free guidance scale (not implemented).
nsteps (int): Number of steps.
method (str): Integration method.
return_intermediates (bool): Whether to return intermediate steps.
model_extras (dict): Additional model arguments.
solver_params (Optional[dict]): Additional solver parameters.
Returns:
Array: Sampled output.
"""
sample = self.get_sampler(
condition_mask=condition_mask,
condition_value=condition_value,
cfg_scale=cfg_scale,
nsteps=nsteps,
method=method,
return_intermediates=return_intermediates,
model_extras=model_extras,
solver_params=solver_params,
)
return sample(key, x_init)