from dataclasses import dataclass
from typing import Union
import jax
import jax.numpy as jnp
from jax import Array
from flax import nnx
from jax.typing import DTypeLike
from einops import repeat, rearrange
from gensbi.models.flux1.layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time
#TODO enforce rope usage, remove unused code
@dataclass
class FluxParams:
    """Parameters for the Flux model.

    Args:
        in_channels (int): Number of input channels.
        vec_in_dim (Union[int, None]): Dimension of the vector input, if applicable.
        context_in_dim (int): Dimension of the context input.
        mlp_ratio (float): Ratio for the MLP layers.
        num_heads (int): Number of attention heads.
        depth (int): Number of double stream blocks.
        depth_single_blocks (int): Number of single stream blocks.
        axes_dim (list[int]): Dimensions of the axes for positional encoding.
        qkv_bias (bool): Whether to use bias in QKV layers.
        rngs (nnx.Rngs): Random number generators for initialization.
        obs_dim (int): Observation dimension.
        cond_dim (int): Condition dimension.
        theta (int): Scaling factor for positional encoding.
        guidance_embed (bool): Whether to use guidance embedding.
        qkv_multiplier (int): Multiplier for QKV features.
        param_dtype (DTypeLike): Data type for model parameters.
    """

    # NOTE(review): the fields below were reconstructed from the docstring and
    # from the attributes read in __post_init__ / Flux.__init__ — the original
    # field list was mangled by the docs-extraction artifacts. Verify against
    # upstream before relying on positional construction order.
    in_channels: int
    vec_in_dim: Union[int, None]
    context_in_dim: int
    mlp_ratio: float
    num_heads: int
    depth: int
    depth_single_blocks: int
    axes_dim: list[int]
    qkv_bias: bool
    rngs: nnx.Rngs
    obs_dim: int  # observation dimension
    cond_dim: int  # condition dimension
    theta: int
    guidance_embed: bool = False
    qkv_multiplier: int = 1
    param_dtype: DTypeLike = jnp.bfloat16

    def __post_init__(self):
        # Derived sizes: hidden_size is the concatenated positional-axis width
        # scaled by the QKV multiplier and the head count; qkv_features undoes
        # the multiplier so the per-head positional dim stays sum(axes_dim).
        self.hidden_size = int(
            jnp.sum(jnp.asarray(self.axes_dim, dtype=jnp.int32))
            * self.qkv_multiplier
            * self.num_heads
        )
        self.qkv_features = self.hidden_size // self.qkv_multiplier
class Identity(nnx.Module):
    """No-op module: returns its input unchanged.

    Used as a drop-in stand-in where an optional sub-module (e.g. the
    guidance embedder in ``Flux``) is disabled.
    """

    def __call__(self, x: Array) -> Array:
        return x
