gensbi.models.flux1 package#
Submodules#
gensbi.models.flux1.layers module#
- class gensbi.models.flux1.layers.DoubleStreamBlock(*args: Any, **kwargs: Any)[source]#
Bases: Module
- class gensbi.models.flux1.layers.ModulationOut(shift: jax.Array, scale: jax.Array, gate: jax.Array)[source]#
Bases: object
- gate: Array#
- scale: Array#
- shift: Array#
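ModulationOut is a plain container for the shift/scale/gate triplet produced by a modulation layer. The snippet below is a hedged illustration of how such a triplet is commonly applied in DiT-style blocks (adaptive scale-and-shift plus a gated residual); the exact way the Flux blocks consume it may differ, so treat the usage pattern and shapes as assumptions.

```python
import jax.numpy as jnp
from gensbi.models.flux1.layers import ModulationOut

# Build a triplet by hand; in practice it comes from a modulation layer.
mod = ModulationOut(
    shift=jnp.zeros((1, 64)),
    scale=0.1 * jnp.ones((1, 64)),
    gate=jnp.ones((1, 64)),
)

x = jnp.ones((4, 64))                   # (tokens, features); shapes are illustrative
h = (1 + mod.scale) * x + mod.shift     # adaptive scale-and-shift (assumed usage)
out = x + mod.gate * h                  # gated residual update (assumed usage)
```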
- class gensbi.models.flux1.layers.SingleStreamBlock(*args: Any, **kwargs: Any)[source]#
Bases: Module
A DiT block with parallel linear layers, as described in https://arxiv.org/abs/2302.05442, and an adapted modulation interface (see the illustrative sketch below).
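For orientation, here is a minimal JAX/Flax NNX sketch of the "parallel" idea from the paper above: the attention projections and the first MLP layer share one fused linear map, and their outputs are merged by a second fused linear map before the residual add. This is an illustration only, not gensbi's SingleStreamBlock; the class name, sizes, and the omission of shift/scale/gate modulation are all assumptions.

```python
import jax
import jax.numpy as jnp
from flax import nnx

class ParallelBlockSketch(nnx.Module):
    """Conceptual parallel attention + MLP block (not the actual SingleStreamBlock)."""

    def __init__(self, hidden: int, heads: int, mlp_ratio: float, *, rngs: nnx.Rngs):
        self.heads = heads
        mlp_hidden = int(hidden * mlp_ratio)
        # One fused projection yields q, k, v and the MLP pre-activation in a single matmul.
        self.linear1 = nnx.Linear(hidden, 3 * hidden + mlp_hidden, rngs=rngs)
        # One fused projection maps [attention output, MLP activation] back to hidden.
        self.linear2 = nnx.Linear(hidden + mlp_hidden, hidden, rngs=rngs)
        self.norm = nnx.LayerNorm(hidden, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:  # x: (batch, seq, hidden)
        h = self.norm(x)
        d = h.shape[-1]
        qkv, mlp = jnp.split(self.linear1(h), [3 * d], axis=-1)
        q, k, v = jnp.split(qkv, 3, axis=-1)
        b, l, _ = q.shape
        hd = d // self.heads
        q, k, v = (t.reshape(b, l, self.heads, hd) for t in (q, k, v))
        # Plain scaled dot-product attention, written out for clarity.
        logits = jnp.einsum("blhd,bmhd->bhlm", q, k) / jnp.sqrt(hd)
        attn = jnp.einsum("bhlm,bmhd->blhd", jax.nn.softmax(logits, axis=-1), v)
        out = self.linear2(jnp.concatenate([attn.reshape(b, l, d), jax.nn.gelu(mlp)], axis=-1))
        return x + out  # residual connection; modulation omitted for brevity

block = ParallelBlockSketch(hidden=64, heads=4, mlp_ratio=4.0, rngs=nnx.Rngs(0))
y = block(jnp.ones((2, 16, 64)))  # (2, 16, 64)
```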
- gensbi.models.flux1.layers.timestep_embedding(t: Array, dim: int, max_period=10000, time_factor: float = 1000.0) Array [source]#
Generate timestep embeddings.
- Parameters:
t (Array) – a 1-D array of N indices, one per batch element; these may be fractional.
dim (int) – the dimension of the output.
max_period (int) – controls the minimum frequency of the embeddings.
time_factor (float) – scaling factor applied to t before the embedding is computed.
- Returns:
An array of shape (N, dim) containing the timestep embeddings.
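A minimal usage sketch; the values are illustrative and the (N, dim) output shape follows from the parameter description above.

```python
import jax.numpy as jnp
from gensbi.models.flux1.layers import timestep_embedding

t = jnp.array([0.0, 0.25, 0.5, 1.0])   # one (possibly fractional) timestep per batch element
emb = timestep_embedding(t, dim=64)     # expected shape (4, 64)
```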
gensbi.models.flux1.loss module#
- class gensbi.models.flux1.loss.FluxCFMLoss(*args: Any, **kwargs: Any)[source]#
Bases: ContinuousFMLoss
gensbi.models.flux1.math module#
- gensbi.models.flux1.math.apply_rope(xq: Array, xk: Array, freqs_cis: Array) Tuple[Array, Array] [source]#
Apply rotary positional embeddings.
- Parameters:
xq (Array) – Query tensor.
xk (Array) – Key tensor.
freqs_cis (Array) – Frequency embeddings.
- Returns:
Transformed query and key tensors.
- Return type:
Tuple[Array, Array]
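Rotary embeddings rotate consecutive channel pairs of the queries and keys by position-dependent angles. The self-contained sketch below implements that rotation directly (Flux-style 2x2 rotation matrices) so the operation is visible end to end; the actual freqs_cis layout expected by gensbi's apply_rope may differ, so treat the helper names and shapes here as assumptions.

```python
import jax
import jax.numpy as jnp

def rope_freqs(pos: jnp.ndarray, dim: int, theta: float = 10000.0) -> jnp.ndarray:
    """Per-position 2x2 rotation matrices, one per channel pair (illustrative layout)."""
    omega = 1.0 / (theta ** (jnp.arange(0, dim, 2) / dim))    # (dim/2,)
    ang = pos[..., None] * omega                               # (..., seq, dim/2)
    cos, sin = jnp.cos(ang), jnp.sin(ang)
    return jnp.stack(
        [jnp.stack([cos, -sin], axis=-1), jnp.stack([sin, cos], axis=-1)], axis=-2
    )                                                          # (..., seq, dim/2, 2, 2)

def apply_rope_sketch(xq, xk, freqs_cis):
    """Rotate each (even, odd) channel pair of the queries and keys."""
    def rot(x):
        x2 = x.reshape(*x.shape[:-1], -1, 1, 2)                # pair up channels
        out = freqs_cis[..., 0] * x2[..., 0] + freqs_cis[..., 1] * x2[..., 1]
        return out.reshape(x.shape)
    return rot(xq), rot(xk)

pos = jnp.arange(16.0)                                         # sequence positions
freqs = rope_freqs(pos, dim=64)
q = k = jax.random.normal(jax.random.PRNGKey(0), (16, 64))
q_rot, k_rot = apply_rope_sketch(q, k, freqs)                  # same shapes as q, k
```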
- gensbi.models.flux1.math.attention(q: Array, k: Array, v: Array, pe: Array | None = None, mask: Array | None = None) Array [source]#
Compute attention mechanism.
- Parameters:
q (Array) – Query tensor.
k (Array) – Key tensor.
v (Array) – Value tensor.
pe (Optional[Array]) – Positional encoding.
mask (Optional[Array]) – Attention mask.
- Returns:
Attention output.
- Return type:
Array
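A hedged usage sketch with the optional arguments left at None; the (batch, heads, sequence, head_dim) layout is an assumption based on common multi-head attention conventions, not something stated above.

```python
import jax
from gensbi.models.flux1.math import attention

rng = jax.random.PRNGKey(0)
q = k = v = jax.random.normal(rng, (1, 4, 16, 32))  # assumed (batch, heads, seq, head_dim)
out = attention(q, k, v, pe=None, mask=None)         # positional encoding and mask are optional
```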
gensbi.models.flux1.model module#
- class gensbi.models.flux1.model.Flux(*args: Any, **kwargs: Any)[source]#
Bases: Module
Transformer model for flow matching on sequences.
- class gensbi.models.flux1.model.FluxParams(in_channels: int, vec_in_dim: int, context_in_dim: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], qkv_bias: bool, rngs: flax.nnx.rnglib.Rngs, obs_dim: int | None = None, cond_dim: int | None = None, use_rope: bool = True, theta: int = 10000, guidance_embed: bool = False, qkv_multiplier: int = 1, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16)[source]#
Bases: object
- axes_dim: list[int]#
- cond_dim: int | None = None#
- context_in_dim: int#
- depth: int#
- depth_single_blocks: int#
- guidance_embed: bool = False#
- in_channels: int#
- mlp_ratio: float#
- num_heads: int#
- obs_dim: int | None = None#
- param_dtype# – alias of bfloat16
- qkv_bias: bool#
- qkv_multiplier: int = 1#
- rngs: Rngs#
- theta: int = 10000#
- use_rope: bool = True#
- vec_in_dim: int#
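Since FluxParams is a plain dataclass, it can be filled in directly and handed to the model. The values below are purely illustrative, and because Flux's constructor is shown above only as *args/**kwargs, passing the params object on its own is an assumption rather than a documented call.

```python
from flax import nnx
from gensbi.models.flux1 import Flux, FluxParams

params = FluxParams(
    in_channels=1,
    vec_in_dim=1,
    context_in_dim=1,
    mlp_ratio=4.0,
    num_heads=4,
    depth=2,
    depth_single_blocks=2,
    axes_dim=[8, 8],      # per-axis positional-embedding dims (values are illustrative)
    qkv_bias=True,
    rngs=nnx.Rngs(0),
    obs_dim=3,            # optional: dimensionality of the observed variables (assumption)
    cond_dim=2,           # optional: dimensionality of the conditioning variables (assumption)
)
model = Flux(params)      # assumed constructor call; check the Flux class for the exact API
```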
Module contents#
- class gensbi.models.flux1.Flux(*args: Any, **kwargs: Any)[source]#
Bases: Module
Transformer model for flow matching on sequences.
- class gensbi.models.flux1.FluxCFMLoss(*args: Any, **kwargs: Any)[source]#
Bases: ContinuousFMLoss
- class gensbi.models.flux1.FluxParams(in_channels: int, vec_in_dim: int, context_in_dim: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], qkv_bias: bool, rngs: flax.nnx.rnglib.Rngs, obs_dim: int | None = None, cond_dim: int | None = None, use_rope: bool = True, theta: int = 10000, guidance_embed: bool = False, qkv_multiplier: int = 1, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16)[source]#
Bases: object
- axes_dim: list[int]#
- cond_dim: int | None = None#
- context_in_dim: int#
- depth: int#
- depth_single_blocks: int#
- guidance_embed: bool = False#
- in_channels: int#
- mlp_ratio: float#
- num_heads: int#
- obs_dim: int | None = None#
- param_dtype# – alias of bfloat16
- qkv_bias: bool#
- qkv_multiplier: int = 1#
- rngs: Rngs#
- theta: int = 10000#
- use_rope: bool = True#
- vec_in_dim: int#