# Source code for gensbi.models.simformer.model
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike
from einops import rearrange
from flax import nnx
from functools import partial
from typing import Optional
from dataclasses import dataclass
from .transformer import Transformer
from .embedding import GaussianFourierEmbedding, MLPEmbedder
from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time
@dataclass
class SimformerParams:
    """Parameters for the Simformer model.

    Args:
        rngs (nnx.Rngs): Random number generators for initialization.
        dim_value (int): Dimension of the value embeddings.
        dim_id (int): Dimension of the ID embeddings.
        dim_condition (int): Dimension of the condition embeddings.
        dim_joint (int): Total dimension of the joint embeddings.
        fourier_features (int): Number of Fourier features for time embedding.
        num_heads (int): Number of attention heads.
        num_layers (int): Number of transformer layers.
        widening_factor (int): Widening factor for the transformer.
        qkv_features (Optional[int]): Number of features for QKV layers;
            defaults to ``dim_value + dim_id + dim_condition`` when ``None``.
        num_hidden_layers (int): Number of hidden layers in the transformer.
    """

    # NOTE(review): the field declarations below were missing from the
    # extracted source; they are reconstructed from the docstring and from
    # the attributes __post_init__ reads — confirm against the original.
    # qkv_features is placed last so it can default to None (dataclass
    # fields with defaults must follow fields without).
    rngs: nnx.Rngs
    dim_value: int
    dim_id: int
    dim_condition: int
    dim_joint: int
    fourier_features: int
    num_heads: int
    num_layers: int
    widening_factor: int
    num_hidden_layers: int
    qkv_features: Optional[int] = None
    # param_dtype: DTypeLike = jnp.float32

    def __post_init__(self):
        # Default the QKV width to the full per-token embedding width
        # (value | id | condition concatenated).
        if self.qkv_features is None:
            self.qkv_features = self.dim_value + self.dim_id + self.dim_condition
