gensbi.utils.model_wrapping
Classes
| GuidedModelWrapper | Wraps another model and returns a guided output: a weighted sum of conditioned and unconditioned predictions. |
| ModelWrapper | Wraps another model; __call__ returns the model output, and helper methods expose its vector field and divergence for ODE solvers. |
Functions
Module Contents
- class gensbi.utils.model_wrapping.GuidedModelWrapper(model, cfg_scale=0.7)
Bases: ModelWrapper
Wraps another model and applies classifier-free guidance: __call__ returns a weighted sum of the wrapped model's conditioned and unconditioned predictions, with the weighting controlled by cfg_scale. The get_vector_field and get_divergence methods are inherited from ModelWrapper and provide the vector field and its divergence in a form suitable for diffrax ODE solvers.
- Parameters:
model (flax.nnx.Module) – the model to wrap.
cfg_scale (float) – guidance scale weighting the conditioned against the unconditioned prediction (default 0.7).
- __call__(t, obs, *args, **kwargs)
Compute the guided model output as a weighted sum of conditioned and unconditioned predictions.
- Parameters:
t (Array) – time (batch_size).
obs (Array) – input data to the model (batch_size, …).
*args – additional information forwarded to the model, e.g., text condition.
**kwargs – additional keyword arguments.
- Returns:
guided model output.
- Return type:
Array
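The guidance rule can be sketched as follows. This is a minimal NumPy illustration, not the gensbi implementation: it assumes the wrapped model accepts a hypothetical `conditioned` keyword that switches the condition on or off, and it uses one common classifier-free guidance combination, (1 + w) * cond - w * uncond with w = cfg_scale. The exact weighting used by GuidedModelWrapper may differ; guided_output and toy_model are hypothetical names.

```python
import numpy as np


def guided_output(model, t, obs, *args, cfg_scale=0.7, **kwargs):
    """Weighted combination of conditioned and unconditioned predictions (sketch)."""
    # Hypothetical convention: `conditioned` toggles the condition on or off.
    cond = model(obs, t, *args, conditioned=True, **kwargs)
    uncond = model(obs, t, *args, conditioned=False, **kwargs)
    # Assumed CFG rule: move the output away from the unconditioned prediction.
    return (1.0 + cfg_scale) * cond - cfg_scale * uncond


def toy_model(obs, t, conditioned=True):
    # Toy stand-in: returns 1 when conditioned and 0 otherwise.
    return np.ones_like(obs) if conditioned else np.zeros_like(obs)


out = guided_output(toy_model, np.zeros(4), np.zeros((4, 2)), cfg_scale=0.7)
```

Under this assumed rule, cfg_scale = 0 reduces to the conditioned prediction alone, and larger values weight the condition more strongly.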
- class gensbi.utils.model_wrapping.ModelWrapper(model)
Bases: flax.nnx.Module
Wraps another model. The __call__ method returns the wrapped model's output, get_vector_field computes the model's vector field, and get_divergence computes its divergence, both in a form suitable for diffrax ODE solvers that need the vector field and its divergence.
- Parameters:
model (flax.nnx.Module) – the model to wrap.
- __call__(t, obs, *args, **kwargs)
Defines how inputs are passed through the wrapped model. Here, we assume that the wrapped model takes both obs and t as input, along with any additional keyword arguments.
- Optional steps to perform here:
check that t has the dimensions the model expects.
add custom forward-pass logic.
call the wrapped model.
Given obs and t, returns the model output for input obs at time t, using any extra information passed in *args and **kwargs.
- Parameters:
t (Array) – time (batch_size).
obs (Array) – input data to the model (batch_size, …).
*args, **kwargs – additional information forwarded to the model, e.g., text condition.
- Returns:
model output.
- Return type:
Array
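A minimal sketch of the three optional steps listed above. It assumes the wrapped model is called as model(obs, t, ...) (the argument order is an assumption) and uses a plain Python class with NumPy instead of flax.nnx so it runs without Flax; SimpleModelWrapper and toy_model are hypothetical names.

```python
import numpy as np


class SimpleModelWrapper:
    """Plain-Python stand-in for ModelWrapper (the real class is a flax.nnx.Module)."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, *args, **kwargs):
        obs = np.asarray(obs)
        t = np.asarray(t, dtype=obs.dtype)
        # Optional step: broadcast a scalar time to one value per batch element.
        if t.ndim == 0:
            t = np.full((obs.shape[0],), t)
        # Optional step: custom forward-pass logic would go here.
        # Call the wrapped model, forwarding any extra information.
        return self.model(obs, t, *args, **kwargs)


def toy_model(obs, t):
    # Scales each sample by its own time value.
    return obs * t[:, None]


wrapper = SimpleModelWrapper(toy_model)
out = wrapper(0.5, np.ones((3, 2)))
```

Broadcasting a scalar t here is one way to "check that t is in the dimensions the model expects"; a real wrapper might instead raise an error on a shape mismatch.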
- get_divergence(**kwargs)
Compute the divergence of the model's vector field with respect to x.
- Parameters:
t (Array) – time (batch_size).
x (Array) – input data to the model (batch_size, …).
args – additional information forwarded to the model, e.g., text condition.
- Returns:
divergence of the model.
- Return type:
Array
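The divergence of a vector field f at x is the trace of its Jacobian, \(\nabla \cdot f(x) = \sum_i \partial f_i / \partial x_i\). The sketch below estimates it for a single unbatched sample with central finite differences in NumPy. This is a swapped-in numerical technique for illustration only; a JAX implementation would more naturally use automatic differentiation (for example jax.jacfwd followed by a trace). divergence_fd is a hypothetical name and the f(t, x) call signature is assumed.

```python
import numpy as np


def divergence_fd(f, t, x, eps=1e-5):
    """Approximate div f(t, x), the trace of the Jacobian, for one sample x of shape (d,)."""
    x = np.asarray(x, dtype=float)
    div = 0.0
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        # Central difference of the i-th output component along the i-th input.
        div += (f(t, x + step)[i] - f(t, x - step)[i]) / (2.0 * eps)
    return div


def field(t, x):
    # Elementwise x**2 has divergence 2 * sum(x).
    return x ** 2


div = divergence_fd(field, 0.0, np.array([1.0, 2.0, 3.0]))
```

For a batch, the same function would be applied per sample (in JAX, typically via jax.vmap).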
- get_vector_field(**kwargs)
Compute the vector field of the model, properly squeezed for the ODE term.
- Parameters:
t (Array) – time (batch_size).
x (Array) – input data to the model (batch_size, …).
args – additional information forwarded to the model, e.g., text condition.
- Returns:
vector field of the model.
- Return type:
Array
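ODE solvers such as diffrax call a vector field as f(t, y, args) with a scalar time. A minimal sketch, assuming a model called as model(obs, t): the adapter broadcasts t to the batch and reshapes the output to match the state (one possible reading of "properly squeezed"), and the result is integrated with a toy fixed-step Euler loop in NumPy. make_vector_field, euler_solve, and toy_model are hypothetical names; in practice the vector field would be handed to a diffrax solver rather than this integrator.

```python
import numpy as np


def make_vector_field(model):
    """Hypothetical adapter from model(obs, t) to the f(t, y, args) form."""

    def vf(t, x, args=None):
        x = np.asarray(x, dtype=float)
        # Broadcast the solver's scalar time to one value per batch element.
        t_batch = np.full((x.shape[0],), t, dtype=x.dtype)
        out = model(x, t_batch)
        # Reshape to the state's shape (e.g. dropping a trailing singleton axis).
        return np.reshape(out, x.shape)

    return vf


def euler_solve(vf, x0, t0, t1, n_steps, args=None):
    """Fixed-step explicit Euler integration of dx/dt = vf(t, x, args)."""
    x = np.asarray(x0, dtype=float)
    dt = (t1 - t0) / n_steps
    t = t0
    for _ in range(n_steps):
        x = x + dt * vf(t, x, args)
        t += dt
    return x


def toy_model(obs, t):
    # dx/dt = -x, returned with a trailing singleton axis to exercise the reshape.
    return -obs[..., None]


vf = make_vector_field(toy_model)
x1 = euler_solve(vf, np.ones((2, 3)), 0.0, 1.0, 1000)
```

With dx/dt = -x and x(0) = 1, the result approaches exp(-1) as the step count grows.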