gensbi.models.autoencoders.commons
Classes
- AutoEncoderParams: Configuration parameters for the AutoEncoder models.
- DiagonalGaussian: Diagonal Gaussian distribution module for VAE latent space.
- Loss: Placeholder variable for storing loss values in the model.
Functions
- vae_loss_fn: Compute the VAE loss as the sum of reconstruction and KL divergence losses.
Module Contents
- class gensbi.models.autoencoders.commons.AutoEncoderParams
Configuration parameters for the AutoEncoder models.
- in_channels
Number of input channels (e.g., 1 for scalar features, >1 for multi-channel).
- Type:
int
- out_ch
Number of output channels produced by the decoder (typically equal to in_channels, so that the output reconstructs the input).
- Type:
int
- ch_mult
Channel multipliers for each resolution level: the list length sets the number of levels (model depth) and each entry scales the channel count (model width) at that level.
- Type:
list[int]
- z_channels
Number of latent channels in the bottleneck (size of encoded representation).
- Type:
int
- scale_factor
Scaling factor applied to the latent representation (for normalization or data scaling).
- Type:
float
- shift_factor
Shift factor applied to the latent representation (for normalization or data centering).
- Type:
float
- rngs
Random number generators for parameter initialization and stochastic layers.
- Type:
nnx.Rngs
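For orientation, the documented fields can be mirrored in a plain dataclass. The sketch below uses a hypothetical stand-in, `AutoEncoderParamsSketch`, not the gensbi class, which may define additional fields or defaults; `rngs` is left as a placeholder instead of a real `nnx.Rngs`. The helpers `normalize_latent` and `denormalize_latent` are also hypothetical and assume the shift-then-scale convention used by Flux-style autoencoders, z' = scale_factor * (z - shift_factor); the source does not confirm how gensbi applies these two factors.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class AutoEncoderParamsSketch:
    # Hypothetical stand-in holding only the fields documented above;
    # the real AutoEncoderParams may define more fields or different defaults.
    in_channels: int
    out_ch: int
    ch_mult: list[int]
    z_channels: int
    scale_factor: float
    shift_factor: float
    rngs: Any = None  # an nnx.Rngs instance in the real class


def normalize_latent(z, p):
    # Assumed convention: shift first, then scale.
    return p.scale_factor * (z - p.shift_factor)


def denormalize_latent(z, p):
    # Exact inverse of normalize_latent.
    return z / p.scale_factor + p.shift_factor


params = AutoEncoderParamsSketch(
    in_channels=1,
    out_ch=1,            # equal to in_channels for reconstruction
    ch_mult=[1, 2, 4],   # three resolution levels
    z_channels=4,
    scale_factor=0.5,
    shift_factor=0.1,
)
```

Keeping the inverse next to the forward map makes it easy to check that a latent round-trips unchanged before it is passed back to the decoder.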
- class gensbi.models.autoencoders.commons.DiagonalGaussian(sample=True, chunk_dim=-1)
Bases:
flax.nnx.Module
Diagonal Gaussian distribution module for VAE latent space.
- Parameters:
sample (bool) – Whether to sample from the distribution (default: True).
chunk_dim (int) – Axis along which the input is split into mean and log-variance halves (default: -1).
- __call__(z, key=None)
Splits the input into mean and log-variance, computes the KL loss, and draws a sample if self.sample is True.
- Parameters:
z (Array) – Input array containing the mean and log-variance concatenated along chunk_dim.
key (Array, optional) – PRNG key for sampling. Required if sampling is enabled.
- Returns:
Sampled latent or mean, depending on self.sample.
- Return type:
Array
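The behaviour described for `DiagonalGaussian.__call__` matches the standard VAE reparameterization trick. The function below, `diagonal_gaussian_sketch`, is a hypothetical functional version in JAX, not the module's actual code. It splits `z` in two along `chunk_dim`, computes the closed-form KL divergence to a standard normal, KL = 0.5 * sum(mean^2 + exp(logvar) - 1 - logvar), and returns either a reparameterized sample mean + exp(0.5 * logvar) * eps or the mean itself. The KL reduction (sum over the latent axis, mean over the remaining axes) is an assumption.

```python
import jax
import jax.numpy as jnp


def diagonal_gaussian_sketch(z, key=None, sample=True, chunk_dim=-1):
    # Split the encoder output into mean and log-variance halves.
    mean, logvar = jnp.split(z, 2, axis=chunk_dim)

    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over the
    # latent axis and averaged over all remaining axes (assumed reduction).
    kl_per_example = 0.5 * jnp.sum(
        mean**2 + jnp.exp(logvar) - 1.0 - logvar, axis=chunk_dim
    )
    kl = jnp.mean(kl_per_example)

    if sample:
        if key is None:
            raise ValueError("a PRNG key is required when sample=True")
        # Reparameterization: mean + std * eps with eps ~ N(0, I).
        eps = jax.random.normal(key, mean.shape)
        return mean + jnp.exp(0.5 * logvar) * eps, kl
    return mean, kl


# Usage: a batch of 3 examples with 2 latent dimensions (4 = 2 * 2 channels).
z = jnp.zeros((3, 4))
latent, kl = diagonal_gaussian_sketch(z, key=jax.random.PRNGKey(0))
```

Returning the KL alongside the latent keeps the sketch purely functional; the documented module instead computes it internally, and where it stores the value is not stated above.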
- class gensbi.models.autoencoders.commons.Loss(value, *, is_hijax=None, has_ref=False, is_mutable=True, eager_sharding=None, **metadata)
Bases:
flax.nnx.Variable
Placeholder variable for storing loss values in the model.
- Parameters:
value (A | VariableMetadata[A])
is_hijax (bool | None)
has_ref (bool)
is_mutable (bool)
eager_sharding (bool | None)
metadata (Any)
- gensbi.models.autoencoders.commons.vae_loss_fn(model, x, key)
Compute the VAE loss as the sum of reconstruction and KL divergence losses.
- Parameters:
model (nnx.Module) – The VAE model.
x (Array) – Input data.
key (Array) – PRNG key for stochastic operations.
- Returns:
Scalar loss value.
- Return type:
jax.Array
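In outline, `vae_loss_fn` adds a reconstruction term to the KL term. The sketch below assumes a mean-squared-error reconstruction loss and a hypothetical `apply_fn(x, key)` that returns the reconstruction and the KL value. The real function takes an `nnx.Module` instead, and the source does not state its reconstruction metric or reduction.

```python
import jax
import jax.numpy as jnp


def vae_loss_sketch(apply_fn, x, key):
    # apply_fn is hypothetical: (x, key) -> (x_recon, kl_scalar)
    x_recon, kl = apply_fn(x, key)
    # Assumed reconstruction term: mean squared error over all elements.
    recon = jnp.mean((x_recon - x) ** 2)
    return recon + kl


def toy_apply(x, key):
    # Toy "model" that reconstructs perfectly and reports a fixed KL value.
    return x, jnp.array(0.25)


x = jnp.ones((4, 3))
loss = vae_loss_sketch(toy_apply, x, jax.random.PRNGKey(0))
```

In training, a loss of this shape is typically wrapped in `jax.value_and_grad` (or `nnx.value_and_grad` when the first argument is an NNX module) to obtain the scalar loss and its gradients for an optimizer step.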