gensbi.models.autoencoders
Autoencoders submodule.
This module provides 1D and 2D autoencoder architectures with Gaussian latent spaces, including configuration dataclasses and VAE loss functions.
Submodules
Classes
| AutoEncoder1D | 1D Autoencoder model with Gaussian latent space. |
| AutoEncoder2D | 2D Autoencoder model with Gaussian latent space. |
| AutoEncoderParams | Configuration parameters for the AutoEncoder models. |
Functions
| vae_loss_fn(model, x, key) | Compute the VAE loss as the sum of reconstruction and KL divergence losses. |
Package Contents
- class gensbi.models.autoencoders.AutoEncoder1D(params)
Bases: flax.nnx.Module
1D Autoencoder model with Gaussian latent space.
- Parameters:
params (AutoEncoderParams) – Configuration parameters for the autoencoder.
- __call__(x, key)
Forward pass: encode and then decode the input.
- Parameters:
x (Array) – Input tensor.
key (Array) – PRNG key for sampling the latent variable.
- Returns:
Reconstructed output.
- Return type:
Array
- decode(z)
Decode latent representation back to data space.
- Parameters:
z (Array) – Latent tensor.
- Returns:
Reconstructed output.
- Return type:
Array
- encode(x, key)
Encode input data into the latent space.
- Parameters:
x (Array) – Input tensor.
key (Array) – PRNG key for sampling the latent variable.
- Returns:
Latent representation.
- Return type:
Array
- Decoder1D
- Encoder1D
- reg
- scale_factor
- shift_factor
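Because `encode` takes a PRNG key, the Gaussian latent is sampled rather than returned deterministically. The sketch below shows how such a step is commonly written in JAX with the reparameterization trick. It assumes the encoder emits a mean and a log-variance concatenated along the last (channel) axis; the helper `sample_gaussian_latent` and that channel layout are hypothetical. In the Flux autoencoder, whose attribute names (`reg`, `scale_factor`, `shift_factor`) this class resembles, this step is performed by a diagonal-Gaussian regularizer, but whether gensbi's `reg` does exactly this is an assumption.

```python
import jax
import jax.numpy as jnp


def sample_gaussian_latent(moments: jax.Array, key: jax.Array) -> jax.Array:
    """Draw z ~ N(mean, exp(logvar)) with the reparameterization trick.

    Hypothetical helper: assumes `moments` stacks [mean, logvar] along the
    last axis, so it has 2 * z_channels channels.
    """
    mean, logvar = jnp.split(moments, 2, axis=-1)
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(key, mean.shape, dtype=mean.dtype)
    # z is a differentiable function of (mean, std); the randomness lives in eps.
    return mean + std * eps


key = jax.random.PRNGKey(0)
moments = jnp.zeros((4, 16, 2 * 8))  # assumed layout: (batch, length, 2 * z_channels)
z = sample_gaussian_latent(moments, key)
print(z.shape)  # (4, 16, 8)
```

Keeping the noise in `eps` means gradients flow through `mean` and `std` during training, which is why the key is passed explicitly instead of being drawn inside the encoder network.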
- class gensbi.models.autoencoders.AutoEncoder2D(params)
Bases: flax.nnx.Module
2D Autoencoder model with Gaussian latent space.
- Parameters:
params (AutoEncoderParams) – Configuration parameters for the autoencoder.
- __call__(x, key)
Forward pass: encode and then decode the input.
- Parameters:
x (Array) – Input tensor.
key (Array) – PRNG key for sampling the latent variable.
- Returns:
Reconstructed output.
- Return type:
Array
- decode(z)
Decode latent representation back to data space.
- Parameters:
z (Array) – Latent tensor.
- Returns:
Reconstructed output.
- Return type:
Array
- encode(x, key)
Encode input data into the latent space.
- Parameters:
x (Array) – Input tensor.
key (Array) – PRNG key for sampling the latent variable.
- Returns:
Latent representation.
- Return type:
Array
- Decoder2D
- Encoder2D
- reg
- scale_factor
- shift_factor
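Both autoencoders expose `scale_factor` and `shift_factor`. In latent-diffusion style autoencoders such as the Flux autoencoder, these are applied as an affine normalization of the latent after encoding and inverted before decoding. The minimal sketch below assumes that convention; the helper names `normalize_latent` and `denormalize_latent` are hypothetical, and whether gensbi applies exactly this formula is an assumption.

```python
import jax.numpy as jnp


def normalize_latent(z, scale_factor: float, shift_factor: float):
    # Assumed convention (after encoding): center by shift_factor, then rescale.
    return scale_factor * (z - shift_factor)


def denormalize_latent(z_norm, scale_factor: float, shift_factor: float):
    # Exact inverse of normalize_latent (before decoding).
    return z_norm / scale_factor + shift_factor


z = jnp.linspace(-2.0, 2.0, 5)
z_norm = normalize_latent(z, scale_factor=0.5, shift_factor=0.1)
z_back = denormalize_latent(z_norm, scale_factor=0.5, shift_factor=0.1)
print(jnp.allclose(z, z_back))  # True
```

Under this convention the pair is invertible for any nonzero `scale_factor`, so a downstream model can work in the normalized latent space while `decode` still receives latents in the scale the decoder was trained on.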
- class gensbi.models.autoencoders.AutoEncoderParams
Configuration parameters for the AutoEncoder models.
- resolution
The input feature dimension (length for 1D, height/width for 2D).
- Type:
int
- in_channels
Number of input channels (e.g., 1 for scalar features, >1 for multi-channel).
- Type:
int
- ch
Base number of channels for the first convolutional layer.
- Type:
int
- out_ch
Number of output channels produced by the decoder (matches input channels for reconstruction).
- Type:
int
- ch_mult
Per-level multipliers applied to ch: the number of entries sets the number of resolution levels (model depth) and each value sets that level's channel width.
- Type:
list[int]
- num_res_blocks
Number of residual blocks per resolution level.
- Type:
int
- z_channels
Number of latent channels in the bottleneck (size of encoded representation).
- Type:
int
- scale_factor
Scaling factor applied to the latent representation (for normalization or data scaling).
- Type:
float
- shift_factor
Shift factor applied to the latent representation (for normalization or data centering).
- Type:
float
- rngs
Random number generators for parameter initialization and stochastic layers.
- Type:
nnx.Rngs
- param_dtype
Data type for model parameters (e.g., jnp.float32, jnp.bfloat16).
- Type:
DTypeLike
- ch: int
- ch_mult: list[int]
- in_channels: int
- num_res_blocks: int
- out_ch: int
- param_dtype: jax.typing.DTypeLike
- resolution: int
- rngs: flax.nnx.Rngs
- scale_factor: float
- shift_factor: float
- z_channels: int
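To make the interplay of `resolution`, `ch`, `ch_mult`, and `z_channels` concrete, the sketch below computes per-level channel widths and the latent size under the convention used by latent-diffusion style encoders: level `i` has width `ch * ch_mult[i]`, with one 2x downsampling between consecutive levels. That convention, the hypothetical helper `describe_architecture`, and the example values are assumptions for illustration, not a description of gensbi internals.

```python
def describe_architecture(resolution: int, ch: int, ch_mult: list[int], z_channels: int) -> dict:
    """Hypothetical helper: summarize encoder widths and latent size.

    Assumes one 2x downsampling between consecutive levels, so the latent
    resolution is resolution / 2 ** (len(ch_mult) - 1).
    """
    num_levels = len(ch_mult)
    widths = [ch * m for m in ch_mult]  # channel width at each resolution level
    downsample = 2 ** (num_levels - 1)
    if resolution % downsample != 0:
        raise ValueError(f"resolution {resolution} is not divisible by {downsample}")
    return {
        "level_widths": widths,
        "latent_resolution": resolution // downsample,
        "z_channels": z_channels,
    }


info = describe_architecture(resolution=64, ch=32, ch_mult=[1, 2, 4], z_channels=4)
print(info)
# {'level_widths': [32, 64, 128], 'latent_resolution': 16, 'z_channels': 4}
```

Under these assumptions a 1D input of length 64 would map to a latent of length 16 with 4 channels, and a 2D 64x64 input to a 16x16 latent with 4 channels.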
- gensbi.models.autoencoders.vae_loss_fn(model, x, key)
Compute the VAE loss as the sum of reconstruction and KL divergence losses.
- Parameters:
model (nnx.Module) – The VAE model.
x (Array) – Input data.
key (Array) – PRNG key for stochastic operations.
- Returns:
Scalar loss value.
- Return type:
jax.Array
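The loss pairs a reconstruction term with the KL divergence between the diagonal-Gaussian posterior and a standard normal prior, which has the closed form KL = -1/2 Σ (1 + log σ² - μ² - σ²). A minimal sketch, assuming mean-squared-error reconstruction, summation over non-batch axes, averaging over the batch, and direct access to the posterior mean and log-variance. The real `vae_loss_fn` instead receives `model`, `x`, and `key`, and its exact reduction and any KL weighting are not documented here; `gaussian_kl` and `vae_loss_sketch` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def gaussian_kl(mean: jax.Array, logvar: jax.Array) -> jax.Array:
    # Closed-form KL( N(mean, exp(logvar)) || N(0, I) ), summed over all non-batch axes.
    kl = -0.5 * (1.0 + logvar - mean**2 - jnp.exp(logvar))
    return kl.reshape(kl.shape[0], -1).sum(axis=-1)


def vae_loss_sketch(x, x_recon, mean, logvar) -> jax.Array:
    """Hypothetical VAE loss: squared-error reconstruction + KL, averaged over the batch."""
    recon = ((x_recon - x) ** 2).reshape(x.shape[0], -1).sum(axis=-1)
    return jnp.mean(recon + gaussian_kl(mean, logvar))


x = jnp.ones((2, 8, 1))
mean = jnp.zeros((2, 4, 3))
logvar = jnp.zeros((2, 4, 3))
loss = vae_loss_sketch(x, x, mean, logvar)  # perfect reconstruction, posterior equal to prior
print(float(loss))
```

With a perfect reconstruction and a posterior equal to the prior both terms vanish; shifting the mean away from zero or shrinking the variance increases the KL term, which is what keeps the latent space close to a standard normal.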