gensbi.models.simformer.model
Classes
| Simformer | Simformer model for joint density estimation. |
| SimformerParams | Parameters for the Simformer model. |
Module Contents
- class gensbi.models.simformer.model.Simformer(params)
Bases: flax.nnx.Module
Simformer model for joint density estimation.
- Parameters:
params (SimformerParams) – Parameters for the Simformer model.
- __call__(t, obs, node_ids, condition_mask, edge_mask=None)
Forward pass of the Simformer model.
- Parameters:
t (Array) – Time steps.
obs (Array) – Input data.
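node_ids (Array) – Identifier of each node (variable), used for the ID embedding.
condition_mask (Array) – Mask marking which nodes are conditioned on rather than generated.
edge_mask (Optional[Array]) – Mask over pairs of nodes restricting which nodes attend to each other (the edges of the dependency graph).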
- Returns:
Model output.
- Return type:
Array
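The arguments of `__call__` can be prepared as in the sketch below. It is not taken from GenSBI: the helper `make_simformer_inputs` is hypothetical, and the shapes are assumptions for illustration (`obs` as `(batch, num_nodes, in_channels)`, `node_ids` as `(num_nodes,)`, `condition_mask` as `(batch, num_nodes, 1)` with `True` meaning "conditioned on", `edge_mask` as a boolean `(num_nodes, num_nodes)` matrix where `True` means "may attend", and `t` drawn uniformly from [0, 1) as a placeholder). Two parameters followed by three observations form five nodes, and conditioning on the observations turns joint density estimation into posterior estimation. NumPy is used so the sketch runs without a model; the arrays can be converted with `jax.numpy.asarray`.

```python
import numpy as np


def make_simformer_inputs(batch_size, dim_theta, dim_x, in_channels=1, seed=0):
    """Hypothetical helper: build example inputs for Simformer.__call__.

    All shapes and conventions are assumptions, not GenSBI's documented API.
    """
    rng = np.random.default_rng(seed)
    num_nodes = dim_theta + dim_x

    # Values of every node: parameters first, then observations.
    obs = rng.normal(size=(batch_size, num_nodes, in_channels)).astype(np.float32)

    # One integer identifier per node, shared across the batch.
    node_ids = np.arange(num_nodes, dtype=np.int32)

    # Assumed convention: True marks a node that is conditioned on (here: x).
    condition_mask = np.zeros((batch_size, num_nodes, 1), dtype=bool)
    condition_mask[:, dim_theta:, :] = True

    # Assumed convention: True means "may attend"; all True is a fully connected graph.
    edge_mask = np.ones((num_nodes, num_nodes), dtype=bool)

    # Placeholder time steps, one per batch element, uniform in [0, 1).
    t = rng.uniform(size=(batch_size, 1)).astype(np.float32)
    return t, obs, node_ids, condition_mask, edge_mask


t, obs, node_ids, condition_mask, edge_mask = make_simformer_inputs(
    batch_size=4, dim_theta=2, dim_x=3
)
# With a constructed model, the call would follow the documented signature:
# out = model(t, obs, node_ids, condition_mask, edge_mask=edge_mask)
print(obs.shape, condition_mask[0, :, 0])
```

Changing which entries of `condition_mask` are `True` selects a different conditional of the same joint model, for example conditioning on the parameters to emulate the likelihood.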
- class gensbi.models.simformer.model.SimformerParams
Parameters for the Simformer model.
- Parameters:
rngs (nnx.Rngs) – Random number generators for initialization.
in_channels (int) – Number of input channels.
dim_value (int) – Dimension of the value embeddings.
dim_id (int) – Dimension of the ID embeddings.
dim_condition (int) – Dimension of the condition embeddings.
dim_joint (int) – Total dimension of the joint embeddings.
fourier_features (int) – Number of Fourier features for time embedding.
num_heads (int) – Number of attention heads.
num_layers (int) – Number of transformer layers.
widening_factor (int) – Widening factor for the transformer.
qkv_features (int) – Number of features for QKV layers.
num_hidden_layers (int) – Number of hidden layers in the transformer.
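The parameter list does not spell out how the embedding sizes relate. The sketch below assumes, as the word "total" suggests and as in the original Simformer design, that each node's token concatenates a value embedding (`dim_value`), an ID embedding (`dim_id`) and a condition embedding (`dim_condition`), so that `dim_joint = dim_value + dim_id + dim_condition`. It also assumes `fourier_features` is the output width of a sine/cosine time embedding. The helpers `joint_tokens` and `fourier_time_embedding`, the frequency schedule and the sizes are hypothetical and are not GenSBI's implementation or defaults.

```python
import numpy as np


def fourier_time_embedding(t, num_features, max_freq=10.0):
    """Hypothetical sin/cos embedding of times t, shape (batch, 1) -> (batch, num_features).

    The linearly spaced frequency schedule is an assumption for illustration.
    """
    if num_features % 2 != 0:
        raise ValueError("num_features must be even (half sine, half cosine)")
    freqs = np.linspace(1.0, max_freq, num_features // 2)  # (num_features // 2,)
    angles = 2.0 * np.pi * t * freqs                        # broadcasts to (batch, num_features // 2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


def joint_tokens(value_emb, id_emb, cond_emb):
    """Hypothetical token construction by concatenating per-node embeddings.

    value_emb: (batch, num_nodes, dim_value)
    id_emb:    (num_nodes, dim_id), shared across the batch
    cond_emb:  (batch, num_nodes, dim_condition)
    returns:   (batch, num_nodes, dim_value + dim_id + dim_condition)
    """
    batch = value_emb.shape[0]
    id_emb = np.broadcast_to(id_emb, (batch,) + id_emb.shape)
    return np.concatenate([value_emb, id_emb, cond_emb], axis=-1)


# Illustrative sizes only; they are not GenSBI defaults.
dim_value, dim_id, dim_condition = 20, 10, 10
dim_joint = dim_value + dim_id + dim_condition  # holds under the concatenation assumption
batch, num_nodes = 4, 5

rng = np.random.default_rng(0)
tokens = joint_tokens(
    rng.normal(size=(batch, num_nodes, dim_value)),
    rng.normal(size=(num_nodes, dim_id)),
    rng.normal(size=(batch, num_nodes, dim_condition)),
)
time_features = fourier_time_embedding(rng.uniform(size=(batch, 1)), num_features=32)
print(tokens.shape, time_features.shape)
```

If GenSBI instead projects the concatenated embeddings to `dim_joint`, the sum relation would not be required; the check above only holds under the stated assumption.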