gensbi.models.flux1.layers#
Classes#
DoubleStreamBlock | Base class for all neural network modules.
EmbedND | Base class for all neural network modules.
LastLayer | Base class for all neural network modules.
MLPEmbedder | Base class for all neural network modules.
Modulation | Base class for all neural network modules.
QKNorm | Base class for all neural network modules.
SelfAttention | Base class for all neural network modules.
SingleStreamBlock | A DiT block with parallel linear layers as described in arXiv:2302.05442.
Functions#
timestep_embedding | Generate timestep embeddings.
Module Contents#
- class gensbi.models.flux1.layers.DoubleStreamBlock(hidden_size, num_heads, mlp_ratio, rngs, qkv_features=None, param_dtype=jnp.bfloat16, qkv_bias=False)[source]#
Bases:
flax.nnx.Module
Base class for all neural network modules.
Layers and models should subclass this class.
Modules can contain submodules, and in this way can be nested in a tree structure. Submodules can be assigned as regular attributes inside the __init__ method.
You can define arbitrary “forward pass” methods on your Module subclass. While no methods are special-cased, __call__ is a popular choice since you can call the Module directly:

>>> from flax import nnx
>>> import jax.numpy as jnp

>>> class Model(nnx.Module):
...   def __init__(self, rngs):
...     self.linear1 = nnx.Linear(2, 3, rngs=rngs)
...     self.linear2 = nnx.Linear(3, 4, rngs=rngs)
...   def __call__(self, x):
...     x = self.linear1(x)
...     x = nnx.relu(x)
...     x = self.linear2(x)
...     return x

>>> x = jnp.ones((1, 2))
>>> model = Model(rngs=nnx.Rngs(0))
>>> y = model(x)
- Parameters:
hidden_size (int)
num_heads (int)
mlp_ratio (float)
rngs (flax.nnx.Rngs)
qkv_features (int | None)
param_dtype (jax.typing.DTypeLike)
qkv_bias (bool)
- __call__(obs, cond, vec, pe=None, mask=None)[source]#
- Parameters:
obs (jax.Array)
cond (jax.Array)
vec (jax.Array)
pe (jax.Array | None)
mask (jax.Array | None)
- Return type:
tuple[jax.Array, jax.Array]
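Example (a minimal sketch, assuming obs and cond are shaped (batch, tokens, hidden_size) and vec is shaped (batch, hidden_size); only the signatures above are taken from this page):

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> from gensbi.models.flux1.layers import DoubleStreamBlock
>>> block = DoubleStreamBlock(hidden_size=256, num_heads=8, mlp_ratio=4.0, rngs=nnx.Rngs(0))
>>> obs = jnp.ones((1, 16, 256))   # observation stream (assumed shape)
>>> cond = jnp.ones((1, 16, 256))  # conditioning stream (assumed shape)
>>> vec = jnp.ones((1, 256))       # global modulation vector (assumed shape)
>>> obs_out, cond_out = block(obs, cond, vec)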
- class gensbi.models.flux1.layers.EmbedND(dim, theta, axes_dim)[source]#
Bases:
flax.nnx.Module
- Parameters:
dim (int)
theta (int)
axes_dim (list[int])
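Example (construction only, since __call__ is not documented here; the argument values, and the assumption that dim equals the sum of axes_dim, are illustrative):

>>> from gensbi.models.flux1.layers import EmbedND
>>> pe_embedder = EmbedND(dim=32, theta=10000, axes_dim=[16, 16])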
- class gensbi.models.flux1.layers.LastLayer(hidden_size, patch_size, out_channels, rngs, param_dtype=jnp.bfloat16)[source]#
Bases:
flax.nnx.Module
- Parameters:
hidden_size (int)
patch_size (int)
out_channels (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
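Example (construction only; the argument values are illustrative and not taken from this page):

>>> from flax import nnx
>>> from gensbi.models.flux1.layers import LastLayer
>>> final_layer = LastLayer(hidden_size=256, patch_size=1, out_channels=2, rngs=nnx.Rngs(0))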
- class gensbi.models.flux1.layers.MLPEmbedder(in_dim, hidden_dim, rngs, param_dtype=jnp.bfloat16)[source]#
Bases:
flax.nnx.Module
- Parameters:
in_dim (int)
hidden_dim (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
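Example (construction only; the argument values are illustrative):

>>> from flax import nnx
>>> from gensbi.models.flux1.layers import MLPEmbedder
>>> vec_embedder = MLPEmbedder(in_dim=256, hidden_dim=512, rngs=nnx.Rngs(0))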
- class gensbi.models.flux1.layers.Modulation(dim, double, rngs, param_dtype=jnp.bfloat16)[source]#
Bases:
flax.nnx.Module
- Parameters:
dim (int)
double (bool)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
- __call__(vec)[source]#
- Parameters:
vec (jax.Array)
- Return type:
tuple[ModulationOut, ModulationOut | None]
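Example (a minimal sketch, assuming vec is shaped (batch, dim); the signatures are from above):

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> from gensbi.models.flux1.layers import Modulation
>>> mod = Modulation(dim=256, double=True, rngs=nnx.Rngs(0))
>>> vec = jnp.ones((1, 256))
>>> mod1, mod2 = mod(vec)  # the second element is None when double=False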
- class gensbi.models.flux1.layers.QKNorm(dim, rngs, param_dtype=jnp.bfloat16)[source]#
Bases:
flax.nnx.Module
- Parameters:
dim (int)
rngs (flax.nnx.Rngs)
param_dtype (jax.typing.DTypeLike)
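Example (construction only; the dim value is illustrative):

>>> from flax import nnx
>>> from gensbi.models.flux1.layers import QKNorm
>>> qk_norm = QKNorm(dim=32, rngs=nnx.Rngs(0))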
- class gensbi.models.flux1.layers.SelfAttention(dim, rngs, qkv_features=None, param_dtype=jnp.bfloat16, num_heads=8, qkv_bias=False)[source]#
Bases:
flax.nnx.Module
- Parameters:
dim (int)
rngs (flax.nnx.Rngs)
qkv_features (int | None)
param_dtype (jax.typing.DTypeLike)
num_heads (int)
qkv_bias (bool)
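Example (construction only; the argument values are illustrative):

>>> from flax import nnx
>>> from gensbi.models.flux1.layers import SelfAttention
>>> attn = SelfAttention(dim=256, rngs=nnx.Rngs(0), num_heads=8)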
- class gensbi.models.flux1.layers.SingleStreamBlock(hidden_size, num_heads, rngs, qkv_features=None, param_dtype=jnp.bfloat16, mlp_ratio=4.0, qk_scale=None)[source]#
Bases:
flax.nnx.Module
A DiT block with parallel linear layers as described in arXiv:2302.05442 and adapted modulation interface.
- Parameters:
hidden_size (int)
num_heads (int)
rngs (flax.nnx.Rngs)
qkv_features (int | None)
param_dtype (jax.typing.DTypeLike)
mlp_ratio (float)
qk_scale (float | None)
- __call__(x, vec, pe=None, mask=None)[source]#
- Parameters:
x (jax.Array)
vec (jax.Array)
pe (jax.Array | None)
mask (jax.Array | None)
- Return type:
jax.Array
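Example (a minimal sketch, assuming x is shaped (batch, tokens, hidden_size) and vec is shaped (batch, hidden_size); only the signatures above are from this page):

>>> from flax import nnx
>>> import jax.numpy as jnp
>>> from gensbi.models.flux1.layers import SingleStreamBlock
>>> block = SingleStreamBlock(hidden_size=256, num_heads=8, rngs=nnx.Rngs(0))
>>> x = jnp.ones((1, 16, 256))
>>> vec = jnp.ones((1, 256))
>>> y = block(x, vec)  # jax.Array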
- gensbi.models.flux1.layers.timestep_embedding(t, dim, max_period=10000, time_factor=1000.0)[source]#
Generate timestep embeddings.
- Parameters:
t (jax.Array) – a 1-D Tensor of N indices, one per batch element. These may be fractional.
dim (int) – the dimension of the output.
max_period – controls the minimum frequency of the embeddings.
time_factor (float) – scaling factor applied to the timesteps before the embeddings are computed.
- Returns:
timestep embeddings.
- Return type:
jax.Array
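Example (a minimal sketch; the (N, dim) output shape is an assumption based on the usual sinusoidal embedding convention):

>>> import jax.numpy as jnp
>>> from gensbi.models.flux1.layers import timestep_embedding
>>> t = jnp.array([0.0, 0.5, 1.0])       # one (possibly fractional) timestep per element
>>> emb = timestep_embedding(t, dim=64)  # assumed shape: (3, 64)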