gensbi.models.autoencoders.autoencoder_2d#

Classes#

AttnBlock2D

2D Self-attention block for image or grid data.

AutoEncoder2D

2D Autoencoder model with Gaussian latent space.

Decoder2D

2D Decoder for autoencoder architectures.

Downsample2D

2D Downsampling block using strided convolution.

Encoder2D

2D Encoder for autoencoder architectures.

ResnetBlock2D

2D Residual block that can optionally change the number of channels.

Upsample2D

2D Upsampling block using nearest-neighbor interpolation and convolution.

Module Contents#

class gensbi.models.autoencoders.autoencoder_2d.AttnBlock2D(in_channels, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

2D Self-attention block for image or grid data.

Parameters:
  • in_channels (int) – Number of input channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)[source]#

Forward pass for the attention block.

Parameters:

x (Array) – Input tensor of shape (batch, height, width, channels).

Returns:

Output tensor after residual attention.

Return type:

Array

attention(h_)[source]#

Compute self-attention for 2D input.

Parameters:
  • h_ (jax.Array) – Input tensor of shape (batch, height, width, channels).

Returns:

Output tensor after attention.

Return type:

Array
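
The attention step described above can be sketched in plain JAX. This is a minimal illustration, not the library's implementation: the function name grid_self_attention is hypothetical, a single head is assumed, and the normalization and learned q/k/v/proj_out projections suggested by the attribute names are left out.

```python
import jax
import jax.numpy as jnp


def grid_self_attention(q, k, v):
    """Single-head self-attention over a (batch, height, width, channels) grid."""
    b, h, w, c = q.shape
    # Flatten the spatial grid so each of the h*w positions becomes one token.
    q, k, v = (t.reshape(b, h * w, c) for t in (q, k, v))
    # Scaled dot-product scores between every pair of positions.
    scores = jnp.einsum("bic,bjc->bij", q, k) / jnp.sqrt(c)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    out = jnp.einsum("bij,bjc->bic", weights, v)
    # Restore the spatial layout.
    return out.reshape(b, h, w, c)


key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (2, 4, 4, 8))
# For illustration q, k and v are all the raw input (an assumption);
# a real block would normally project x with learned layers first.
y = grid_self_attention(x, x, x)
residual_out = x + y  # residual connection around the attention output
print(residual_out.shape)  # (2, 4, 4, 8)
```

Because every position attends to every other position, the score matrix has (height * width)^2 entries per batch element, which is why attention of this kind is usually applied at low spatial resolution.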

in_channels[source]#
k[source]#
norm[source]#
proj_out[source]#
q[source]#
v[source]#
class gensbi.models.autoencoders.autoencoder_2d.AutoEncoder2D(params)[source]#

Bases: flax.nnx.Module

2D Autoencoder model with Gaussian latent space.

Parameters:

params (AutoEncoderParams) – Configuration parameters for the autoencoder.

__call__(x, key)[source]#

Forward pass: encode and then decode the input.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – PRNG key for sampling the latent variable.

Returns:

Reconstructed output.

Return type:

Array

decode(z)[source]#

Decode latent representation back to data space.

Parameters:

z (Array) – Latent tensor.

Returns:

Reconstructed output.

Return type:

Array

encode(x, key)[source]#

Encode input data into the latent space.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – PRNG key for sampling the latent variable.

Returns:

Latent representation.

Return type:

Array
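
Sampling from a Gaussian latent space with a PRNG key is commonly done with the reparameterization trick. The sketch below shows that general technique under stated assumptions: the encoder output is assumed to hold the mean and log-variance stacked along the channel axis, and the affine rescaling is only a guess at what the scale_factor and shift_factor attributes are used for. All function names and numeric values are hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_diagonal_gaussian(moments, key):
    """Sample z ~ N(mean, exp(logvar)) with the reparameterization trick."""
    # Assumed layout: first half of the channels is the mean,
    # second half is the log-variance.
    mean, logvar = jnp.split(moments, 2, axis=-1)
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(key, mean.shape, dtype=mean.dtype)
    return mean + std * eps


def to_latent(z, scale_factor, shift_factor):
    # Assumed affine normalization applied after sampling.
    return scale_factor * (z - shift_factor)


def from_latent(z, scale_factor, shift_factor):
    # Inverse of to_latent, applied before decoding.
    return z / scale_factor + shift_factor


key = jax.random.PRNGKey(0)
moments = jnp.zeros((1, 8, 8, 2 * 4))  # hypothetical encoder output, 4 latent channels
z = sample_diagonal_gaussian(moments, key)
z_scaled = to_latent(z, scale_factor=0.5, shift_factor=0.1)
z_back = from_latent(z_scaled, scale_factor=0.5, shift_factor=0.1)
print(z.shape)  # (1, 8, 8, 4)
```

Passing the key explicitly keeps the sampling reproducible: the same input and key always give the same latent.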

Decoder2D[source]#
Encoder2D[source]#
reg[source]#
scale_factor[source]#
shift_factor[source]#
class gensbi.models.autoencoders.autoencoder_2d.Decoder2D(ch, out_ch, ch_mult, num_res_blocks, in_channels, resolution, z_channels, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

2D Decoder for autoencoder architectures.

Parameters:
  • ch (int) – Base number of channels.

  • out_ch (int) – Number of output channels.

  • ch_mult (list[int]) – Channel multipliers for each resolution.

  • num_res_blocks (int) – Number of residual blocks per resolution.

  • in_channels (int) – Number of input channels.

  • resolution (int) – Output image height/width.

  • z_channels (int) – Number of latent channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(z)[source]#

Forward pass for the decoder.

Parameters:

z (Array) – Latent tensor of shape (batch, latent_height, latent_width, latent_channels).

Returns:

Reconstructed output tensor.

Return type:

Array
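
The relation between resolution, ch_mult and the latent shape can be worked out with simple arithmetic. The sketch below assumes one factor-of-2 resampling step between consecutive resolution levels, which is the usual convention for this kind of architecture but is not confirmed by this page. The helper names are hypothetical.

```python
def latent_resolution(resolution: int, ch_mult: list[int]) -> int:
    """Spatial size of the latent grid for a given image resolution."""
    # Assumption: len(ch_mult) resolution levels give len(ch_mult) - 1 halvings.
    factor = 2 ** (len(ch_mult) - 1)
    if resolution % factor != 0:
        raise ValueError(f"resolution must be divisible by {factor}")
    return resolution // factor


def stage_channels(ch: int, ch_mult: list[int]) -> list[int]:
    """Channel count at each resolution level: base channels times multiplier."""
    return [ch * m for m in ch_mult]


print(latent_resolution(256, [1, 2, 4, 4]))  # 32
print(stage_channels(128, [1, 2, 4, 4]))  # [128, 256, 512, 512]
```

Under these assumptions, a 256x256 input with four resolution levels would correspond to a latent tensor of shape (batch, 32, 32, z_channels).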

ch[source]#
conv_in[source]#
conv_out[source]#
ffactor[source]#
in_channels[source]#
mid[source]#
norm_out[source]#
num_res_blocks[source]#
num_resolutions[source]#
resolution[source]#
up[source]#
z_shape[source]#
class gensbi.models.autoencoders.autoencoder_2d.Downsample2D(in_channels, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

2D Downsampling block using strided convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)[source]#

Downsample the input tensor by a factor of 2.

Parameters:

x (Array) – Input tensor of shape (batch, height, width, channels).

Returns:

Downsampled tensor.

Return type:

Array
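
A strided-convolution downsampling step can be sketched directly with jax.lax. This is an illustration only: the 3x3 kernel size, the bottom/right padding, and the weights are assumptions, and downsample_2x is a hypothetical name rather than part of the library.

```python
import jax
import jax.numpy as jnp


def downsample_2x(x, kernel):
    """Halve height and width with a 3x3 convolution of stride 2."""
    # Pad one row at the bottom and one column on the right so that an
    # even-sized input maps to exactly half its size (assumed padding scheme).
    x = jnp.pad(x, ((0, 0), (0, 1), (0, 1), (0, 0)))
    return jax.lax.conv_general_dilated(
        x,
        kernel,
        window_strides=(2, 2),
        padding="VALID",
        dimension_numbers=("NHWC", "HWIO", "NHWC"),
    )


x = jnp.ones((1, 16, 16, 4))
kernel = jnp.ones((3, 3, 4, 4)) / (3 * 3 * 4)  # hypothetical weights
y = downsample_2x(x, kernel)
print(y.shape)  # (1, 8, 8, 4)
```

Unlike pooling, a strided convolution has trainable weights, so the network can learn how to summarize each neighborhood.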

conv[source]#
class gensbi.models.autoencoders.autoencoder_2d.Encoder2D(resolution, in_channels, ch, ch_mult, num_res_blocks, z_channels, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

2D Encoder for autoencoder architectures.

Parameters:
  • resolution (int) – Input image height/width.

  • in_channels (int) – Number of input channels.

  • ch (int) – Base number of channels.

  • ch_mult (list[int]) – Channel multipliers for each resolution.

  • num_res_blocks (int) – Number of residual blocks per resolution.

  • z_channels (int) – Number of latent channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)[source]#

Forward pass for the encoder.

Parameters:

x (Array) – Input tensor of shape (batch, height, width, channels).

Returns:

Encoded latent representation.

Return type:

Array

ch[source]#
conv_in[source]#
conv_out[source]#
down[source]#
in_ch_mult[source]#
in_channels[source]#
mid[source]#
norm_out[source]#
num_res_blocks[source]#
num_resolutions[source]#
resolution[source]#
class gensbi.models.autoencoders.autoencoder_2d.ResnetBlock2D(in_channels, out_channels, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

2D Residual block that can optionally change the number of channels.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)[source]#

Forward pass for the residual block.

Parameters:

x (Array) – Input tensor.

Returns:

Output tensor after residual connection.

Return type:

Array
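
A residual connection whose input and output channel counts differ needs a projection on the shortcut path. The sketch below shows one common way to do this, a 1x1 convolution written as a per-pixel linear map. Whether this class uses the same shortcut is an assumption, and residual_block, body, and the weights are hypothetical stand-ins for the real normalization and convolution layers.

```python
import jax
import jax.numpy as jnp


def residual_block(x, body, shortcut_weight=None):
    """Return shortcut(x) + body(x) for (batch, height, width, channels) input."""
    h = body(x)
    if shortcut_weight is not None:
        # A 1x1 convolution is a linear map over channels at every pixel.
        x = jnp.einsum("bhwi,io->bhwo", x, shortcut_weight)
    return x + h


key_x, key_w, key_s = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(key_x, (1, 8, 8, 4))
w_body = jax.random.normal(key_w, (4, 6))   # hypothetical body weights
w_short = jax.random.normal(key_s, (4, 6))  # hypothetical shortcut weights


def body(t):
    # Stand-in for the norm/activation/conv stack of a real block.
    return jax.nn.silu(jnp.einsum("bhwi,io->bhwo", t, w_body))


y = residual_block(x, body, w_short)
print(y.shape)  # (1, 8, 8, 6)
```

When in_channels equals out_channels, the shortcut can be the identity and shortcut_weight can be left as None.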

conv1[source]#
conv2[source]#
in_channels[source]#
norm1[source]#
norm2[source]#
out_channels[source]#
class gensbi.models.autoencoders.autoencoder_2d.Upsample2D(in_channels, rngs, param_dtype=jnp.bfloat16)[source]#

Bases: flax.nnx.Module

2D Upsampling block using nearest-neighbor interpolation and convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)[source]#

Upsample the input tensor by a factor of 2.

Parameters:

x (Array) – Input tensor of shape (batch, height, width, channels).

Returns:

Upsampled tensor.

Return type:

Array
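
Nearest-neighbor upsampling by a factor of 2 amounts to repeating every pixel along both spatial axes. The sketch below shows only that interpolation step; the convolution that follows it in this block is omitted, and upsample_nearest_2x is a hypothetical name.

```python
import jax.numpy as jnp


def upsample_nearest_2x(x):
    """Double height and width of a (batch, height, width, channels) tensor."""
    x = jnp.repeat(x, 2, axis=1)  # repeat rows
    x = jnp.repeat(x, 2, axis=2)  # repeat columns
    return x


x = jnp.arange(4.0).reshape(1, 2, 2, 1)
y = upsample_nearest_2x(x)
print(y.shape)  # (1, 4, 4, 1)
print(y[0, :, :, 0])
# [[0. 0. 1. 1.]
#  [0. 0. 1. 1.]
#  [2. 2. 3. 3.]
#  [2. 2. 3. 3.]]
```

Following the interpolation with a convolution lets the network smooth the blocky output that pixel repetition produces.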

conv[source]#