Source code for gensbi.utils.model_wrapping

# FIXME: first pass

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from abc import ABC
from flax import nnx
from jax import Array
import jax.numpy as jnp

from typing import Callable

from .math import divergence


class ModelWrapper(nnx.Module):
    """
    Wrap a model and expose it in the forms an ODE solver needs.

    Calling the wrapper returns the output of the wrapped model.
    :meth:`get_vector_field` and :meth:`get_divergence` build closures with
    the ``(t, x, args)`` argument order, in a form useful for diffrax, for
    ODE solvers that require the vector field and the divergence of the model.
    """

    def __init__(self, model: nnx.Module):
        self.model = model

    @staticmethod
    def _drop_unit_leading_axis(value: Array) -> Array:
        """Remove the first axis of ``value`` when that axis has length 1.

        Args:
            value (Array): array to inspect.

        Returns:
            Array: ``value`` without its first axis if that axis has length 1,
            otherwise ``value`` unchanged.
        """
        # A 0-d value has no first axis: hand it back untouched instead of
        # indexing into an empty shape tuple (which raises IndexError).
        if len(value.shape) > 0 and value.shape[0] == 1:
            return jnp.squeeze(value, axis=0)
        return value

    def _call_model(self, x: Array, t: Array, args, **kwargs) -> Array:
        r"""
        Forward the inputs to the wrapped model.

        The wrapped model is called as ``model(x, t, args=args, **kwargs)``,
        which allows passing additional information to it, such as text
        conditions or other auxiliary data.

        Args:
            x (Array): input data to the model (batch_size, ...).
            t (Array): time (batch_size).
            args: additional information forwarded to the model, e.g., text condition.
            **kwargs: additional keyword arguments forwarded to the model.

        Returns:
            Array: model output.
        """
        return self.model(x, t, args=args, **kwargs)  # type: ignore

    def __call__(self, x: Array, t: Array, args=None, **kwargs) -> Array:
        r"""
        Define how inputs are passed through the wrapped model.

        The wrapped model is assumed to take both :math:`x` and :math:`t` as
        input, along with ``args`` and any additional keyword arguments.

        Subclasses may override this method to, for example:

        - check that t has the dimensions the model is expecting;
        - add custom forward pass logic;
        - call the wrapped model.

        Args:
            x (Array): input data to the model (batch_size, ...).
            t (Array): time (batch_size).
            args: additional information forwarded to the model, e.g., text condition.
            **kwargs: additional keyword arguments forwarded to the model.

        Returns:
            Array: model output for input x at time t.
        """
        return self._call_model(x, t, args, **kwargs)

    def get_vector_field(self, **kwargs) -> Callable:
        r"""Build the vector field function of the model, squeezed for the ODE term.

        Args:
            **kwargs: keyword arguments forwarded to the wrapped model on
                every evaluation of the returned function.

        Returns:
            Callable: function ``vf(t, x, args)`` returning the model output
            for input ``x`` at time ``t``, with the first axis removed when
            it has length 1.
        """

        def vf(t, x, args):
            # The returned closure takes time first; the model takes x first.
            out = self._call_model(x, t, args, **kwargs)
            return self._drop_unit_leading_axis(out)

        return vf

    def get_divergence(self, **kwargs) -> Callable:
        r"""Build the function computing the divergence of the model's vector field.

        Args:
            **kwargs: keyword arguments forwarded to :meth:`get_vector_field`.

        Returns:
            Callable: function ``div_(t, x, args)`` returning
            ``divergence(vf, t, x, args)`` for the vector field ``vf``, with
            the first axis removed when it has length 1.
        """
        # Resolved through self, so a subclass overriding get_vector_field
        # has its own vector field differentiated here.
        vf = self.get_vector_field(**kwargs)

        def div_(t, x, args):
            # NOTE(review): the output shape of `divergence` is defined in
            # `.math` and is not visible here; a 0-d result is returned as is.
            return self._drop_unit_leading_axis(divergence(vf, t, x, args))

        return div_
class GuidedModelWrapper(ModelWrapper):
    """
    Model wrapper that blends conditioned and unconditioned predictions.

    The wrapped model is evaluated twice, once with ``conditioned=True`` and
    once with ``conditioned=False``, and the two outputs are combined as
    ``(1 - cfg_scale) * unconditioned + cfg_scale * conditioned``.

    The wrapped model must therefore accept a ``conditioned`` keyword
    argument, and callers must not pass ``conditioned`` themselves through
    ``**kwargs``.
    """

    # Weight given to the conditioned prediction in the blend.
    cfg_scale: float

    def __init__(self, model, cfg_scale=0.7):
        super().__init__(model)
        self.cfg_scale = cfg_scale

    def __call__(self, x: Array, t: Array, args=None, **kwargs) -> Array:
        r"""Return the guided model output for input x at time t.

        Args:
            x (Array): input data to the model (batch_size, ...).
            t (Array): time (batch_size).
            args: additional information forwarded to the model, e.g., text condition.
            **kwargs: additional keyword arguments forwarded to the model.

        Returns:
            Array: weighted sum of the conditioned and unconditioned outputs.
        """
        # Conditioned pass first, then the unconditioned one.
        conditional = self._call_model(x, t, args, conditioned=True, **kwargs)
        unconditional = self._call_model(x, t, args, conditioned=False, **kwargs)

        weight = self.cfg_scale
        return (1 - weight) * unconditional + weight * conditional

    def get_vector_field(self, **kwargs) -> Callable:
        """Build the guided vector field function.

        Args:
            **kwargs: keyword arguments forwarded to the wrapped model on
                every evaluation of the returned function.

        Returns:
            Callable: function ``g_vf(t, x, args)`` returning the weighted
            sum of the conditioned and unconditioned vector fields.
        """
        # Both branches come from the parent implementation; they differ only
        # in the `conditioned` flag handed to the wrapped model.
        conditional_field = super().get_vector_field(conditioned=True, **kwargs)
        unconditional_field = super().get_vector_field(conditioned=False, **kwargs)

        def g_vf(t, x, args):
            # cfg_scale is read at evaluation time, not when the closure is built.
            weight = self.cfg_scale
            unconditional = unconditional_field(t, x, args)
            conditional = conditional_field(t, x, args)
            return (1 - weight) * unconditional + weight * conditional

        return g_vf