gensbi.models.autoencoders

Autoencoders submodule.

This module provides 1D and 2D autoencoder architectures with Gaussian latent spaces, together with a configuration dataclass (AutoEncoderParams) and a VAE loss function (vae_loss_fn).

Submodules

Classes

AutoEncoder1D

1D Autoencoder model with Gaussian latent space.

AutoEncoder2D

2D Autoencoder model with Gaussian latent space.

AutoEncoderParams

Configuration parameters for the AutoEncoder models.

Functions

vae_loss_fn(model, x, key)

Compute the VAE loss as the sum of reconstruction and KL divergence losses.

Package Contents

class gensbi.models.autoencoders.AutoEncoder1D(params)

Bases: flax.nnx.Module

1D Autoencoder model with Gaussian latent space.

Parameters:

params (AutoEncoderParams) – Configuration parameters for the autoencoder.

__call__(x, key)

Forward pass: encode and then decode the input.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – PRNG key for sampling the latent variable.

Returns:

Reconstructed output.

Return type:

Array

decode(z)

Decode latent representation back to data space.

Parameters:

z (Array) – Latent tensor.

Returns:

Reconstructed output.

Return type:

Array

encode(x, key)

Encode input data into the latent space.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – PRNG key for sampling the latent variable.

Returns:

Latent representation.

Return type:

Array
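The `key` argument is needed because encoding draws a random sample from the Gaussian latent distribution. A minimal sketch of how such a draw is commonly implemented (the reparameterization trick, assumed here rather than taken from GenSBI's source): the encoder predicts a mean and a log-variance, and the latent is `mean + exp(0.5 * logvar) * eps` with `eps` drawn from a standard normal. NumPy and a seeded `numpy.random.Generator` stand in for `jax.numpy` and a JAX PRNG key; `sample_gaussian_latent` is a hypothetical helper, not part of the package.

```python
import numpy as np


def sample_gaussian_latent(mean, logvar, rng):
    """Reparameterized sample z = mean + std * eps from N(mean, diag(exp(logvar))).

    Hypothetical helper for illustration; `rng` plays the role of the PRNG key.
    """
    std = np.exp(0.5 * logvar)             # standard deviation from log-variance
    eps = rng.standard_normal(mean.shape)  # noise drawn independently of the parameters
    return mean + std * eps


# Toy encoder output: batch of 2, latent length 4, 1 latent channel (layout is illustrative).
mean = np.zeros((2, 4, 1))
logvar = np.full((2, 4, 1), -2.0)  # std = exp(-1), roughly 0.37
z = sample_gaussian_latent(mean, logvar, np.random.default_rng(0))
print(z.shape)  # (2, 4, 1)
```

In JAX the noise line would read `eps = jax.random.normal(key, mean.shape)`, which is why `encode` takes an explicit key: the same key always yields the same sample.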

Decoder1D
Encoder1D
reg
scale_factor
shift_factor
class gensbi.models.autoencoders.AutoEncoder2D(params)

Bases: flax.nnx.Module

2D Autoencoder model with Gaussian latent space.

Parameters:

params (AutoEncoderParams) – Configuration parameters for the autoencoder.

__call__(x, key)

Forward pass: encode and then decode the input.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – PRNG key for sampling the latent variable.

Returns:

Reconstructed output.

Return type:

Array
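The documented interface composes the two halves: `__call__(x, key)` encodes stochastically and then decodes. The toy class below is a hypothetical stand-in that mirrors only that `encode`/`decode`/`__call__` structure; it uses fixed random linear maps on flattened inputs instead of the convolutional encoder and decoder of the real models, and a seeded NumPy generator in place of the PRNG key.

```python
import numpy as np


class ToyGaussianAutoencoder:
    """Hypothetical stand-in mirroring the documented encode/decode/__call__ interface.

    Uses fixed random linear maps instead of convolutions, for illustration only.
    """

    def __init__(self, in_dim, z_dim, seed=0):
        rng = np.random.default_rng(seed)
        self.w_mean = rng.standard_normal((in_dim, z_dim)) / np.sqrt(in_dim)
        self.w_logvar = rng.standard_normal((in_dim, z_dim)) / np.sqrt(in_dim)
        self.w_dec = rng.standard_normal((z_dim, in_dim)) / np.sqrt(z_dim)

    def encode(self, x, rng):
        # Predict Gaussian parameters, then draw a reparameterized sample.
        mean, logvar = x @ self.w_mean, x @ self.w_logvar
        return mean + np.exp(0.5 * logvar) * rng.standard_normal(mean.shape)

    def decode(self, z):
        # Map latents back to data space.
        return z @ self.w_dec

    def __call__(self, x, rng):
        # Forward pass: encode (stochastic) and then decode.
        return self.decode(self.encode(x, rng))


model = ToyGaussianAutoencoder(in_dim=8, z_dim=2)
x = np.ones((3, 8))  # batch of 3 flattened inputs
x_rec = model(x, np.random.default_rng(1))
print(x_rec.shape)  # (3, 8)
```

Because the latent is sampled, two calls with different keys return different reconstructions, while reusing a key reproduces the output exactly.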

decode(z)

Decode latent representation back to data space.

Parameters:

z (Array) – Latent tensor.

Returns:

Reconstructed output.

Return type:

Array

encode(x, key)

Encode input data into the latent space.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – PRNG key for sampling the latent variable.

Returns:

Latent representation.

Return type:

Array

Decoder2D
Encoder2D
reg
scale_factor
shift_factor
class gensbi.models.autoencoders.AutoEncoderParams

Configuration parameters for the AutoEncoder models.

resolution

The input feature dimension (length for 1D, height/width for 2D).

Type:

int

in_channels

Number of input channels (e.g., 1 for scalar features, >1 for multi-channel).

Type:

int

ch

Base number of channels for the first convolutional layer.

Type:

int

out_ch

Number of output channels produced by the decoder (matches input channels for reconstruction).

Type:

int

ch_mult

Multipliers applied to ch at each resolution level: the list length sets the number of levels (model depth), and each entry sets that level's width as ch × multiplier.

Type:

list[int]
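To see how `resolution`, `ch`, `ch_mult` and `z_channels` interact, the sketch below computes per-level channel widths and the latent shape. It assumes the convention of LDM/FLUX-style encoders, where resolution is halved between consecutive levels (so `len(ch_mult) - 1` downsamplings in total); GenSBI's exact layout is not confirmed here. `latent_shape` is a hypothetical helper, and its (spatial..., channels) ordering is chosen only for illustration.

```python
def latent_shape(resolution, ch, ch_mult, z_channels, ndim=1):
    """Hypothetical helper: latent shape and per-level channel widths.

    Assumes one downsampling by 2 between consecutive levels (none after the
    last), as in LDM/FLUX-style encoders; not confirmed against GenSBI's source.
    """
    num_down = len(ch_mult) - 1
    factor = 2 ** num_down
    if resolution % factor != 0:
        raise ValueError(f"resolution {resolution} is not divisible by {factor}")
    widths = [ch * m for m in ch_mult]           # channels at each resolution level
    spatial = (resolution // factor,) * ndim     # length (1D) or height/width (2D)
    return spatial + (z_channels,), widths


shape, widths = latent_shape(resolution=64, ch=32, ch_mult=[1, 2, 4], z_channels=4)
print(shape, widths)  # (16, 4) [32, 64, 128]
```

Under this assumption, `resolution` must be divisible by `2 ** (len(ch_mult) - 1)`; with `ndim=2` the same settings give a `(16, 16, 4)` latent.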

num_res_blocks

Number of residual blocks per resolution level.

Type:

int

z_channels

Number of latent channels in the bottleneck (size of encoded representation).

Type:

int

scale_factor

Scaling factor applied to the latent representation (for normalization or data scaling).

Type:

float

shift_factor

Shift factor applied to the latent representation (for normalization or data centering).

Type:

float
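A common convention for these two factors, used for example by FLUX-style autoencoders and assumed here rather than confirmed for GenSBI, is to center and rescale the sampled latent after encoding, `z = scale_factor * (z_raw - shift_factor)`, and to invert that map before decoding. The hypothetical helpers below show that the two operations round-trip exactly.

```python
import numpy as np


def normalize_latent(z_raw, scale_factor, shift_factor):
    # Center, then rescale the raw encoder sample (assumed FLUX-style convention).
    return scale_factor * (z_raw - shift_factor)


def denormalize_latent(z, scale_factor, shift_factor):
    # Exact inverse of normalize_latent, applied before decoding.
    return z / scale_factor + shift_factor


z_raw = np.array([0.5, 1.5, 2.5])
z = normalize_latent(z_raw, scale_factor=2.0, shift_factor=1.5)
print(z)  # [-2.  0.  2.]
```

Choosing `shift_factor` near the mean of the raw latents and `scale_factor` near the inverse of their standard deviation gives roughly zero-mean, unit-variance latents, which is the usual motivation for these parameters.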

rngs

Random number generators for parameter initialization and stochastic layers.

Type:

nnx.Rngs

param_dtype

Data type for model parameters (e.g., jnp.float32, jnp.bfloat16).

Type:

DTypeLike

ch: int
ch_mult: list[int]
in_channels: int
num_res_blocks: int
out_ch: int
param_dtype: jax.typing.DTypeLike
resolution: int
rngs: flax.nnx.Rngs
scale_factor: float
shift_factor: float
z_channels: int
gensbi.models.autoencoders.vae_loss_fn(model, x, key)

Compute the VAE loss as the sum of reconstruction and KL divergence losses.

Parameters:
  • model (nnx.Module) – The VAE model.

  • x (Array) – Input data.

  • key (Array) – PRNG key for stochastic operations.

Returns:

Scalar loss value.

Return type:

jax.Array
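The two terms can be sketched as follows, assuming a squared-error reconstruction term and the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior, `KL = -0.5 * sum(1 + logvar - mean**2 - exp(logvar))`, both averaged over the batch. This is a hypothetical stand-in: the real `vae_loss_fn` takes `(model, x, key)` and runs the model itself, and its exact reconstruction term, reductions and weighting are not documented here.

```python
import numpy as np


def vae_loss_sketch(x, x_rec, mean, logvar):
    """Hypothetical VAE loss: squared-error reconstruction + KL(q || N(0, I)).

    Both terms are summed over non-batch axes and averaged over the batch.
    """
    batch = x.shape[0]
    recon = np.sum((x - x_rec) ** 2) / batch
    kl = -0.5 * np.sum(1.0 + logvar - mean**2 - np.exp(logvar)) / batch
    return float(recon + kl)


x = np.ones((2, 3))
print(vae_loss_sketch(x, x, np.zeros((2, 3)), np.zeros((2, 3))))  # 0.0
print(vae_loss_sketch(x, x, np.ones((2, 3)), np.zeros((2, 3))))   # 1.5
```

With a perfect reconstruction and a posterior equal to the prior, both terms vanish; shifting the posterior mean to 1 adds a KL of 0.5 per latent entry, i.e. 1.5 per sample of 3 entries.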