gensbi.models.simformer package

Submodules

gensbi.models.simformer.embedding module

class gensbi.models.simformer.embedding.GaussianFourierEmbedding(*args: Any, **kwargs: Any)[source]

Bases: Module

class gensbi.models.simformer.embedding.MLPEmbedder(*args: Any, **kwargs: Any)[source]

Bases: Module

class gensbi.models.simformer.embedding.SimpleTimeEmbedding(*args: Any, **kwargs: Any)[source]

Bases: Module

class gensbi.models.simformer.embedding.SinusoidalEmbedding(*args: Any, **kwargs: Any)[source]

Bases: Module

gensbi.models.simformer.loss module

class gensbi.models.simformer.loss.SimformerCFMLoss(*args: Any, **kwargs: Any)[source]

Bases: ContinuousFMLoss

gensbi.models.simformer.simformer module

class gensbi.models.simformer.simformer.Simformer(*args: Any, **kwargs: Any)[source]

Bases: Module

class gensbi.models.simformer.simformer.SimformerConditioner(*args: Any, **kwargs: Any)[source]

Bases: Module

conditioned(obs: Array, obs_ids: Array, cond: Array, cond_ids: Array, t: Array, edge_mask: Array | None = None) → Array[source]

Perform conditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • t (Array) – Time steps.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Conditioned output.

Return type:

Array

unconditioned(obs: Array, obs_ids: Array, t: Array, edge_mask: Array | None = None) → Array[source]

Perform unconditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • t (Array) – Time steps.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Unconditioned output.

Return type:

Array

class gensbi.models.simformer.simformer.SimformerParams(rngs: flax.nnx.rnglib.Rngs, dim_value: int, dim_id: int, dim_condition: int, dim_joint: int, fourier_features: int = 128, num_heads: int = 4, num_layers: int = 6, widening_factor: int = 3, qkv_features: int = 8, num_hidden_layers: int = 1, dropout_rate: float = 0.0)[source]

Bases: object

dim_condition: int
dim_id: int
dim_joint: int
dim_value: int
dropout_rate: float = 0.0
fourier_features: int = 128
num_heads: int = 4
num_hidden_layers: int = 1
num_layers: int = 6
qkv_features: int = 8
rngs: Rngs
widening_factor: int = 3

gensbi.models.simformer.transformer module

class gensbi.models.simformer.transformer.AttentionBlock(*args: Any, **kwargs: Any)[source]

Bases: Module

class gensbi.models.simformer.transformer.DenseBlock(*args: Any, **kwargs: Any)[source]

Bases: Module

class gensbi.models.simformer.transformer.Transformer(*args: Any, **kwargs: Any)[source]

Bases: Module

A transformer stack.

Module contents

class gensbi.models.simformer.Simformer(*args: Any, **kwargs: Any)[source]

Bases: Module

class gensbi.models.simformer.SimformerCFMLoss(*args: Any, **kwargs: Any)[source]

Bases: ContinuousFMLoss

class gensbi.models.simformer.SimformerConditioner(*args: Any, **kwargs: Any)[source]

Bases: Module

conditioned(obs: Array, obs_ids: Array, cond: Array, cond_ids: Array, t: Array, edge_mask: Array | None = None) → Array[source]

Perform conditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • t (Array) – Time steps.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Conditioned output.

Return type:

Array

unconditioned(obs: Array, obs_ids: Array, t: Array, edge_mask: Array | None = None) → Array[source]

Perform unconditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • t (Array) – Time steps.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Unconditioned output.

Return type:

Array

class gensbi.models.simformer.SimformerParams(rngs: flax.nnx.rnglib.Rngs, dim_value: int, dim_id: int, dim_condition: int, dim_joint: int, fourier_features: int = 128, num_heads: int = 4, num_layers: int = 6, widening_factor: int = 3, qkv_features: int = 8, num_hidden_layers: int = 1, dropout_rate: float = 0.0)[source]

Bases: object

dim_condition: int
dim_id: int
dim_joint: int
dim_value: int
dropout_rate: float = 0.0
fourier_features: int = 128
num_heads: int = 4
num_hidden_layers: int = 1
num_layers: int = 6
qkv_features: int = 8
rngs: Rngs
widening_factor: int = 3