gensbi.diffusion.solver.edm_samplers
Functions
- edm_ablation_sampler(sde, model, x_1, *, key[, ...])
- edm_sampler(sde, model, x_1, *, key[, ...]): EDM sampler for diffusion models.
Module Contents
- gensbi.diffusion.solver.edm_samplers.edm_ablation_sampler(sde, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, method='Heun', model_kwargs={})
- gensbi.diffusion.solver.edm_samplers.edm_sampler(sde, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, method='Heun', model_kwargs={})
Sample from a diffusion model with the EDM sampler of Karras et al. (2022).
- Parameters:
sde (Any) – SDE scheduler object.
model (Callable) – Model function.
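x_1 (Array) – Initial state from which sampling starts.
key (Array) – JAX random key.
condition_mask (Optional[Array]) – Mask indicating which entries are conditioned on.
condition_value (Optional[Array]) – Values of the conditioned entries.
return_intermediates (bool) – Whether to return the intermediate states.
n_steps (int) – Number of sampling (discretization) steps.
S_churn (float) – Amount of stochasticity ("churn") injected per step; 0 gives a deterministic sampler.
S_min (float) – Lower bound of the noise-level range in which churn is applied.
S_max (float) – Upper bound of the noise-level range in which churn is applied.
S_noise (float) – Scale factor for the noise added during churn.
method (str) – Integration method: “Euler” (first order) or “Heun” (second order).
model_kwargs (dict) – Additional keyword arguments passed to the model.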
- Returns:
Sampled output.
- Return type:
Array
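The churn and integration parameters above follow the stochastic sampler of Karras et al. (2022, Algorithm 2). The block below is a minimal, self-contained JAX sketch of that algorithm, not GenSBI's implementation. The names `edm_sampler_sketch`, `karras_sigmas`, and the `denoise(x, sigma)` interface are hypothetical, and the noise schedule (sigma_min=0.002, sigma_max=80, rho=7) is assumed to be the paper's default rather than whatever the `sde` object supplies. Conditioning (`condition_mask`, `condition_value`), `return_intermediates`, and `model_kwargs` are omitted.

```python
import jax
import jax.numpy as jnp


def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Assumed EDM schedule: sigma_0 > ... > sigma_{N-1} = sigma_min, then sigma_N = 0."""
    i = jnp.arange(n_steps)
    inv_max, inv_min = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    sigmas = (inv_max + i / (n_steps - 1) * (inv_min - inv_max)) ** rho
    return jnp.append(sigmas, 0.0)


def edm_sampler_sketch(denoise, x_init, key, n_steps=18, S_churn=0.0, S_min=0.0,
                       S_max=float("inf"), S_noise=1.0, method="Heun"):
    """Hypothetical sketch of the stochastic EDM sampler as a plain Python loop.

    denoise(x, sigma) is an assumed denoiser D(x; sigma) returning the predicted
    clean sample; x_init is assumed to be noise already scaled by sigma_max.
    """
    if method not in ("Euler", "Heun"):
        raise ValueError(f"unknown method: {method}")
    sigmas = karras_sigmas(n_steps)
    x = x_init
    for i in range(n_steps):
        sigma, sigma_next = sigmas[i], sigmas[i + 1]
        # Churn: raise the noise level to sigma_hat, but only inside [S_min, S_max].
        in_range = (sigma >= S_min) & (sigma <= S_max)
        gamma = jnp.where(in_range, jnp.minimum(S_churn / n_steps, jnp.sqrt(2.0) - 1.0), 0.0)
        sigma_hat = sigma * (1.0 + gamma)
        key, subkey = jax.random.split(key)
        eps = S_noise * jax.random.normal(subkey, x.shape)
        x_hat = x + jnp.sqrt(sigma_hat**2 - sigma**2) * eps
        # Euler step along the probability-flow ODE dx/dsigma = (x - D(x; sigma)) / sigma.
        d = (x_hat - denoise(x_hat, sigma_hat)) / sigma_hat
        x_next = x_hat + (sigma_next - sigma_hat) * d
        # Heun correction: average the two slopes; skipped on the last step (sigma_next = 0).
        if method == "Heun" and i < n_steps - 1:
            d_next = (x_next - denoise(x_next, sigma_next)) / sigma_next
            x_next = x_hat + (sigma_next - sigma_hat) * 0.5 * (d + d_next)
        x = x_next
    return x
```

With `S_churn=0` no noise is injected and the loop reduces to a deterministic Euler or Heun ODE solver. The Heun correction is skipped on the final step because its slope would require dividing by sigma = 0. As a sanity check, if the data is a single point mu, the ideal denoiser is D(x; sigma) = mu, and the final step lands exactly on mu regardless of churn.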