Source code for gensbi.models.simformer.embedding

import jax
from jax import numpy as jnp
from flax import nnx
import numpy as np
from jax.typing import DTypeLike 
from jax import Array

class MLPEmbedder(nnx.Module):
    """Two-layer MLP embedding blended with a learnable skip connection.

    The input goes through ``Linear(in_dim -> hidden_dim)``, SiLU and
    ``Linear(hidden_dim -> hidden_dim)``. The MLP output is then mixed with
    the (broadcast) input using a learnable per-channel weight ``p_skip``::

        out = p_skip * x + (1 - p_skip) * mlp(x)

    Args:
        in_dim: Size of the last axis of the input.
        hidden_dim: Size of the hidden and output feature axis.
        rngs: Random number generators used to initialise the linear layers.
        param_dtype: Dtype of every parameter of the module, including
            ``p_skip``. Defaults to ``jnp.float32``.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.float32,
    ):
        # The skip weight starts small (0.01), so the MLP branch dominates at
        # initialisation. It is created with ``param_dtype`` so that it
        # matches the linear layers instead of always being float32 (which
        # would silently promote low-precision outputs).
        self.p_skip = nnx.Param(
            0.01 * jnp.ones((1, 1, hidden_dim), dtype=param_dtype)
        )
        self.in_layer = nnx.Linear(
            in_features=in_dim,
            out_features=hidden_dim,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.silu = nnx.silu
        self.out_layer = nnx.Linear(
            in_features=hidden_dim,
            out_features=hidden_dim,
            use_bias=True,
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, x: Array) -> Array:
        """Embed ``x``.

        Args:
            x: Input whose last axis has size ``in_dim``. Scalars are promoted
                to shape ``(1,)``.

        Returns:
            The blended embedding. Because ``p_skip`` has shape
            ``(1, 1, hidden_dim)``, the result always has at least three
            dimensions.

        NOTE(review): ``jnp.broadcast_arrays(x, out)`` requires the last axis
        of ``x`` (``in_dim``) to be 1 or ``hidden_dim``; other combinations
        raise a broadcasting error. Confirm callers respect this.
        """
        x = jnp.atleast_1d(x)
        out = self.out_layer(self.silu(self.in_layer(x)))
        # Broadcast the raw input to the MLP output shape for the skip path.
        x_repeated, out = jnp.broadcast_arrays(x, out)
        out = x_repeated * self.p_skip + (1 - self.p_skip) * out
        return out
class SimpleTimeEmbedding(nnx.Module):
    """Parameter-free, hand-crafted embedding, mostly used for time.

    Each time value ``t`` is mapped to the four features
    ``[t - 0.5, cos(2*pi*t), sin(2*pi*t), -cos(4*pi*t)]``.
    """

    def __init__(self):
        """Create the embedding; there is nothing to configure or learn."""
        pass

    def __call__(self, t):
        """Embed ``t``.

        Args:
            t: Scalar or array of times. Scalars and 1-D arrays are reshaped
                into a column of shape ``(n, 1)``; higher-rank inputs are used
                unchanged.

        Returns:
            The four feature groups concatenated along the last axis, so the
            last axis is four times that of the (reshaped) input.
        """
        t = jnp.atleast_1d(t)
        if t.ndim == 1:
            # Turn a flat vector of times into a column vector.
            t = t[:, None]
        features = (
            t - 0.5,
            jnp.cos(2 * jnp.pi * t),
            jnp.sin(2 * jnp.pi * t),
            -jnp.cos(4 * jnp.pi * t),
        )
        return jnp.concatenate(features, axis=-1)
class SinusoidalEmbedding(nnx.Module):
    """Sinusoidal (transformer-style) embedding, mostly used to embed time.

    ``output_dim // 2 + 1`` frequencies are spaced geometrically from 1 down
    to 1/10000. The output is ``[sin(t * f), cos(t * f)]`` concatenated and
    truncated to ``output_dim`` features. Every sine term is therefore kept,
    and the trailing cosine terms are dropped by the truncation.
    """

    def __init__(self, output_dim: int = 128):
        """Sinusoidal embedding module. Mostly used to embed time.

        Args:
            output_dim (int, optional): Output dimension. Defaults to 128.
        """
        self.output_dim = output_dim

    def __call__(self, t):
        """Embed ``t``.

        Args:
            t: Scalar or array of times. Scalars and 1-D arrays become a
                column of shape ``(n, 1)``. The last axis must have size 1,
                because it is contracted against the ``(1, half_dim)``
                frequency row by ``jnp.dot``.

        Returns:
            Array whose last axis has (at most) ``output_dim`` features.
        """
        t = jnp.atleast_1d(t)
        if t.ndim == 1:
            t = jnp.expand_dims(t, axis=1)
        half_dim = self.output_dim // 2 + 1
        # For output_dim < 2, half_dim == 1 and an unguarded division by zero
        # gives inf; then 0 * -inf = NaN frequencies. Clamping the
        # denominator leaves every output_dim >= 2 unchanged.
        emb = jnp.log(10000) / max(half_dim - 1, 1)
        emb = jnp.exp(jnp.arange(half_dim) * -emb)
        emb = jnp.expand_dims(emb, 0)
        emb = jnp.dot(t, emb)
        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], -1)
        return emb[..., : self.output_dim]
# class GaussianFourierEmbedding(nnx.Module): # def __init__( # self, # output_dim: int = 128, # learnable: bool = False, # *, # rngs: nnx.Rngs # ): # """Gaussian Fourier embedding module. Mostly used to embed time. # Args: # output_dim (int, optional): Output dimesion. Defaults to 128. # """ # self.output_dim = output_dim # self.B = nnx.Param(jax.random.normal(rngs.params(), [self.output_dim // 2 + 1])) # self.learnable = learnable # return # def __call__(self, t): # t = jnp.atleast_1d(t) # if t.ndim == 1: # t = jnp.expand_dims(t, axis=1) # if not self.learnable: # B = jnp.expand_dims(jax.lax.stop_gradient(self.B), 0) # else: # B = jnp.expand_dims(self.B, 0) # arg = 2 * jnp.pi * jnp.dot(t,B) # term1 = jnp.cos(arg) # term2 = jnp.sin(arg) # out = jnp.concatenate([term1, term2], axis=-1) # return out[..., : self.output_dim]
class GaussianFourierEmbedding(nnx.Module):
    """Random Gaussian Fourier-feature embedding, mostly used to embed time.

    A column of ``output_dim // 2 + 1`` frequencies ``B`` is drawn from a
    standard normal. The output is
    ``[cos(2*pi*t*B), sin(2*pi*t*B)]`` truncated to ``output_dim`` features.
    """

    def __init__(
        self,
        output_dim: int = 128,
        learnable: bool = False,
        *,
        rngs: nnx.Rngs,
    ):
        """Gaussian Fourier embedding module. Mostly used to embed time.

        Args:
            output_dim (int, optional): Output dimension. Defaults to 128.
            learnable (bool, optional): If True, the frequencies ``B`` are an
                ``nnx.Param`` and receive gradients. If False (default), they
                are a plain ``nnx.Variable``, which ``nnx.Param``-filtered
                grads and optimizers skip, and are wrapped in
                ``stop_gradient`` at call time.
            rngs (nnx.Rngs): Source of the ``params`` key used to sample ``B``.
        """
        self.output_dim = output_dim
        self.learnable = learnable
        half_dim = output_dim // 2 + 1
        init = jax.random.normal(rngs.params(), (half_dim, 1))
        # Calling stop_gradient here would have no effect on gradients taken
        # in later calls, and it would turn the Variable into a bare array.
        # Freezing is therefore expressed via the variable type plus a
        # stop_gradient inside __call__.
        variable_type = nnx.Param if learnable else nnx.Variable
        self.B = variable_type(init)

    def __call__(self, t):
        """Embed ``t``.

        Args:
            t: Scalar or array of times. Scalars and 1-D arrays become a
                column of shape ``(n, 1)``. The last axis must have size 1,
                matching the ``(1, half_dim)`` shape of ``B.T``.

        Returns:
            Array whose last axis has (at most) ``output_dim`` features.
        """
        t = jnp.atleast_1d(t)
        if t.ndim == 1:
            t = jnp.expand_dims(t, axis=1)
        B = self.B.value
        if not self.learnable:
            # Block gradients through the fixed random frequencies.
            B = jax.lax.stop_gradient(B)
        arg = 2 * jnp.pi * jnp.dot(t, B.T)
        out = jnp.concatenate([jnp.cos(arg), jnp.sin(arg)], axis=-1)
        return out[..., : self.output_dim]