class Simformer(nnx.Module):
    """Simformer model for joint density estimation.

    Each of the ``dim_joint`` variables becomes one token whose encoding is
    the concatenation of a value embedding, a node-id embedding and a
    conditioning marker; a transformer processes the tokens with the time
    embedding as context, and a final projection maps each token back to a
    scalar.

    Args:
        params (SimformerParams): Parameters for the Simformer model.
    """

    def __init__(
        self,
        params: SimformerParams,
    ):
        # NOTE(review): the assignments for dim_id, dim_condition,
        # condition_embedding, total_tokens and output_fn were missing from
        # the extracted source but are all read by __call__ (or used below);
        # they are reconstructed here — confirm against the original module.
        self.dim_id = params.dim_id
        self.dim_condition = params.dim_condition

        # Lift each scalar observation to a dim_value-dimensional embedding.
        self.embedding_net_value = MLPEmbedder(
            in_dim=1, hidden_dim=params.dim_value, rngs=params.rngs
        )
        # self.embedding_net_value = lambda obs: jnp.repeat(obs, dim_value, axis=-1)

        fourier_features = params.fourier_features
        # Time embedding, used as the transformer's context input.
        self.embedding_time = GaussianFourierEmbedding(
            fourier_features, rngs=params.rngs, learnable=True
        )
        # One learned id vector per joint variable.
        self.embedding_net_id = nnx.Embed(
            num_embeddings=params.dim_joint, features=params.dim_id, rngs=params.rngs
        )
        # Learnable marker for conditioned tokens; __call__ multiplies it by
        # the 0/1 condition mask so unconditioned tokens get a zero vector.
        self.condition_embedding = nnx.Param(
            jax.random.normal(params.rngs.params(), (1, 1, params.dim_condition))
        )

        # Width of one token after concatenating value/id/condition parts.
        self.total_tokens = params.dim_value + params.dim_id + params.dim_condition

        self.transformer = Transformer(
            din=self.total_tokens,
            dcontext=fourier_features,
            num_heads=params.num_heads,
            num_layers=params.num_layers,
            features=params.qkv_features,
            widening_factor=params.widening_factor,
            num_hidden_layers=params.num_hidden_layers,
            act=jax.nn.gelu,
            skip_connection_attn=True,
            skip_connection_mlp=True,
            rngs=params.rngs,
        )
        # Project each token embedding back to a scalar model output.
        self.output_fn = nnx.Linear(
            in_features=self.total_tokens, out_features=1, rngs=params.rngs
        )
        return

    def __call__(
        self,
        t: Array,
        obs: Array,
        node_ids: Array,
        condition_mask: Array,
        edge_mask: Optional[Array] = None,
    ) -> Array:
        """
        Forward pass of the Simformer model.

        Args:
            t (Array): Time steps, scalar or one per batch element.
            obs (Array): Input data of shape (batch_size, seq_len, 1).
            node_ids (Array): Node identifiers, ndim <= 3.
            condition_mask (Array): Boolean mask marking conditioned tokens.
            edge_mask (Optional[Array]): Mask for edges (attention mask).

        Returns:
            Array: Model output of shape (batch_size, seq_len, 1).
        """
        obs = jnp.asarray(obs)
        t = jnp.atleast_1d(t)
        assert (
            obs.ndim == 3
        ), f"Input obs must be of shape (batch_size, seq_len, 1), got {obs.shape}"
        assert (
            len(t.ravel()) == obs.shape[0] or len(t.ravel()) == 1
        ), "t must have the same batch size as obs or size 1, got {} and {}".format(
            t.shape, obs.shape
        )
        t = t.reshape(-1, 1, 1)
        batch_size, seq_len, _ = obs.shape
        condition_mask = condition_mask.astype(jnp.bool_).reshape(-1, seq_len, 1)
        condition_mask = jnp.broadcast_to(condition_mask, (batch_size, seq_len, 1))
        # Normalize node_ids to shape (batch, seq_len).
        if node_ids.ndim == 1:
            node_ids = node_ids.reshape(-1, seq_len)
        elif node_ids.ndim == 2:
            assert (
                node_ids.shape[1] == seq_len
            ), f"node_ids must have shape (-1, {seq_len}), got {node_ids.shape}"
        elif node_ids.ndim == 3:
            assert (
                node_ids.shape[1] == seq_len and node_ids.shape[2] == 1
            ), f"node_ids must have shape (-1, {seq_len}, 1), got {node_ids.shape}"
            node_ids = jnp.squeeze(node_ids, axis=-1)
        else:
            raise ValueError(f"node_ids must have ndim <=3, got {node_ids.ndim}")
        time_embeddings = self.embedding_time(t)
        condition_embedding = (
            self.condition_embedding * condition_mask
        )  # If condition_mask is 0, then the embedding is 0, otherwise it is the condition_embedding vector
        condition_embedding = jnp.broadcast_to(
            condition_embedding, (batch_size, seq_len, self.dim_condition)
        )
        # Embed inputs and broadcast
        value_embeddings = self.embedding_net_value(obs)
        id_embeddings = self.embedding_net_id(node_ids)
        id_embeddings = jnp.broadcast_to(
            id_embeddings, (batch_size, seq_len, self.dim_id)
        )
        # Concatenate embeddings (alternatively you can also add instead of concatenating)
        x_encoded = jnp.concatenate(
            [value_embeddings, id_embeddings, condition_embedding], axis=-1
        )
        h = self.transformer(x_encoded, context=time_embeddings, mask=edge_mask)
        out = self.output_fn(h)
        # out = jnp.squeeze(out, axis=-1)
        return out
