gensbi.models.simformer.embedding#

Classes#

GaussianFourierEmbedding

Gaussian Fourier feature embedding of scalar inputs.

MLPEmbedder

MLP embedder with SiLU activation and a skip path.

SimpleTimeEmbedding

Minimal, parameter-free time embedding.

SinusoidalEmbedding

Sinusoidal embedding of scalar inputs.

Module Contents#

class gensbi.models.simformer.embedding.GaussianFourierEmbedding(output_dim=128, learnable=False, *, rngs)[source]#

Bases: flax.nnx.Module

Base class for all neural network modules.

Layers and models should subclass this class.

Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the __init__ method.

You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice since you can call the Module directly:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
Parameters:
  • output_dim (int)

  • learnable (bool)

  • rngs (flax.nnx.Rngs)

__call__(t)[source]#

B[source]#

output_dim = 128[source]#
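This page does not spell out the transform, but the members above (a frequency matrix B and output_dim) match the standard Gaussian Fourier feature construction. A conceptual sketch in plain NumPy follows; it is an assumption about the construction, not the library's implementation, and the scale and seed parameters are illustrative:

```python
import numpy as np

def gaussian_fourier_embedding(t, output_dim=128, scale=1.0, seed=0):
    """Map scalar inputs t of shape (n,) to Gaussian Fourier features.

    A fixed random frequency vector B ~ N(0, scale^2) is drawn once;
    the embedding concatenates cos and sin of the projected inputs.
    """
    rng = np.random.default_rng(seed)
    B = rng.normal(0.0, scale, size=(output_dim // 2,))  # frozen frequencies
    proj = 2.0 * np.pi * np.asarray(t)[..., None] * B    # (n, output_dim // 2)
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=-1)

emb = gaussian_fourier_embedding(np.linspace(0.0, 1.0, 4), output_dim=128)
# emb.shape == (4, 128); all entries lie in [-1, 1]
```

With learnable=True, the library presumably treats B as a trainable parameter rather than a frozen buffer.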
class gensbi.models.simformer.embedding.MLPEmbedder(in_dim, hidden_dim, rngs, param_dtype=jnp.float32)[source]#

Bases: flax.nnx.Module

An MLP embedder mapping inputs of dimension in_dim to embeddings of dimension hidden_dim, built from an input linear layer, a SiLU activation, an output linear layer, and a skip path.
Parameters:
  • in_dim (int)

  • hidden_dim (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x)[source]#
Parameters:

x (jax.Array)

Return type:

jax.Array

in_layer[source]#

out_layer[source]#

p_skip[source]#

silu[source]#
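The members above (in_layer, silu, out_layer, p_skip) suggest a two-layer MLP with a skip connection. A hedged NumPy sketch follows; treating p_skip as a linear skip projection is a guess based on the attribute name, and the weight names are illustrative:

```python
import numpy as np

def silu(x):
    """SiLU (swish) activation: x * sigmoid(x)."""
    return x / (1.0 + np.exp(-x))

def mlp_embed(x, w_in, b_in, w_out, b_out, w_skip):
    """Linear -> SiLU -> Linear, plus a linear skip from the input."""
    h = silu(x @ w_in + b_in)            # (n, hidden_dim)
    return h @ w_out + b_out + x @ w_skip

rng = np.random.default_rng(0)
in_dim, hidden_dim = 3, 8
x = rng.normal(size=(5, in_dim))
out = mlp_embed(
    x,
    rng.normal(size=(in_dim, hidden_dim)), np.zeros(hidden_dim),
    rng.normal(size=(hidden_dim, hidden_dim)), np.zeros(hidden_dim),
    rng.normal(size=(in_dim, hidden_dim)),
)
# out.shape == (5, 8)
```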
class gensbi.models.simformer.embedding.SimpleTimeEmbedding[source]#

Bases: flax.nnx.Module

A minimal, parameter-free embedding of scalar times.
__call__(t)[source]#
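No parameters are listed, so the class is presumably a fixed transform of t. One plausible parameter-free form, purely an assumption for illustration (the actual transform in gensbi may differ), stacks t with fixed trigonometric views of it:

```python
import numpy as np

def simple_time_embedding(t):
    """Embed scalar times t of shape (n,) as a few fixed features of t."""
    t = np.asarray(t)
    return np.stack(
        [t, np.sin(2.0 * np.pi * t), np.cos(2.0 * np.pi * t)], axis=-1
    )

emb = simple_time_embedding(np.linspace(0.0, 1.0, 5))
# emb.shape == (5, 3)
```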
class gensbi.models.simformer.embedding.SinusoidalEmbedding(output_dim=128)[source]#

Bases: flax.nnx.Module

A sinusoidal embedding of scalar inputs into output_dim features.
Parameters:

output_dim (int)

__call__(t)[source]#

output_dim = 128[source]#
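The standard transformer-style construction uses geometrically spaced frequencies. A NumPy sketch follows; the exact frequency schedule (including the max_period constant) is an assumption, not taken from the library:

```python
import numpy as np

def sinusoidal_embedding(t, output_dim=128, max_period=10000.0):
    """Embed scalar inputs t of shape (n,) with sin/cos features at
    geometrically spaced frequencies, as in transformer position encodings."""
    half = output_dim // 2
    freqs = np.exp(-np.log(max_period) * np.arange(half) / half)  # (half,)
    args = np.asarray(t)[..., None] * freqs                        # (n, half)
    return np.concatenate([np.sin(args), np.cos(args)], axis=-1)

emb = sinusoidal_embedding(np.arange(4.0), output_dim=128)
# emb.shape == (4, 128)
```

Unlike GaussianFourierEmbedding, the frequencies here are deterministic, so no rngs argument is needed; this matches the signature above.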