gensbi.utils package

Submodules

gensbi.utils.math module

gensbi.utils.math.divergence(vf: Callable, t: Array, x: Array, args: Array | None = None)

Compute the divergence of the vector field vf with respect to x, at point x and time t.

Parameters:
  • vf (Callable) – the vector field function.

  • t (Array) – the time at which to compute the divergence.

  • x (Array) – the point at which to compute the divergence.

  • args (Array | None) – optional additional arguments for vf (default None).

Returns:

The divergence of the vector field at point x and time t.

Return type:

Array
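For intuition, the divergence of a field v(t, x) with respect to x is the trace of its Jacobian, ∇·v = Σ_i ∂v_i/∂x_i. The sketch below computes it with jax.jacfwd. It is not GenSBI's implementation. It assumes vf is called as vf(t, x, args) on a single unbatched vector x, and the name divergence_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp


def divergence_sketch(vf, t, x, args=None):
    """Hypothetical helper: divergence of vf with respect to x at (t, x).

    Assumes vf(t, x, args) maps a 1-D array x of shape (d,) to shape (d,);
    the divergence is the trace of the Jacobian d vf / d x.
    """
    jac = jax.jacfwd(vf, argnums=1)(t, x, args)  # Jacobian w.r.t. x, shape (d, d)
    return jnp.trace(jac)


# A linear field vf(t, x) = A @ x has divergence trace(A) = 1 + 4 = 5.
A = jnp.array([[1.0, 2.0], [3.0, 4.0]])


def linear_vf(t, x, args):
    return A @ x


div = divergence_sketch(linear_vf, 0.0, jnp.array([0.5, -1.0]))
print(div)  # 5.0
```

For large d, forming the full Jacobian is costly. Stochastic trace estimators are a common alternative, but this sketch keeps the exact form for clarity.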

gensbi.utils.model_wrapping module

class gensbi.utils.model_wrapping.GuidedModelWrapper(*args: Any, **kwargs: Any)

Bases: ModelWrapper

A ModelWrapper that applies classifier-free guidance. Its get_vector_field combines the conditioned and unconditioned predictions of the wrapped model, weighted by cfg_scale. All other behavior is inherited from ModelWrapper.

cfg_scale: float

Guidance scale used to weight the conditioned and unconditioned predictions.

get_vector_field(**kwargs) → Callable

Compute the guided vector field as a weighted sum of conditioned and unconditioned predictions.
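One common way to form such a weighted sum is the classifier-free guidance rule v = v_uncond + cfg_scale · (v_cond - v_uncond), which equals (1 - cfg_scale) · v_uncond + cfg_scale · v_cond. The sketch below assumes this form, although the exact weighting GenSBI uses is not stated here. The helper guided_vector_field is hypothetical.

```python
import jax.numpy as jnp


def guided_vector_field(v_cond, v_uncond, cfg_scale):
    """Hypothetical helper: classifier-free guidance combination.

    Extrapolates from the unconditioned prediction toward the conditioned one;
    equivalent to (1 - cfg_scale) * v_uncond + cfg_scale * v_cond.
    """
    return v_uncond + cfg_scale * (v_cond - v_uncond)


v_c = jnp.array([1.0, 2.0])  # conditioned prediction
v_u = jnp.array([0.0, 1.0])  # unconditioned prediction
print(guided_vector_field(v_c, v_u, 1.0))  # equals v_c
print(guided_vector_field(v_c, v_u, 2.0))  # [2. 3.]
```

With cfg_scale = 1 the guided field reduces to the conditioned prediction. Values above 1 push samples further toward the condition.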

class gensbi.utils.model_wrapping.ModelWrapper(*args: Any, **kwargs: Any)

Bases: Module

Wraps another model for use with ODE solvers such as diffrax. Calling the wrapper (__call__) returns the wrapped model's output. get_vector_field and get_divergence return functions that compute the model's vector field and its divergence, in the form a diffrax ODE term expects.
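The wrapping pattern can be illustrated with a small, framework-free class. TinyWrapper below is hypothetical: it does not subclass the Module base of ModelWrapper, and it assumes the wrapped model is a plain callable model(t, x, args). Its get_vector_field and get_divergence return closures with the (t, y, args) signature that diffrax ODE terms use.

```python
import jax
import jax.numpy as jnp


class TinyWrapper:
    """Hypothetical, framework-free stand-in for ModelWrapper.

    `model(t, x, args)` is any callable returning an array shaped like x.
    """

    def __init__(self, model):
        self.model = model

    def __call__(self, t, x, args=None):
        return self.model(t, x, args)

    def get_vector_field(self):
        # Closure with the (t, y, args) signature diffrax uses for ODE terms.
        def vf(t, x, args):
            return self(t, x, args)

        return vf

    def get_divergence(self):
        # Closure returning tr(d vf / d x) for a single unbatched x.
        def div(t, x, args):
            jac = jax.jacfwd(self.__call__, argnums=1)(t, x, args)
            return jnp.trace(jac)

        return div


wrapper = TinyWrapper(lambda t, x, args: -x)  # dx/dt = -x
vf = wrapper.get_vector_field()
div = wrapper.get_divergence()
x0 = jnp.array([1.0, 2.0, 3.0])
print(vf(0.0, x0, None))   # [-1. -2. -3.]
print(div(0.0, x0, None))  # -3.0
```

Returning closures rather than arrays lets a solver call the same function repeatedly at new (t, x) points without re-wrapping the model.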

get_divergence(**kwargs) → Callable

Return a function that computes the divergence of the model's vector field.

Parameters of the returned function:
  • t (Array) – time (batch_size).

  • x (Array) – input data to the model (batch_size, …).

  • args – additional information forwarded to the model, e.g., a text condition.

Returns:

A function of t, x and args that returns the divergence of the model as an Array.

Return type:

Callable
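Because t and x carry a leading batch dimension, the divergence is computed per sample. The sketch below assumes the vector field acts on one sample at a time. It maps a Jacobian-trace divergence over the batch with jax.vmap. The names batched_divergence and scaled_vf are hypothetical.

```python
import jax
import jax.numpy as jnp


def batched_divergence(vf, t, x, args=None):
    """Hypothetical helper: per-sample divergence.

    t has shape (batch_size,), x has shape (batch_size, d), and vf is assumed
    to act on a single sample: vf(t_i, x_i, args) -> array of shape (d,).
    """

    def single(t_i, x_i):
        jac = jax.jacfwd(vf, argnums=1)(t_i, x_i, args)  # shape (d, d)
        return jnp.trace(jac)

    return jax.vmap(single)(t, x)  # shape (batch_size,)


def scaled_vf(t, x, args):
    # Time-dependent linear field t * x; its divergence is t * d.
    return t * x


t = jnp.array([0.0, 0.5, 1.0])
x = jnp.ones((3, 4))  # batch of 3 samples in d = 4 dimensions
print(batched_divergence(scaled_vf, t, x))  # [0. 2. 4.]
```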

get_vector_field(**kwargs) → Callable

Return a function that computes the vector field of the model, with its output squeezed for use in an ODE term.

Parameters of the returned function:
  • t (Array) – time (batch_size).

  • x (Array) – input data to the model (batch_size, …).

  • args – additional information forwarded to the model, e.g., a text condition.

Returns:

A function of t, x and args that returns the vector field of the model as an Array.

Return type:

Callable
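To show how an ODE solver consumes a function of (t, x, args), the sketch below integrates one with fixed-step Euler via jax.lax.scan. It stands in for a diffrax solver and is not GenSBI's sampling code. The names euler_integrate and decay_vf are hypothetical.

```python
import jax
import jax.numpy as jnp


def euler_integrate(vf, x0, t0, t1, n_steps, args=None):
    """Hypothetical helper: fixed-step Euler solve of dx/dt = vf(t, x, args).

    Stands in for a diffrax solver to show how the vector field is called.
    """
    dt = (t1 - t0) / n_steps

    def step(x, i):
        t = t0 + i * dt  # current time for step i
        return x + dt * vf(t, x, args), None

    x1, _ = jax.lax.scan(step, x0, jnp.arange(n_steps))
    return x1


def decay_vf(t, x, args):
    return -x  # exact solution: x(t) = x0 * exp(-t)


x1 = euler_integrate(decay_vf, jnp.array([1.0]), 0.0, 1.0, 1000)
print(x1)  # approximately 0.368, close to exp(-1)
```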

gensbi.utils.plotting module

gensbi.utils.plotting.plot_trajectories(traj)

Module contents