Source code for gensbi.models.simformer.simformer

import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike

from einops import rearrange
from flax import nnx

from functools import partial
from typing import Optional

from dataclasses import dataclass

from .transformer import Transformer
from .embedding import GaussianFourierEmbedding, MLPEmbedder


@dataclass
class SimformerParams:
    """Hyperparameters and RNG streams used to build a ``Simformer``.

    Args:
        rngs (nnx.Rngs): Random number generators handed to every
            sub-module that ``Simformer`` creates.
        dim_value (int): Width of the per-token value embedding
            (``hidden_dim`` of the value ``MLPEmbedder``).
        dim_id (int): Width of the node-id embedding.
        dim_condition (int): Width of the learned condition embedding.
        dim_joint (int): Number of entries in the node-id embedding table,
            i.e. how many distinct node ids can be embedded.
        fourier_features (int): Size passed to the Gaussian Fourier time
            embedding; also used as the transformer context width.
        num_heads (int): Forwarded to ``Transformer`` as ``num_heads``.
        num_layers (int): Forwarded to ``Transformer`` as ``num_layers``.
        widening_factor (int): Forwarded to ``Transformer`` as
            ``widening_factor``.
        qkv_features (int): Forwarded to ``Transformer`` as ``features``.
        num_hidden_layers (int): Forwarded to ``Transformer`` as
            ``num_hidden_layers``.
        dropout_rate (float): Forwarded to ``Transformer`` as
            ``dropout_rate``.
    """

    # Required settings (no defaults).
    rngs: nnx.Rngs
    dim_value: int
    dim_id: int
    dim_condition: int
    dim_joint: int

    # Optional settings.
    fourier_features: int = 128
    num_heads: int = 4
    num_layers: int = 6
    widening_factor: int = 3
    qkv_features: int = 8
    num_hidden_layers: int = 1
    dropout_rate: float = 0.0
    # Not enabled:
    # param_dtype: DTypeLike = jnp.float32