class SimformerWrapper(ModelWrapper):
    """ModelWrapper around :class:`Simformer` that exposes conditioned and
    unconditioned evaluation with a position-based token layout."""

    def __init__(self, model: Simformer):
        # NOTE(review): the original __init__ body was lost in extraction;
        # both methods below read ``self.model``, which the ModelWrapper base
        # class is assumed to set from its constructor argument — confirm.
        super().__init__(model)

    def conditioned(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        edge_mask: Optional[Array] = None,
    ) -> Array:
        """
        Perform conditioned inference.

        Args:
            t (Array): Time steps.
            obs (Array): Observations.
            obs_ids (Array): Observation identifiers.
            cond (Array): Conditioning values.
            cond_ids (Array): Conditioning identifiers.
            edge_mask (Optional[Array]): Mask for edges.

        Returns:
            Array: Conditioned output, restricted to the obs tokens.
        """
        # Broadcast cond and the id arrays on the batch dimension to match obs.
        batch_size = obs.shape[0]
        cond = jnp.broadcast_to(cond, (batch_size, *cond.shape[1:]))
        cond_ids = jnp.broadcast_to(cond_ids, (batch_size, *cond_ids.shape[1:]))
        obs_ids = jnp.broadcast_to(obs_ids, (batch_size, *obs_ids.shape[1:]))

        # Tokens are laid out as [obs | cond] along the sequence axis, so the
        # conditioned positions are exactly the trailing cond tokens.
        # (BUGFIX: the original ``condition_mask.at[cond_ids].set(True)``
        # indexed the *batch* axis with node ids, marking whole rows True
        # instead of the conditioned token positions.)
        condition_mask = jnp.concatenate(
            (
                jnp.zeros((batch_size, obs.shape[1]), dtype=jnp.bool_),
                jnp.ones((batch_size, cond.shape[1]), dtype=jnp.bool_),
            ),
            axis=1,
        )

        obs = jnp.concatenate((obs, cond), axis=1)
        node_ids = jnp.concatenate((obs_ids, cond_ids), axis=1)

        res = self.model(
            obs=obs,
            t=t,
            node_ids=node_ids,
            condition_mask=condition_mask,
            edge_mask=edge_mask,
        )
        # Return only the values we are not conditioning on. obs_ids are
        # assumed to index token positions (node ids == positions) — confirm.
        # take_along_axis needs the index ndim to match res (batch, seq, 1).
        idx = obs_ids if obs_ids.ndim == res.ndim else obs_ids[..., None]
        return jnp.take_along_axis(res, idx, axis=1)

    def unconditioned(
        self, t: Array, obs: Array, obs_ids: Array, edge_mask: Optional[Array] = None
    ) -> Array:
        """
        Perform unconditioned inference.

        Args:
            t (Array): Time steps.
            obs (Array): Observations.
            obs_ids (Array): Observation identifiers.
            edge_mask (Optional[Array]): Mask for edges.

        Returns:
            Array: Unconditioned output.
        """
        batch_size = obs.shape[0]
        # BUGFIX: the original passed (batch_size, obs.shape[1:]) — a tuple
        # nested inside a tuple — which jnp.zeros rejects; unpack the tail.
        condition_mask = jnp.zeros((batch_size, *obs.shape[1:]), dtype=jnp.bool_)
        obs_ids = jnp.broadcast_to(obs_ids, (batch_size, *obs_ids.shape[1:]))
        node_ids = obs_ids
        res = self.model(
            obs=obs,
            t=t,
            node_ids=node_ids,
            condition_mask=condition_mask,
            edge_mask=edge_mask,
        )
        # take_along_axis needs the index ndim to match res (batch, seq, 1).
        idx = obs_ids if obs_ids.ndim == res.ndim else obs_ids[..., None]
        return jnp.take_along_axis(res, idx, axis=1)

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,
        edge_mask: Optional[Array] = None,
    ) -> Array:
        r"""
        This method defines how inputs should be passed through the wrapped model.
        Here, we're assuming that the wrapped model takes both :math:`obs` and :math:`t` as input,
        along with additional keyword arguments.

        Args:
            t (Array): time (batch_size).
            obs (Array): input data to the model (batch_size, ...).
            obs_ids (Array): observation ids (batch_size, obs_dim).
            cond (Array): conditioning data to the model (batch_size, ...).
            cond_ids (Array): condition ids (batch_size, cond_dim).
            conditioned (bool | Array): whether to use conditioning or not.
            edge_mask (Optional[Array]): mask for edges.

        Returns:
            Array: model output.
        """
        obs = _expand_dims(obs)
        t = _expand_time(t)
        cond = _expand_dims(cond)
        obs_ids = _expand_dims(obs_ids)
        cond_ids = _expand_dims(cond_ids)
        # NOTE: ``conditioned`` is evaluated with Python truthiness, so it
        # must be concrete (not a traced array under jit).
        if conditioned:
            return self.conditioned(
                t, obs, obs_ids, cond, cond_ids, edge_mask=edge_mask
            )
        else:
            return self.unconditioned(t, obs, obs_ids, edge_mask=edge_mask)