gensbi.models.simformer.embedding
Classes
GaussianFourierEmbedding | Gaussian Fourier embedding (optionally learnable).
MLPEmbedder | MLP-based embedder.
SimpleTimeEmbedding | Simple time embedding.
SinusoidalEmbedding | Sinusoidal embedding.
Module Contents
- class gensbi.models.simformer.embedding.GaussianFourierEmbedding(output_dim=128, learnable=False, *, rngs)
Bases: flax.nnx.Module
(Inherited from flax.nnx.Module.)
Base class for all neural network modules.
Layers and models should subclass this class.
Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the __init__ method.
You can define arbitrary "forward pass" methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice since you can call the Module directly:
>>> from flax import nnx
>>> import jax.numpy as jnp
>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x
>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
- Parameters:
output_dim (int)
learnable (bool)
rngs (flax.nnx.Rngs)
- class gensbi.models.simformer.embedding.MLPEmbedder(in_dim, hidden_dim, rngs, param_dtype=jnp.float32)
Bases: flax.nnx.Module
- Parameters:
in_dim (int)
hidden_dim (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
- class gensbi.models.simformer.embedding.SimpleTimeEmbedding
Bases: flax.nnx.Module
- class gensbi.models.simformer.embedding.SinusoidalEmbedding(output_dim=128)
Bases: flax.nnx.Module
- Parameters:
output_dim (int)