gensbi.models.autoencoders.commons

Classes

AutoEncoderParams

Configuration parameters for the AutoEncoder models.

DiagonalGaussian

Diagonal Gaussian distribution module for VAE latent space.

Loss

Placeholder variable for storing loss values in the model.

Functions

vae_loss_fn(model, x, key)

Compute the VAE loss as the sum of reconstruction and KL divergence losses.

Module Contents

class gensbi.models.autoencoders.commons.AutoEncoderParams

Configuration parameters for the AutoEncoder models.

resolution

The input feature dimension (length for 1D, height/width for 2D).

Type:

int

in_channels

Number of input channels (e.g., 1 for scalar features, >1 for multi-channel).

Type:

int

ch

Base number of channels for the first convolutional layer.

Type:

int

out_ch

Number of output channels produced by the decoder (usually equal to in_channels, so that the reconstruction has the same shape as the input).

Type:

int

ch_mult

Channel multipliers, one per resolution level: each level's width is a multiple of ch, and the number of entries sets the number of resolution levels.

Type:

list[int]

num_res_blocks

Number of residual blocks per resolution level.

Type:

int

z_channels

Number of latent channels in the bottleneck (size of encoded representation).

Type:

int

scale_factor

Scaling factor applied to the latent representation (for normalization or data scaling).

Type:

float

shift_factor

Shift factor applied to the latent representation (for normalization or data centering).

Type:

float

rngs

Random number generators for parameter initialization and stochastic layers.

Type:

nnx.Rngs

param_dtype

Data type for model parameters (e.g., jnp.float32, jnp.bfloat16).

Type:

DTypeLike
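The descriptions of scale_factor and shift_factor above are loose. The sketch below assumes gensbi follows the convention of the FLUX autoencoder, whose parameter names this class matches: the raw encoder latent is shifted and then scaled on the way into latent space, and the decoder side undoes that. The helper names encode_scaling and decode_scaling are hypothetical and not part of gensbi.

```python
def encode_scaling(z, scale_factor, shift_factor):
    """Map a raw encoder latent into the normalized latent space.

    Assumption: FLUX-style convention, scale * (z - shift).
    Works on Python floats and elementwise on JAX/NumPy arrays.
    """
    return scale_factor * (z - shift_factor)


def decode_scaling(z, scale_factor, shift_factor):
    """Invert encode_scaling before handing a latent to the decoder."""
    return z / scale_factor + shift_factor


# Round trip with illustrative values.
raw = 2.5
latent = encode_scaling(raw, scale_factor=0.5, shift_factor=1.0)      # 0.5 * (2.5 - 1.0) = 0.75
restored = decode_scaling(latent, scale_factor=0.5, shift_factor=1.0)  # 0.75 / 0.5 + 1.0 = 2.5
```

Under this convention the two maps are exact inverses, so choosing scale_factor and shift_factor only changes the statistics of the latent space a downstream model sees, not what the decoder reconstructs.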

ch: int
ch_mult: list[int]
in_channels: int
num_res_blocks: int
out_ch: int
param_dtype: jax.typing.DTypeLike
resolution: int
rngs: flax.nnx.Rngs
scale_factor: float
shift_factor: float
z_channels: int
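To show how resolution, ch, ch_mult and z_channels interact, here is a sketch of the per-level layout. It assumes a FLUX-style encoder in which every level except the last halves the spatial size; this is not confirmed for gensbi. In that design the encoder emits 2 * z_channels channels, which DiagonalGaussian then splits into mean and log-variance. level_schedule and latent_shape are hypothetical helpers.

```python
def level_schedule(resolution, ch, ch_mult):
    """Return (spatial_size, channels) for each encoder level.

    Assumption: each level except the last downsamples by a factor of 2.
    """
    schedule = []
    size = resolution
    for i, mult in enumerate(ch_mult):
        # Level width is the base channel count times this level's multiplier.
        schedule.append((size, ch * mult))
        if i != len(ch_mult) - 1:
            size //= 2
    return schedule


def latent_shape(resolution, ch_mult, z_channels):
    """Spatial size and channel count of the sampled latent (assumed)."""
    return resolution // 2 ** (len(ch_mult) - 1), z_channels


print(level_schedule(64, 128, [1, 2, 4, 4]))  # [(64, 128), (32, 256), (16, 512), (8, 512)]
print(latent_shape(64, [1, 2, 4, 4], 16))     # (8, 16)
```

The length of ch_mult therefore controls how aggressively the input is compressed, while its values control how wide each level is.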
class gensbi.models.autoencoders.commons.DiagonalGaussian(sample=True, chunk_dim=-1)

Bases: flax.nnx.Module

Diagonal Gaussian distribution module for VAE latent space.

Parameters:
  • sample (bool) – Whether to sample from the distribution (default: True).

  • chunk_dim (int) – Axis along which to split mean and logvar (default: -1).

__call__(z, key=None)

Split the input into mean and log-variance, compute the KL loss, and draw a sample if self.sample is True.

Parameters:
  • z (Array) – Input tensor containing concatenated mean and logvar.

  • key (Array, optional) – PRNG key for sampling. Required if sampling is enabled.

Returns:

Sampled latent or mean, depending on self.sample.

Return type:

Array

chunk_dim = -1
sample = True
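The behaviour of __call__ can be sketched as a pure JAX function. This is an illustration under stated assumptions, not gensbi's implementation: diagonal_gaussian is a hypothetical name, it returns the KL term instead of storing it on the module, and the KL reduction (sum over latent elements, mean over the leading batch axis) is a guess.

```python
import jax
import jax.numpy as jnp


def diagonal_gaussian(z, key=None, sample=True, chunk_dim=-1):
    """Hypothetical functional analogue of DiagonalGaussian.__call__.

    Returns (latent, kl) rather than storing the KL term in the module.
    """
    # Split the encoder output into two equal halves: mean and log-variance.
    mean, logvar = jnp.split(z, 2, axis=chunk_dim)

    # Closed-form KL(N(mean, exp(logvar)) || N(0, I)), summed over all
    # elements and divided by the batch size (assumed reduction).
    kl_elems = 0.5 * (mean**2 + jnp.exp(logvar) - 1.0 - logvar)
    kl = jnp.sum(kl_elems) / kl_elems.shape[0]

    if not sample:
        # Deterministic mode: return the mean as the latent.
        return mean, kl
    if key is None:
        raise ValueError("a PRNG key is required when sample=True")

    # Reparameterization trick: latent = mean + std * eps, eps ~ N(0, I).
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(key, mean.shape, dtype=mean.dtype)
    return mean + std * eps, kl


# Batch of 2 with 8 channels on the last axis: 4 means followed by 4 log-variances.
z = jnp.zeros((2, 8))
latent, kl = diagonal_gaussian(z, key=jax.random.PRNGKey(0))
print(latent.shape)  # (2, 4)
```

Sampling through mean + std * eps, rather than drawing directly from the distribution, keeps the latent differentiable with respect to mean and logvar, which is what lets the encoder be trained by gradient descent.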
class gensbi.models.autoencoders.commons.Loss(value, *, is_hijax=None, has_ref=False, is_mutable=True, eager_sharding=None, **metadata)

Bases: flax.nnx.Variable

Placeholder variable for storing loss values in the model.

Parameters:
  • value (A | VariableMetadata[A])

  • is_hijax (bool | None)

  • has_ref (bool)

  • is_mutable (bool)

  • eager_sharding (bool | None)

  • metadata (Any)

gensbi.models.autoencoders.commons.vae_loss_fn(model, x, key)

Compute the VAE loss as the sum of reconstruction and KL divergence losses.

Parameters:
  • model (nnx.Module) – The VAE model.

  • x (Array) – Input data.

  • key (Array) – PRNG key for stochastic operations.

Returns:

Scalar loss value.

Return type:

jax.Array
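A functional sketch of the loss described above, assuming an encoder that outputs mean and log-variance concatenated on the last axis and a mean-squared-error reconstruction term. vae_loss_sketch, encode_fn and decode_fn are hypothetical; the real vae_loss_fn takes an nnx.Module and may use a different reconstruction loss, KL reduction, or way of reading the KL term from the model.

```python
import jax
import jax.numpy as jnp


def vae_loss_sketch(encode_fn, decode_fn, x, key):
    """Hypothetical stand-in for vae_loss_fn: reconstruction MSE + KL."""
    # Encoder output holds [mean, logvar] concatenated on the last axis.
    h = encode_fn(x)
    mean, logvar = jnp.split(h, 2, axis=-1)

    # Reparameterized sample so gradients reach mean and logvar.
    std = jnp.exp(0.5 * logvar)
    z = mean + std * jax.random.normal(key, mean.shape, dtype=mean.dtype)

    # Reconstruction term: mean squared error (assumed choice).
    x_hat = decode_fn(z)
    recon = jnp.mean((x_hat - x) ** 2)

    # KL against N(0, I), summed over latent dims and averaged over the batch.
    kl = 0.5 * jnp.sum(mean**2 + jnp.exp(logvar) - 1.0 - logvar) / mean.shape[0]
    return recon + kl


# Usage: toy linear encoder/decoder, differentiated with jax.grad.
def loss_of_params(w_enc, w_dec, x, key):
    return vae_loss_sketch(lambda v: v @ w_enc, lambda z: z @ w_dec, x, key)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (16, 3))
w_enc = 0.1 * jax.random.normal(k2, (3, 4))  # 3 features -> 2 means + 2 log-variances
w_dec = 0.1 * jax.random.normal(k3, (2, 3))  # 2 latent dims -> 3 features
grads = jax.grad(loss_of_params, argnums=(0, 1))(w_enc, w_dec, x, jax.random.PRNGKey(1))
```

The KL term pulls the approximate posterior toward the standard normal prior while the reconstruction term keeps the latent informative; summing them gives a single scalar that jax.grad can differentiate end to end.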