from abc import ABC
from flax import nnx
from jax import Array
import jax.numpy as jnp
from typing import Callable
from .math import divergence
from einops import rearrange
[docs]
def _expand_dims(x: Array) -> Array:
    """Pad an array of rank below 3 with singleton axes.

    A 1-D array gains both a leading and a trailing axis of size 1; any
    other array with fewer than three dimensions gains only a trailing
    axis of size 1. Arrays that already have three or more dimensions are
    returned unchanged.

    Args:
        x (Array): input array.

    Returns:
        Array: ``x`` with the extra singleton axes described above.
    """
    # Nothing to do once the array has at least three axes.
    if x.ndim >= 3:
        return x

    pattern = '... -> ... 1' if x.ndim != 1 else '... -> 1 ... 1'
    return rearrange(x, pattern)
[docs]
def _expand_time(t: Array) -> Array:
    """Return ``t`` as an array with at least two dimensions.

    Scalars are first promoted to 1-D; a 1-D result then receives a
    trailing axis of size 1. Inputs that already have two or more
    dimensions pass through unchanged.

    Args:
        t (Array): time value(s).

    Returns:
        Array: ``t`` with ``ndim >= 2``.
    """
    times = jnp.atleast_1d(t)
    # After ``atleast_1d`` a rank below 2 can only mean exactly 1-D.
    return times if times.ndim >= 2 else times[..., None]
[docs]
class ModelWrapper(nnx.Module):
    """
    Wrap another model and expose it in a form useful for diffrax.

    Calling the wrapper forwards to the wrapped model. ``get_vector_field``
    and ``get_divergence`` return closures with the ``(t, x, args)``
    signature, so that ODE solvers which need the vector field and its
    divergence can use the model directly.
    """

    def __init__(self, model: nnx.Module):
        """Store the wrapped model.

        Args:
            model (nnx.Module): model that is invoked as
                ``model(obs, t, *args, **kwargs)``.
        """
        self.model = model

    def _expand_dims(self, x: Array) -> Array:
        """Method alias of the module-level ``_expand_dims`` helper."""
        return _expand_dims(x)

    def _expand_time(self, t: Array) -> Array:
        """Method alias of the module-level ``_expand_time`` helper."""
        return _expand_time(t)

    def __call__(self, t: Array, obs: Array, *args, **kwargs) -> Array:
        r"""
        Pass the inputs through the wrapped model.

        Note the argument order: the wrapper receives ``(t, obs)`` but the
        wrapped model is called as ``model(obs, t, *args, **kwargs)``.
        ``obs`` is padded with singleton axes by ``_expand_dims`` before the
        call; ``t`` is forwarded unchanged.

        Args:
            t (Array): time.
            obs (Array): input data to the model.
            *args: additional positional arguments forwarded to the model.
            **kwargs: additional keyword arguments forwarded to the model,
                e.g., text condition.

        Returns:
            Array: model output.
        """
        obs = _expand_dims(obs)
        # NOTE: ``t`` is deliberately not passed through ``_expand_time``;
        # the wrapped model receives it exactly as the caller supplied it.
        return self.model(obs, t, *args, **kwargs)

    def get_vector_field(self, **kwargs) -> Callable:
        r"""Build the vector-field function of the model, squeezed for the ODE term.

        Args:
            **kwargs: keyword arguments passed to the model on every
                evaluation of the returned function.

        Returns:
            Callable: function ``vf(t, x, args)`` returning the model
            output. ``args`` is either ``None`` or a mapping of extra
            keyword arguments; a key present in both ``args`` and
            ``kwargs`` raises ``TypeError``.
        """
        def vf(t, x, args):
            # ``args`` may be None when the solver has nothing extra to pass.
            extra = args if args is not None else {}
            out = self(t, x, **extra, **kwargs)
            # Squeeze the first dimension of the vector field if it is 1.
            if out.shape[0] == 1:
                out = jnp.squeeze(out, axis=0)
            # Drop the trailing singleton axis (``jnp.squeeze`` raises if
            # that axis is not of size 1).
            return jnp.squeeze(out, axis=-1)

        return vf

    def get_divergence(self, **kwargs) -> Callable:
        r"""Build the divergence function of the model's vector field.

        Args:
            **kwargs: keyword arguments passed to the model on every
                evaluation, exactly as in ``get_vector_field``.

        Returns:
            Callable: function ``div_(t, x, args)`` returning
            ``divergence(vf, t, x, args)``, with a leading axis of size 1
            removed.
        """
        vf = self.get_vector_field(**kwargs)

        def div_(t, x, args):
            div = divergence(vf, t, x, args)
            # Squeeze the first dimension of the divergence if it is 1.
            # The ndim guard keeps a 0-d result from raising on shape[0].
            if div.ndim > 0 and div.shape[0] == 1:
                div = jnp.squeeze(div, axis=0)
            return div

        return div_
[docs]
class GuidedModelWrapper(ModelWrapper):
    """
    Model wrapper whose output blends conditioned and unconditioned predictions.

    Every call evaluates the wrapped model twice, once with
    ``conditioned=True`` and once with ``conditioned=False``, and returns
    ``(1 - cfg_scale) * unconditioned + cfg_scale * conditioned``.
    The vector field and divergence helpers inherited from
    ``ModelWrapper`` call the wrapper itself, so they use this blended
    output as well.
    """

    def __init__(self, model, cfg_scale=0.7):
        """Store the wrapped model and the guidance weight.

        Args:
            model: model to wrap; it must accept a ``conditioned`` keyword.
            cfg_scale: weight of the conditioned prediction; the
                unconditioned prediction is weighted by ``1 - cfg_scale``.
        """
        super().__init__(model)
        self.cfg_scale = cfg_scale

    def __call__(self, t: Array, obs: Array, *args, **kwargs) -> Array:
        r"""Compute the guided model output as a weighted sum of conditioned and unconditioned predictions.

        Args:
            t (Array): time.
            obs (Array): input data to the model.
            *args: additional positional arguments forwarded to the model.
            **kwargs: additional keyword arguments forwarded to the model.
                A ``conditioned`` entry is discarded, because the flag is
                set internally for each of the two passes.

        Returns:
            Array: guided model output.
        """
        kwargs.pop('conditioned', None)  # we set this flag manually

        # Two passes through the parent wrapper, differing only in the flag.
        c_out = super().__call__(t, obs, *args, conditioned=True, **kwargs)
        u_out = super().__call__(t, obs, *args, conditioned=False, **kwargs)
        return (1 - self.cfg_scale) * u_out + self.cfg_scale * c_out

    def get_vector_field(self, **kwargs) -> Callable:
        """Build the guided vector field function.

        The closure produced by the parent class evaluates ``self(...)``,
        which dispatches to the guided ``__call__`` of this class and
        therefore already blends the conditioned and unconditioned
        predictions. Blending two such closures again would evaluate the
        model four times per step and return the same value (up to
        floating-point rounding), so the parent closure is returned as is.

        Args:
            **kwargs: keyword arguments passed to the model on every
                evaluation. A ``conditioned`` entry is discarded.

        Returns:
            Callable: function ``vf(t, x, args)`` returning the guided
            vector field.
        """
        # ``__call__`` ignores this flag anyway; dropping it here keeps the
        # keyword set passed down to the parent minimal.
        kwargs.pop('conditioned', None)
        return super().get_vector_field(**kwargs)