import jax
import jax.numpy as jnp
# from chex import Array
from jax import Array
from einops import rearrange
from flax import nnx
from jax.typing import DTypeLike, ArrayLike
from gensbi.models.autoencoders.commons import AutoEncoderParams, DiagonalGaussian
from flax.nnx import swish
class AttnBlock2D(nnx.Module):
    """
    Single-head self-attention over the spatial positions of a 2D feature map.

    The input is group-normalized, projected to queries/keys/values with 1x1
    convolutions, attended over all ``height * width`` positions, projected
    back with another 1x1 convolution and added to the input (residual).

    Args:
        in_channels (int): Number of input (and output) channels. GroupNorm
            uses 32 groups here, so this is expected to be a multiple of 32.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        in_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ) -> None:
        def pointwise() -> nnx.Conv:
            # 1x1 convolution that keeps the channel count unchanged.
            return nnx.Conv(
                in_features=in_channels,
                out_features=in_channels,
                kernel_size=(1, 1),
                rngs=rngs,
                param_dtype=param_dtype,
            )

        self.in_channels = in_channels
        self.norm = nnx.GroupNorm(
            num_groups=32,
            num_features=in_channels,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        # Construction order matters: every layer draws keys from `rngs`.
        self.q = pointwise()
        self.k = pointwise()
        self.v = pointwise()
        self.proj_out = pointwise()

    def attention(self, h_: Array) -> Array:
        """
        Normalize the input and run single-head attention over all pixels.

        Args:
            h_ (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Attention output of shape (batch, height, width, channels),
            before the output projection and residual connection.
        """
        normed = self.norm(h_)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)
        batch, height, width, channels = query.shape

        # Flatten the grid into a token axis and add a singleton head axis:
        # (b, h, w, c) -> (b, h*w, 1, c), i.e. the (batch, seq, heads, dim)
        # layout consumed by jax.nn.dot_product_attention.
        token_shape = (batch, height * width, 1, channels)
        attended = jax.nn.dot_product_attention(
            query.reshape(token_shape),
            key.reshape(token_shape),
            value.reshape(token_shape),
        )
        return attended.reshape(batch, height, width, channels)

    def __call__(self, x: Array) -> Array:
        """
        Apply attention with a residual connection.

        Args:
            x (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: ``x + proj_out(attention(x))``, same shape as ``x``.
        """
        attended = self.attention(x)
        return x + self.proj_out(attended)
class ResnetBlock2D(nnx.Module):
    """
    2D residual block with optional channel up/downsampling.

    Computes ``shortcut(x) + conv2(swish(norm2(conv1(swish(norm1(x))))))``.
    ``shortcut`` is the identity when the channel count is unchanged and a
    1x1 convolution otherwise.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int | None): Number of output channels. ``None`` keeps
            ``in_channels``.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).

    Note:
        Both GroupNorm layers use 32 groups, so ``in_channels`` and the
        resolved output channel count are expected to be multiples of 32.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int | None,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ) -> None:
        self.in_channels = in_channels
        # Resolve ``None`` once, then build every layer from the resolved
        # value. Passing the raw argument on would make ``None`` crash layer
        # construction even though it is accepted here.
        self.out_channels = in_channels if out_channels is None else out_channels
        out_ch = self.out_channels

        # Construction order matters: every layer draws keys from `rngs`.
        self.norm1 = nnx.GroupNorm(
            num_groups=32,
            num_features=in_channels,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv1 = nnx.Conv(
            in_features=in_channels,
            out_features=out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.norm2 = nnx.GroupNorm(
            num_groups=32,
            num_features=out_ch,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv2 = nnx.Conv(
            in_features=out_ch,
            out_features=out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )
        if self.in_channels != self.out_channels:
            # Channel-matching projection so the residual sum is well-shaped.
            self.nin_shortcut = nnx.Conv(
                in_features=in_channels,
                out_features=out_ch,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding=(0, 0),
                rngs=rngs,
                param_dtype=param_dtype,
            )

    def __call__(self, x: Array) -> Array:
        """
        Forward pass for the residual block.

        Args:
            x (Array): Input tensor of shape (batch, height, width, in_channels).

        Returns:
            Array: Output tensor of shape (batch, height, width, out_channels).
        """
        h = self.norm1(x)
        h = swish(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)
        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)
        return x + h
class Downsample2D(nnx.Module):
    """
    Halve the spatial size of a 2D feature map with a strided convolution.

    The input is zero-padded by one row at the bottom and one column on the
    right. A 3x3, stride-2 convolution without its own padding is then
    applied. For even height/width the output is ``(H/2, W/2)``, and the
    channel count is preserved.

    Args:
        in_channels (int): Number of input (and output) channels.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        in_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        conv_kwargs = dict(
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=(0, 0),
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv = nnx.Conv(
            in_features=in_channels, out_features=in_channels, **conv_kwargs
        )

    def __call__(self, x: Array) -> Array:
        """
        Downsample the input tensor by a factor of 2.

        Args:
            x (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Downsampled tensor.
        """
        # Batch and channel axes are left alone; height and width get one
        # trailing zero so the asymmetric "SAME"-like stride-2 conv halves them.
        no_pad, pad_end = (0, 0), (0, 1)
        padded = jnp.pad(
            x, (no_pad, pad_end, pad_end, no_pad), mode="constant", constant_values=0
        )
        return self.conv(padded)
class Upsample2D(nnx.Module):
    """
    Double the spatial size of a 2D feature map.

    Uses nearest-neighbor resizing followed by a 3x3, stride-1 convolution
    with padding 1, which keeps the spatial size and channel count.

    Args:
        in_channels (int): Number of input (and output) channels.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        in_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.conv = nnx.Conv(
            in_features=in_channels,
            out_features=in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            param_dtype=param_dtype,
            rngs=rngs,
        )

    def __call__(self, x: Array) -> Array:
        """
        Upsample the input tensor by a factor of 2.

        Args:
            x (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Tensor of shape (batch, 2 * height, 2 * width, channels).
        """
        batch, height, width, channels = x.shape
        target_shape = (batch, 2 * height, 2 * width, channels)
        enlarged = jax.image.resize(x, target_shape, method="nearest")
        return self.conv(enlarged)
class Encoder2D(nnx.Module):
    """
    2D Encoder for autoencoder architectures.

    Layout: input conv -> per-level residual blocks (with a stride-2
    downsample after every level except the last) -> middle
    ResNet/attention/ResNet stack -> GroupNorm + swish -> output conv
    producing ``2 * z_channels`` channels.

    Args:
        resolution (int): Input image height/width.
        in_channels (int): Number of input channels.
        ch (int): Base number of channels.
        ch_mult (list[int]): Channel multipliers for each resolution.
        num_res_blocks (int): Number of residual blocks per resolution.
        z_channels (int): Number of latent channels.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ) -> None:
        # `ch` must be stored: it was previously read as `self.ch` without
        # ever being assigned, which raised AttributeError on construction.
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # NOTE: construction order below matters, since every layer draws
        # keys from `rngs`; reordering would change initial weights.

        # downsampling
        self.conv_in = nnx.Conv(
            in_features=in_channels,
            out_features=self.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nnx.Sequential()
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nnx.Sequential()
            # Kept (empty) for structural parity with `__call__`, which applies
            # attention layers only if any are present.
            attn = nnx.Sequential()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.layers.append(
                    ResnetBlock2D(
                        in_channels=block_in,
                        out_channels=block_out,
                        rngs=rngs,
                        param_dtype=param_dtype,
                    )
                )
                block_in = block_out
            down = nnx.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.Downsample2D = Downsample2D(
                    in_channels=block_in,
                    rngs=rngs,
                    param_dtype=param_dtype,
                )
            self.down.layers.append(down)

        # middle
        self.mid = nnx.Module()
        self.mid.block_1 = ResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.mid.attn_1 = AttnBlock2D(
            in_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.mid.block_2 = ResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )

        # end
        self.norm_out = nnx.GroupNorm(
            num_groups=32,
            num_features=block_in,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv_out = nnx.Conv(
            in_features=block_in,
            out_features=2 * z_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, x: Array) -> Array:
        """
        Forward pass for the encoder.

        Args:
            x (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Encoded tensor with ``2 * z_channels`` channels, spatially
            downsampled by 2 once per level except the last.
        """
        # downsampling: only the latest activation is needed (no skip
        # connections), so carry a single `h` instead of a list of all of them.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down.layers[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block.layers[i_block](h)
                if len(level.attn.layers) > 0:
                    h = level.attn.layers[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.Downsample2D(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h
class Decoder2D(nnx.Module):
    """
    2D Decoder for autoencoder architectures.

    Layout: latent conv -> middle ResNet/attention/ResNet stack -> per-level
    residual blocks (``num_res_blocks + 1`` each), processed from the coarsest
    level to the finest, with a 2x upsample after every level except level 0
    -> GroupNorm + swish -> output conv producing ``out_ch`` channels.

    Args:
        ch (int): Base number of channels.
        out_ch (int): Number of output channels.
        ch_mult (list[int]): Channel multipliers for each resolution.
        num_res_blocks (int): Number of residual blocks per resolution.
        in_channels (int): Number of input channels (stored, not used to
            build layers).
        resolution (int): Output image height/width.
        z_channels (int): Number of latent channels.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        n_levels = len(ch_mult)
        self.num_resolutions = n_levels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # Total spatial upsampling factor between latent and output.
        self.ffactor = 2 ** (n_levels - 1)

        # Channel width and spatial size at the coarsest (latent) level.
        channels = ch * ch_mult[-1]
        latent_res = resolution // self.ffactor
        # Expected latent shape, channels-last.
        self.z_shape = (1, latent_res, latent_res, z_channels)

        def resnet(c_in: int, c_out: int) -> ResnetBlock2D:
            return ResnetBlock2D(
                in_channels=c_in,
                out_channels=c_out,
                rngs=rngs,
                param_dtype=param_dtype,
            )

        # NOTE: construction order below matters, since every layer draws
        # keys from `rngs`; reordering would change initial weights.

        # z -> channels
        self.conv_in = nnx.Conv(
            in_features=z_channels,
            out_features=channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

        # middle
        mid = nnx.Module()
        mid.block_1 = resnet(channels, channels)
        mid.attn_1 = AttnBlock2D(
            in_channels=channels,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        mid.block_2 = resnet(channels, channels)
        self.mid = mid

        # upsampling: build coarse -> fine, then store fine -> coarse so that
        # `self.up.layers[i]` corresponds to level `i`.
        stages = []
        for level in reversed(range(n_levels)):
            width = ch * ch_mult[level]
            stage = nnx.Module()
            stage.attn = nnx.Sequential()
            blocks = []
            for _ in range(num_res_blocks + 1):
                blocks.append(resnet(channels, width))
                channels = width
            stage.block = nnx.Sequential(*blocks)
            if level != 0:
                stage.Upsample2D = Upsample2D(
                    in_channels=channels,
                    rngs=rngs,
                    param_dtype=param_dtype,
                )
            stages.append(stage)
        stages.reverse()
        self.up = nnx.Sequential(*stages)

        # end
        self.norm_out = nnx.GroupNorm(
            num_groups=32,
            num_features=channels,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv_out = nnx.Conv(
            in_features=channels,
            out_features=out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, z: Array) -> Array:
        """
        Forward pass for the decoder.

        Args:
            z (Array): Latent tensor of shape (batch, latent_height, latent_width, latent_channels).

        Returns:
            Array: Reconstructed output tensor with ``out_ch`` channels.
        """
        h = self.conv_in(z)
        for layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = layer(h)

        for i_level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up.layers[i_level]
            has_attn = len(stage.attn.layers) > 0
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block.layers[i_block](h)
                if has_attn:
                    h = stage.attn.layers[i_block](h)
            if i_level != 0:
                h = stage.Upsample2D(h)

        h = swish(self.norm_out(h))
        return self.conv_out(h)
class AutoEncoder2D(nnx.Module):
    """
    2D Autoencoder model with Gaussian latent space.

    Encoding runs ``Encoder2D``, passes the result and a PRNG key through
    ``DiagonalGaussian``, then applies ``scale * (z - shift)``. Decoding
    inverts that affine map before running ``Decoder2D``.

    Args:
        params (AutoEncoderParams): Configuration parameters for the autoencoder.
    """

    def __init__(
        self,
        params: AutoEncoderParams,
    ):
        # Arguments shared by encoder and decoder.
        common = dict(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Encoder is built first so it consumes `params.rngs` first.
        self.Encoder2D = Encoder2D(**common)
        self.Decoder2D = Decoder2D(out_ch=params.out_ch, **common)
        self.reg = DiagonalGaussian()
        # NOTE(review): registered as nnx.Param, so these are part of the
        # parameter state (and trainable unless filtered) — confirm intended.
        self.scale_factor = nnx.Param(params.scale_factor)
        self.shift_factor = nnx.Param(params.shift_factor)

    def encode(self, x: Array, key: Array) -> Array:
        """
        Encode input data into the latent space.

        Args:
            x (Array): Input tensor.
            key (Array): PRNG key forwarded to the ``DiagonalGaussian`` regularizer.

        Returns:
            Array: Latent representation, ``scale_factor * (z - shift_factor)``.
        """
        latent = self.reg(self.Encoder2D(x), key)
        return self.scale_factor * (latent - self.shift_factor)

    def decode(self, z: Array) -> Array:
        """
        Decode latent representation back to data space.

        Args:
            z (Array): Latent tensor (as produced by ``encode``).

        Returns:
            Array: Reconstructed output.
        """
        unscaled = z / self.scale_factor + self.shift_factor
        return self.Decoder2D(unscaled)

    def __call__(self, x: Array, key: Array) -> Array:
        """
        Forward pass: encode and then decode the input.

        Args:
            x (Array): Input tensor.
            key (Array): PRNG key used during encoding.

        Returns:
            Array: Reconstructed output.
        """
        latent = self.encode(x, key)
        return self.decode(latent)