gensbi.models.simformer

Submodules

Classes

Simformer

Simformer model for joint density estimation.

SimformerParams

Parameters for the Simformer model.

Package Contents

class gensbi.models.simformer.Simformer(params)

Bases: flax.nnx.Module

Simformer model for joint density estimation.

Parameters:

params (SimformerParams) – Parameters for the Simformer model.

__call__(t, obs, node_ids, condition_mask, edge_mask=None)

Forward pass of the Simformer model.

Parameters:
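  • t (Array) – Time values at which the model is evaluated.

  • obs (Array) – Input values for every node (variable) of the joint distribution.

  • node_ids (Array) – Node identifiers assigning each entry of obs to a node; used for the ID embeddings.

  • condition_mask (Array) – Mask marking which nodes are conditioned on (observed) rather than generated.

  • edge_mask (Optional[Array]) – Optional attention mask restricting which nodes may attend to each other. Defaults to None.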

Returns:

Model output.

Return type:

Array
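The edge_mask argument restricts which nodes can exchange information inside the transformer. A minimal NumPy sketch of masked single-head self-attention over node tokens, assuming edge_mask[i, j] == True means node i may attend to node j and that None means full attention. The helper masked_self_attention is hypothetical and uses identity projections for brevity; it is not the gensbi transformer.

```python
import numpy as np


def masked_self_attention(tokens, edge_mask=None):
    """Single-head scaled dot-product self-attention (hypothetical helper).

    tokens:    (num_nodes, dim) array, one token per node.
    edge_mask: optional (num_nodes, num_nodes) boolean array; edge_mask[i, j]
               is True if node i may attend to node j. None means full attention.
               Every row must allow at least one node (typically itself).
    """
    num_nodes, dim = tokens.shape
    # Identity projections keep the sketch short; a real layer learns W_q, W_k, W_v.
    q, k, v = tokens, tokens, tokens
    scores = q @ k.T / np.sqrt(dim)
    if edge_mask is not None:
        # Disallowed edges get -inf so they receive exactly zero softmax weight.
        scores = np.where(edge_mask, scores, -np.inf)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
tokens = rng.normal(size=(3, 4))
# Chain graph 0 -> 1 -> 2: each node attends to itself and its parent.
edge_mask = np.array([[True, False, False],
                      [True, True, False],
                      [False, True, True]])
out, weights = masked_self_attention(tokens, edge_mask)
```

Node 0 may attend only to itself here, so its output equals its own token; with edge_mask=None every node attends to every other node.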

condition_embedding
dim_condition
dim_id
dim_value
embedding_net_id
embedding_net_value
embedding_time
in_channels
output_fn
params
total_tokens
transformer
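In the Simformer design, each variable becomes one token built from its value, its identity and its conditioning state. A minimal NumPy sketch of that tokenization, assuming embedding_net_value acts like a linear projection of obs, embedding_net_id like a lookup table indexed by node_ids, and condition_embedding like a learned vector applied only to conditioned nodes. The helper simformer_tokens and its arguments are hypothetical; in this sketch each token has width dim_value + dim_id + dim_condition, and how gensbi combines these pieces (and with the time embedding) may differ.

```python
import numpy as np


def simformer_tokens(obs, node_ids, condition_mask, w_value, id_table, cond_vec):
    """Build one token per node by concatenating three embeddings (hypothetical helper).

    obs:            (num_nodes, in_channels) values of each variable.
    node_ids:       (num_nodes,) integer identifiers of each variable.
    condition_mask: (num_nodes,) booleans, True where the variable is conditioned on.
    w_value:        (in_channels, dim_value) linear value projection.
    id_table:       (num_ids, dim_id) embedding table indexed by node id.
    cond_vec:       (dim_condition,) vector marking conditioned nodes.
    """
    value_emb = obs @ w_value                  # (num_nodes, dim_value)
    id_emb = id_table[node_ids]                # (num_nodes, dim_id)
    # Conditioned nodes receive cond_vec; latent (generated) nodes receive zeros.
    cond_emb = condition_mask[:, None].astype(obs.dtype) * cond_vec[None, :]
    return np.concatenate([value_emb, id_emb, cond_emb], axis=-1)


rng = np.random.default_rng(0)
num_nodes, in_channels, dim_value, dim_id, dim_condition = 3, 1, 8, 4, 2
tokens = simformer_tokens(
    rng.normal(size=(num_nodes, in_channels)),
    node_ids=np.arange(num_nodes),
    condition_mask=np.array([False, False, True]),  # condition on node 2
    w_value=rng.normal(size=(in_channels, dim_value)),
    id_table=rng.normal(size=(num_nodes, dim_id)),
    cond_vec=np.ones(dim_condition),
)
print(tokens.shape)  # (3, 14)
```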
class gensbi.models.simformer.SimformerParams

Parameters for the Simformer model.

Parameters:
  • rngs (nnx.Rngs) – Random number generators for initialization.

  • in_channels (int) – Number of input channels.

  • dim_value (int) – Dimension of the value embeddings.

  • dim_id (int) – Dimension of the ID embeddings.

  • dim_condition (int) – Dimension of the condition embeddings.

  • dim_joint (int) – Total dimension of the joint embeddings.

  • fourier_features (int) – Number of Fourier features for time embedding.

  • num_heads (int) – Number of attention heads.

  • num_layers (int) – Number of transformer layers.

  • widening_factor (int) – Expansion factor for the hidden width of the transformer's feed-forward layers.

  • qkv_features (int | None) – Number of features for the query/key/value projections. Defaults to None.

  • num_hidden_layers (int) – Number of hidden layers in the transformer.
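The fourier_features parameter sets the size of the time embedding. A minimal NumPy sketch of one common construction, Gaussian random Fourier features of a scalar time, assuming fourier_features is the output width (half sines, half cosines). The helper fourier_time_embedding, its scale=30.0 default and the fixed seed are hypothetical; gensbi may build its time embedding differently.

```python
import numpy as np


def fourier_time_embedding(t, fourier_features=128, scale=30.0, seed=0):
    """Map scalar times to random Fourier features (hypothetical helper).

    t: (batch,) array of times. Returns (batch, fourier_features); half the
    columns are sines and half cosines, so fourier_features must be even.
    """
    if fourier_features % 2 != 0:
        raise ValueError("fourier_features must be even")
    rng = np.random.default_rng(seed)
    # Fixed random frequencies, drawn once and not trained.
    freqs = rng.normal(scale=scale, size=fourier_features // 2)
    angles = 2.0 * np.pi * np.asarray(t, dtype=float)[:, None] * freqs[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


emb = fourier_time_embedding(np.array([0.0, 0.5, 1.0]))
print(emb.shape)  # (3, 128)
```

Using many random frequencies lets the network resolve small changes in t that a single raw scalar input would expose only weakly.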

__post_init__()
dim_condition: int
dim_id: int
dim_joint: int
dim_value: int
fourier_features: int = 128
in_channels: int
num_heads: int
num_hidden_layers: int = 1
num_layers: int
qkv_features: int | None = None
rngs: flax.nnx.Rngs
widening_factor: int = 3