gensbi.diffusion.path.edm_path

Classes

EDMPath

Probability path for EDM diffusion. Noise-scale sampling and the loss are delegated to a scheduler.

Module Contents

class gensbi.diffusion.path.edm_path.EDMPath(scheduler)

Bases: gensbi.diffusion.path.path.ProbPath

Probability path for EDM diffusion. The scheduler passed at construction supplies both the noise scales (sample_sigma) and the training loss (get_loss_fn).

get_loss_fn()

Returns the loss function for the EDM path.

Returns:

The loss function as provided by the scheduler.

Return type:

Callable
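As a rough illustration of what a scheduler-provided loss can look like, the sketch below implements the weighted denoising loss of Karras et al. (2022), λ(σ)‖D(x_t; σ) − x_1‖² with λ(σ) = (σ² + σ_data²)/(σ·σ_data)². It assumes the scheduler follows that paper; the names `edm_loss` and `edm_loss_weight`, the `denoiser(x, sigma)` call signature, and σ_data = 0.5 are hypothetical choices for illustration, not gensbi's API.

```python
import jax
import jax.numpy as jnp

# Assumed data standard deviation (the default used by Karras et al.); hypothetical here.
SIGMA_DATA = 0.5


def edm_loss_weight(sigma, sigma_data=SIGMA_DATA):
    # lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2
    return (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2


def edm_loss(denoiser, x_1, x_t, sigma):
    """Hypothetical EDM loss: weighted squared error of the denoiser output.

    denoiser: callable (x_t, sigma) -> estimate of the clean x_1 (assumed signature)
    x_1:      clean targets, shape (batch_size, ...)
    x_t:      noisy inputs from the path, same shape as x_1
    sigma:    noise scales with batch_size elements, e.g. shape (batch_size, 1)
    """
    d = denoiser(x_t, sigma)
    # Per-sample mean squared error over all non-batch axes.
    per_sample = jnp.mean((d - x_1) ** 2, axis=tuple(range(1, x_1.ndim)))
    # One weight per sample, flattened to match per_sample.
    w = edm_loss_weight(sigma).reshape(per_sample.shape)
    return jnp.mean(w * per_sample)
```

The weight λ(σ) rebalances the per-noise-level error so that no single σ range dominates the gradient; at σ = σ_data it equals 8.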

sample(key, x_1, sigma)

Sample from the EDM probability path.

Parameters:
  • key (Array) – JAX random key.

  • x_1 (Array) – Target data point, shape (batch_size, …).

  • sigma (Array) – Noise scale, shape (batch_size, …).

Returns:

A sample from the EDM path.

Return type:

PathSample
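A minimal sketch of sampling an EDM path, assuming the variance-exploding perturbation of Karras et al. (2022), x_t = x_1 + σ·ε with ε ~ N(0, I). `ToyPathSample` and `sample_edm_path` are hypothetical stand-ins; the real `PathSample` fields and the exact computation inside `EDMPath.sample` may differ.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToyPathSample(NamedTuple):
    # Hypothetical container; gensbi's PathSample may hold different fields.
    x_1: jax.Array
    x_t: jax.Array
    sigma: jax.Array


def sample_edm_path(key, x_1, sigma):
    # eps ~ N(0, I) with the same shape and dtype as the clean data.
    eps = jax.random.normal(key, x_1.shape, dtype=x_1.dtype)
    # sigma of shape (batch_size, 1) broadcasts over the feature axes of x_1.
    x_t = x_1 + sigma * eps
    return ToyPathSample(x_1=x_1, x_t=x_t, sigma=sigma)
```

For example, `sample_edm_path(jax.random.PRNGKey(0), jnp.ones((4, 3)), jnp.full((4, 1), 2.0))` returns a sample whose `x_t` has shape (4, 3).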

sample_sigma(key, batch_size)

Sample the noise scale sigma from the scheduler.

Parameters:
  • key (Array) – JAX random key.

  • batch_size (int) – Number of samples to generate.

Returns:

Samples of sigma, shape (batch_size, …).

Return type:

Array
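One common way a scheduler draws σ is the log-normal training distribution from Karras et al. (2022), ln σ ~ N(P_mean, P_std²) with P_mean = −1.2 and P_std = 1.2. The sketch below assumes that choice; `sample_sigma_lognormal` is a hypothetical helper, and the trailing singleton axis is an assumption made so σ broadcasts against (batch_size, dim) data. gensbi's scheduler may use a different distribution or shape.

```python
import jax
import jax.numpy as jnp

# Defaults from Karras et al. (2022); assumed here, not read from gensbi.
P_MEAN, P_STD = -1.2, 1.2


def sample_sigma_lognormal(key, batch_size, p_mean=P_MEAN, p_std=P_STD):
    # ln(sigma) ~ N(p_mean, p_std^2), one draw per sample, shape (batch_size, 1).
    log_sigma = p_mean + p_std * jax.random.normal(key, (batch_size, 1))
    return jnp.exp(log_sigma)
```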

scheduler

The noise scheduler passed to the constructor.