gensbi.models.simformer package#
Submodules#
gensbi.models.simformer.embedding module#
- class gensbi.models.simformer.embedding.GaussianFourierEmbedding(*args: Any, **kwargs: Any)[source]#
Bases: Module
- class gensbi.models.simformer.embedding.MLPEmbedder(*args: Any, **kwargs: Any)[source]#
Bases: Module
gensbi.models.simformer.loss module#
- class gensbi.models.simformer.loss.SimformerCFMLoss(*args: Any, **kwargs: Any)[source]#
Bases: ContinuousFMLoss
gensbi.models.simformer.simformer module#
- class gensbi.models.simformer.simformer.SimformerConditioner(*args: Any, **kwargs: Any)[source]#
Bases: Module
- conditioned(obs: Array, obs_ids: Array, cond: Array, cond_ids: Array, t: Array, edge_mask: Array | None = None) → Array [source]#
Perform conditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
t (Array) – Time steps.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Conditioned output.
- Return type:
Array
- unconditioned(obs: Array, obs_ids: Array, t: Array, edge_mask: Array | None = None) → Array [source]#
Perform unconditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
t (Array) – Time steps.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Unconditioned output.
- Return type:
Array
- class gensbi.models.simformer.simformer.SimformerParams(rngs: flax.nnx.rnglib.Rngs, dim_value: int, dim_id: int, dim_condition: int, dim_joint: int, fourier_features: int = 128, num_heads: int = 4, num_layers: int = 6, widening_factor: int = 3, qkv_features: int = 8, num_hidden_layers: int = 1, dropout_rate: float = 0.0)[source]#
Bases: object
- dim_condition: int#
- dim_id: int#
- dim_joint: int#
- dim_value: int#
- dropout_rate: float = 0.0#
- fourier_features: int = 128#
- num_heads: int = 4#
- num_hidden_layers: int = 1#
- num_layers: int = 6#
- qkv_features: int = 8#
- rngs: Rngs#
- widening_factor: int = 3#
gensbi.models.simformer.transformer module#
- class gensbi.models.simformer.transformer.AttentionBlock(*args: Any, **kwargs: Any)[source]#
Bases: Module
Module contents#
- class gensbi.models.simformer.SimformerCFMLoss(*args: Any, **kwargs: Any)[source]#
Bases: ContinuousFMLoss
- class gensbi.models.simformer.SimformerConditioner(*args: Any, **kwargs: Any)[source]#
Bases: Module
- conditioned(obs: Array, obs_ids: Array, cond: Array, cond_ids: Array, t: Array, edge_mask: Array | None = None) → Array [source]#
Perform conditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
t (Array) – Time steps.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Conditioned output.
- Return type:
Array
- unconditioned(obs: Array, obs_ids: Array, t: Array, edge_mask: Array | None = None) → Array [source]#
Perform unconditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
t (Array) – Time steps.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Unconditioned output.
- Return type:
Array
- class gensbi.models.simformer.SimformerParams(rngs: flax.nnx.rnglib.Rngs, dim_value: int, dim_id: int, dim_condition: int, dim_joint: int, fourier_features: int = 128, num_heads: int = 4, num_layers: int = 6, widening_factor: int = 3, qkv_features: int = 8, num_hidden_layers: int = 1, dropout_rate: float = 0.0)[source]#
Bases: object
- dim_condition: int#
- dim_id: int#
- dim_joint: int#
- dim_value: int#
- dropout_rate: float = 0.0#
- fourier_features: int = 128#
- num_heads: int = 4#
- num_hidden_layers: int = 1#
- num_layers: int = 6#
- qkv_features: int = 8#
- rngs: Rngs#
- widening_factor: int = 3#