gensbi.models.simformer.transformer#
Classes#
| Class | Summary |
|---|---|
| AttentionBlock | Base class for all neural network modules. |
| DenseBlock | Base class for all neural network modules. |
| Transformer | A transformer stack. |
Module Contents#
- class gensbi.models.simformer.transformer.AttentionBlock(din, num_heads, features, skip_connection, rngs)[source]#
Bases:
flax.nnx.Module
Base class for all neural network modules.
Layers and models should subclass this class.
`Module`s can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the `__init__` method.
You can define arbitrary "forward pass" methods on your `Module` subclass. While no methods are special-cased, `__call__` is a popular choice since you can call the `Module` directly:
>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
- Parameters:
din (int)
num_heads (int)
features (int)
skip_connection (bool)
rngs (flax.nnx.Rngs)
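A minimal construction sketch, not taken from the source: the sizes below are illustrative assumptions, and since this page does not document `AttentionBlock.__call__`, the forward pass shown is an assumption as well.

```python
import jax.numpy as jnp
from flax import nnx
from gensbi.models.simformer.transformer import AttentionBlock

# Illustrative sizes (assumptions): feature dimension 32, 4 attention heads.
block = AttentionBlock(din=32, num_heads=4, features=32,
                       skip_connection=True, rngs=nnx.Rngs(0))

# Assumed input layout (batch, tokens, din); the call signature is not
# documented on this page, so this forward call is illustrative only.
x = jnp.ones((8, 16, 32))
y = block(x)
```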
- class gensbi.models.simformer.transformer.DenseBlock(din, dcontext, num_hidden_layers, widening_factor, act, skip_connection, rngs)[source]#
Bases:
flax.nnx.Module
Base class for all neural network modules.
Layers and models should subclass this class. (See `AttentionBlock` above for the full `flax.nnx.Module` docstring and usage example.)
- Parameters:
din (int)
dcontext (int)
num_hidden_layers (int)
widening_factor (int)
act (Callable)
skip_connection (bool)
rngs (flax.nnx.Rngs)
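A minimal construction sketch under stated assumptions: the sizes are illustrative, and the common convention that the hidden width is `widening_factor * din` is an assumption, not something this page states.

```python
import jax
from flax import nnx
from gensbi.models.simformer.transformer import DenseBlock

# Illustrative sizes (assumptions). The hidden width is presumably
# widening_factor * din, but this page does not specify it.
mlp = DenseBlock(din=32, dcontext=8, num_hidden_layers=1, widening_factor=4,
                 act=jax.nn.gelu, skip_connection=True, rngs=nnx.Rngs(0))
```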
- class gensbi.models.simformer.transformer.Transformer(din, dcontext, num_heads, num_layers, features, widening_factor=4, num_hidden_layers=1, act=jax.nn.gelu, skip_connection_attn=True, skip_connection_mlp=True, *, rngs)[source]#
Bases:
flax.nnx.Module
A transformer stack.
- Parameters:
din (int)
dcontext (int)
num_heads (int)
num_layers (int)
features (int)
widening_factor (int)
num_hidden_layers (int)
act (Callable)
skip_connection_attn (bool)
skip_connection_mlp (bool)
rngs (flax.nnx.Rngs)
- __call__(inputs, context=None, mask=None)[source]#
- Parameters:
inputs (jaxtyping.Array)
context (Optional[jaxtyping.Array])
mask (jaxtyping.Array | None)
- Return type:
jax.Array
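A minimal end-to-end sketch built from the signatures above; the array shapes and the boolean mask convention are assumptions, since the page does not specify them.

```python
import jax.numpy as jnp
from flax import nnx
from gensbi.models.simformer.transformer import Transformer

model = Transformer(din=32, dcontext=8, num_heads=4, num_layers=2,
                    features=32, rngs=nnx.Rngs(0))

batch, tokens = 8, 16
inputs = jnp.ones((batch, tokens, 32))         # assumed (batch, tokens, din)
context = jnp.ones((batch, tokens, 8))         # assumed (batch, tokens, dcontext)
mask = jnp.ones((tokens, tokens), dtype=bool)  # assumed all-visible attention mask

out = model(inputs, context=context, mask=mask)  # documented to return a jax.Array
```

Both `context` and `mask` are optional per the documented `__call__` signature, so `model(inputs)` alone should also be a valid call.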