gensbi.diffusion.solver

Submodules

Classes

  • SDESolver – Solver that samples from a diffusion SDE using a score model.

  • Solver – Abstract base class for solvers.

Package Contents

class gensbi.diffusion.solver.SDESolver(score_model, path)

Bases: gensbi.diffusion.solver.solver.Solver

Solver that samples from a diffusion SDE using a score model.

Parameters:
  • score_model – Score model used by the sampler.

  • path – Diffusion path that defines the SDE.

get_sampler(condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})

Returns a sampler function for the SDE.

Parameters:
  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of integration steps.

  • method (str) – Integration method.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Additional model arguments.

  • solver_params (Optional[dict]) – Additional solver parameters.

Returns:

Sampler function.

Return type:

Callable
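get_sampler returns a callable, but the reference does not describe the integration scheme behind method='Heun'. As a hedged illustration, the sketch below builds such a callable for a toy problem, assuming an EDM-style variance-exploding diffusion in which the noise level sigma acts as time and the probability-flow ODE is dx/dsigma = -sigma * score(x, sigma). make_heun_sampler, karras_sigmas, toy_score and DATA_STD are hypothetical names, and the rho-spaced noise schedule is the one from Karras et al. (2022), not necessarily what gensbi uses.

```python
from typing import Callable

import jax
import jax.numpy as jnp

DATA_STD = 0.5  # hypothetical toy data distribution N(0, DATA_STD**2)


def toy_score(x: jax.Array, sigma: jax.Array) -> jax.Array:
    # Exact score of N(0, DATA_STD**2 + sigma**2); stands in for a trained score model.
    return -x / (DATA_STD**2 + sigma**2)


def karras_sigmas(nsteps: int, sigma_min: float = 0.002,
                  sigma_max: float = 80.0, rho: float = 7.0) -> jax.Array:
    # nsteps noise levels spaced uniformly in sigma**(1/rho), followed by a final 0.
    ramp = jnp.linspace(0.0, 1.0, nsteps)
    lo, hi = sigma_min ** (1.0 / rho), sigma_max ** (1.0 / rho)
    return jnp.append((hi + ramp * (lo - hi)) ** rho, 0.0)


def make_heun_sampler(score_fn: Callable[[jax.Array, jax.Array], jax.Array],
                      nsteps: int = 18,
                      return_intermediates: bool = False) -> Callable[[jax.Array], jax.Array]:
    """Hypothetical sampler factory: integrates from sigma_max down to 0 with Heun's method."""
    sigmas = karras_sigmas(nsteps)

    def drift(x, sigma):
        # Probability-flow ODE in the noise level: dx/dsigma = -sigma * score(x, sigma).
        return -sigma * score_fn(x, sigma)

    def heun_step(x, sigma_pair):
        s_cur, s_next = sigma_pair
        h = s_next - s_cur  # negative: integrating from high to low noise
        d_cur = drift(x, s_cur)
        x_pred = x + h * d_cur                   # Euler predictor
        d_next = drift(x_pred, s_next)
        x_next = x + 0.5 * h * (d_cur + d_next)  # trapezoidal (Heun) corrector
        return x_next, x_next

    @jax.jit
    def sampler(x_init: jax.Array) -> jax.Array:
        x_final, trajectory = jax.lax.scan(heun_step, x_init, (sigmas[:-1], sigmas[1:]))
        if return_intermediates:
            # Prepend the initial state so the output holds nsteps + 1 states.
            return jnp.concatenate([x_init[None], trajectory], axis=0)
        return x_final

    return sampler
```

Closing over the schedule and the score function and compiling the returned sampler with jax.jit means repeated calls with same-shaped inputs reuse the compiled integration loop; return_intermediates stacks the initial state and every Heun step into one array.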

sample(key, x_init, condition_mask=None, condition_value=None, cfg_scale=None, nsteps=18, method='Heun', return_intermediates=False, model_extras={}, solver_params={})

Sample from the SDE using the sampler.

Parameters:
  • key (Array) – JAX random key.

  • x_init (Array) – Initial value.

  • condition_mask (Optional[Array]) – Mask for conditioning.

  • condition_value (Optional[Array]) – Value for conditioning.

  • cfg_scale (Optional[float]) – Classifier-free guidance scale (not implemented).

  • nsteps (int) – Number of integration steps.

  • method (str) – Integration method.

  • return_intermediates (bool) – Whether to return intermediate steps.

  • model_extras (dict) – Additional model arguments.

  • solver_params (Optional[dict]) – Additional solver parameters.

Returns:

Sampled output.

Return type:

Array
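The reference does not say how key, condition_mask and condition_value are used inside sample. The sketch below assumes a stochastic Euler–Maruyama discretization of the reverse-time variance-exploding SDE with noise level sigma(t) = t (so g(t)**2 = 2t), consumes one subkey per step, and clamps the conditioned entries to condition_value after every step, a common masking scheme in simulation-based inference. sample_em, toy_score and DATA_STD are hypothetical names; this is not gensbi's implementation.

```python
import jax
import jax.numpy as jnp

DATA_STD = 0.5  # hypothetical toy data distribution N(0, DATA_STD**2) per coordinate


def toy_score(x, sigma):
    # Exact score of N(0, DATA_STD**2 + sigma**2); stands in for a trained score model.
    return -x / (DATA_STD**2 + sigma**2)


def sample_em(key, x_init, condition_mask=None, condition_value=None,
              nsteps=500, sigma_max=10.0):
    """Hypothetical Euler-Maruyama sampler for the reverse VE SDE with sigma(t) = t.

    Entries where condition_mask is True are overwritten with condition_value
    after every step (an assumed conditioning scheme).
    """
    ts = jnp.linspace(sigma_max, 0.0, nsteps + 1)

    def clamp(x):
        if condition_mask is None:
            return x
        return jnp.where(condition_mask, condition_value, x)

    def step(x, inputs):
        t, dt, step_key = inputs  # dt > 0 is the size of the backward step
        noise = jax.random.normal(step_key, x.shape)
        # Reverse-time update with g(t)**2 = 2t: x += g^2 * score * dt + g * sqrt(dt) * z.
        x = x + 2.0 * t * dt * toy_score(x, t) + jnp.sqrt(2.0 * t * dt) * noise
        return clamp(x), None

    keys = jax.random.split(key, nsteps)  # one independent subkey per step
    dts = ts[:-1] - ts[1:]
    x_final, _ = jax.lax.scan(step, clamp(x_init), (ts[:-1], dts, keys))
    return x_final
```

Because the toy score acts on each coordinate independently, masking the second coordinate fixes it at the observed value while the first coordinate is still drawn from the toy data distribution; with a learned joint score the clamped values would also steer the unmasked coordinates.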

path

score_model

class gensbi.diffusion.solver.Solver

Bases: abc.ABC

Abstract base class for solvers.

abstract sample(key, x_1)

Parameters:
  • key – JAX random key.

  • x_1 (jax.Array) – Initial value.

Return type:

jax.Array
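Solver only fixes the interface: a concrete solver must override sample(key, x_1). The sketch below re-creates that contract with Python's abc module to show what the abstract method enforces. The local Solver is a stand-in mirroring the documented signature rather than an import from gensbi, and IdentitySolver is a hypothetical subclass.

```python
import abc

import jax
import jax.numpy as jnp


class Solver(abc.ABC):
    """Stand-in mirroring the documented interface: subclasses must implement sample."""

    @abc.abstractmethod
    def sample(self, key: jax.Array, x_1: jax.Array) -> jax.Array:
        ...


class IdentitySolver(Solver):
    """Hypothetical trivial solver: ignores the key and returns its input unchanged."""

    def sample(self, key: jax.Array, x_1: jax.Array) -> jax.Array:
        return x_1
```

Instantiating Solver directly raises TypeError, because abc refuses to create instances of classes whose abstract methods are not implemented; a subclass such as SDESolver becomes instantiable once it defines sample.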