class Simformer(nnx.Module):
    """Token-based transformer for joint density estimation.

    Every scalar entry of the input becomes one token. A token is the
    concatenation of three embeddings: an MLP embedding of the value, a
    learned embedding of the node id, and a learned condition vector that is
    zeroed for tokens that are not conditioned on. The tokens are processed
    by a ``Transformer`` that receives a Fourier embedding of the time as
    context, and a linear head maps each output token back to one scalar.
    """

    def __init__(
        self,
        params: SimformerParams,
    ):
        """
        Build all sub-modules of the Simformer.

        Args:
            params (SimformerParams): Parameters for the Simformer model.
        """
        rngs = params.rngs

        self.params = params
        self.dim_value = params.dim_value
        self.dim_id = params.dim_id
        self.dim_condition = params.dim_condition

        # NOTE: the sub-modules below are created in a fixed order because
        # each one draws from the same ``rngs`` object.

        # Scalar value -> dim_value features.
        # (A parameter-free alternative would be to repeat the value
        # dim_value times along the last axis.)
        self.embedding_net_value = MLPEmbedder(
            in_dim=1, hidden_dim=params.dim_value, rngs=rngs
        )

        # Time -> Fourier features, used as the transformer context.
        time_features = params.fourier_features
        self.embedding_time = GaussianFourierEmbedding(time_features, rngs=rngs)

        # Node id -> dim_id features.
        self.embedding_net_id = nnx.Embed(
            num_embeddings=params.dim_joint,
            features=params.dim_id,
            rngs=rngs,
        )

        # One learned vector shared by every conditioned token.
        self.condition_embedding = nnx.Param(
            0.01 * jnp.ones((1, 1, params.dim_condition))
        )

        # Width of a token after concatenating the three embeddings.
        self.total_tokens = params.dim_value + params.dim_id + params.dim_condition

        self.transformer = Transformer(
            din=self.total_tokens,
            dcontext=time_features,
            num_heads=params.num_heads,
            num_layers=params.num_layers,
            features=params.qkv_features,
            widening_factor=params.widening_factor,
            dropout_rate=params.dropout_rate,
            num_hidden_layers=params.num_hidden_layers,
            act=jax.nn.gelu,
            skip_connection_attn=True,
            skip_connection_mlp=True,
            rngs=rngs,
        )

        # Token -> one scalar output.
        self.output_fn = nnx.Linear(self.total_tokens, 1, rngs=rngs)

    def __call__(
        self,
        x: Array,
        t: Array,
        args: Optional[dict] = None,
        *,
        node_ids: Array,
        condition_mask: Array,
        edge_mask: Optional[Array] = None
    ) -> Array:
        """
        Forward pass of the Simformer model.

        Args:
            x (Array): Input values. A 1-D input is treated as a single
                sequence, a 2-D input as ``(batch, seq_len)``; inputs with
                three or more dimensions are used as given.
            t (Array): Time steps; reshaped to ``(-1, 1, 1)``.
            args (Optional[dict]): Additional arguments. Accepted but not
                used by this implementation.
            node_ids (Array): Node identifiers; reshaped to
                ``(-1, seq_len)``.
            condition_mask (Array): Mask for conditioning; cast to bool,
                reshaped to ``(-1, seq_len, 1)`` and broadcast over the
                batch.
            edge_mask (Optional[Array]): Mask for edges, passed to the
                transformer as ``mask``.

        Returns:
            Array: Model output with the trailing size-1 axis of the linear
            head removed.
        """
        # Bring the values to (batch, seq_len, 1).
        x = jnp.atleast_1d(x)
        if x.ndim == 1:
            x = x[None, :, None]
        elif x.ndim == 2:
            x = x[..., None]
        batch_size, seq_len, _ = x.shape

        t = jnp.atleast_1d(t).reshape(-1, 1, 1)

        # One boolean flag per token, shared across the batch if needed.
        is_conditioned = jnp.broadcast_to(
            condition_mask.astype(jnp.bool_).reshape(-1, seq_len, 1),
            (batch_size, seq_len, 1),
        )
        node_ids = node_ids.reshape(-1, seq_len)

        time_embeddings = self.embedding_time(t)

        # Multiplying by the boolean mask zeroes the condition vector for
        # tokens that are not conditioned on.
        cond_part = self.condition_embedding * is_conditioned
        cond_part = jnp.broadcast_to(
            cond_part, (batch_size, seq_len, self.dim_condition)
        )

        value_part = self.embedding_net_value(x)
        id_part = jnp.broadcast_to(
            self.embedding_net_id(node_ids),
            (batch_size, seq_len, self.dim_id),
        )

        # Tokens are the concatenation of the three embeddings (summing them
        # would be an alternative design).
        tokens = jnp.concatenate([value_part, id_part, cond_part], axis=-1)

        hidden = self.transformer(tokens, context=time_embeddings, mask=edge_mask)
        return jnp.squeeze(self.output_fn(hidden), axis=-1)
