gensbi.utils package
Submodules
gensbi.utils.math module
- gensbi.utils.math.divergence(vf: Callable, t: Array, x: Array, args: Array | None = None)
Compute the divergence of the vector field vf at point x and time t.
- Parameters:
vf (Callable) – The vector field function.
x (Array) – The point at which to compute the divergence.
t (Array) – The time at which to compute the divergence.
- Returns:
The divergence of the vector field at point x and time t.
- Return type:
Array
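The following minimal sketch shows what the divergence computes. It assumes vf is called as vf(t, x, args) on a single unbatched point x of shape (d,); under that assumption, the divergence is the trace of the Jacobian of vf with respect to x. The helper name divergence_sketch and the toy linear field are hypothetical, and the gensbi implementation may differ (for example, in how it handles batches).

```python
from typing import Callable, Optional

import jax
import jax.numpy as jnp


def divergence_sketch(vf: Callable, t: jax.Array, x: jax.Array,
                      args: Optional[jax.Array] = None) -> jax.Array:
    """Hypothetical helper: divergence of vf(t, x, args) w.r.t. x, x of shape (d,)."""
    # Jacobian of the vector field with respect to x, shape (d, d).
    jac = jax.jacfwd(lambda y: vf(t, y, args))(x)
    # The divergence is the sum of the diagonal entries dv_i/dx_i.
    return jnp.trace(jac)


# Toy check: a linear field v(x) = A x has divergence trace(A) = 1 + 4.
A = jnp.array([[1.0, 2.0], [3.0, 4.0]])


def linear_vf(t, x, args):
    return A @ x


div = divergence_sketch(linear_vf, jnp.array(0.0), jnp.array([0.5, -1.0]))
# div == 5.0
```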
gensbi.utils.model_wrapping module
- class gensbi.utils.model_wrapping.GuidedModelWrapper(*args: Any, **kwargs: Any)
Bases: ModelWrapper
This class wraps another model. It defines a call method that returns the model output, along with methods that compute the model's vector field and its divergence in a form suited to diffrax. This is useful for ODE solvers that need both the vector field and its divergence.
- cfg_scale: float
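The guidance rule behind cfg_scale is not documented here. The sketch below assumes the common classifier-free guidance convention, which combines a conditional and an unconditional prediction linearly. The function guided_vector_field and its arguments are hypothetical and are not part of gensbi.

```python
import jax.numpy as jnp


def guided_vector_field(v_cond, v_uncond, cfg_scale: float):
    # Assumed classifier-free guidance convention: start from the unconditional
    # prediction and move towards the conditional one, by cfg_scale times
    # their difference. cfg_scale = 0 gives v_uncond; cfg_scale = 1 gives v_cond;
    # cfg_scale > 1 extrapolates past v_cond.
    return v_uncond + cfg_scale * (v_cond - v_uncond)


v_cond = jnp.array([1.0, 2.0])
v_uncond = jnp.array([0.5, 0.5])
v = guided_vector_field(v_cond, v_uncond, cfg_scale=2.0)
# v == [1.5, 3.5]
```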
- class gensbi.utils.model_wrapping.ModelWrapper(*args: Any, **kwargs: Any)
Bases: Module
This class wraps another model. It defines a call method that returns the model output, along with methods that compute the model's vector field and its divergence in a form suited to diffrax. This is useful for ODE solvers that need both the vector field and its divergence.
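The wrapping pattern described above can be sketched as follows. ToyModelWrapper is a hypothetical plain-Python stand-in (the real class subclasses Module), and the calling convention model(t, x, args, **kwargs) is an assumption. Keyword arguments are bound when the vector field is requested. The returned function has the (t, y, args) signature that diffrax.ODETerm expects.

```python
from typing import Any, Callable, Optional

import jax
import jax.numpy as jnp


class ToyModelWrapper:
    """Hypothetical stand-in for ModelWrapper (plain Python, not a Module)."""

    def __init__(self, model: Callable):
        # Assumed convention: model(t, x, args, **kwargs) -> output shaped like x.
        self.model = model

    def __call__(self, t, x, args: Optional[Any] = None, **kwargs) -> jax.Array:
        # Return the raw model output.
        return self.model(t, x, args, **kwargs)

    def get_vector_field(self, **kwargs) -> Callable:
        # Bind the extra keyword arguments now and return a (t, x, args) function.
        def vf(t, x, args):
            return self(t, x, args, **kwargs)
        return vf

    def get_divergence(self, **kwargs) -> Callable:
        vf = self.get_vector_field(**kwargs)

        # Exact divergence (Jacobian trace) for a single point x of shape (d,).
        def div(t, x, args):
            return jnp.trace(jax.jacfwd(lambda y: vf(t, y, args))(x))
        return div


# Usage with a toy "model" whose output is scale * x.
def toy_model(t, x, args, scale=1.0):
    return scale * x


wrapper = ToyModelWrapper(toy_model)
vf = wrapper.get_vector_field(scale=2.0)
div = wrapper.get_divergence(scale=2.0)
x = jnp.array([1.0, -1.0, 0.5])
v = vf(0.0, x, None)   # 2 * x
d = div(0.0, x, None)  # trace(2 * I_3) = 6.0
```

Binding keyword arguments in a closure keeps the ODE-facing function to the fixed (t, y, args) shape while still allowing per-call model options.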
- get_divergence(**kwargs) → Callable
Compute the divergence of the model.
- Parameters:
t (Array) – time (batch_size).
x (Array) – input data to the model (batch_size, …).
args – additional information forwarded to the model, e.g., text condition.
- Returns:
a function that computes the divergence of the model.
- Return type:
Callable
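The parameters above carry a leading batch_size axis. One way to obtain a per-sample divergence for a batch is to take the Jacobian trace for each sample and map it over the batch with jax.vmap. This sketch assumes the vector field acts on one sample of shape (d,) at a time. The toy field is hypothetical, and the source does not say whether gensbi batches this way.

```python
import jax
import jax.numpy as jnp


def vector_field(t, x, args):
    # Hypothetical toy field for a single sample x of shape (d,): v(t, x) = -t * x.
    return -t * x


def divergence_single(t, x, args):
    # Exact divergence: trace of the (d, d) Jacobian with respect to x (argnums=1).
    return jnp.trace(jax.jacfwd(vector_field, argnums=1)(t, x, args))


# Map over the batch axis of t and x; args is shared across the batch.
batched_divergence = jax.vmap(divergence_single, in_axes=(0, 0, None))

t = jnp.array([0.0, 0.5, 1.0])  # shape (batch_size,)
x = jnp.ones((3, 4))            # shape (batch_size, d) with d = 4
div = batched_divergence(t, x, None)
# div is approximately [0.0, -2.0, -4.0], i.e. -t * d
```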
- get_vector_field(**kwargs) → Callable
Compute the vector field of the model, properly squeezed for the ODE term.
- Parameters:
x (Array) – input data to the model (batch_size, …).
t (Array) – time (batch_size).
args – additional information forwarded to the model, e.g., text condition.
- Returns:
a function that computes the vector field of the model.
- Return type:
Callable
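An ODE solver consumes a vector field of this form. With diffrax, the function would be wrapped in diffrax.ODETerm and passed to diffrax.diffeqsolve. To stay self-contained, the sketch below instead integrates a hypothetical closure with a hand-written fixed-step Euler loop. The closure comes from make_vector_field, which stands in for get_vector_field.

```python
import jax
import jax.numpy as jnp


def make_vector_field(rate: float):
    # Hypothetical stand-in for get_vector_field(**kwargs): the keyword argument
    # is bound here, and the returned function has the (t, y, args) signature.
    def vf(t, x, args):
        return -rate * x
    return vf


def euler_integrate(vf, x0, t0=0.0, t1=1.0, n_steps=100, args=None):
    # Fixed-step explicit Euler: x_{k+1} = x_k + dt * vf(t_k, x_k, args).
    dt = (t1 - t0) / n_steps

    def step(k, x):
        t = t0 + k * dt
        return x + dt * vf(t, x, args)

    return jax.lax.fori_loop(0, n_steps, step, x0)


vf = make_vector_field(rate=1.0)
x1 = euler_integrate(vf, jnp.array([1.0, 2.0]))
# Each component is scaled by (1 - 1/100) ** 100, roughly 0.366.
```

A production solver such as those in diffrax adds adaptive step sizes and higher-order schemes. The vector field interface stays the same.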