gensbi.diffusion.solver.sde_solver
Classes
- SDESolver – Abstract base class for solvers.
Module Contents
- class gensbi.diffusion.solver.sde_solver.SDESolver(score_model, path)
Bases: gensbi.diffusion.solver.solver.Solver
Abstract base class for solvers.
- Parameters:
score_model (Callable)
- get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})
Returns a sampler function for the SDE.
- Parameters:
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method; defaults to 'Heun'.
return_intermediates (bool) – Whether to return intermediate steps.
model_extras (dict) – Additional model arguments.
solver_params (Optional[dict]) – Additional solver parameters.
- Returns:
Sampler function.
- Return type:
Callable
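The reference above does not say how the returned sampler works internally. As a rough illustration only, the sketch below builds a `get_sampler`-style factory in plain JAX: it returns a jitted `sampler(key, x_init)` that integrates the probability-flow ODE with Heun's method on the Karras et al. (2022) noise schedule, whose reference setup also uses 18 Heun steps. Whether gensbi uses this schedule or this ODE, and whether it handles `condition_mask`/`condition_value` by inpainting-style clamping as done here, are assumptions; `make_heun_sampler`, `edm_sigmas` and `toy_score` are hypothetical names, not gensbi API.

```python
import jax
import jax.numpy as jnp


def edm_sigmas(nsteps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (2022) noise levels, decreasing, with a final 0 appended."""
    i = jnp.arange(nsteps)
    a, b = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    sigmas = (a + i / (nsteps - 1) * (b - a)) ** rho
    return jnp.concatenate([sigmas, jnp.zeros(1)])


def make_heun_sampler(score_fn, nsteps=18, condition_mask=None, condition_value=None):
    """Hypothetical get_sampler-style factory returning sampler(key, x_init)."""
    sigmas = edm_sigmas(nsteps)

    def clamp(x):
        # Assumed conditioning scheme: overwrite the conditioned entries.
        if condition_mask is None:
            return x
        return jnp.where(condition_mask, condition_value, x)

    def drift(x, sigma):
        # Probability-flow ODE with sigma as time: dx/dsigma = -sigma * score.
        return -sigma * score_fn(x, sigma)

    def step(x, sig_pair):
        s, s_next = sig_pair
        d = drift(x, s)
        x_euler = x + (s_next - s) * d
        # Heun correction (explicit trapezoidal rule); plain Euler on the
        # final step down to sigma = 0.
        d_next = drift(x_euler, s_next)
        x_heun = x + (s_next - s) * 0.5 * (d + d_next)
        return clamp(jnp.where(s_next > 0, x_heun, x_euler)), None

    @jax.jit
    def sampler(key, x_init):
        del key  # unused: this ODE sampler is deterministic given x_init
        x, _ = jax.lax.scan(step, clamp(x_init), (sigmas[:-1], sigmas[1:]))
        return x

    return sampler


# Toy usage: isotropic Gaussian data N(mu, s0^2) has the closed-form noised
# score -(x - mu) / (s0^2 + sigma^2), standing in for a trained score model.
mu, s0 = 2.0, 1.0
toy_score = lambda x, sigma: -(x - mu) / (s0**2 + sigma**2)
sampler = make_heun_sampler(toy_score, nsteps=18)
x_init = 80.0 * jax.random.normal(jax.random.PRNGKey(0), (5,))  # sigma_max * noise
samples = sampler(jax.random.PRNGKey(1), x_init)
```

For this toy score the ODE is linear and has the exact solution x(0) = mu + (x_init - mu) * s0 / sqrt(s0^2 + sigma_max^2), which makes a cheap sanity check: the Heun result should approach it roughly quadratically as `nsteps` grows.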
- sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})
Sample from the SDE using the sampler returned by get_sampler.
- Parameters:
key (Array) – JAX random key.
x_init (Array) – Initial state from which the integration starts.
condition_mask (Optional[Array]) – Mask for conditioning.
condition_value (Optional[Array]) – Value for conditioning.
cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).
nsteps (int) – Number of steps.
method (str) – Integration method; defaults to 'Heun'.
return_intermediates (bool) – Whether to return intermediate steps.
model_extras (dict) – Additional model arguments.
solver_params (Optional[dict]) – Additional solver parameters.
- Returns:
Sampled output.
- Return type:
Array
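Because `sample` takes a JAX random key, a stochastic integrator is one natural way to use it. The sketch below is an assumption rather than gensbi's implementation: it integrates the reverse-time variance-exploding SDE with Euler–Maruyama (the "reverse diffusion" update of Song et al., 2021), splitting `key` into one subkey per step. `euler_maruyama_sample`, `edm_sigmas` and `toy_score` are hypothetical helpers, and which `method` strings gensbi accepts besides the default 'Heun' is not stated in this reference.

```python
import jax
import jax.numpy as jnp


def edm_sigmas(nsteps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Karras et al. (2022) noise levels, decreasing, with a final 0 appended."""
    i = jnp.arange(nsteps)
    a, b = sigma_max ** (1.0 / rho), sigma_min ** (1.0 / rho)
    sigmas = (a + i / (nsteps - 1) * (b - a)) ** rho
    return jnp.concatenate([sigmas, jnp.zeros(1)])


def euler_maruyama_sample(key, x_init, score_fn, nsteps=1000):
    """Hypothetical stochastic sampler for the reverse-time VE SDE."""
    sigmas = edm_sigmas(nsteps)
    step_keys = jax.random.split(key, nsteps)  # one subkey per step

    def step(x, inputs):
        s, s_next, k = inputs
        dvar = s**2 - s_next**2  # noise variance removed in this step
        z = jax.random.normal(k, x.shape, x.dtype)
        # Drift uses the score at the current noise level, then fresh noise
        # with variance dvar is injected.
        x = x + dvar * score_fn(x, s) + jnp.sqrt(dvar) * z
        return x, None

    x, _ = jax.lax.scan(step, x_init, (sigmas[:-1], sigmas[1:], step_keys))
    return x


# Toy usage with the same closed-form Gaussian score as a stand-in model.
mu, s0 = 2.0, 1.0
toy_score = lambda x, sigma: -(x - mu) / (s0**2 + sigma**2)
key_init, key_sde = jax.random.split(jax.random.PRNGKey(0))
x_init = 80.0 * jax.random.normal(key_init, (20_000,))  # sigma_max * noise
samples = euler_maruyama_sample(key_sde, x_init, toy_score)
print(samples.mean(), samples.std())  # approximately mu = 2.0 and s0 = 1.0
```

Euler–Maruyama needs many more steps than the deterministic Heun scheme for comparable accuracy, which is why this sketch defaults to `nsteps=1000`; with only 18 steps the sample variance in this toy example would be noticeably inflated.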