class SimformerConditioner(nnx.Module):
    """Wrapper that calls a ``Simformer`` with or without conditioning.

    ``conditioned`` merges observed and conditioning variables into a single
    token sequence, marks the conditioning tokens in the condition mask, and
    returns the model output for the observed variables only.
    ``unconditioned`` evaluates the model with an all-``False`` mask.
    """

    def __init__(self, model: Simformer):
        """
        Initialize the SimformerConditioner.

        Args:
            model (Simformer): Simformer model instance.
        """
        self.model = model
        self.dim_joint = model.params.dim_joint

    @staticmethod
    def _as_token_batch(values: Array) -> Array:
        """Return ``values`` as ``(batch, seq_len, 1)``.

        A 1-D input becomes a single sequence, a 2-D input gets a trailing
        axis, and inputs with three or more dimensions are returned as is.
        """
        values = jnp.atleast_1d(values)
        if values.ndim == 1:
            return rearrange(values, "... -> 1 ... 1")
        if values.ndim == 2:
            return rearrange(values, "... -> ... 1")
        return values

    def conditioned(
        self,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        t: Array,
        edge_mask: Optional[Array] = None
    ) -> Array:
        """
        Perform conditioned inference.

        Args:
            obs (Array): Observations, as ``(n_obs,)``, ``(batch, n_obs)``
                or ``(batch, n_obs, 1)``.
            obs_ids (Array): 1-D node identifiers, one per observed token.
            cond (Array): Conditioning values, as ``(n_cond,)``,
                ``(batch, n_cond)`` or ``(batch, n_cond, 1)``. Its batch
                axis is broadcast against the one of ``obs``.
            cond_ids (Array): 1-D node identifiers, one per conditioning
                token. Assumed to be disjoint from ``obs_ids``.
            t (Array): Time steps.
            edge_mask (Optional[Array]): Mask for edges. Tokens are sorted
                by node id before the model is called, so the mask is
                applied to tokens in that order.

        Returns:
            Array: Model output for the observed variables only, in the
            order given by ``obs_ids``.

        Raises:
            ValueError: If the number of ids differs from the number of
                tokens in ``obs`` and ``cond`` combined.
        """
        obs = self._as_token_batch(obs)
        cond = self._as_token_batch(cond)
        t = jnp.atleast_1d(t)
        obs_ids = jnp.atleast_1d(obs_ids)
        cond_ids = jnp.atleast_1d(cond_ids)

        # Broadcast the batch axis ONLY. Broadcasting the whole arrays would
        # also stretch the token axis, which fails when n_obs != n_cond and
        # silently duplicates tokens when one of the two has length 1.
        (batch_size,) = jnp.broadcast_shapes((obs.shape[0],), (cond.shape[0],))
        obs = jnp.broadcast_to(obs, (batch_size, *obs.shape[1:]))
        cond = jnp.broadcast_to(cond, (batch_size, *cond.shape[1:]))

        num_obs = obs.shape[1]
        x = jnp.concatenate([obs, cond], axis=1)
        node_ids = jnp.concatenate([obs_ids, cond_ids])
        if node_ids.size != x.shape[1]:
            raise ValueError(
                f"Got {node_ids.size} node ids for {x.shape[1]} tokens; "
                "obs_ids/cond_ids must have one entry per obs/cond value."
            )

        # In concatenation order the conditioning tokens are exactly the
        # ones after the first `num_obs` positions.
        condition_mask = jnp.arange(x.shape[1]) >= num_obs

        # Sort the nodes and carry values and mask along with them.
        order = jnp.argsort(node_ids)
        x = x[:, order]
        node_ids = node_ids[order]
        condition_mask = condition_mask[order]

        res = self.model(
            x=x,
            t=t,
            node_ids=node_ids,
            condition_mask=condition_mask,
            edge_mask=edge_mask,
        )

        # argsort(order) is the inverse permutation: entry i is the sorted
        # position of the i-th concatenated token. Its first `num_obs`
        # entries locate the observed tokens. When the ids are exactly
        # 0..dim_joint-1 this equals `obs_ids`.
        obs_positions = jnp.argsort(order)[:num_obs]
        return res[:, obs_positions]

    def unconditioned(
        self,
        obs: Array,
        obs_ids: Array,
        t: Array,
        edge_mask: Optional[Array] = None
    ) -> Array:
        """
        Perform unconditioned inference.

        Args:
            obs (Array): Observations, as ``(n_obs,)``, ``(batch, n_obs)``
                or ``(batch, n_obs, 1)``.
            obs_ids (Array): Observation identifiers, passed to the model
                as ``node_ids``.
            t (Array): Time steps.
            edge_mask (Optional[Array]): Mask for edges.

        Returns:
            Array: Unconditioned output.
        """
        obs = self._as_token_batch(obs)
        t = jnp.atleast_1d(t)

        # No token is conditioned on.
        condition_mask = jnp.zeros((obs.shape[1],), dtype=jnp.bool_)

        return self.model(
            x=obs,
            t=t,
            node_ids=obs_ids,
            condition_mask=condition_mask,
            edge_mask=edge_mask,
        )

    def __call__(
        self,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        timesteps: Array,
        conditioned: bool = True,
        edge_mask: Optional[Array] = None
    ) -> Array:
        """
        Perform inference based on conditioning.

        Args:
            obs (Array): Observations.
            obs_ids (Array): Observation identifiers.
            cond (Array): Conditioning values. Ignored when ``conditioned``
                is ``False``.
            cond_ids (Array): Conditioning identifiers. Ignored when
                ``conditioned`` is ``False``.
            timesteps (Array): Time steps.
            conditioned (bool): Whether to perform conditioned inference.
            edge_mask (Optional[Array]): Mask for edges.

        Returns:
            Array: Model output.
        """
        if not conditioned:
            return self.unconditioned(obs, obs_ids, timesteps, edge_mask=edge_mask)
        return self.conditioned(
            obs, obs_ids, cond, cond_ids, timesteps, edge_mask=edge_mask
        )