gensbi.diffusion.solver package

Submodules

gensbi.diffusion.solver.edm_samplers module

gensbi.diffusion.solver.edm_samplers.edm_ablation_sampler(sde, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=inf, S_noise=1, method='Heun', model_kwargs={})
gensbi.diffusion.solver.edm_samplers.edm_sampler(sde: Any, model: Callable, x_1: Array, *, key: Array, condition_mask: Array | None = None, condition_value: Array | None = None, return_intermediates: bool = False, n_steps: int = 18, S_churn: float = 0, S_min: float = 0, S_max: float = inf, S_noise: float = 1, method: str = 'Heun', model_kwargs: dict = {}) → Array

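EDM sampler for diffusion models. Starting from x_1, it integrates over n_steps noise levels with the chosen method ("Euler" or "Heun") and returns the resulting sample.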

Parameters:
  • sde – SDE scheduler object.

  • model (Callable) – Model function.

  • x_1 (Array) – Initial value.

  • key (Array) – JAX random key.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • n_steps (int) – Number of steps.

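  • S_churn (float) – Churn parameter. In the EDM formulation it sets how much fresh noise is re-injected at each step; the default of 0 gives no re-injected noise.

  • S_min (float) – Minimum S value. In the EDM formulation, the lower bound of the noise-level range in which churn is applied.

  • S_max (float) – Maximum S value. In the EDM formulation, the upper bound of the noise-level range in which churn is applied.

  • S_noise (float) – Noise scale. In the EDM formulation, a multiplier on the standard deviation of the noise added by churn.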

  • method (str) – Integration method (“Euler” or “Heun”).

  • model_kwargs (dict) – Additional model arguments.

Returns:

Sampled output.

Return type:

Array
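The parameters above follow the naming of the EDM sampling scheme. As a rough illustration of how n_steps, S_churn, S_min, S_max, S_noise and method can interact, here is a minimal scalar sketch in plain Python. It is not the gensbi implementation: the names karras_sigmas and toy_edm_sampler, the schedule constants (sigma_min=0.002, sigma_max=80, rho=7, taken from the EDM paper's defaults) and the use of random.Random in place of a JAX key are all assumptions made for this example.

```python
import math
import random


def karras_sigmas(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels from sigma_max down to sigma_min, followed by a final 0.

    The constants are assumed EDM-paper defaults, not values read from gensbi.
    """
    lo, hi = sigma_min ** (1 / rho), sigma_max ** (1 / rho)
    sigmas = [(hi + i / (n_steps - 1) * (lo - hi)) ** rho for i in range(n_steps)]
    return sigmas + [0.0]


def toy_edm_sampler(denoise, x, *, rng, n_steps=18, S_churn=0.0, S_min=0.0,
                    S_max=math.inf, S_noise=1.0, method="Heun"):
    """Hypothetical scalar sampler; `denoise(x, sigma)` returns a denoised estimate."""
    sigmas = karras_sigmas(n_steps)
    for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
        # Churn: temporarily raise the noise level, but only inside [S_min, S_max].
        in_range = S_min <= sigma <= S_max
        gamma = min(S_churn / n_steps, math.sqrt(2) - 1) if in_range else 0.0
        sigma_hat = sigma * (1 + gamma)
        extra_std = math.sqrt(sigma_hat ** 2 - sigma ** 2)
        x_hat = x + extra_std * S_noise * rng.gauss(0, 1)

        # Euler step from sigma_hat to sigma_next.
        d = (x_hat - denoise(x_hat, sigma_hat)) / sigma_hat
        x = x_hat + (sigma_next - sigma_hat) * d

        # Heun correction: average the slopes at both ends of the step.
        # Skipped on the last step, where sigma_next is 0.
        if method == "Heun" and sigma_next > 0:
            d_next = (x - denoise(x, sigma_next)) / sigma_next
            x = x_hat + (sigma_next - sigma_hat) * 0.5 * (d + d_next)
    return x


# Toy usage: every data point equals 3.0, so the ideal denoiser is constant.
rng = random.Random(0)
x_start = 80.0 * rng.gauss(0, 1)  # start from noise at the largest sigma
print(toy_edm_sampler(lambda x, sigma: 3.0, x_start, rng=rng))
```

With S_churn=0 the loop is deterministic given the starting value; a positive S_churn adds noise only while the current noise level lies between S_min and S_max.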

gensbi.diffusion.solver.sde_solver module

class gensbi.diffusion.solver.sde_solver.SDESolver(score_model: Callable, path: EDMPath)

Bases: Solver

get_sampler(condition_mask: Array | None = None, condition_value: Array | None = None, cfg_scale: float | None = None, nsteps: int = 18, method: str = 'Heun', return_intermediates: bool = False, model_extras: dict = {}, solver_params: dict | None = {}) → Callable

Returns a sampler function for the SDE.

Parameters:
  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of steps.

  • method (str) – Integration method.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Additional model arguments.

  • solver_params (Optional[dict]) – Additional solver parameters.

Returns:

Sampler function.

Return type:

Callable
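The pairing of get_sampler with sample suggests a factory pattern: the solver settings are bound once, and the returned callable only needs the per-call inputs. The sketch below illustrates that pattern in plain Python. ToySolver, its step_fn argument and the "scale" entry in solver_params are hypothetical, and the call signature of the returned sampler (key, x_init) is an assumption that this page does not confirm.

```python
from typing import Callable, Optional


class ToySolver:
    """Hypothetical stand-in used for illustration; not the gensbi SDESolver."""

    def __init__(self, step_fn: Callable[[float, int], float]):
        self.step_fn = step_fn  # one update of the state, given the step index

    def get_sampler(self, nsteps: int = 18,
                    solver_params: Optional[dict] = None) -> Callable:
        # Copy the settings so later changes by the caller do not leak in.
        params = dict(solver_params or {})
        scale = params.get("scale", 1.0)  # hypothetical extra parameter

        def sampler(key, x_init: float) -> float:
            # `key` is unused here; a real sampler would use it to draw noise.
            x = x_init
            for i in range(nsteps):
                x = self.step_fn(x, i)
            return scale * x

        return sampler

    def sample(self, key, x_init: float, **kwargs) -> float:
        # Convenience wrapper: build the sampler, then call it once.
        return self.get_sampler(**kwargs)(key, x_init)


solver = ToySolver(step_fn=lambda x, i: 0.5 * x)
halve_three_times = solver.get_sampler(nsteps=3)
print(halve_three_times(None, 8.0))  # prints 1.0
```

The sketch defaults solver_params to None and copies the dict, so one dict object is never shared between calls, which a literal `{}` default would allow in Python.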

sample(key: Array, x_init: Array, condition_mask: Array | None = None, condition_value: Array | None = None, cfg_scale: float | None = None, nsteps: int = 18, method: str = 'Heun', return_intermediates: bool = False, model_extras: dict = {}, solver_params: dict | None = {}) → Array

Sample from the SDE using the sampler.

Parameters:
  • key (Array) – JAX random key.

  • x_init (Array) – Initial value.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of steps.

  • method (str) – Integration method.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Additional model arguments.

  • solver_params (Optional[dict]) – Additional solver parameters.

Returns:

Sampled output.

Return type:

Array
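This page does not say how condition_mask and condition_value are combined with the sampled state. One common convention in masked samplers, assumed here purely for illustration, is that entries where the mask is true are held at the conditioning value while the remaining entries are sampled. The helper apply_condition below is hypothetical and uses plain lists rather than JAX arrays.

```python
from typing import Optional, Sequence


def apply_condition(x: Sequence[float],
                    condition_mask: Optional[Sequence[bool]] = None,
                    condition_value: Optional[Sequence[float]] = None) -> list:
    """Return a copy of x with masked entries replaced by the conditioning values."""
    if condition_mask is None or condition_value is None:
        return list(x)  # nothing to condition on
    if not (len(x) == len(condition_mask) == len(condition_value)):
        raise ValueError("x, condition_mask and condition_value must have equal length")
    return [c if m else xi for xi, m, c in zip(x, condition_mask, condition_value)]


x = [0.3, -1.2, 0.8]
mask = [False, True, False]   # only the second entry is observed
value = [0.0, 2.5, 0.0]       # entries where the mask is False are ignored
print(apply_condition(x, mask, value))  # prints [0.3, 2.5, 0.8]
```

With JAX arrays, the same element-wise selection can be written as `jnp.where(condition_mask, condition_value, x)`.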

gensbi.diffusion.solver.solver module

class gensbi.diffusion.solver.solver.Solver

Bases: ABC

Abstract base class for solvers.

abstractmethod sample(key, x_1: Array) → Array
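Because sample is abstract, a concrete solver has to override it before it can be instantiated. The sketch below shows this with a local stand-in for Solver that mirrors the documented signature; it is not imported from gensbi, and IdentitySolver is a hypothetical subclass.

```python
from abc import ABC, abstractmethod


class Solver(ABC):
    """Local stand-in mirroring the documented interface; not imported from gensbi."""

    @abstractmethod
    def sample(self, key, x_1):
        ...


class IdentitySolver(Solver):
    """Hypothetical subclass that returns its input unchanged."""

    def sample(self, key, x_1):
        return x_1


try:
    Solver()  # abstract classes cannot be instantiated
except TypeError as err:
    print("TypeError:", err)

print(IdentitySolver().sample(key=None, x_1=[1.0, 2.0]))  # prints [1.0, 2.0]
```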

Module contents

class gensbi.diffusion.solver.SDESolver(score_model: Callable, path: EDMPath)

Bases: Solver

Documented above as gensbi.diffusion.solver.sde_solver.SDESolver; the get_sampler and sample entries are identical.

class gensbi.diffusion.solver.Solver

Bases: ABC

Abstract base class for solvers.

abstractmethod sample(key, x_1: Array) → Array