gensbi.models.simformer.model
Classes
Simformer | Simformer model for joint density estimation.
SimformerParams | Parameters for the Simformer model.
SimformerWrapper | This class is used to wrap around another model. We define a call method which returns the model output.
Module Contents
- class gensbi.models.simformer.model.Simformer(params)
Bases: flax.nnx.Module
Simformer model for joint density estimation.
- Parameters:
params (SimformerParams) – Parameters for the Simformer model.
- __call__(t, obs, node_ids, condition_mask, edge_mask=None)
Forward pass of the Simformer model.
- Parameters:
t (Array) – Time steps.
obs (Array) – Input data.
node_ids (Array) – Node identifiers.
condition_mask (Array) – Mask for conditioning.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Model output.
- Return type:
Array
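To make the call signature concrete, the sketch below builds placeholder inputs with NumPy and shows one common way an edge mask can restrict attention between nodes. The shapes, the boolean mask convention (True means "may attend"), and the helper masked_attention_weights are assumptions made for illustration; none of them is taken from the gensbi source.

```python
import numpy as np

batch_size, num_nodes = 2, 4  # illustrative sizes (assumption)

# Placeholder inputs named after the documented parameters.
t = np.full((batch_size,), 0.5)                             # time steps
obs = np.zeros((batch_size, num_nodes))                     # input data, one value per node
node_ids = np.tile(np.arange(num_nodes), (batch_size, 1))   # node identifiers
condition_mask = np.array([[0, 0, 1, 1]] * batch_size)      # 1 = node is conditioned on

# Assumed convention: edge_mask[i, j] is True when node i may attend to node j.
edge_mask = np.tril(np.ones((num_nodes, num_nodes), dtype=bool))


def masked_attention_weights(scores, mask):
    """Hypothetical helper: softmax over the last axis, ignoring masked-out edges."""
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)


# Uniform scores, so each node spreads its attention evenly over allowed edges.
weights = masked_attention_weights(np.zeros((num_nodes, num_nodes)), edge_mask)
```

With this lower-triangular mask, the first node attends only to itself and the last node attends to all four nodes equally.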
- class gensbi.models.simformer.model.SimformerParams
Parameters for the Simformer model.
- Parameters:
rngs (nnx.Rngs) – Random number generators for initialization.
dim_value (int) – Dimension of the value embeddings.
dim_id (int) – Dimension of the ID embeddings.
dim_condition (int) – Dimension of the condition embeddings.
dim_joint (int) – Total dimension of the joint embeddings.
fourier_features (int) – Number of Fourier features for time embedding.
num_heads (int) – Number of attention heads.
num_layers (int) – Number of transformer layers.
widening_factor (int) – Widening factor for the transformer.
qkv_features (int) – Number of features for QKV layers.
num_hidden_layers (int) – Number of hidden layers in the transformer.
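The fields above can be collected in a small container, sketched below. SimformerParamsSketch is a hypothetical stand-in written for this page, not the library class: it omits rngs, its constructor may differ from the real SimformerParams, and every value is illustrative rather than a recommended default. The head_dim check reflects the usual multi-head attention requirement that the QKV feature size split evenly across heads; whether Simformer enforces it is an assumption.

```python
from dataclasses import dataclass


@dataclass
class SimformerParamsSketch:
    """Hypothetical stand-in mirroring the documented fields (rngs omitted)."""

    dim_value: int
    dim_id: int
    dim_condition: int
    dim_joint: int
    fourier_features: int
    num_heads: int
    num_layers: int
    widening_factor: int
    qkv_features: int
    num_hidden_layers: int

    def head_dim(self) -> int:
        # Multi-head attention usually splits qkv_features evenly across heads.
        if self.qkv_features % self.num_heads != 0:
            raise ValueError("qkv_features should be divisible by num_heads")
        return self.qkv_features // self.num_heads


# Illustrative values only.
params = SimformerParamsSketch(
    dim_value=20, dim_id=10, dim_condition=10, dim_joint=40,
    fourier_features=128, num_heads=4, num_layers=6,
    widening_factor=3, qkv_features=64, num_hidden_layers=1,
)
per_head = params.head_dim()
```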
- class gensbi.models.simformer.model.SimformerWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wraps another model. The call method returns the output of the wrapped model. The class also defines a vector_field method, which computes the model's vector field, and a divergence method, which computes its divergence, both in a form suitable for diffrax. ODE solvers that require the vector field and its divergence can use these methods directly.
- Parameters:
model (Simformer)
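The divergence of a vector field is the trace of its Jacobian with respect to the state, here obs. In a JAX setting this is normally obtained by automatic differentiation; the sketch below instead uses central finite differences in NumPy on a toy linear field so that it stays self-contained. The function bodies and the matrix A are assumptions for illustration and do not reflect how gensbi computes these quantities.

```python
import numpy as np

A = np.array([[1.0, 2.0], [0.0, -3.0]])  # toy linear dynamics (assumption)


def vector_field(t, x):
    """Toy stand-in for a model's vector field: v(t, x) = A @ x."""
    return A @ x


def divergence(t, x, eps=1e-5):
    """Sum of dv_i/dx_i, estimated with central finite differences."""
    total = 0.0
    for i in range(x.shape[0]):
        step = np.zeros_like(x)
        step[i] = eps
        forward = vector_field(t, x + step)[i]
        backward = vector_field(t, x - step)[i]
        total += (forward - backward) / (2 * eps)
    return total


div = divergence(0.0, np.array([0.3, -1.2]))
```

For a linear field the divergence equals the trace of A regardless of x, which gives a simple check on the estimate.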
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, edge_mask=None)
Defines how inputs are passed through the wrapped model. The wrapped model is assumed to take both obs and t as input, along with additional keyword arguments.
- Parameters:
obs (Array) – input data to the model (batch_size, …).
t (Array) – time (batch_size).
cond (Array) – conditioning data to the model (batch_size, …).
obs_ids (Array) – observation ids (batch_size, obs_dim).
cond_ids (Array) – condition ids (batch_size, cond_dim).
conditioned (bool | Array) – whether to use conditioning or not.
edge_mask (Optional[Array]) – mask for edges.
- Returns:
model output.
- Return type:
Array
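One plausible reading of this signature is that the wrapper joins obs and cond into a single sequence of nodes and marks which nodes are conditioned on, matching the node_ids and condition_mask inputs of Simformer.__call__. The sketch below illustrates that reading with NumPy. The helper build_joint_inputs, its concatenation order, and the handling of conditioned=False are assumptions, not the gensbi implementation.

```python
import numpy as np


def build_joint_inputs(obs, obs_ids, cond, cond_ids, conditioned=True):
    """Hypothetical helper: join observed and conditioning nodes into one sequence."""
    joint = np.concatenate([obs, cond], axis=1)
    node_ids = np.concatenate([obs_ids, cond_ids], axis=1)
    # 0 marks nodes to be generated, 1 marks nodes that are conditioned on.
    mask = np.concatenate([np.zeros_like(obs_ids), np.ones_like(cond_ids)], axis=1)
    # Assumed behavior: with conditioned=False no node is treated as conditioning.
    mask = mask * np.asarray(conditioned, dtype=mask.dtype)
    return joint, node_ids, mask


batch_size, obs_dim, cond_dim = 2, 3, 2
obs = np.random.default_rng(0).normal(size=(batch_size, obs_dim))
cond = np.ones((batch_size, cond_dim))
obs_ids = np.tile(np.arange(obs_dim), (batch_size, 1))
cond_ids = np.tile(np.arange(obs_dim, obs_dim + cond_dim), (batch_size, 1))

joint, node_ids, condition_mask = build_joint_inputs(obs, obs_ids, cond, cond_ids)
_, _, free_mask = build_joint_inputs(obs, obs_ids, cond, cond_ids, conditioned=False)
```

Under these assumptions, a model output defined over all five nodes would be sliced back to the first obs_dim columns to recover the part that corresponds to obs.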
- conditioned(t, obs, obs_ids, cond, cond_ids, edge_mask=None)
Perform conditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
t (Array) – Time steps.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Conditioned output.
- Return type:
Array