# Source code for gensbi.utils.model_wrapping

from abc import ABC
from flax import nnx
from jax import Array
import jax.numpy as jnp

from typing import Callable

from .math import divergence

from einops import rearrange

[docs] def _expand_dims(x: Array) -> Array: if x.ndim < 3: x = rearrange(x, '... -> 1 ... 1' if x.ndim == 1 else '... -> ... 1') return x
[docs] def _expand_time(t: Array) -> Array: t = jnp.atleast_1d(t) if t.ndim < 2: t = t[..., None] return t
class ModelWrapper(nnx.Module):
    """Wrap another model and expose ODE-solver-friendly callables.

    Calling the wrapper forwards ``obs`` (after :meth:`_expand_dims`) and ``t``
    to the wrapped model. :meth:`get_vector_field` and :meth:`get_divergence`
    return closures with the ``(t, x, args)`` signature used by diffrax terms,
    so the wrapper can be handed directly to an ODE solver.
    """

    def __init__(self, model: nnx.Module):
        """Store the model to wrap.

        Args:
            model (nnx.Module): the wrapped model. It is called as
                ``model(obs, t, *args, **kwargs)``.
        """
        self.model = model

    def _expand_dims(self, x: Array) -> Array:
        """Add unit axes to ``x``; override in subclasses to change the layout.

        Delegates to the module-level ``_expand_dims`` helper (class attributes
        are not in scope inside method bodies, so the name resolves to it).
        """
        return _expand_dims(x)

    def _expand_time(self, t: Array) -> Array:
        """Promote ``t`` to a 2-D column via the module-level ``_expand_time``."""
        return _expand_time(t)

    def __call__(self, t: Array, obs: Array, *args, **kwargs) -> Array:
        r"""Forward :math:`obs` and :math:`t` through the wrapped model.

        ``obs`` is first passed through :meth:`_expand_dims`. ``t`` is
        forwarded unchanged. Subclasses can override this method to add custom
        forward logic, e.g. reshaping ``t`` with :meth:`_expand_time`.

        Args:
            t (Array): time, forwarded to the model as-is.
            obs (Array): input data to the model.
            *args: extra positional arguments forwarded to the model.
            **kwargs: extra keyword arguments forwarded to the model, e.g. a
                condition.

        Returns:
            Array: the wrapped model's output for ``obs`` at time ``t``.
        """
        # Go through the method (not the module function) so subclasses that
        # override ``_expand_dims`` actually take effect.
        obs = self._expand_dims(obs)
        return self.model(obs, t, *args, **kwargs)

    def get_vector_field(self, **kwargs) -> Callable:
        r"""Return the model's vector field as a ``vf(t, x, args)`` callable.

        Args:
            **kwargs: keyword arguments bound into every call of the wrapper.

        Returns:
            Callable: ``vf(t, x, args)``, which evaluates
            ``self(t, x, **args, **kwargs)`` (``args`` may be ``None`` or a
            dict of extra keyword arguments) and removes singleton axes from
            the result.
        """

        def vf(t, x, args):
            extra = {} if args is None else args
            out = self(t, x, **extra, **kwargs)
            # Drop a singleton leading axis (e.g. the one _expand_dims adds to
            # 1-D inputs). NOTE(review): this also turns a (1, D) state into a
            # (D,) field; confirm downstream code relies on broadcasting there.
            if out.ndim > 0 and out.shape[0] == 1:
                out = jnp.squeeze(out, axis=0)
            # Only drop the trailing axis when it is a singleton: squeezing a
            # non-unit axis raises, which broke multi-channel inputs.
            if out.ndim > 0 and out.shape[-1] == 1:
                out = jnp.squeeze(out, axis=-1)
            return out

        return vf

    def get_divergence(self, **kwargs) -> Callable:
        r"""Return the divergence of the vector field as a ``(t, x, args)`` callable.

        Args:
            **kwargs: keyword arguments bound into the underlying vector field
                (see :meth:`get_vector_field`).

        Returns:
            Callable: ``div_(t, x, args)``, which returns
            ``divergence(vf, t, x, args)`` with a singleton leading axis removed.
        """
        vf = self.get_vector_field(**kwargs)

        def div_(t, x, args):
            div = divergence(vf, t, x, args)
            # Guard on ndim so a scalar divergence is returned as-is instead of
            # raising IndexError on ``shape[0]``.
            if div.ndim > 0 and div.shape[0] == 1:
                div = jnp.squeeze(div, axis=0)
            return div

        return div_
class GuidedModelWrapper(ModelWrapper):
    """Model wrapper that mixes conditioned and unconditioned predictions.

    Each call evaluates the wrapped model twice, once with ``conditioned=True``
    and once with ``conditioned=False``, and returns
    ``(1 - cfg_scale) * unconditioned + cfg_scale * conditioned``. With
    ``cfg_scale=1`` this is the purely conditioned output and with ``0`` the
    purely unconditioned one; values above 1 extrapolate away from the
    unconditioned prediction. The vector-field and divergence helpers inherited
    from :class:`ModelWrapper` evaluate ``self(...)`` and therefore use this
    guided output.
    """

    cfg_scale: float  # weight on the conditioned prediction

    def __init__(self, model, cfg_scale=0.7):
        """Wrap ``model`` and store the guidance weight.

        Args:
            model: the wrapped model; it must accept a ``conditioned`` keyword.
            cfg_scale (float): weight of the conditioned prediction.
        """
        super().__init__(model)
        self.cfg_scale = cfg_scale

    def __call__(self, t: Array, obs: Array, *args, **kwargs) -> Array:
        r"""Return the guided mix of conditioned and unconditioned outputs.

        Args:
            t (Array): time, forwarded to the model.
            obs (Array): input data to the model.
            *args: extra positional arguments forwarded to the model.
            **kwargs: extra keyword arguments forwarded to the model. Any
                ``conditioned`` entry is ignored; it is set here explicitly.

        Returns:
            Array: ``(1 - cfg_scale) * u_out + cfg_scale * c_out``.
        """
        kwargs.pop("conditioned", None)  # we set this flag manually
        c_out = super().__call__(t, obs, *args, conditioned=True, **kwargs)
        u_out = super().__call__(t, obs, *args, conditioned=False, **kwargs)
        return (1 - self.cfg_scale) * u_out + self.cfg_scale * c_out

    def get_vector_field(self, **kwargs) -> Callable:
        """Return the guided vector field as a ``vf(t, x, args)`` callable.

        The parent's closure calls ``self(...)``, which dispatches to the guided
        :meth:`__call__` above, so it already yields the guided field. Mixing
        two such closures (as done previously) ran the model four times per
        evaluation for the same result.

        Args:
            **kwargs: keyword arguments bound into every call. Any
                ``conditioned`` entry is ignored, matching :meth:`__call__`.

        Returns:
            Callable: the guided vector field ``vf(t, x, args)``.
        """
        # Popping here avoids a duplicate-keyword TypeError and keeps the
        # behaviour consistent with __call__.
        kwargs.pop("conditioned", None)
        return super().get_vector_field(**kwargs)