gensbi.models.flux1.layers#

Classes#

DoubleStreamBlock

Double-stream block that updates the observation and conditioning token streams in parallel.

EmbedND

Multi-axis positional embedding built from integer coordinate ids.

LastLayer

Final output layer with adaptive layer-norm modulation and a linear projection.

MLPEmbedder

Two-layer MLP embedder (Linear → SiLU → Linear).

Modulation

Projects a conditioning vector into shift/scale/gate modulation parameters.

ModulationOut

Container for the shift, scale, and gate modulation arrays.

QKNorm

Normalization of query and key tensors before attention.

SelfAttention

Multi-head self-attention with query/key normalization and positional embedding.

SingleStreamBlock

A DiT block with parallel linear layers as described in arXiv:2302.05442.

Functions#

timestep_embedding(t, dim[, max_period, time_factor])

Generate timestep embeddings.

Module Contents#

class gensbi.models.flux1.layers.DoubleStreamBlock(hidden_size, num_heads, mlp_ratio, rngs, qkv_features=None, param_dtype=jnp.bfloat16, qkv_bias=False)[source]#

Bases: flax.nnx.Module

Double-stream transformer block. The observation (obs) and conditioning (cond) token streams are each processed by their own normalization, attention, MLP, and modulation sub-layers, driven by a global conditioning vector vec, and the updated pair of streams is returned.
Parameters:
  • hidden_size (int)

  • num_heads (int)

  • mlp_ratio (float)

  • rngs (flax.nnx.Rngs)

  • qkv_features (int | None)

  • param_dtype (jax.typing.DTypeLike)

  • qkv_bias (bool)

__call__(obs, cond, vec, pe=None, mask=None)[source]#
Parameters:
  • obs (jax.Array)

  • cond (jax.Array)

  • vec (jax.Array)

  • pe (jax.Array | None)

  • mask (jax.Array | None)

Return type:

tuple[jax.Array, jax.Array]

cond_attn[source]#
cond_mlp[source]#
cond_mod[source]#
cond_norm1[source]#
cond_norm2[source]#
hidden_size[source]#
num_heads[source]#
obs_attn[source]#
obs_mlp[source]#
obs_mod[source]#
obs_norm1[source]#
obs_norm2[source]#
qkv_features = None[source]#
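A minimal usage sketch (the token and vector shapes below are assumptions, not part of the documented API); here the positional embedding is omitted, relying on the pe=None default:

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.flux1.layers import DoubleStreamBlock

>>> block = DoubleStreamBlock(hidden_size=64, num_heads=4, mlp_ratio=4.0, rngs=nnx.Rngs(0))
>>> obs = jnp.ones((1, 16, 64))   # assumed (batch, obs tokens, hidden_size)
>>> cond = jnp.ones((1, 8, 64))   # assumed (batch, cond tokens, hidden_size)
>>> vec = jnp.ones((1, 64))       # assumed (batch, hidden_size) modulation vector
>>> obs_out, cond_out = block(obs, cond, vec)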
class gensbi.models.flux1.layers.EmbedND(dim, theta, axes_dim)[source]#

Bases: flax.nnx.Module

Multi-axis positional embedding. Builds a positional embedding from integer coordinate ids, allocating axes_dim[i] features to the i-th axis and using theta as the frequency base; the result can be passed as the pe argument of the attention layers.
Parameters:
  • dim (int)

  • theta (int)

  • axes_dim (list[int])

__call__(ids)[source]#
Parameters:

ids (jax.Array)

Return type:

jax.Array

axes_dim[source]#
dim[source]#
theta[source]#
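A hedged sketch of building a positional embedding from integer coordinate ids. The id layout (batch, tokens, len(axes_dim)) and the use of the result as the pe argument of the attention blocks are assumptions based on the signatures documented here:

>>> import jax.numpy as jnp
>>> from gensbi.models.flux1.layers import EmbedND

>>> embed = EmbedND(dim=16, theta=10000, axes_dim=[16])
>>> ids = jnp.arange(16).reshape(1, 16, 1)   # assumed (batch, tokens, len(axes_dim)) integer ids
>>> pe = embed(ids)                          # positional embedding, passed as `pe` to attention blocks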
class gensbi.models.flux1.layers.LastLayer(hidden_size, patch_size, out_channels, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

Final output layer. Applies an adaptive layer-norm modulation driven by vec (adaLN_modulation, norm_final) followed by a linear projection of the hidden tokens to the output channels.
Parameters:
  • hidden_size (int)

  • patch_size (int)

  • out_channels (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x, vec)[source]#
Parameters:
  • x (jax.Array)

  • vec (jax.Array)

Return type:

jax.Array

adaLN_modulation[source]#
linear[source]#
norm_final[source]#
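A hedged sketch of the final projection (the shapes below are assumptions): hidden tokens x are modulated by vec and then linearly projected.

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.flux1.layers import LastLayer

>>> last = LastLayer(hidden_size=64, patch_size=1, out_channels=2, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 16, 64))   # assumed (batch, tokens, hidden_size)
>>> vec = jnp.ones((1, 64))     # assumed (batch, hidden_size) modulation vector
>>> out = last(x, vec)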
class gensbi.models.flux1.layers.MLPEmbedder(in_dim, hidden_dim, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

Two-layer MLP embedder. Maps an in_dim-dimensional input to a hidden_dim-dimensional embedding through a Linear → SiLU → Linear stack (in_layer, silu, out_layer).
Parameters:
  • in_dim (int)

  • hidden_dim (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(x)[source]#
Parameters:

x (jax.Array)

Return type:

jax.Array

in_layer[source]#
out_layer[source]#
silu[source]#
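A hedged usage sketch; the Linear → SiLU → Linear structure is suggested by the in_layer, silu, and out_layer attributes above, and the shapes are assumptions:

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.flux1.layers import MLPEmbedder

>>> embedder = MLPEmbedder(in_dim=256, hidden_dim=64, rngs=nnx.Rngs(0))
>>> t_emb = jnp.ones((4, 256))   # e.g. a batch of timestep embeddings of width in_dim
>>> vec = embedder(t_emb)        # assumed output shape (4, hidden_dim)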
class gensbi.models.flux1.layers.Modulation(dim, double, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

Modulation layer. Projects a conditioning vector into shift/scale/gate triples (ModulationOut); with double=True two triples are produced, otherwise the second element of the returned tuple is None.
Parameters:
  • dim (int)

  • double (bool)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(vec)[source]#
Parameters:

vec (jax.Array)

Return type:

tuple[ModulationOut, ModulationOut | None]

is_double[source]#
lin[source]#
multiplier = 6[source]#
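A hedged usage sketch (the vector shape is an assumption); the two-triple return with double=True is inferred from the return type documented above:

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.flux1.layers import Modulation

>>> mod = Modulation(dim=64, double=True, rngs=nnx.Rngs(0))
>>> vec = jnp.ones((1, 64))   # assumed (batch, dim) conditioning vector
>>> mod1, mod2 = mod(vec)     # two ModulationOut instances; the second is None when double=False
>>> shift, scale, gate = mod1.shift, mod1.scale, mod1.gate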
class gensbi.models.flux1.layers.ModulationOut[source]#
gate: jax.Array[source]#
scale: jax.Array[source]#
shift: jax.Array[source]#
class gensbi.models.flux1.layers.QKNorm(dim, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

Query/key normalization. Normalizes the query and key tensors (query_norm, key_norm) before attention and returns the normalized pair.
Parameters:
  • dim (int)

  • rngs (flax.nnx.Rngs)

  • param_dtype (jax.typing.DTypeLike)

__call__(q, k, v)[source]#
Parameters:
  • q (jax.Array)

  • k (jax.Array)

  • v (jax.Array)

Return type:

tuple[jax.Array, jax.Array]

key_norm[source]#
query_norm[source]#
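A hedged sketch, assuming dim is the per-head feature dimension (the last axis of q and k); only the normalized query/key pair is returned, matching the documented return type:

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.flux1.layers import QKNorm

>>> norm = QKNorm(dim=16, rngs=nnx.Rngs(0))
>>> q = k = v = jnp.ones((1, 4, 16, 16))   # assumed (batch, heads, tokens, head_dim)
>>> q_n, k_n = norm(q, k, v)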
class gensbi.models.flux1.layers.SelfAttention(dim, rngs, qkv_features=None, param_dtype=jnp.bfloat16, num_heads=8, qkv_bias=False)[source]#

Bases: flax.nnx.Module

Multi-head self-attention layer. Uses a fused QKV projection (qkv), query/key normalization (norm), a positional embedding pe, an optional attention mask, and a final output projection (proj).
Parameters:
  • dim (int)

  • rngs (flax.nnx.Rngs)

  • qkv_features (int | None)

  • param_dtype (jax.typing.DTypeLike)

  • num_heads (int)

  • qkv_bias (bool)

__call__(x, pe, mask=None)[source]#
Parameters:
  • x (jax.Array)

  • pe (jax.Array)

  • mask (jax.Array | None)

Return type:

jax.Array

norm[source]#
num_heads = 8[source]#
proj[source]#
qkv[source]#
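A hedged sketch combining SelfAttention with EmbedND. It assumes the per-head dimension is dim // num_heads and that sum(axes_dim) should match it, following the reference Flux rotary-embedding convention; the shapes are likewise assumptions:

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.flux1.layers import EmbedND, SelfAttention

>>> attn = SelfAttention(dim=64, num_heads=4, rngs=nnx.Rngs(0))
>>> embed = EmbedND(dim=16, theta=10000, axes_dim=[16])   # 16 == 64 // 4, the assumed per-head dim
>>> x = jnp.ones((1, 16, 64))                             # assumed (batch, tokens, dim)
>>> pe = embed(jnp.arange(16).reshape(1, 16, 1))
>>> y = attn(x, pe)                                       # assumed same shape as x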
class gensbi.models.flux1.layers.SingleStreamBlock(hidden_size, num_heads, rngs, qkv_features=None, param_dtype=jnp.bfloat16, mlp_ratio=4.0, qk_scale=None)[source]#

Bases: flax.nnx.Module

A DiT block with parallel linear layers, as described in arXiv:2302.05442, and an adapted modulation interface.

Parameters:
  • hidden_size (int)

  • num_heads (int)

  • rngs (flax.nnx.Rngs)

  • qkv_features (int | None)

  • param_dtype (jax.typing.DTypeLike)

  • mlp_ratio (float)

  • qk_scale (float | None)

__call__(x, vec, pe=None, mask=None)[source]#
Parameters:
  • x (jax.Array)

  • vec (jax.Array)

  • pe (jax.Array | None)

  • mask (jax.Array | None)

Return type:

jax.Array

hidden_dim[source]#
hidden_size[source]#
linear1[source]#
linear2[source]#
mlp_act[source]#
mlp_hidden_dim[source]#
modulation[source]#
norm[source]#
num_heads[source]#
pre_norm[source]#
scale[source]#
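A minimal sketch (shapes assumed): a single-stream block processes one token sequence x, modulated by vec; as with the double-stream block, the positional embedding is omitted here via the pe=None default.

>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from gensbi.models.flux1.layers import SingleStreamBlock

>>> block = SingleStreamBlock(hidden_size=64, num_heads=4, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 24, 64))   # assumed (batch, tokens, hidden_size), e.g. obs and cond concatenated
>>> vec = jnp.ones((1, 64))     # assumed (batch, hidden_size) modulation vector
>>> y = block(x, vec)           # assumed same shape as x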
gensbi.models.flux1.layers.timestep_embedding(t, dim, max_period=10000, time_factor=1000.0)[source]#

Generate timestep embeddings.

Parameters:
  • t (jax.Array) – a 1-D Tensor of N indices, one per batch element. These may be fractional.

  • dim (int) – the dimension of the output.

  • max_period – controls the minimum frequency of the embeddings.

  • time_factor (float) – scaling factor applied to t before the embedding is computed.

Returns:

timestep embeddings.

Return type:

jax.Array
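A hedged usage sketch; the (N, dim) output shape is an assumption based on the parameter descriptions above.

>>> import jax.numpy as jnp
>>> from gensbi.models.flux1.layers import timestep_embedding

>>> t = jnp.array([0.0, 0.25, 1.0])       # one (possibly fractional) timestep per batch element
>>> emb = timestep_embedding(t, dim=64)   # sinusoidal embeddings, assumed shape (3, 64)

In the reference Flux architecture, such an embedding is typically passed through an MLPEmbedder to obtain the modulation vector vec used by the blocks above.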