gensbi.diffusion.solver.edm_samplers

Functions

  • edm_ablation_sampler(sde, model, x_1, *, key[, ...])

  • edm_sampler(sde, model, x_1, *, key[, condition_mask, ...]) – EDM sampler for diffusion models.

Module Contents

gensbi.diffusion.solver.edm_samplers.edm_ablation_sampler(sde, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, method='Heun', model_kwargs={})

gensbi.diffusion.solver.edm_samplers.edm_sampler(sde, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=float('inf'), S_noise=1, method='Heun', model_kwargs={})

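EDM sampler for diffusion models, following the sampler of Karras et al. (2022), with optional stochastic "churn" and a choice of Euler or Heun integration.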
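
An EDM-style sampler steps through a decreasing sequence of noise levels. As a minimal sketch, assuming the time discretization from Karras et al. (2022, Eq. 5) with its published defaults (sigma_min=0.002, sigma_max=80, rho=7), the 18 default steps could be laid out as follows. The helper name karras_sigmas is hypothetical, and the sde object passed to edm_sampler may define its own schedule.

```python
import jax.numpy as jnp


def karras_sigmas(n_steps=18, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Hypothetical helper: EDM noise levels (Karras et al. 2022, Eq. 5) plus a final 0."""
    i = jnp.arange(n_steps)
    # Interpolate linearly in sigma^(1/rho) space, then raise back to the power rho,
    # which concentrates steps at low noise levels.
    sigmas = (
        sigma_max ** (1 / rho)
        + i / (n_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
    ) ** rho
    # Append sigma = 0 so the last step lands on the clean sample.
    return jnp.concatenate([sigmas, jnp.zeros(1)])


sigmas = karras_sigmas()
print(sigmas.shape)  # (19,)
```

Larger rho spends more of the step budget near sigma_min, where fine detail is resolved.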

Parameters:
  • sde (Any) – SDE scheduler object.

  • model (Callable) – Model function.

  • x_1 (Array) – Initial sample from which the integration starts.

  • key (Array) – JAX random key.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • n_steps (int) – Number of integration steps.

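  • S_churn (float) – Churn parameter: amount of noise re-injected at each step (0 gives deterministic sampling).

  • S_min (float) – Lower bound of the noise-level range in which churn is applied.

  • S_max (float) – Upper bound of the noise-level range in which churn is applied.

  • S_noise (float) – Scale factor for the standard deviation of the re-injected noise.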

  • method (str) – Integration method (“Euler” or “Heun”).

  • model_kwargs (dict) – Additional model arguments.
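
The S_churn, S_min, S_max, S_noise and method parameters mirror the stochastic sampler in Karras et al. (2022, Algorithm 2). The JAX sketch below shows one step of that algorithm under that assumption; it is not gensbi's implementation. The denoiser signature denoiser(x, sigma), the helper names edm_step and toy_denoiser, and the geometric noise grid are all hypothetical stand-ins for whatever the model and sde objects actually provide.

```python
import jax
import jax.numpy as jnp


def edm_step(denoiser, x, t_cur, t_next, key, *, n_steps=18, S_churn=0.0,
             S_min=0.0, S_max=float("inf"), S_noise=1.0, method="Heun"):
    """Hypothetical single step of EDM Algorithm 2; denoiser(x, sigma) is assumed."""
    # Churn: raise the noise level from t_cur to t_hat, but only when
    # t_cur lies in [S_min, S_max]; gamma is capped at sqrt(2) - 1.
    in_range = (t_cur >= S_min) & (t_cur <= S_max)
    gamma = jnp.where(in_range, jnp.minimum(S_churn / n_steps, jnp.sqrt(2.0) - 1.0), 0.0)
    t_hat = t_cur + gamma * t_cur
    eps = S_noise * jax.random.normal(key, x.shape)
    x_hat = x + jnp.sqrt(t_hat**2 - t_cur**2) * eps

    # Euler step along the probability-flow ODE dx/dt = (x - D(x; t)) / t.
    d_cur = (x_hat - denoiser(x_hat, t_hat)) / t_hat
    x_next = x_hat + (t_next - t_hat) * d_cur

    # Heun correction: average the slopes at both ends (skipped when t_next == 0).
    if method == "Heun" and t_next > 0:
        d_next = (x_next - denoiser(x_next, t_next)) / t_next
        x_next = x_hat + (t_next - t_hat) * 0.5 * (d_cur + d_next)
    return x_next


def toy_denoiser(x, sigma):
    # Exact denoiser when the data are standard normal: E[x0 | x] = x / (1 + sigma^2).
    return x / (1.0 + sigma**2)


key, init_key = jax.random.split(jax.random.PRNGKey(0))
# Hypothetical noise grid: 18 geometric levels from 80 down to 0.002, then 0.
ts = jnp.concatenate([jnp.geomspace(80.0, 0.002, 18), jnp.zeros(1)])
x = ts[0] * jax.random.normal(init_key, (4096, 2))
for i in range(len(ts) - 1):
    key, step_key = jax.random.split(key)
    x = edm_step(toy_denoiser, x, ts[i], ts[i + 1], step_key, n_steps=18)
print(float(x.std()))  # roughly 1.0 (Heun with 18 steps is only approximate)
```

With S_churn=0 the step is deterministic and reduces to a plain ODE solve; raising S_churn trades some of that determinism for noise injection, which Karras et al. report can correct errors accumulated in earlier steps.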

Returns:

Sampled output.

Return type:

Array