gensbi.models.flux1 package#

Submodules#

gensbi.models.flux1.layers module#

class gensbi.models.flux1.layers.DoubleStreamBlock(*args: Any, **kwargs: Any)[source]#

Bases: Module

class gensbi.models.flux1.layers.EmbedND(*args: Any, **kwargs: Any)[source]#

Bases: Module

class gensbi.models.flux1.layers.LastLayer(*args: Any, **kwargs: Any)[source]#

Bases: Module

class gensbi.models.flux1.layers.MLPEmbedder(*args: Any, **kwargs: Any)[source]#

Bases: Module

class gensbi.models.flux1.layers.Modulation(*args: Any, **kwargs: Any)[source]#

Bases: Module

class gensbi.models.flux1.layers.ModulationOut(shift: jax.Array, scale: jax.Array, gate: jax.Array)[source]#

Bases: object

gate: Array#
scale: Array#
shift: Array#

class gensbi.models.flux1.layers.QKNorm(*args: Any, **kwargs: Any)[source]#

Bases: Module

class gensbi.models.flux1.layers.SelfAttention(*args: Any, **kwargs: Any)[source]#

Bases: Module

class gensbi.models.flux1.layers.SingleStreamBlock(*args: Any, **kwargs: Any)[source]#

Bases: Module

A DiT block with parallel linear layers as described in https://arxiv.org/abs/2302.05442 and adapted modulation interface.

gensbi.models.flux1.layers.timestep_embedding(t: Array, dim: int, max_period=10000, time_factor: float = 1000.0) Array[source]#

Generate timestep embeddings.

Parameters:
  • t (Array) – a 1-D array of N timestep indices, one per batch element; values may be fractional.

  • dim (int) – the dimension of the output embedding.

  • max_period (int) – controls the minimum frequency of the embeddings.

  • time_factor (float) – scaling factor applied to t before the embedding is computed.

Returns:

An array of sinusoidal timestep embeddings, one per element of t.
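The standard sinusoidal construction described above can be sketched as follows. This is a minimal illustration assuming the usual diffusion-style embedding; `timestep_embedding_sketch` is a stand-in name, not the library function, and it omits the odd-`dim` padding case:

```python
import jax.numpy as jnp

def timestep_embedding_sketch(t, dim, max_period=10000, time_factor=1000.0):
    # Scale the timesteps, then evaluate cos/sin at geometrically
    # spaced frequencies between 1 and 1/max_period.
    t = time_factor * t
    half = dim // 2
    freqs = jnp.exp(-jnp.log(max_period) * jnp.arange(half) / half)
    args = t[:, None] * freqs[None, :]
    return jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)

emb = timestep_embedding_sketch(jnp.array([0.0, 0.5, 1.0]), dim=8)
print(emb.shape)  # (3, 8)
```

Each row mixes frequencies from 1 down to 1/max_period, so nearby timesteps receive similar embeddings.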

gensbi.models.flux1.loss module#

class gensbi.models.flux1.loss.FluxCFMLoss(*args: Any, **kwargs: Any)[source]#

Bases: ContinuousFMLoss

gensbi.models.flux1.math module#

gensbi.models.flux1.math.apply_rope(xq: Array, xk: Array, freqs_cis: Array) Tuple[Array, Array][source]#

Apply rotary positional embeddings.

Parameters:
  • xq (Array) – Query tensor.

  • xk (Array) – Key tensor.

  • freqs_cis (Array) – Frequency embeddings.

Returns:

Transformed query and key tensors.

Return type:

Tuple[Array, Array]
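The pairwise-rotation form of RoPE can be illustrated as below. This is a sketch assuming `freqs_cis` stores one 2x2 rotation matrix per position and channel pair; `apply_rope_sketch` is a hypothetical stand-in, not the library function:

```python
import jax.numpy as jnp

def apply_rope_sketch(xq, xk, freqs_cis):
    # Treat consecutive channel pairs of q and k as 2-vectors and rotate
    # them with the per-position 2x2 matrices in freqs_cis
    # (shape (..., dim//2, 2, 2)).
    def rotate(x):
        x_ = x.reshape(*x.shape[:-1], -1, 1, 2)
        out = freqs_cis[..., 0] * x_[..., 0] + freqs_cis[..., 1] * x_[..., 1]
        return out.reshape(x.shape)

    return rotate(xq), rotate(xk)

# 90-degree rotations for every pair: (x, y) -> (-y, x)
n, d = 3, 4
rot90 = jnp.array([[0.0, -1.0], [1.0, 0.0]])
freqs = jnp.broadcast_to(rot90, (n, d // 2, 2, 2))
q = jnp.ones((n, d))
q_out, _ = apply_rope_sketch(q, q, freqs)
print(q_out[0])  # [-1.  1. -1.  1.]
```

Because each pair is rotated, the per-pair norms of q and k are preserved, which keeps attention logits well scaled.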

gensbi.models.flux1.math.attention(q: Array, k: Array, v: Array, pe: Array | None = None, mask: Array | None = None) Array[source]#

Compute attention mechanism.

Parameters:
  • q (Array) – Query tensor.

  • k (Array) – Key tensor.

  • v (Array) – Value tensor.

  • pe (Optional[Array]) – Rotary positional embeddings; if provided, applied to q and k before the attention computation.

  • mask (Optional[Array]) – Attention mask.

Returns:

Attention output.

Return type:

Array
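With the positional-encoding step set aside, the core computation is standard scaled dot-product attention. A minimal sketch (`attention_sketch` is a stand-in name; the rotary `pe` step is omitted):

```python
import jax
import jax.numpy as jnp

def attention_sketch(q, k, v, mask=None):
    # Scaled dot-product attention over the last two axes:
    # logits = q k^T / sqrt(d), then a softmax-weighted sum of v.
    scale = q.shape[-1] ** -0.5
    logits = jnp.einsum("...qd,...kd->...qk", q, k) * scale
    if mask is not None:
        logits = jnp.where(mask, logits, -jnp.inf)
    weights = jax.nn.softmax(logits, axis=-1)
    return jnp.einsum("...qk,...kd->...qd", weights, v)

q = k = v = jnp.ones((2, 4, 8))  # (batch, seq, head_dim)
out = attention_sketch(q, k, v)
print(out.shape)  # (2, 4, 8)
```

With identical keys, the softmax weights are uniform and the output equals the mean of v, a quick sanity check for the einsum axes.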

gensbi.models.flux1.math.rope(pos: Array, dim: int, theta: int) Array[source]#

Compute rotary positional embeddings.

Parameters:
  • pos (Array) – Position tensor.

  • dim (int) – Dimension of embeddings.

  • theta (int) – Base of the geometric frequency progression (typically 10000).

Returns:

Rotary embeddings.

Return type:

Array
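One common way to build such rotary embeddings, sketched under the assumption that the function returns a 2x2 rotation matrix per position and channel pair (`rope_sketch` is a hypothetical name, not the library function):

```python
import jax.numpy as jnp

def rope_sketch(pos, dim, theta=10000):
    # One rotation frequency per channel pair; dim must be even.
    assert dim % 2 == 0
    scale = jnp.arange(0, dim, 2) / dim
    omega = 1.0 / (theta ** scale)
    angles = pos[..., None] * omega        # (..., N, dim//2)
    cos, sin = jnp.cos(angles), jnp.sin(angles)
    # Pack each angle into a 2x2 rotation matrix [[cos, -sin], [sin, cos]].
    rot = jnp.stack([cos, -sin, sin, cos], axis=-1)
    return rot.reshape(*angles.shape, 2, 2)

freqs = rope_sketch(jnp.arange(4.0), dim=8)
print(freqs.shape)  # (4, 4, 2, 2)
```

At position 0 every matrix is the identity, so rotary embeddings leave the first token's query and key untouched.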

gensbi.models.flux1.model module#

class gensbi.models.flux1.model.Flux(*args: Any, **kwargs: Any)[source]#

Bases: Module

Transformer model for flow matching on sequences.

class gensbi.models.flux1.model.FluxParams(in_channels: int, vec_in_dim: int, context_in_dim: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], qkv_bias: bool, rngs: flax.nnx.rnglib.Rngs, obs_dim: int | None = None, cond_dim: int | None = None, use_rope: bool = True, theta: int = 10000, guidance_embed: bool = False, qkv_multiplier: int = 1, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = <class 'jax.numpy.bfloat16'>)[source]#

Bases: object

axes_dim: list[int]#
cond_dim: int | None = None#
context_in_dim: int#
depth: int#
depth_single_blocks: int#
guidance_embed: bool = False#
in_channels: int#
mlp_ratio: float#
num_heads: int#
obs_dim: int | None = None#
param_dtype#

alias of bfloat16

qkv_bias: bool#
qkv_multiplier: int = 1#
rngs: Rngs#
theta: int = 10000#
use_rope: bool = True#
vec_in_dim: int#

class gensbi.models.flux1.model.Identity(*args: Any, **kwargs: Any)[source]#

Bases: Module

Module contents#

class gensbi.models.flux1.Flux(*args: Any, **kwargs: Any)[source]#

Bases: Module

Transformer model for flow matching on sequences.

class gensbi.models.flux1.FluxCFMLoss(*args: Any, **kwargs: Any)[source]#

Bases: ContinuousFMLoss

class gensbi.models.flux1.FluxParams(in_channels: int, vec_in_dim: int, context_in_dim: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], qkv_bias: bool, rngs: flax.nnx.rnglib.Rngs, obs_dim: int | None = None, cond_dim: int | None = None, use_rope: bool = True, theta: int = 10000, guidance_embed: bool = False, qkv_multiplier: int = 1, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = <class 'jax.numpy.bfloat16'>)[source]#

Bases: object

axes_dim: list[int]#
cond_dim: int | None = None#
context_in_dim: int#
depth: int#
depth_single_blocks: int#
guidance_embed: bool = False#
in_channels: int#
mlp_ratio: float#
num_heads: int#
obs_dim: int | None = None#
param_dtype#

alias of bfloat16

qkv_bias: bool#
qkv_multiplier: int = 1#
rngs: Rngs#
theta: int = 10000#
use_rope: bool = True#
vec_in_dim: int#