gensbi.models.simformer
Submodules
Classes
| Class | Description |
|---|---|
| Simformer | Simformer model for joint density estimation. |
| SimformerParams | Parameters for the Simformer model. |
Package Contents
- class gensbi.models.simformer.Simformer(params)
Bases: flax.nnx.Module
Simformer model for joint density estimation.
- Parameters:
params (SimformerParams) – Parameters for the Simformer model.
- __call__(t, obs, node_ids, condition_mask, edge_mask=None)
Forward pass of the Simformer model.
- Parameters:
t (Array) – Time steps.
obs (Array) – Input data.
node_ids (Array) – Node identifiers.
condition_mask (Array) – Mask marking which nodes are conditioned on (observed).
edge_mask (Optional[Array]) – Mask over pairs of nodes that restricts attention between them.
- Returns:
Model output.
- Return type:
Array
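The inputs to `__call__` can be assembled as in the sketch below. It assumes a joint over a 2-dimensional parameter θ and a 3-dimensional observation x (5 nodes in total), one value channel per node, an `obs` layout of `(batch, num_nodes, in_channels)`, a `condition_mask` in which 1 marks a conditioned node, and a boolean `(num_nodes, num_nodes)` `edge_mask`. These shapes and conventions follow the Simformer paper's description and are assumptions, not confirmed by this page; the helper functions are hypothetical. NumPy is used so the sketch runs without JAX.

```python
from typing import Sequence

import numpy as np

# Hypothetical layout: nodes 0-1 hold theta, nodes 2-4 hold x.
NUM_THETA, NUM_X = 2, 3
NUM_NODES = NUM_THETA + NUM_X


def make_node_ids(batch_size: int) -> np.ndarray:
    """One integer identifier per node, repeated over the batch (assumed layout)."""
    return np.tile(np.arange(NUM_NODES), (batch_size, 1))


def make_condition_mask(batch_size: int, observed: Sequence[int]) -> np.ndarray:
    """1 marks a conditioned (observed) node, 0 a node to generate (assumed convention)."""
    mask = np.zeros((batch_size, NUM_NODES), dtype=np.int32)
    mask[:, list(observed)] = 1
    return mask


def make_dense_edge_mask() -> np.ndarray:
    """All-True mask: every node may attend to every other node."""
    return np.ones((NUM_NODES, NUM_NODES), dtype=bool)


rng = np.random.default_rng(0)
batch_size = 4
t = rng.uniform(size=(batch_size, 1))                 # assumed: one time per sample
obs = rng.normal(size=(batch_size, NUM_NODES, 1))     # (batch, nodes, in_channels)
node_ids = make_node_ids(batch_size)

# Posterior p(theta | x): condition on the x nodes.
posterior_mask = make_condition_mask(batch_size, observed=[2, 3, 4])
# Likelihood p(x | theta): condition on the theta nodes instead.
likelihood_mask = make_condition_mask(batch_size, observed=[0, 1])
edge_mask = make_dense_edge_mask()

# Given a constructed model, a forward pass would then look like:
# out = model(t, obs, node_ids, posterior_mask, edge_mask)
```

Changing only the condition mask is what lets a single trained Simformer serve posterior, likelihood, or other conditional queries, as described in the Simformer paper; a sparser edge mask would encode known conditional independencies between nodes.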
- condition_embedding
- dim_condition
- dim_id
- dim_value
- embedding_net_id
- embedding_net_value
- embedding_time
- in_channels
- output_fn
- params
- total_tokens
- transformer
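The attribute names suggest how each node becomes a transformer token: a value embedding (`embedding_net_value`), a node-identifier embedding (`embedding_net_id`) and a condition-state embedding (`condition_embedding`), in line with the tokenizer described in the Simformer paper. The NumPy sketch below concatenates these three pieces per node. The random weights, the linear value map, the lookup-table id embedding, the learned condition vector gated by the mask, and the sizes are all illustrative assumptions, not the gensbi implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Illustrative sizes (assumptions, not library defaults).
batch, num_nodes, in_channels = 4, 5, 1
dim_value, dim_id, dim_condition = 20, 10, 2
total_tokens = num_nodes  # assumed: one identifier per node

# Hypothetical parameters standing in for the learned sub-modules.
w_value = rng.normal(size=(in_channels, dim_value))   # value embedding (linear map)
id_table = rng.normal(size=(total_tokens, dim_id))    # id embedding (lookup table)
cond_vector = rng.normal(size=(dim_condition,))       # learned condition vector


def tokenize(obs, node_ids, condition_mask):
    """Build one token per node by concatenating value, id and condition embeddings."""
    value_emb = obs @ w_value                           # (batch, nodes, dim_value)
    id_emb = id_table[node_ids]                         # (batch, nodes, dim_id)
    cond_emb = condition_mask[..., None] * cond_vector  # zero for unobserved nodes
    return np.concatenate([value_emb, id_emb, cond_emb], axis=-1)


obs = rng.normal(size=(batch, num_nodes, in_channels))
node_ids = np.tile(np.arange(num_nodes), (batch, 1))
condition_mask = np.array([[0, 0, 1, 1, 1]] * batch)

tokens = tokenize(obs, node_ids, condition_mask)  # (batch, nodes, 32)
```

In this sketch the token width is `dim_value + dim_id + dim_condition`; the transformer would then operate on tokens of that size.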
- class gensbi.models.simformer.SimformerParams
Parameters for the Simformer model.
- Parameters:
rngs (nnx.Rngs) – Random number generators for initialization.
in_channels (int) – Number of input channels.
dim_value (int) – Dimension of the value embeddings.
dim_id (int) – Dimension of the ID embeddings.
dim_condition (int) – Dimension of the condition embeddings.
dim_joint (int) – Total dimension of the joint embeddings.
fourier_features (int) – Number of Fourier features for time embedding.
num_heads (int) – Number of attention heads.
num_layers (int) – Number of transformer layers.
widening_factor (int) – Widening factor for the transformer.
qkv_features (int) – Number of features for QKV layers.
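`fourier_features` sets the width of the time embedding. One common construction in score- and flow-based models is random Gaussian Fourier features, sketched below in NumPy. Whether gensbi uses exactly this form (fixed random frequencies, a sine/cosine split, the `scale` value of 16) is an assumption. In this sketch, the documented default of 128 maps each scalar time to a 128-dimensional vector.

```python
import numpy as np


def gaussian_fourier_embedding(t, num_features=128, scale=16.0, seed=0):
    """Map times of shape (batch, 1) to (batch, num_features) sin/cos features.

    Frequencies are drawn once from N(0, scale^2) and kept fixed (assumed setup).
    """
    assert num_features % 2 == 0, "half the features are sines, half cosines"
    rng = np.random.default_rng(seed)
    freqs = rng.normal(scale=scale, size=(1, num_features // 2))
    angles = 2.0 * np.pi * t * freqs  # (batch, num_features // 2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


t = np.linspace(0.0, 1.0, 8).reshape(-1, 1)
emb = gaussian_fourier_embedding(t)  # shape (8, 128)
```

Each sine/cosine pair lies on the unit circle, so the embedding has a fixed norm regardless of `t`, while the spread of frequencies lets the network resolve both coarse and fine changes in time.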
- dim_condition: int
- dim_id: int
- dim_joint: int
- dim_value: int
- fourier_features: int = 128
- in_channels: int
- num_heads: int
- num_layers: int
- qkv_features: int | None = None
- rngs: flax.nnx.Rngs
- widening_factor: int = 3
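The fields above jointly fix the transformer's sizes. The helper below is a hypothetical consistency check, not part of gensbi. It assumes that `dim_joint` equals `dim_value + dim_id + dim_condition` (the concatenated token width), that `qkv_features=None` falls back to the token width (as `flax.nnx.MultiHeadAttention` does when `qkv_features` is None), and that `widening_factor` scales the hidden width of each feed-forward block.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class DerivedSizes:
    head_dim: int
    mlp_hidden: int


def derive_sizes(
    dim_value: int,
    dim_id: int,
    dim_condition: int,
    dim_joint: int,
    num_heads: int,
    qkv_features: Optional[int] = None,  # documented default
    widening_factor: int = 3,            # documented default
) -> DerivedSizes:
    """Hypothetical check of SimformerParams-style sizes (assumptions in comments)."""
    # Assumption: a token is the concatenation of value, id and condition embeddings.
    if dim_joint != dim_value + dim_id + dim_condition:
        raise ValueError("dim_joint should equal dim_value + dim_id + dim_condition")
    # Assumption: None means "use the token width", as in flax.nnx.MultiHeadAttention.
    qkv = dim_joint if qkv_features is None else qkv_features
    if qkv % num_heads != 0:
        raise ValueError("qkv_features must be divisible by num_heads")
    # Assumption: the feed-forward block widens the token by widening_factor.
    return DerivedSizes(head_dim=qkv // num_heads, mlp_hidden=widening_factor * dim_joint)


sizes = derive_sizes(dim_value=40, dim_id=40, dim_condition=10, dim_joint=90, num_heads=6)
```

With these illustrative values, each of the 6 heads works on 15 features and the feed-forward hidden layer has 270 units.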