gensbi.models.simformer.transformer#

Classes#

AttentionBlock

Multi-head attention block with layer normalization and an optional skip connection.

DenseBlock

Dense (MLP) block with context conditioning, layer normalization, and an optional skip connection.

Transformer

A transformer stack.

Module Contents#

class gensbi.models.simformer.transformer.AttentionBlock(din, num_heads, features, skip_connection, rngs)[source]#

Bases: flax.nnx.Module

Multi-head attention block with layer normalization and an optional skip connection.
Parameters:
  • din (int)

  • num_heads (int)

  • features (int)

  • skip_connection (bool)

  • rngs (flax.nnx.Rngs)

__call__(x, mask)[source]#
Parameters:
  • x (jax.numpy.ndarray)

  • mask (jax.numpy.ndarray | None)

Return type:

jax.numpy.ndarray

attn[source]#
layer_norm[source]#
skip_connection[source]#
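
A minimal usage sketch (illustrative, not taken from the package's own documentation). The array layout is an assumption — x is treated as (batch, tokens, din) — and the constructor values are placeholders:

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> from gensbi.models.simformer.transformer import AttentionBlock

>>> # illustrative sizes: 8 input features, 2 heads, 16 attention features
>>> block = AttentionBlock(din=8, num_heads=2, features=16, skip_connection=True, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 5, 8))  # assumed layout: (batch, tokens, din)
>>> y = block(x, mask=None)  # mask=None applies no attention masking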
class gensbi.models.simformer.transformer.DenseBlock(din, dcontext, num_hidden_layers, widening_factor, act, skip_connection, rngs)[source]#

Bases: flax.nnx.Module

Dense (MLP) block with context conditioning, layer normalization, and an optional skip connection.

Parameters:
  • din (int)

  • dcontext (int)

  • num_hidden_layers (int)

  • widening_factor (int)

  • act (Callable)

  • skip_connection (bool)

  • rngs (flax.nnx.Rngs)

__call__(x, context)[source]#
act[source]#
context_block[source]#
hidden_blocks[source]#
layer_norm[source]#
skip_connection[source]#
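
A minimal usage sketch under the same assumptions (illustrative values; the context layout expected by context_block is not documented here, so the shape below is a guess):

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.simformer.transformer import DenseBlock

>>> block = DenseBlock(din=8, dcontext=4, num_hidden_layers=1, widening_factor=4,
...                    act=jax.nn.gelu, skip_connection=True, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 5, 8))        # assumed layout: (batch, tokens, din)
>>> context = jnp.ones((1, 5, 4))  # assumed layout: (batch, tokens, dcontext)
>>> y = block(x, context)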
class gensbi.models.simformer.transformer.Transformer(din, dcontext, num_heads, num_layers, features, widening_factor=4, num_hidden_layers=1, act=jax.nn.gelu, skip_connection_attn=True, skip_connection_mlp=True, *, rngs)[source]#

Bases: flax.nnx.Module

A transformer stack.

Parameters:
  • din (int)

  • dcontext (int)

  • num_heads (int)

  • num_layers (int)

  • features (int)

  • widening_factor (int)

  • num_hidden_layers (int)

  • act (Callable)

  • skip_connection_attn (bool)

  • skip_connection_mlp (bool)

  • rngs (flax.nnx.Rngs)

__call__(inputs, context=None, mask=None)[source]#
Parameters:
  • inputs (jaxtyping.Array)

  • context (Optional[jaxtyping.Array])

  • mask (jaxtyping.Array | None)

Return type:

jax.Array

act[source]#
attention_blocks[source]#
dcontext[source]#
dense_blocks[source]#
din[source]#
layer_norm[source]#
num_heads[source]#
num_hidden_layers = 1[source]#
num_layers[source]#
rngs[source]#
skip_connection_attn = True[source]#
skip_connection_mlp = True[source]#
widening_factor = 4[source]#
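
A minimal end-to-end sketch of the transformer stack based on the signature above; the shapes and sizes are assumptions for illustration — inputs as (batch, tokens, din), context as (batch, tokens, dcontext), and mask left as None (no attention masking):

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.simformer.transformer import Transformer

>>> model = Transformer(din=8, dcontext=4, num_heads=2, num_layers=3, features=16,
...                     rngs=nnx.Rngs(0))  # remaining arguments keep their documented defaults
>>> inputs = jnp.ones((1, 5, 8))   # assumed layout: (batch, tokens, din)
>>> context = jnp.ones((1, 5, 4))  # assumed layout: (batch, tokens, dcontext)
>>> out = model(inputs, context=context, mask=None)

Judging from the attributes listed above (attention_blocks, dense_blocks), each of the num_layers layers appears to pair an AttentionBlock with a DenseBlock, but the exact wiring is not documented on this page.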