Source code for gensbi.utils.math


import jax 
import jax.numpy as jnp
from jax import Array
from typing import Callable, Optional

def _divergence_single(vf, t, x):
    """Return the trace of the Jacobian of ``vf`` with respect to ``x``.

    The Jacobian is obtained with forward-mode differentiation
    (``jax.jacfwd``) on the second positional argument of ``vf``, and the
    trace is taken over its last two axes.

    Args:
        vf: Callable invoked as ``vf(t, x)``.
        t: Time input forwarded unchanged to ``vf``.
        x: Point at which the Jacobian is evaluated.

    Returns:
        The trace of ``d vf / d x`` evaluated at ``(t, x)``.
    """
    # argnums=1 selects ``x`` (not ``t``) as the differentiation variable.
    jacobian_fn = jax.jacfwd(vf, argnums=1)
    jacobian = jacobian_fn(t, x)
    return jnp.trace(jacobian, axis1=-2, axis2=-1)

    
def divergence(
    vf: Callable,
    t: Array,
    x: Array,
    args: Optional[Array] = None,
):
    """
    Compute the divergence of the vector field vf at point x and time t.

    ``x`` is promoted to at least two dimensions and its leading axis is
    mapped over with ``jax.vmap``. ``t`` is promoted to at least one
    dimension and broadcast to ``(*x.shape[:-1], t.shape[-1])`` so that it
    can be mapped over the same leading axis.

    Args:
        vf (Callable): The vector field function, called as
            ``vf(t, x, args=args)``.
        t (Array): The time at which to compute the divergence.
        x (Array): The point at which to compute the divergence.
        args (Optional[Array]): Extra value forwarded to ``vf`` through its
            ``args`` keyword. Defaults to None.

    Returns:
        Array: The divergence of the vector field at point x and time t,
        with every size-1 axis removed by ``jnp.squeeze``.
    """
    # Scalars become (1, 1) and 1-D inputs become (1, n); arrays that are
    # already 2-D or higher pass through unchanged.
    points = jnp.atleast_2d(x)

    times = jnp.atleast_1d(t)
    target_shape = points.shape[:-1] + (times.shape[-1],)
    times = jnp.broadcast_to(times, target_shape)

    def _vf_with_args(t_i, x_i):
        # Bind ``args`` so the helper sees a plain two-argument ``vf(t, x)``.
        return vf(t_i, x_i, args=args)

    def _per_sample(t_i, x_i):
        return _divergence_single(_vf_with_args, t_i, x_i)

    batched = jax.vmap(_per_sample, in_axes=(0, 0))(times, points)
    return jnp.squeeze(batched)