"""Source code for gensbi.models.autoencoders.autoencoder_2d: 2D convolutional autoencoder building blocks implemented with Flax NNX."""

import jax
import jax.numpy as jnp

# from chex import Array
from jax import Array
from einops import rearrange
from flax import nnx
from jax.typing import DTypeLike, ArrayLike


from gensbi.models.autoencoders.commons import AutoEncoderParams, DiagonalGaussian
from flax.nnx import swish




class AttnBlock2D(nnx.Module):
    """
    Residual single-head self-attention over all spatial positions of a 2D feature map.

    The input is group-normalised, projected to queries/keys/values with 1x1
    convolutions, attended over the flattened ``height * width`` positions, projected
    back with another 1x1 convolution and added to the input.

    Args:
        in_channels (int): Number of input (and output) channels.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        in_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ) -> None:
        self.in_channels = in_channels

        def pointwise_conv() -> nnx.Conv:
            # 1x1 convolution that keeps the channel count unchanged.
            return nnx.Conv(
                in_features=in_channels,
                out_features=in_channels,
                kernel_size=(1, 1),
                rngs=rngs,
                param_dtype=param_dtype,
            )

        # All submodules share ``rngs``; keep this construction order
        # (norm, q, k, v, proj_out) so parameter initialisation stays reproducible.
        self.norm = nnx.GroupNorm(
            num_groups=32,
            num_features=in_channels,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def attention(self, h_: Array) -> Array:
        """
        Apply normalised single-head self-attention (without the output projection).

        Args:
            h_ (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Attention output with the same shape as ``h_``.
        """
        normed = self.norm(h_)
        query, key, value = self.q(normed), self.k(normed), self.v(normed)
        batch, height, width, channels = query.shape

        # Flatten the spatial grid into a token axis and add a singleton head axis,
        # giving the (batch, tokens, heads, dim) layout dot_product_attention expects.
        query, key, value = (
            rearrange(t, "b h w c-> b (h w) 1 c") for t in (query, key, value)
        )
        attended = jax.nn.dot_product_attention(query, key, value)
        return rearrange(
            attended,
            "b (h w) 1 c -> b h w c",
            h=height,
            w=width,
            c=channels,
            b=batch,
        )

    def __call__(self, x: Array) -> Array:
        """
        Residual attention: ``x + proj_out(attention(x))``.

        Args:
            x (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Tensor with the same shape as ``x``.
        """
        attended = self.attention(x)
        return x + self.proj_out(attended)


class ResnetBlock2D(nnx.Module):
    """
    2D residual block: (GroupNorm -> swish -> 3x3 conv) applied twice, plus a skip path.

    When the input and output channel counts differ, the skip path is projected with
    a 1x1 convolution (``nin_shortcut``); otherwise the input is added unchanged.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int | None): Number of output channels. ``None`` means
            "same as ``in_channels``".
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int | None,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ) -> None:
        self.in_channels = in_channels
        self.out_channels = in_channels if out_channels is None else out_channels
        # Build every layer from the *resolved* channel count so that
        # ``out_channels=None`` yields a channel-preserving block.
        out_ch = self.out_channels

        self.norm1 = nnx.GroupNorm(
            num_groups=32,
            num_features=in_channels,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv1 = nnx.Conv(
            in_features=in_channels,
            out_features=out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.norm2 = nnx.GroupNorm(
            num_groups=32,
            num_features=out_ch,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv2 = nnx.Conv(
            in_features=out_ch,
            out_features=out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )
        if self.in_channels != self.out_channels:
            # 1x1 projection so the skip path matches the residual's channel count.
            self.nin_shortcut = nnx.Conv(
                in_features=in_channels,
                out_features=out_ch,
                kernel_size=(1, 1),
                strides=(1, 1),
                padding=(0, 0),
                rngs=rngs,
                param_dtype=param_dtype,
            )

    def __call__(self, x: Array) -> Array:
        """
        Forward pass for the residual block.

        Args:
            x (Array): Input tensor of shape (batch, height, width, in_channels).

        Returns:
            Array: ``skip(x) + residual(x)`` with ``out_channels`` channels.
        """
        h = x
        h = self.norm1(h)
        h = swish(h)
        h = self.conv1(h)

        h = self.norm2(h)
        h = swish(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h


class Downsample2D(nnx.Module):
    """
    2D downsampling by a stride-2 3x3 convolution with asymmetric zero padding.

    Args:
        in_channels (int): Number of input (and output) channels.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        in_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        # The convolution itself is unpadded; padding is applied manually in __call__.
        self.conv = nnx.Conv(
            in_features=in_channels,
            out_features=in_channels,
            kernel_size=(3, 3),
            strides=(2, 2),
            padding=(0, 0),
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, x: Array) -> Array:
        """
        Downsample the input tensor by a factor of 2 along height and width.

        Args:
            x (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Tensor of shape (batch, height // 2, width // 2, channels)
            for height, width >= 2.
        """
        # Leave batch and channel axes alone; add one zero row at the bottom and one
        # zero column on the right, so the unpadded stride-2 conv halves H and W.
        padding = ((0, 0), (0, 1), (0, 1), (0, 0))
        padded = jnp.pad(x, padding, mode="constant", constant_values=0)
        return self.conv(padded)


class Upsample2D(nnx.Module):
    """
    2D upsampling: nearest-neighbour resize by 2x followed by a 3x3 convolution.

    Args:
        in_channels (int): Number of input (and output) channels.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        in_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        # "Same"-size 3x3 conv (stride 1, padding 1) applied after the resize.
        self.conv = nnx.Conv(
            in_features=in_channels,
            out_features=in_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, x: Array) -> Array:
        """
        Upsample the input tensor by a factor of 2 along height and width.

        Args:
            x (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Tensor of shape (batch, 2 * height, 2 * width, channels).
        """
        batch, height, width, channels = x.shape
        factor = 2.0
        target_shape = (batch, int(height * factor), int(width * factor), channels)
        upsampled = jax.image.resize(x, target_shape, method="nearest")
        return self.conv(upsampled)


class Encoder2D(nnx.Module):
    """
    2D convolutional encoder for autoencoder architectures.

    Structure: input conv -> per-resolution stacks of residual blocks (each level
    except the last ends with a 2x downsample) -> middle (resnet, attention, resnet)
    -> GroupNorm + swish + output conv producing ``2 * z_channels`` channels.

    Args:
        resolution (int): Input image height/width (stored; it does not affect
            which layers are built).
        in_channels (int): Number of input channels.
        ch (int): Base number of channels.
        ch_mult (list[int]): Channel multipliers for each resolution level.
        num_res_blocks (int): Number of residual blocks per resolution level.
        z_channels (int): Number of latent channels (the output has twice as many).
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        resolution: int,
        in_channels: int,
        ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        z_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ) -> None:
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # NOTE: submodules share ``rngs``; the construction order below is part of
        # the parameter-initialisation behaviour and is kept stable.

        # downsampling
        self.conv_in = nnx.Conv(
            in_features=in_channels,
            out_features=self.ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

        # Channel multiplier at the *input* of each level (level 0 starts from ``ch``).
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nnx.Sequential()
        # Fallback used by the middle/out layers when ``ch_mult`` is empty.
        block_in = self.ch
        for i_level in range(self.num_resolutions):
            block = nnx.Sequential()
            # No per-level attention layers are created; the empty container keeps
            # the module tree layout that ``__call__`` inspects.
            attn = nnx.Sequential()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks):
                block.layers.append(
                    ResnetBlock2D(
                        in_channels=block_in,
                        out_channels=block_out,
                        rngs=rngs,
                        param_dtype=param_dtype,
                    )
                )
                block_in = block_out
            down = nnx.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.Downsample2D = Downsample2D(
                    in_channels=block_in,
                    rngs=rngs,
                    param_dtype=param_dtype,
                )
            self.down.layers.append(down)

        # middle
        self.mid = nnx.Module()
        self.mid.block_1 = ResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.mid.attn_1 = AttnBlock2D(
            in_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.mid.block_2 = ResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )

        # end
        self.norm_out = nnx.GroupNorm(
            num_groups=32,
            num_features=block_in,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv_out = nnx.Conv(
            in_features=block_in,
            out_features=2 * z_channels,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, x: Array) -> Array:
        """
        Forward pass for the encoder.

        Args:
            x (Array): Input tensor of shape (batch, height, width, channels).

        Returns:
            Array: Encoder output with ``2 * z_channels`` channels at the lowest
            spatial resolution.
        """
        # downsampling: only the latest activation is ever needed, so keep a single
        # running tensor instead of accumulating every intermediate feature map.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            level = self.down.layers[i_level]
            for i_block in range(self.num_res_blocks):
                h = level.block.layers[i_block](h)
                if len(level.attn.layers) > 0:
                    h = level.attn.layers[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = level.Downsample2D(h)

        # middle
        h = self.mid.block_1(h)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h)

        # end
        h = self.norm_out(h)
        h = swish(h)
        h = self.conv_out(h)
        return h


class Decoder2D(nnx.Module):
    """
    2D convolutional decoder for autoencoder architectures.

    Structure: latent conv -> middle (resnet, attention, resnet) -> per-resolution
    stacks of ``num_res_blocks + 1`` residual blocks, processed from the lowest
    resolution upwards (each level except level 0 ends with a 2x upsample)
    -> GroupNorm + swish + output conv producing ``out_ch`` channels.

    Args:
        ch (int): Base number of channels.
        out_ch (int): Number of output channels.
        ch_mult (list[int]): Channel multipliers for each resolution level.
        num_res_blocks (int): Number of residual blocks per resolution level (each
            level actually holds ``num_res_blocks + 1`` blocks).
        in_channels (int): Number of input channels (stored only).
        resolution (int): Output image height/width; used to derive ``z_shape``.
        z_channels (int): Number of latent channels.
        rngs (nnx.Rngs): Random number generators for parameter initialization.
        param_dtype (DTypeLike): Data type for parameters (default: jnp.bfloat16).
    """

    def __init__(
        self,
        ch: int,
        out_ch: int,
        ch_mult: list[int],
        num_res_blocks: int,
        in_channels: int,
        resolution: int,
        z_channels: int,
        rngs: nnx.Rngs,
        param_dtype: DTypeLike = jnp.bfloat16,
    ):
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        # Total spatial reduction between the output and the latent grid.
        self.ffactor = 2 ** (self.num_resolutions - 1)

        # Channel count and spatial size at the lowest resolution level.
        block_in = ch * ch_mult[self.num_resolutions - 1]
        latent_res = resolution // self.ffactor
        # Channels-last latent shape with a unit batch dimension.
        self.z_shape = (1, latent_res, latent_res, z_channels)

        # NOTE: submodules share ``rngs``; the construction order below is part of
        # the parameter-initialisation behaviour and is kept stable.

        # z to block_in
        self.conv_in = nnx.Conv(
            in_features=z_channels,
            out_features=block_in,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

        # middle
        self.mid = nnx.Module()
        self.mid.block_1 = ResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.mid.attn_1 = AttnBlock2D(
            in_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.mid.block_2 = ResnetBlock2D(
            in_channels=block_in,
            out_channels=block_in,
            rngs=rngs,
            param_dtype=param_dtype,
        )

        # upsampling: levels are built from the lowest resolution upwards but
        # inserted at the front, so ``self.up.layers[i]`` is the stage for level i.
        self.up = nnx.Sequential()
        for i_level in range(self.num_resolutions - 1, -1, -1):
            stage_blocks = nnx.Sequential()
            # No per-level attention layers are created; the empty container keeps
            # the module tree layout that ``__call__`` inspects.
            stage_attn = nnx.Sequential()
            block_out = ch * ch_mult[i_level]
            for _ in range(self.num_res_blocks + 1):
                stage_blocks.layers.append(
                    ResnetBlock2D(
                        in_channels=block_in,
                        out_channels=block_out,
                        rngs=rngs,
                        param_dtype=param_dtype,
                    )
                )
                block_in = block_out
            stage = nnx.Module()
            stage.block = stage_blocks
            stage.attn = stage_attn
            if i_level != 0:
                stage.Upsample2D = Upsample2D(
                    in_channels=block_in,
                    rngs=rngs,
                    param_dtype=param_dtype,
                )
            self.up.layers.insert(0, stage)

        # end
        self.norm_out = nnx.GroupNorm(
            num_groups=32,
            num_features=block_in,
            epsilon=1e-6,
            rngs=rngs,
            param_dtype=param_dtype,
        )
        self.conv_out = nnx.Conv(
            in_features=block_in,
            out_features=out_ch,
            kernel_size=(3, 3),
            strides=(1, 1),
            padding=(1, 1),
            rngs=rngs,
            param_dtype=param_dtype,
        )

    def __call__(self, z: Array) -> Array:
        """
        Forward pass for the decoder.

        Args:
            z (Array): Latent tensor of shape
                (batch, latent_height, latent_width, latent_channels).

        Returns:
            Array: Reconstructed output tensor with ``out_ch`` channels.
        """
        h = self.conv_in(z)
        for middle_layer in (self.mid.block_1, self.mid.attn_1, self.mid.block_2):
            h = middle_layer(h)

        # Walk the levels from the lowest resolution back up to level 0.
        for i_level in range(self.num_resolutions - 1, -1, -1):
            stage = self.up.layers[i_level]
            has_attn = len(stage.attn.layers) > 0
            for i_block in range(self.num_res_blocks + 1):
                h = stage.block.layers[i_block](h)
                if has_attn:
                    h = stage.attn.layers[i_block](h)
            if i_level > 0:
                h = stage.Upsample2D(h)

        return self.conv_out(swish(self.norm_out(h)))


class AutoEncoder2D(nnx.Module):
    """
    2D autoencoder with a diagonal-Gaussian latent and an affine latent rescaling.

    ``encode`` samples a latent via ``DiagonalGaussian`` and maps it to
    ``scale_factor * (z - shift_factor)``; ``decode`` inverts that affine map
    before running the decoder.

    Args:
        params (AutoEncoderParams): Configuration parameters for the autoencoder.
    """

    def __init__(
        self,
        params: AutoEncoderParams,
    ):
        shared_kwargs = dict(
            resolution=params.resolution,
            in_channels=params.in_channels,
            ch=params.ch,
            ch_mult=params.ch_mult,
            num_res_blocks=params.num_res_blocks,
            z_channels=params.z_channels,
            rngs=params.rngs,
            param_dtype=params.param_dtype,
        )
        # Encoder is constructed before the decoder; both share ``params.rngs``,
        # so this order is part of the parameter-initialisation behaviour.
        self.Encoder2D = Encoder2D(**shared_kwargs)
        self.Decoder2D = Decoder2D(out_ch=params.out_ch, **shared_kwargs)
        self.reg = DiagonalGaussian()
        # NOTE(review): stored as ``nnx.Param``, so these are part of the model's
        # parameter state (and presumably updated by an optimizer that trains all
        # Params) — confirm that trainable scale/shift is intended.
        self.scale_factor = nnx.Param(params.scale_factor)
        self.shift_factor = nnx.Param(params.shift_factor)

    def encode(self, x: Array, key) -> Array:
        """
        Encode input data into the (rescaled) latent space.

        Args:
            x (Array): Input tensor.
            key (Array): PRNG key passed to the latent regulariser for sampling.

        Returns:
            Array: ``scale_factor * (z - shift_factor)`` where ``z`` is the sampled latent.
        """
        encoded = self.Encoder2D(x)
        sampled = self.reg(encoded, key)
        return self.scale_factor * (sampled - self.shift_factor)

    def decode(self, z: Array) -> Array:
        """
        Decode a (rescaled) latent back to data space.

        Args:
            z (Array): Latent tensor as produced by ``encode``.

        Returns:
            Array: Reconstructed output.
        """
        # Undo the affine map applied in ``encode``.
        unscaled = z / self.scale_factor + self.shift_factor
        return self.Decoder2D(unscaled)

    def __call__(self, x: Array, key) -> Array:
        """
        Encode and then decode the input.

        Args:
            x (Array): Input tensor.
            key: PRNG key used when sampling the latent.

        Returns:
            Array: Reconstructed output.
        """
        latent = self.encode(x, key)
        return self.decode(latent)