gensbi.diffusion.path package

Subpackages

Submodules

gensbi.diffusion.path.edm_path module

class gensbi.diffusion.path.edm_path.EDMPath(scheduler)

Bases: ProbPath

get_loss_fn() → Callable

Returns the loss function for the EDM path.

Returns:

The loss function as provided by the scheduler.

Return type:

Callable
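The loss itself is supplied by the scheduler and is not spelled out here. For illustration only, the following is a minimal sketch of the weighted denoising loss from Karras et al. (2022, "EDM"), assuming the standard EDM preconditioning, sigma_data = 0.5, and a raw network called as model_fn(x, c_noise). The names edm_precond, edm_loss, and model_fn are hypothetical and are not the gensbi API; the callable returned by get_loss_fn may have a different signature and details.

```python
import jax.numpy as jnp

SIGMA_DATA = 0.5  # assumed standard deviation of the data (EDM default)


def edm_precond(sigma, sigma_data=SIGMA_DATA):
    """Hypothetical helper: preconditioning coefficients c_skip, c_out, c_in."""
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)
    c_out = sigma * sigma_data / jnp.sqrt(sigma**2 + sigma_data**2)
    c_in = 1.0 / jnp.sqrt(sigma**2 + sigma_data**2)
    return c_skip, c_out, c_in


def edm_loss(model_fn, x_1, x_t, sigma, sigma_data=SIGMA_DATA):
    """Hypothetical helper: batch mean of lambda(sigma) * ||D(x_t; sigma) - x_1||^2."""
    c_skip, c_out, c_in = edm_precond(sigma, sigma_data)
    # The raw network sees a scaled input and the noise embedding c_noise = ln(sigma) / 4.
    f_out = model_fn(c_in * x_t, jnp.log(sigma) / 4.0)
    denoised = c_skip * x_t + c_out * f_out
    # Loss weight lambda(sigma) = (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2.
    weight = (sigma**2 + sigma_data**2) / (sigma * sigma_data) ** 2
    sq_err = weight * (denoised - x_1) ** 2
    # Sum over all non-batch axes, then average over the batch.
    per_sample = jnp.sum(sq_err.reshape(x_1.shape[0], -1), axis=1)
    return jnp.mean(per_sample)
```

The weight equals 1 / c_out(sigma)^2, so in terms of the raw network output every noise level contributes with equal weight, which is the motivation given by Karras et al. (2022).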

sample(key: Array, x_1: Array, sigma: Array) → EDMPathSample

Sample from the EDM probability path.

Parameters:
  • key (Array) – JAX random key.

  • x_1 (Array) – Target data points, shape (batch_size, …).

  • sigma (Array) – Noise scale, shape (batch_size, …).

Returns:

An EDMPathSample holding x_1, sigma, and the noisy sample x_t.

Return type:

EDMPathSample
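The perturbation an EDM-style path typically applies is x_t = x_1 + sigma * eps with eps ~ N(0, I). The sketch below assumes that form and assumes sigma is shaped to broadcast against x_1, for example (batch_size, 1). The function edm_path_sample is hypothetical: it returns the three EDMPathSample fields as a plain tuple instead of constructing the gensbi class, and the real EDMPath.sample may be implemented differently.

```python
import jax
import jax.numpy as jnp


def edm_path_sample(key, x_1, sigma):
    """Hypothetical helper: perturb x_1 with Gaussian noise of scale sigma."""
    eps = jax.random.normal(key, x_1.shape, dtype=x_1.dtype)
    # sigma of shape (batch_size, 1, ...) broadcasts over the feature axes of x_1.
    x_t = x_1 + sigma * eps
    # Same field order as the EDMPathSample constructor: (x_1, sigma, x_t).
    return x_1, sigma, x_t


# Usage: four 2-D points at noise scale 2.
key = jax.random.PRNGKey(0)
x_1 = jnp.zeros((4, 2))
sigma = jnp.full((4, 1), 2.0)
_, _, x_t = edm_path_sample(key, x_1, sigma)
```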

sample_sigma(key: Array, batch_size: int) → Array

Sample the noise scale sigma from the scheduler.

Parameters:
  • key (Array) – JAX random key.

  • batch_size (int) – Number of samples to generate.

Returns:

Samples of sigma, shape (batch_size, …).

Return type:

Array
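The noise distribution is defined by the scheduler and is not specified here. One common choice, used for training by Karras et al. (2022), is log-normal: ln(sigma) ~ N(P_mean, P_std^2) with P_mean = -1.2 and P_std = 1.2. The sketch below assumes that choice and an output shape of (batch_size, 1); sample_sigma_lognormal is a hypothetical helper, not part of gensbi.

```python
import jax
import jax.numpy as jnp

P_MEAN, P_STD = -1.2, 1.2  # assumed: EDM training values from Karras et al. (2022)


def sample_sigma_lognormal(key, batch_size, p_mean=P_MEAN, p_std=P_STD):
    """Hypothetical helper: ln(sigma) ~ N(p_mean, p_std^2), returned as (batch_size, 1)."""
    log_sigma = p_mean + p_std * jax.random.normal(key, (batch_size, 1))
    # The trailing singleton axis lets sigma broadcast against (batch_size, ndim) data.
    return jnp.exp(log_sigma)
```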

gensbi.diffusion.path.path module

class gensbi.diffusion.path.path.ProbPath(scheduler: Any)

Bases: ABC

property name: str

Returns the name of the scheduler.

Returns:

Scheduler name.

Return type:

str

abstractmethod sample(*args: Any, **kwargs: Any) → EDMPathSample

Abstract method to sample from the probability path.

Returns:

Sample from the path.

Return type:

EDMPathSample
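To illustrate the contract ProbPath sets up (a scheduler passed at construction, a name property that reports the scheduler's name, and an abstract sample method), here is a minimal stand-alone sketch. It assumes the scheduler is stored as self.scheduler and exposes a name attribute; ToyScheduler, SketchProbPath, and NoisePath are hypothetical classes, not gensbi code.

```python
from abc import ABC, abstractmethod
from typing import Any


class ToyScheduler:
    """Hypothetical scheduler that only carries a name."""

    name = "toy"


class SketchProbPath(ABC):
    """Stand-alone sketch of the ProbPath contract (not the gensbi class)."""

    def __init__(self, scheduler: Any):
        self.scheduler = scheduler  # assumed attribute name

    @property
    def name(self) -> str:
        # Reports the scheduler's name, as the documented name property does.
        return self.scheduler.name

    @abstractmethod
    def sample(self, *args: Any, **kwargs: Any) -> Any:
        """Subclasses define how a point on the path is drawn."""


class NoisePath(SketchProbPath):
    def sample(self, x_1: float, sigma: float, eps: float) -> float:
        # Deterministic toy: x_t = x_1 + sigma * eps for a given noise draw eps.
        return x_1 + sigma * eps
```

Because sample is abstract, SketchProbPath itself cannot be instantiated; only subclasses that implement sample can.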

sample_prior(key: Array, shape: Any) → Array

Sample from the prior distribution.

Parameters:
  • key (Array) – JAX random key.

  • shape (Any) – Shape of the samples to generate, should be (nsamples, ndim).

Returns:

Samples from the prior distribution, shape (nsamples, ndim).

Return type:

Array
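Which prior is used depends on the scheduler. For EDM, sampling typically starts from N(0, sigma_max^2 I) with sigma_max = 80 (Karras et al., 2022); the sketch below assumes that prior. sample_prior_sketch is a hypothetical free function, not the ProbPath method.

```python
import jax

SIGMA_MAX = 80.0  # assumed: largest noise level in the EDM defaults


def sample_prior_sketch(key, shape, sigma_max=SIGMA_MAX):
    """Hypothetical helper: draw from N(0, sigma_max^2 I) with shape (nsamples, ndim)."""
    return sigma_max * jax.random.normal(key, shape)
```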

gensbi.diffusion.path.path_sample module

class gensbi.diffusion.path.path_sample.EDMPathSample(x_1: Array, sigma: Array, x_t: Array)

Bases: object

Represents a sample from a diffusion probability path.

x_1

The target sample \(X_1\).

Type:

Array

sigma

The noise scale \(\sigma\).

Type:

Array

x_t

Noisy samples \(X_t \sim p_t(X_t)\), shape (batch_size, …).

Type:

Array

get_batch() → Tuple[Array, Array, Array]

Returns the batch as a tuple (x_1, x_t, sigma).

Returns:

The target sample, the noisy sample, and the noise scale.

Return type:

Tuple[Array, Array, Array]
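A minimal sketch of an equivalent container, assuming a plain dataclass; PathSampleSketch is a hypothetical name, not the gensbi class. It highlights that get_batch returns (x_1, x_t, sigma), which differs from the constructor's field order (x_1, sigma, x_t), so positional unpacking must follow the get_batch order.

```python
from dataclasses import dataclass
from typing import Tuple

import jax.numpy as jnp
from jax import Array


@dataclass
class PathSampleSketch:
    """Hypothetical stand-in for EDMPathSample."""

    x_1: Array  # clean target sample X_1
    sigma: Array  # noise scale
    x_t: Array  # noisy sample X_t

    def get_batch(self) -> Tuple[Array, Array, Array]:
        # Order is (x_1, x_t, sigma), not the constructor order (x_1, sigma, x_t).
        return self.x_1, self.x_t, self.sigma


sample = PathSampleSketch(x_1=jnp.zeros((2, 3)), sigma=jnp.ones((2, 1)), x_t=jnp.full((2, 3), 0.5))
x_1, x_t, sigma = sample.get_batch()
```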

Module contents

class gensbi.diffusion.path.EDMPath(scheduler)

Package-level export of gensbi.diffusion.path.edm_path.EDMPath; the full documentation is given under the gensbi.diffusion.path.edm_path module above.