gensbi.diffusion.solver.sde_solver

Classes

SDESolver

Abstract base class for solvers.

Module Contents

class gensbi.diffusion.solver.sde_solver.SDESolver(score_model, path)

Bases: gensbi.diffusion.solver.solver.Solver

Abstract base class for solvers.

Parameters:
  • score_model – Score model.

  • path – Diffusion path.
get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})

Returns a sampler function for the SDE.

Parameters:
  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of integration steps.

  • method (str) – Integration method.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Additional model arguments.

  • solver_params (Optional[dict]) – Additional solver parameters.

Returns:

Sampler function.

Return type:

Callable
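The split between get_sampler (build a callable once) and sample (build it and call it) can be illustrated with a small NumPy sketch. Everything in it is hypothetical: make_sampler, the single-step update step_fn(rng, x, i), and the inpainting-style conditioning that overwrites masked entries with np.where(condition_mask, condition_value, x) are assumptions for illustration, not the gensbi implementation, which may instead pass the mask to the score model.

```python
import numpy as np


def make_sampler(step_fn, nsteps=18, condition_mask=None, condition_value=None,
                 return_intermediates=False):
    """Hypothetical factory in the spirit of get_sampler.

    step_fn(rng, x, i) -> x_next is an assumed single-step update; it is not
    part of the gensbi API.
    """

    def clamp(x):
        # Assumed conditioning strategy: after every step, overwrite the
        # conditioned entries with their observed values (inpainting-style).
        if condition_mask is None:
            return x
        return np.where(condition_mask, condition_value, x)

    def sampler(rng, x_init):
        # rng is where a stochastic (SDE) step would draw its noise.
        x = clamp(x_init)
        trajectory = [x]
        for i in range(nsteps):
            x = clamp(step_fn(rng, x, i))
            trajectory.append(x)
        # Either the full trajectory (nsteps + 1 states) or only the final state.
        return np.stack(trajectory) if return_intermediates else x

    return sampler


def sample(rng, x_init, step_fn, **kwargs):
    # Hypothetical sample(): build the sampler, then call it once.
    return make_sampler(step_fn, **kwargs)(rng, x_init)


# Toy usage: a deterministic "step" that halves every entry.
halve = lambda rng, x, i: 0.5 * x
mask = np.array([True, False])
value = np.array([3.0, 0.0])
x_T = np.array([10.0, 10.0])
# The conditioned entry stays at 3.0; the free entry goes 10 -> 5 -> 2.5 -> 1.25.
print(sample(np.random.default_rng(0), x_T, halve, nsteps=3,
             condition_mask=mask, condition_value=value))
```

Building the callable once and reusing it is mainly useful under JAX, where the returned function can be jit-compiled a single time and then called with many keys.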

sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})

Sample from the SDE using the sampler.

Parameters:
  • key (Array) – JAX random key.

  • x_init (Array) – Initial value.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of integration steps.

  • method (str) – Integration method.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Additional model arguments.

  • solver_params (Optional[dict]) – Additional solver parameters.

Returns:

Sampled output.

Return type:

Array
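The defaults nsteps=18 and method='Heun' coincide with the deterministic sampler setting of Karras et al. (2022, "EDM"). The NumPy sketch below shows what a Heun integrator of that kind looks like. It assumes an EDM-style denoiser(x, sigma) that returns an estimate of the clean sample, the Karras noise schedule (sigma_min=0.002, sigma_max=80, rho=7), and an x_init already at noise level sigma_max; none of these are confirmed to be what SDESolver uses internally. It is the deterministic (probability-flow ODE) variant; a stochastic variant would inject fresh noise at each step, which is where the JAX key would be consumed.

```python
import numpy as np


def karras_sigmas(nsteps, sigma_min=0.002, sigma_max=80.0, rho=7.0):
    """Noise levels decreasing from sigma_max to sigma_min, plus a final 0."""
    ramp = np.arange(nsteps) / (nsteps - 1)
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    sigmas = (hi + ramp * (lo - hi)) ** rho
    return np.append(sigmas, 0.0)


def heun_sample(denoiser, x_init, nsteps=18, **schedule_kwargs):
    """Deterministic Heun integration of the EDM probability-flow ODE.

    denoiser(x, sigma) is an assumed interface returning an estimate of the
    clean sample; x_init is assumed to be noise already at level sigma_max.
    """
    sigmas = karras_sigmas(nsteps, **schedule_kwargs)
    x = x_init
    for s_cur, s_next in zip(sigmas[:-1], sigmas[1:]):
        # Euler predictor using the ODE slope dx/dsigma = (x - D(x, sigma)) / sigma.
        d_cur = (x - denoiser(x, s_cur)) / s_cur
        x_next = x + (s_next - s_cur) * d_cur
        if s_next > 0:
            # Heun corrector: average the slopes at both ends of the step.
            # Skipped on the last step, where the slope at sigma = 0 is undefined.
            d_next = (x_next - denoiser(x_next, s_next)) / s_next
            x_next = x + (s_next - s_cur) * 0.5 * (d_cur + d_next)
        x = x_next
    return x


# Toy check with 1-D data ~ N(0, 1), whose exact denoiser is x / (1 + sigma^2).
gauss_denoiser = lambda x, sigma: x / (1.0 + sigma**2)
x_T = np.array([80.0])
print(heun_sample(gauss_denoiser, x_T, nsteps=18))
```

For this Gaussian toy the exact ODE solution satisfies x(sigma) proportional to sqrt(1 + sigma^2), so starting from 80 at sigma_max the exact endpoint is 80 / sqrt(6401), just under 1. With 18 steps the sketch lands within a few percent of that value, and the error shrinks as nsteps grows. Each Heun step costs two denoiser calls (except the last), which is the trade-off against the cheaper but less accurate Euler update.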

path
score_model