Source code for gensbi.models.simformer.transformer
import jax
from jax import numpy as jnp
from jax import jit, vmap
from flax import nnx
from typing import Callable, Optional
from jaxtyping import Array, PyTree
# layer = nnx.MultiHeadAttention(
# num_heads=8, in_features=5, qkv_features=16, decode=False, rngs=nnx.Rngs(0)
# )
class AttentionBlock(nnx.Module):
    """Pre-norm multi-head self-attention block with an optional residual connection."""

    def __init__(
        self,
        din: int,
        num_heads: int,
        features: int,
        dropout_rate: float,
        skip_connection: bool,
        rngs: nnx.Rngs,
    ):
        self.skip_connection = skip_connection
        self.dropout_rate = dropout_rate
        # Attention dropout is only active when a positive rate is given.
        self.deterministic = dropout_rate <= 0
        self.layer_norm = nnx.LayerNorm(din, rngs=rngs)
        self.attn = nnx.MultiHeadAttention(
            in_features=din,
            num_heads=num_heads,
            qkv_features=features,
            dropout_rate=self.dropout_rate,
            deterministic=self.deterministic,
            decode=False,
            rngs=rngs,
        )

    def __call__(self, x: jnp.ndarray, mask: jnp.ndarray | None) -> jnp.ndarray:
        x = self.layer_norm(x)
        # Note: the residual is taken from the layer-normalized input.
        x_in = x
        x = self.attn(x, mask=mask)
        if self.skip_connection:
            x = x + x_in
        return x
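
# Illustrative usage sketch (not part of the original module); the shapes and
# hyperparameters below are assumptions chosen for demonstration only:
# block = AttentionBlock(
#     din=16, num_heads=4, features=32, dropout_rate=0.0,
#     skip_connection=True, rngs=nnx.Rngs(0),
# )
# y = block(jnp.ones((2, 10, 16)), mask=None)  # -> shape [2, 10, 16]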
class DenseBlock(nnx.Module):
    """Pre-norm MLP block with optional context conditioning and residual connection."""

    def __init__(
        self,
        din: int,
        dcontext: int,
        num_hidden_layers: int,
        widening_factor: int,
        dropout_rate: float,
        act: Callable,
        skip_connection: bool,
        rngs: nnx.Rngs,
    ):
        self.skip_connection = skip_connection
        n_features = din
        self.layer_norm = nnx.LayerNorm(din, rngs=rngs)

        # Hidden stack: widen once, keep the widened width, then project back to din.
        self.hidden_blocks = []
        self.hidden_blocks.append(
            nnx.Linear(n_features, widening_factor * n_features, rngs=rngs)
        )
        n_features *= widening_factor
        for _ in range(1, num_hidden_layers):
            self.hidden_blocks.append(nnx.Linear(n_features, n_features, rngs=rngs))
        self.hidden_blocks.append(nnx.Linear(n_features, din, rngs=rngs))

        self.act = act
        self.dropout_rate = dropout_rate
        self.dropout = nnx.Dropout(rate=dropout_rate, rngs=rngs)
        self.context_block = nnx.Linear(dcontext, din, rngs=rngs)

    def __call__(self, x, context):
        x = self.layer_norm(x)
        x_in = x
        # Activation after every hidden layer except the final projection.
        for i in range(len(self.hidden_blocks) - 1):
            x = self.hidden_blocks[i](x)
            x = self.act(x)
        x = self.hidden_blocks[-1](x)
        if self.dropout_rate > 0:
            x = self.dropout(x)
        if context is not None:
            # Embed the context and broadcast it over the token dimension.
            context_emb = self.context_block(context)
            context_emb = self.act(context_emb)
            while context_emb.ndim < x.ndim:
                context_emb = context_emb[..., None, :]
            x = x + context_emb
        if self.skip_connection:
            x = x + x_in
        return x
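
# Illustrative usage sketch (not part of the original module); shapes are
# assumptions, showing how a [B, D_context] context is broadcast over tokens:
# block = DenseBlock(
#     din=16, dcontext=8, num_hidden_layers=1, widening_factor=4,
#     dropout_rate=0.0, act=jax.nn.gelu, skip_connection=True, rngs=nnx.Rngs(0),
# )
# y = block(jnp.ones((2, 10, 16)), jnp.ones((2, 8)))  # -> shape [2, 10, 16]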
class Transformer(nnx.Module):
    """A transformer stack of alternating attention and dense blocks."""

    def __init__(
        self,
        din: int,
        dcontext: int,
        num_heads: int,
        num_layers: int,
        features: int,
        dropout_rate: float = 0.0,
        widening_factor: int = 4,
        num_hidden_layers: int = 1,
        act: Callable = jax.nn.gelu,
        skip_connection_attn: bool = True,
        skip_connection_mlp: bool = True,
        *,  # Enforce keyword-only rngs
        rngs: nnx.Rngs,
    ):
        self.din = din
        self.dcontext = dcontext
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.dropout_rate = dropout_rate
        self.widening_factor = widening_factor
        self.num_hidden_layers = num_hidden_layers
        self.act = act
        self.skip_connection_attn = skip_connection_attn
        self.skip_connection_mlp = skip_connection_mlp
        self.rngs = rngs

        # One attention block and one dense block per layer, plus a final layer norm.
        self.attention_blocks = []
        self.dense_blocks = []
        self.layer_norm = nnx.LayerNorm(din, rngs=rngs)
        for _ in range(num_layers):
            self.attention_blocks.append(
                AttentionBlock(
                    din=self.din,
                    num_heads=num_heads,
                    features=features,
                    dropout_rate=self.dropout_rate,
                    skip_connection=skip_connection_attn,
                    rngs=rngs,
                )
            )
            self.dense_blocks.append(
                DenseBlock(
                    din,
                    dcontext,
                    num_hidden_layers,
                    widening_factor,
                    self.dropout_rate,
                    act=self.act,
                    skip_connection=skip_connection_mlp,
                    rngs=rngs,
                )
            )

    def __call__(
        self,
        inputs: Array,  # [B, T, D]
        context: Optional[Array] = None,  # [B, D_context]
        mask: Array | None = None,  # [T, T] or [B, T, T]
    ) -> jax.Array:  # [B, T, D]
        if mask is not None:
            # nnx.MultiHeadAttention expects a mask broadcastable to
            # [B, num_heads, T, T]; add the missing batch/head axes.
            if mask.ndim == 2:
                mask = mask[None, None, :, :]
            elif mask.ndim == 3:
                mask = mask[:, None, :, :]
            else:
                raise ValueError(f"Mask must have ndim 2 or 3, got {mask.ndim}.")

        x = inputs
        for i in range(self.num_layers):
            x = self.attention_blocks[i](x, mask)
            x = self.dense_blocks[i](x, context)
        out = self.layer_norm(x)
        return out
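

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). All sizes below are
    # assumptions chosen for demonstration; the boolean mask marks which
    # token pairs may attend to each other.
    rngs = nnx.Rngs(0)
    model = Transformer(
        din=16,
        dcontext=8,
        num_heads=4,
        num_layers=2,
        features=32,
        rngs=rngs,
    )
    x = jnp.ones((2, 10, 16))              # [B, T, D]
    context = jnp.ones((2, 8))             # [B, D_context]
    mask = jnp.ones((10, 10), dtype=bool)  # [T, T], full attention
    out = model(x, context=context, mask=mask)
    print(out.shape)  # (2, 10, 16)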