class Flux(nnx.Module):
    """
    Transformer model for flow matching on sequences.

    Architecture: separate linear embeddings for observation and condition
    streams, a timestep MLP embedder, a stack of double-stream blocks
    (obs/cond cross-attention) followed by single-stream blocks over the
    concatenated sequence, and a final projection back to ``in_channels``.
    """

    def __init__(self, params: FluxParams):
        """Build the model from ``params``.

        Raises:
            ValueError: if ``sum(params.axes_dim)`` does not equal the
                per-head positional-encoding dimension.
        """
        # Bug fix: __call__ reads self.params (param_dtype, guidance_embed),
        # so the parameter set must be retained on the instance.
        self.params = params
        self.in_channels = params.in_channels
        self.out_channels = params.in_channels
        self.hidden_size = params.hidden_size
        self.qkv_features = params.qkv_features
        pe_dim = self.qkv_features // params.num_heads
        if sum(params.axes_dim) != pe_dim:
            raise ValueError(
                f"Got {params.axes_dim} but expected positional dim {pe_dim}"
            )
        self.num_heads = params.num_heads
        self.pe_embedder = EmbedND(
            dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
        )
        # Linear embedding of the observation stream into the hidden size.
        self.obs_in = nnx.Linear(
            in_features=self.in_channels,
            out_features=self.hidden_size,
            use_bias=True,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Sinusoidal timestep features (dim 256) -> hidden-size vector.
        self.time_in = MLPEmbedder(
            in_dim=256,
            hidden_dim=self.hidden_size,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Guidance-strength embedder; identity when guidance is disabled.
        self.vector_in = (
            MLPEmbedder(
                params.vec_in_dim,
                self.hidden_size,
                rngs=params.rngs,
                param_dtype=params.param_dtype,
            )
            if params.guidance_embed
            else Identity()
        )
        # Linear embedding of the condition (context) stream.
        self.cond_in = nnx.Linear(
            in_features=params.context_in_dim,
            out_features=self.hidden_size,
            use_bias=True,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Learned offset added to the timestep vector when running
        # UNconditioned (classifier-free-guidance style switch).
        self.condition_embedding = nnx.Param(
            0.01 * jnp.ones((1, self.hidden_size), dtype=params.param_dtype)
        )
        # Learned "null" condition sequence substituted for the real
        # condition when conditioned=False.
        self.condition_null = nnx.Param(
            jax.random.normal(
                params.rngs.cond(),
                (1, params.cond_dim, self.hidden_size),
                dtype=params.param_dtype,
            )
        )
        self.double_blocks = nnx.Sequential(
            *[
                DoubleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_features=self.qkv_features,
                    qkv_bias=params.qkv_bias,
                    rngs=params.rngs,
                    param_dtype=params.param_dtype,
                )
                for _ in range(params.depth)
            ]
        )
        self.single_blocks = nnx.Sequential(
            *[
                SingleStreamBlock(
                    self.hidden_size,
                    self.num_heads,
                    mlp_ratio=params.mlp_ratio,
                    qkv_features=self.qkv_features,
                    rngs=params.rngs,
                    param_dtype=params.param_dtype,
                )
                for _ in range(params.depth_single_blocks)
            ]
        )
        self.final_layer = LastLayer(
            self.hidden_size,
            1,
            self.out_channels,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,
        guidance: Array | None = None,
    ) -> Array:
        """Run the model.

        Args:
            t: Timestep(s), shape (B,) or (B, 1).
            obs: Observation sequence, shape (B, F, C).
            obs_ids: Positional ids for obs, shape (B, F, C).
            cond: Condition sequence, shape (B, F, C).
            cond_ids: Positional ids for cond, shape (B, F, C).
            conditioned: Per-batch (or scalar) flag; when False the condition
                is replaced by the learned null condition.
            guidance: Guidance strength; required iff guidance_embed is set.

        Returns:
            Array of shape (B, F, out_channels).

        Raises:
            ValueError: if obs/cond are not rank-3, or guidance is missing
                for a guidance-distilled model.
        """
        # assumes obs, cond, obs_ids, cond_ids have shape (B, F, C)
        # assumes t has shape (B,) or (B, 1)
        obs = jnp.asarray(obs, dtype=self.params.param_dtype)
        cond = jnp.asarray(cond, dtype=self.params.param_dtype)
        t = jnp.asarray(t, dtype=self.params.param_dtype)
        if obs.ndim != 3 or cond.ndim != 3:
            raise ValueError("Input obs and cond tensors must have 3 dimensions, got {} and {}".format(obs.ndim, cond.ndim))

        # running on sequences obs
        obs = self.obs_in(obs)
        vec = self.time_in(timestep_embedding(t, 256))

        conditioned = jnp.asarray(conditioned, dtype=jnp.bool_)  # type: ignore
        conditioned_int = jnp.asarray(conditioned, dtype=jnp.int32)[..., None]  # type: ignore
        # Offset is active only for UNconditioned samples (factor 1 - flag).
        condition_embedding = self.condition_embedding * (1 - conditioned_int)
        vec = vec + condition_embedding  # we add the condition embedding to the vector

        if self.params.guidance_embed:
            if guidance is None:
                raise ValueError(
                    "Didn't get guidance strength for guidance distilled model."
                )
            vec = vec + self.vector_in(guidance)

        cond_processed = self.cond_in(cond)  # (B, F, H)
        cond_null = repeat(self.condition_null.value, "1 h c -> b h c", b=obs.shape[0])  # type: ignore
        cond = jnp.where(
            conditioned[..., None, None], cond_processed, cond_null
        )  # we replace the condition with a null vector if not conditioned

        # Condition tokens come first; positional encoding covers both streams.
        ids = jnp.concatenate((cond_ids, obs_ids), axis=1)
        pe = self.pe_embedder(ids)

        for block in self.double_blocks.layers:
            obs, cond = block(obs=obs, cond=cond, vec=vec, pe=pe)

        obs = jnp.concatenate((cond, obs), axis=1)
        for block in self.single_blocks.layers:
            obs = block(obs, vec=vec, pe=pe)
        # Drop the condition prefix, keeping only the observation tokens.
        obs = obs[:, cond.shape[1] :, ...]

        obs = self.final_layer(obs, vec)  # (N, T, patch_size ** 2 * out_channels)
        return obs
class FluxWrapper(ModelWrapper):
    """Wrapper that expands input ranks to (B, F, C) before calling Flux."""

    def __init__(self, model):
        super().__init__(model)

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,
        guidance: Array | None = None,
    ) -> Array:
        # Normalize all sequence inputs to rank 3; Flux.__call__ rejects
        # anything else. `t` is passed through unchanged.
        obs = _expand_dims(obs)
        cond = _expand_dims(cond)
        obs_ids = _expand_dims(obs_ids)
        cond_ids = _expand_dims(cond_ids)
        return self.model(
            obs=obs,
            t=t,
            cond=cond,
            obs_ids=obs_ids,
            cond_ids=cond_ids,
            conditioned=conditioned,
            guidance=guidance,
        )