gensbi.models.simformer.model

Classes

Simformer

Simformer model for joint density estimation.

SimformerParams

Parameters for the Simformer model.

Module Contents

class gensbi.models.simformer.model.Simformer(params)

Bases: flax.nnx.Module

Simformer model for joint density estimation.

Parameters:

params (SimformerParams) – Parameters for the Simformer model.

__call__(t, obs, node_ids, condition_mask, edge_mask=None)

Forward pass of the Simformer model.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Input data: the values of the nodes.

  • node_ids (Array) – Node identifiers, one per token.

  • condition_mask (Array) – Mask indicating which nodes are conditioned on (observed) and which are generated.

  • edge_mask (Optional[Array]) – Mask over pairs of nodes that restricts which tokens can attend to each other.

Returns:

Model output.

Return type:

Array
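The following NumPy sketch shows one way to assemble node_ids, condition_mask and edge_mask for a small hypothetical problem with two parameters and three observables. The conventions are assumptions, not documented on this page: a condition_mask entry of 1 marks an observed (conditioned) node, and edge_mask is a dense boolean adjacency in which entry [i, j] lets token i attend to token j. Batch and channel dimensions are omitted, and the helper masked_attention_weights is hypothetical; it only illustrates how a mask of this form typically gates attention.

```python
import numpy as np

# Hypothetical problem: 2 parameters (theta) and 3 observables (x),
# one token per scalar variable.
NUM_THETA, NUM_X = 2, 3
NUM_NODES = NUM_THETA + NUM_X

# node_ids: one integer identifier per token.
node_ids = np.arange(NUM_NODES)

# condition_mask (assumed convention): 1 = observed/conditioned, 0 = generated.
posterior_mask = np.array([0] * NUM_THETA + [1] * NUM_X, dtype=np.int32)   # p(theta | x)
likelihood_mask = np.array([1] * NUM_THETA + [0] * NUM_X, dtype=np.int32)  # p(x | theta)
joint_mask = np.zeros(NUM_NODES, dtype=np.int32)                           # p(theta, x)

# edge_mask (assumed convention): entry [i, j] is True if token i may attend to token j.
full_edges = np.ones((NUM_NODES, NUM_NODES), dtype=bool)

# Hypothetical structure: observables are conditionally independent given theta,
# so each x token attends only to the theta tokens and to itself.
sparse_edges = np.zeros((NUM_NODES, NUM_NODES), dtype=bool)
sparse_edges[:NUM_THETA, :] = True            # theta tokens attend to every token
sparse_edges[NUM_THETA:, :NUM_THETA] = True   # x tokens attend to theta tokens
x_idx = np.arange(NUM_THETA, NUM_NODES)
sparse_edges[x_idx, x_idx] = True             # x tokens attend to themselves


def masked_attention_weights(scores, edge_mask):
    """Row-wise softmax over keys, giving zero weight to disallowed edges."""
    scores = np.where(edge_mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


# With uniform scores, each token spreads its attention evenly over allowed keys.
weights = masked_attention_weights(np.zeros((NUM_NODES, NUM_NODES)), sparse_edges)
```

Switching between posterior, likelihood and joint queries only changes condition_mask; the node identifiers and the graph stay the same.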

condition_embedding
dim_condition
dim_id
dim_value
embedding_net_id
embedding_net_value
embedding_time
in_channels
output_fn
params
total_tokens
transformer
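The attribute names above suggest how input tokens are formed: embedding_net_value, embedding_net_id and condition_embedding each contribute part of a token, and total_tokens bounds the node identifiers. The NumPy sketch below assumes, without confirmation from this page, that each token is the concatenation of a linear value embedding, an ID-embedding lookup and a learned condition vector scaled by condition_mask, so that dim_joint = dim_value + dim_id + dim_condition. All weights are random stand-ins and the sizes are hypothetical; how embedding_time is injected into the transformer is not shown.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes mirroring SimformerParams fields.
in_channels, dim_value, dim_id, dim_condition = 1, 20, 10, 10
total_tokens = 5                                   # number of distinct node ids
dim_joint = dim_value + dim_id + dim_condition     # assumption: tokens are concatenations

# Random stand-ins for the learned modules (assumed forms, untrained).
W_value = rng.normal(size=(in_channels, dim_value))   # ~ embedding_net_value (linear map)
id_table = rng.normal(size=(total_tokens, dim_id))    # ~ embedding_net_id (lookup table)
condition_vec = rng.normal(size=(dim_condition,))     # ~ condition_embedding (learned vector)


def build_tokens(obs, node_ids, condition_mask):
    """obs: (batch, n, in_channels); node_ids: (batch, n); condition_mask: (batch, n, 1)."""
    value_emb = obs @ W_value                  # (batch, n, dim_value)
    id_emb = id_table[node_ids]                # (batch, n, dim_id)
    cond_emb = condition_mask * condition_vec  # (batch, n, dim_condition), zero for generated nodes
    return np.concatenate([value_emb, id_emb, cond_emb], axis=-1)


batch, n = 4, total_tokens
obs = rng.normal(size=(batch, n, in_channels))
node_ids = np.broadcast_to(np.arange(n), (batch, n))
condition_mask = np.broadcast_to(np.array([0, 0, 1, 1, 1])[None, :, None], (batch, n, 1))
tokens = build_tokens(obs, node_ids, condition_mask)  # (batch, n, dim_joint)
```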
class gensbi.models.simformer.model.SimformerParams

Parameters for the Simformer model.

Parameters:
  • rngs (nnx.Rngs) – Random number generators for initialization.

  • in_channels (int) – Number of input channels.

  • dim_value (int) – Dimension of the value embeddings.

  • dim_id (int) – Dimension of the ID embeddings.

  • dim_condition (int) – Dimension of the condition embeddings.

  • dim_joint (int) – Total dimension of the joint embeddings.

  • fourier_features (int) – Number of Fourier features for time embedding.

  • num_heads (int) – Number of attention heads.

  • num_layers (int) – Number of transformer layers.

  • widening_factor (int) – Widening factor for the transformer.

  • qkv_features (int | None) – Number of features for the query, key and value projections.

  • num_hidden_layers (int) – Number of hidden layers in the transformer.
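fourier_features sets the size of the time embedding. One common construction, used here purely as an illustrative assumption, is a Gaussian random Fourier projection: fixed random frequencies w, with features sin(2πwt) and cos(2πwt). The frequency scale and seed below are hypothetical, and the actual embedding_time module may use a different scheme (for example learnable or log-spaced frequencies).

```python
import numpy as np


def gaussian_fourier_features(t, num_features=128, scale=16.0, seed=0):
    """Map scalar times t of shape (batch,) to features of shape (batch, num_features)."""
    if num_features % 2 != 0:
        raise ValueError("num_features must be even: half sine, half cosine")
    rng = np.random.default_rng(seed)
    # Fixed random frequencies, drawn once and not trained in this sketch.
    w = rng.normal(scale=scale, size=num_features // 2)
    angles = 2 * np.pi * np.asarray(t, dtype=np.float64)[:, None] * w[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


# 128 matches the documented default of fourier_features.
emb = gaussian_fourier_features(np.array([0.0, 0.5, 1.0]), num_features=128)
```

Each sine/cosine pair lies on the unit circle, so the embedding has a bounded scale for any t, which is one reason this family of embeddings is popular for diffusion and flow times.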

__post_init__()
dim_condition: int
dim_id: int
dim_joint: int
dim_value: int
fourier_features: int = 128
in_channels: int
num_heads: int
num_hidden_layers: int = 1
num_layers: int
qkv_features: int | None = None
rngs: flax.nnx.Rngs
widening_factor: int = 3
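SimformerParams is a dataclass with a __post_init__ hook whose behaviour is not documented on this page. The stand-in below (rngs omitted, class name SimformerParamsSketch hypothetical) shows one plausible use of that hook: resolving qkv_features=None to dim_joint, similar to how flax.nnx.MultiHeadAttention falls back to its input width, and checking that the attention width splits evenly across num_heads. Treat both checks as assumptions, not a description of the real class.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class SimformerParamsSketch:
    """Hypothetical stand-in for SimformerParams (rngs omitted)."""

    in_channels: int
    dim_value: int
    dim_id: int
    dim_condition: int
    dim_joint: int
    num_heads: int
    num_layers: int
    fourier_features: int = 128
    widening_factor: int = 3
    qkv_features: Optional[int] = None
    num_hidden_layers: int = 1

    def __post_init__(self):
        # Assumed behaviour: fall back to the token width when qkv_features is unset.
        if self.qkv_features is None:
            self.qkv_features = self.dim_joint
        # Multi-head attention splits qkv_features evenly across the heads.
        if self.qkv_features % self.num_heads != 0:
            raise ValueError("qkv_features must be divisible by num_heads")


p = SimformerParamsSketch(
    in_channels=1, dim_value=40, dim_id=40, dim_condition=10,
    dim_joint=90, num_heads=6, num_layers=4,
)
```

Passing arguments by keyword avoids depending on field order, which the alphabetical listing above does not reveal.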