gensbi.diffusion.solver package
Submodules
gensbi.diffusion.solver.edm_samplers module
- gensbi.diffusion.solver.edm_samplers.edm_ablation_sampler(sde, model, x_1, *, key, condition_mask=None, condition_value=None, return_intermediates=False, n_steps=18, S_churn=0, S_min=0, S_max=inf, S_noise=1, method='Heun', model_kwargs={})
- gensbi.diffusion.solver.edm_samplers.edm_sampler(sde: Any, model: Callable, x_1: Array, *, key: Array, condition_mask: Array | None = None, condition_value: Array | None = None, return_intermediates: bool = False, n_steps: int = 18, S_churn: float = 0, S_min: float = 0, S_max: float = inf, S_noise: float = 1, method: str = 'Heun', model_kwargs: dict = {}) → Array
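Sample from a diffusion model with the stochastic EDM sampler of Karras et al. (2022), starting from x_1 and stepping through n_steps noise levels with an Euler or Heun update.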
- Parameters:
sde – SDE scheduler object.
model (Callable) – Model function.
x_1 (Array) – Initial value.
key (Array) – JAX random key.
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
return_intermediates (bool) – Whether to return intermediate steps.
n_steps (int) – Number of steps.
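S_churn (float) – Overall amount of extra noise (“churn”) injected per step; 0 gives a deterministic sampler.
S_min (float) – Lowest noise level at which churn is applied.
S_max (float) – Highest noise level at which churn is applied.
S_noise (float) – Scale factor applied to the freshly injected noise.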
method (str) – Integration method (“Euler” or “Heun”).
model_kwargs (dict) – Additional model arguments.
- Returns:
Sampled output.
- Return type:
Array
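The S_churn, S_min, S_max, S_noise and method arguments are the knobs of the stochastic sampler in Karras et al. (2022, Algorithm 2). The code below is a minimal NumPy sketch of that algorithm, not the gensbi implementation. Assumptions: the noise range uses the paper's defaults (sigma_min=0.002, sigma_max=80, rho=7), whereas edm_sampler presumably takes these from the sde object; the names edm_time_steps, edm_sampler_sketch and gaussian_denoiser are hypothetical; a NumPy Generator stands in for the JAX key; and conditioning is omitted.

```python
import numpy as np


def edm_time_steps(n_steps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Decreasing EDM noise levels sigma_max -> sigma_min (n_steps >= 2), plus a final 0."""
    i = np.arange(n_steps)
    inv_rho_max, inv_rho_min = sigma_max ** (1 / rho), sigma_min ** (1 / rho)
    t = (inv_rho_max + i / (n_steps - 1) * (inv_rho_min - inv_rho_max)) ** rho
    return np.append(t, 0.0)


def edm_sampler_sketch(denoise, x_init, rng, n_steps=18, S_churn=0.0, S_min=0.0,
                       S_max=np.inf, S_noise=1.0, method="Heun"):
    """Stochastic EDM sampler; denoise(x, sigma) returns an estimate of the clean data."""
    t_steps = edm_time_steps(n_steps)
    x_next = x_init * t_steps[0]  # assumption: x_init is unit-variance noise
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):
        x_cur = x_next
        # Churn: raise the noise level from t_cur to t_hat, only inside [S_min, S_max]
        gamma = min(S_churn / n_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0.0
        t_hat = t_cur + gamma * t_cur
        eps = S_noise * rng.standard_normal(x_cur.shape)
        x_hat = x_cur + np.sqrt(t_hat**2 - t_cur**2) * eps
        # Euler step along dx/dsigma = (x - D(x, sigma)) / sigma
        d_cur = (x_hat - denoise(x_hat, t_hat)) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur
        # Heun correction, skipped on the last step where t_next == 0
        if method == "Heun" and i < n_steps - 1:
            d_prime = (x_next - denoise(x_next, t_next)) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
    return x_next


# Toy check: for data ~ N(2, 0.5**2) the exact denoiser is linear in x.
def gaussian_denoiser(x, sigma, mu=2.0, s2=0.25):
    return mu + s2 / (s2 + sigma**2) * (x - mu)


rng = np.random.default_rng(0)
samples = edm_sampler_sketch(gaussian_denoiser, rng.standard_normal(20000), rng)
print(samples.mean(), samples.std())  # should land near 2.0 and 0.5
```

With S_churn=0 the loop reduces to a deterministic Heun solve of the probability-flow ODE; a positive S_churn adds noise only at noise levels inside [S_min, S_max], and method="Euler" drops the second-order correction.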
gensbi.diffusion.solver.sde_solver module
- class gensbi.diffusion.solver.sde_solver.SDESolver(score_model: Callable, path: EDMPath)
Bases: Solver
- get_sampler(condition_mask: Array | None = None, condition_value: Array | None = None, cfg_scale: float | None = None, nsteps: int = 18, method: str = 'Heun', return_intermediates: bool = False, model_extras: dict = {}, solver_params: dict | None = {}) → Callable
Returns a sampler function for the SDE.
- Parameters:
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method.
return_intermediates (bool) – Whether to return intermediate steps.
model_extras (dict) – Additional model arguments.
solver_params (Optional[dict]) – Additional solver parameters.
- Returns:
Sampler function.
- Return type:
Callable
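get_sampler fixes the solver options up front and returns a callable. The sketch below illustrates that factory pattern together with one common way a condition_mask/condition_value pair can be applied: entries where the mask is True are held at condition_value after every update. That masking convention, the function names make_conditional_sampler and gaussian_denoiser, the geometric noise grid and the Euler probability-flow update are all assumptions for illustration; the parameter list above does not say how SDESolver uses the mask. The key argument is left out because this sketch is deterministic.

```python
import numpy as np


def gaussian_denoiser(x, sigma, mu=2.0, s2=0.25):
    # Exact posterior-mean denoiser for data x0 ~ N(mu, s2), applied elementwise
    return mu + s2 / (s2 + sigma**2) * (x - mu)


def make_conditional_sampler(denoise, nsteps=50, condition_mask=None, condition_value=None,
                             sigma_max=80.0, sigma_min=0.002):
    """Hypothetical factory in the spirit of get_sampler: fix options, return a callable."""
    sigmas = np.append(np.geomspace(sigma_max, sigma_min, nsteps), 0.0)

    def clamp(x):
        # Assumed convention: True marks observed entries, which stay at condition_value
        if condition_mask is None:
            return x
        return np.where(condition_mask, condition_value, x)

    def sampler(x_init):
        x = clamp(x_init * sigmas[0])  # assumption: x_init is unit-variance noise
        for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
            # Euler step of the probability-flow ODE dx/dsigma = (x - D(x, sigma)) / sigma
            x = clamp(x + (s_next - s_cur) * (x - denoise(x, s_cur)) / s_cur)
        return x

    return sampler


mask = np.array([True, False])
value = np.array([5.0, 0.0])
sampler = make_conditional_sampler(gaussian_denoiser, condition_mask=mask, condition_value=value)
x = sampler(np.random.default_rng(0).standard_normal((2000, 2)))
print(x[:, 0].min(), x[:, 0].max())  # 5.0 5.0
```

Because the options are captured in the closure, the same sampler can be reused across batches without restating them.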
- sample(key: Array, x_init: Array, condition_mask: Array | None = None, condition_value: Array | None = None, cfg_scale: float | None = None, nsteps: int = 18, method: str = 'Heun', return_intermediates: bool = False, model_extras: dict = {}, solver_params: dict | None = {}) → Array
Draw samples from the SDE, starting from x_init, using the sampler returned by get_sampler.
- Parameters:
key (Array) – JAX random key.
x_init (Array) – Initial value.
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method.
return_intermediates (bool) – Whether to return intermediate steps.
model_extras (dict) – Additional model arguments.
solver_params (Optional[dict]) – Additional solver parameters.
- Returns:
Sampled output.
- Return type:
Array
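The split between get_sampler and sample can be sketched with a hypothetical ToySolver class: sample builds a sampler with the same options and calls it once on (key, x_init), and return_intermediates stacks every state along a new leading axis. The class, its geometric noise grid, the deterministic Euler update, the (key, x_init) call signature of the returned sampler and the (nsteps + 1)-state trajectory layout are assumptions, not documented SDESolver behavior; here the key is an integer seed that the sketch ignores.

```python
import numpy as np


def gaussian_denoiser(x, sigma, mu=2.0, s2=0.25):
    # Exact denoiser E[x0 | x] for data x0 ~ N(mu, s2) observed at noise level sigma
    return mu + s2 / (s2 + sigma**2) * (x - mu)


class ToySolver:
    """Hypothetical stand-in mimicking the get_sampler/sample split (not the library code)."""

    def __init__(self, denoise, sigma_max=80.0, sigma_min=0.002):
        self.denoise = denoise
        self.sigma_max, self.sigma_min = sigma_max, sigma_min

    def get_sampler(self, nsteps=18, return_intermediates=False):
        # Options are fixed here; the returned closure only needs (key, x_init).
        sigmas = np.append(np.geomspace(self.sigma_max, self.sigma_min, nsteps), 0.0)

        def sampler(key, x_init):
            del key  # unused: this Euler ODE sketch is deterministic
            x = x_init * sigmas[0]
            states = [x]
            for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
                x = x + (s_next - s_cur) * (x - self.denoise(x, s_cur)) / s_cur
                states.append(x)
            return np.stack(states) if return_intermediates else x

        return sampler

    def sample(self, key, x_init, nsteps=18, return_intermediates=False):
        # Convenience wrapper: build the sampler, then call it once.
        return self.get_sampler(nsteps, return_intermediates)(key, x_init)


solver = ToySolver(gaussian_denoiser)
x0 = np.random.default_rng(0).standard_normal((1000, 3))
out = solver.sample(0, x0, nsteps=30)
traj = solver.sample(0, x0, nsteps=30, return_intermediates=True)
print(out.shape, traj.shape)  # (1000, 3) (31, 1000, 3)
```

Calling sample is equivalent to calling get_sampler with the same options and applying the result, so get_sampler is the better fit when one configuration is used for many batches.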
gensbi.diffusion.solver.solver module
Module contents
- class gensbi.diffusion.solver.SDESolver(score_model: Callable, path: EDMPath)
Bases: Solver
Package-level export of SDESolver. Its constructor and its get_sampler and sample methods are documented under gensbi.diffusion.solver.sde_solver.SDESolver above.