Source code for gensbi.models.autoencoders.commons

from dataclasses import dataclass

from flax import nnx
from jax.typing import DTypeLike

from jax import Array
import jax
import jax.numpy as jnp

import optax

@dataclass
[docs] class AutoEncoderParams: """ Configuration parameters for the AutoEncoder models. Attributes: resolution (int): The input feature dimension (length for 1D, height/width for 2D). in_channels (int): Number of input channels (e.g., 1 for scalar features, >1 for multi-channel). ch (int): Base number of channels for the first convolutional layer. out_ch (int): Number of output channels produced by the decoder (matches input channels for reconstruction). ch_mult (list[int]): Multipliers for the number of channels at each resolution level (controls model width/depth). num_res_blocks (int): Number of residual blocks per resolution level. z_channels (int): Number of latent channels in the bottleneck (size of encoded representation). scale_factor (float): Scaling factor applied to the latent representation (for normalization or data scaling). shift_factor (float): Shift factor applied to the latent representation (for normalization or data centering). rngs (nnx.Rngs): Random number generators for parameter initialization and stochastic layers. param_dtype (DTypeLike): Data type for model parameters (e.g., jnp.float32, jnp.bfloat16). """
[docs] resolution: int
[docs] in_channels: int
[docs] ch: int
[docs] out_ch: int
[docs] ch_mult: list[int]
[docs] num_res_blocks: int
[docs] z_channels: int
[docs] scale_factor: float
[docs] shift_factor: float
[docs] rngs: nnx.Rngs
[docs] param_dtype: DTypeLike
[docs] class Loss(nnx.Variable): """ Placeholder variable for storing loss values in the model. """ pass
[docs] class DiagonalGaussian(nnx.Module): """ Diagonal Gaussian distribution module for VAE latent space. Args: sample (bool): Whether to sample from the distribution (default: True). chunk_dim (int): Axis along which to split mean and logvar (default: -1). """ def __init__( self, sample: bool = True, chunk_dim: int = -1, ):
[docs] self.sample = sample
[docs] self.chunk_dim = chunk_dim
[docs] def __call__(self, z: Array, key=None) -> Array: """ Split input into mean and log-variance, compute KL loss, and sample if required. Args: z (Array): Input tensor containing concatenated mean and logvar. key (Array, optional): PRNG key for sampling. Required if sampling is enabled. Returns: Array: Sampled latent or mean, depending on self.sample. """ mean, logvar = jnp.split(z, 2, axis=self.chunk_dim) std = jnp.exp(0.5 * logvar) self.kl_loss = Loss( jnp.mean(0.5 * jnp.mean(-jnp.log(std**2) - 1.0 + std**2 + mean**2, axis=-1)) ) if self.sample: return mean + std * jax.random.normal( key=key, shape=mean.shape, dtype=z.dtype ) else: return mean
[docs] def vae_loss_fn(model: nnx.Module, x: jax.Array, key) -> jax.Array: """ Compute the VAE loss as the sum of reconstruction and KL divergence losses. Args: model (nnx.Module): The VAE model. x (Array): Input data. key (Array): PRNG key for stochastic operations. Returns: jax.Array: Scalar loss value. """ logits = model(x, key) losses = nnx.pop(model, Loss) kl_loss = sum(jax.tree_util.tree_leaves(losses), 0.0) reconstruction_loss = jnp.mean( optax.sigmoid_binary_cross_entropy(logits, x) ) loss = reconstruction_loss + 0.1 * kl_loss return loss