gensbi.models.autoencoders.autoencoder_1d
Classes
| Class | Description |
|---|---|
| AttnBlock1D | 1D self-attention block for sequence data. |
| AutoEncoder1D | 1D autoencoder model with Gaussian latent space. |
| Decoder1D | 1D decoder for autoencoder architectures. |
| Downsample1D | 1D downsampling block using strided convolution. |
| Encoder1D | 1D encoder for autoencoder architectures. |
| ResnetBlock1D | 1D residual block with optional change in channel count. |
| Upsample1D | 1D upsampling block using nearest-neighbor interpolation and convolution. |
Module Contents
- class gensbi.models.autoencoders.autoencoder_1d.AttnBlock1D(in_channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
1D self-attention block for sequence data.
- Parameters:
in_channels (int) – Number of input channels.
rngs (nnx.Rngs) – Random number generators for parameter initialization.
param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).
- __call__(x)
Forward pass for the attention block.
- Parameters:
x (Array) – Input tensor.
- Returns:
The input plus the self-attention output (residual connection).
- Return type:
Array
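The residual self-attention pattern described above can be sketched in plain NumPy. This is a minimal illustration, not the AttnBlock1D implementation: it assumes a single attention head, a channels-last (length, channels) layout with the batch axis omitted, plain matrices `wq`, `wk`, `wv`, `wo` as hypothetical stand-ins for the block's projection layers, and it leaves out any normalization the real block may apply.

```python
import numpy as np

def softmax(a, axis=-1):
    # Subtract the max for numerical stability before exponentiating.
    a = a - a.max(axis=axis, keepdims=True)
    e = np.exp(a)
    return e / e.sum(axis=axis, keepdims=True)

def attn_block_1d(x, wq, wk, wv, wo):
    """Single-head residual self-attention over sequence positions.

    x: (length, channels); wq/wk/wv/wo: (channels, channels) hypothetical
    projection matrices. Normalization is omitted in this sketch.
    """
    q, k, v = x @ wq, x @ wk, x @ wv
    # Scaled dot-product scores between every pair of positions.
    scores = q @ k.T / np.sqrt(x.shape[-1])
    weights = softmax(scores, axis=-1)  # each row sums to 1
    h = weights @ v
    # Residual connection: the output keeps the input shape.
    return x + h @ wo

rng = np.random.default_rng(0)
length, channels = 16, 8
x = rng.standard_normal((length, channels))
ws = [rng.standard_normal((channels, channels)) * 0.1 for _ in range(4)]
y = attn_block_1d(x, *ws)
print(y.shape)  # (16, 8)
```

Because of the residual connection, zeroing the output projection makes the block an identity map, which is one reason such blocks are easy to insert into a deep encoder or decoder.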
- class gensbi.models.autoencoders.autoencoder_1d.AutoEncoder1D(params)
Bases: flax.nnx.Module
1D autoencoder model with Gaussian latent space.
- Parameters:
params (AutoEncoderParams) – Configuration parameters for the autoencoder.
- __call__(x, key)
Forward pass: encode and then decode the input.
- Parameters:
x (Array) – Input tensor.
key (Array) – JAX PRNG key used to sample from the Gaussian latent distribution.
- Returns:
Reconstructed output.
- Return type:
Array
- decode(z)
Decode latent representation back to data space.
- Parameters:
z (Array) – Latent tensor.
- Returns:
Reconstructed output.
- Return type:
Array
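A Gaussian latent space is usually sampled with the reparameterization trick, which is where a PRNG key enters the forward pass. The sketch below assumes, as in latent-diffusion-style autoencoders, that the encoder output stacks the mean and the log-variance along the channel axis (2 * z_channels channels); that layout, the helper name `diagonal_gaussian_sample`, and the use of a NumPy generator in place of a JAX key are all assumptions for illustration, not the AutoEncoder1D API.

```python
import numpy as np

def diagonal_gaussian_sample(moments, rng):
    """Draw z = mean + std * eps from a diagonal Gaussian.

    moments: (length, 2 * z_channels), assumed to hold [mean, logvar]
    along the last axis. rng stands in for a JAX PRNG key.
    """
    mean, logvar = np.split(moments, 2, axis=-1)
    std = np.exp(0.5 * logvar)
    # Reparameterization: all randomness is isolated in eps ~ N(0, I),
    # so gradients can flow through mean and logvar.
    eps = rng.standard_normal(mean.shape)
    return mean + std * eps

rng = np.random.default_rng(0)
moments = np.zeros((16, 8))  # hypothetical encoder output with z_channels = 4
z = diagonal_gaussian_sample(moments, rng)
print(z.shape)  # (16, 4)
```

The sampled z would then be passed to decode; when the log-variance is very negative the sample collapses to the mean, which is the deterministic limit of the encoder.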
- class gensbi.models.autoencoders.autoencoder_1d.Decoder1D(ch, out_ch, ch_mult, num_res_blocks, in_channels, resolution, z_channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
1D decoder for autoencoder architectures.
- Parameters:
ch (int) – Base number of channels.
out_ch (int) – Number of output channels.
ch_mult (list[int]) – Channel multipliers, one per resolution level.
num_res_blocks (int) – Number of residual blocks per resolution level.
in_channels (int) – Number of input channels.
resolution (int) – Output sequence length.
z_channels (int) – Number of latent channels.
rngs (nnx.Rngs) – Random number generators for parameter initialization.
param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).
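The relation between resolution, ch_mult and the latent the decoder consumes can be worked out with a little arithmetic. This sketch assumes one 2x upsampling between consecutive resolution levels (len(ch_mult) - 1 in total), the common layout in convolutional autoencoders; the helper `expected_latent_shape` is hypothetical and not part of gensbi.

```python
def expected_latent_shape(resolution, ch_mult, z_channels):
    """Shape (length, z_channels) of a latent that decodes to `resolution` steps.

    Hypothetical helper: assumes len(ch_mult) - 1 upsamplings by a factor of 2.
    """
    factor = 2 ** (len(ch_mult) - 1)
    if resolution % factor != 0:
        raise ValueError(f"resolution {resolution} is not divisible by {factor}")
    return resolution // factor, z_channels

print(expected_latent_shape(resolution=64, ch_mult=[1, 2, 4], z_channels=4))
# (16, 4)
```

Under this assumption, choosing a resolution that is a multiple of 2 ** (len(ch_mult) - 1) avoids length mismatches between encoder input and decoder output.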
- class gensbi.models.autoencoders.autoencoder_1d.Downsample1D(in_channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
1D downsampling block using strided convolution.
- Parameters:
in_channels (int) – Number of input channels.
rngs (nnx.Rngs) – Random number generators for parameter initialization.
param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).
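Downsampling by strided convolution can be sketched with an explicit channels-last 1D convolution in NumPy. The kernel size 3, stride 2 and the single zero appended on the right are assumptions chosen so that an even length L becomes L // 2; the exact padding used by Downsample1D is not stated here, and `conv1d` and `downsample_1d` are hypothetical helpers.

```python
import numpy as np

def conv1d(x, w, stride=1, pad=(0, 0)):
    """Channels-last 1D convolution: x is (length, c_in), w is (kernel, c_in, c_out)."""
    x = np.pad(x, (pad, (0, 0)))  # zero-pad the length axis only
    k = w.shape[0]
    out_len = (x.shape[0] - k) // stride + 1
    # Stack every receptive-field window, then contract over (kernel, c_in).
    windows = np.stack([x[i * stride : i * stride + k] for i in range(out_len)])
    return np.einsum("lkc,kcd->ld", windows, w)

def downsample_1d(x, w):
    # Assumed settings: kernel 3, stride 2, one zero appended on the right.
    return conv1d(x, w, stride=2, pad=(0, 1))

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 4))
w = rng.standard_normal((3, 4, 4)) * 0.1  # channel count is unchanged
y = downsample_1d(x, w)
print(y.shape)  # (8, 4)
```

Unlike pooling, the strided convolution learns how to combine neighboring positions while halving the sequence length.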
- class gensbi.models.autoencoders.autoencoder_1d.Encoder1D(resolution, in_channels, ch, ch_mult, num_res_blocks, z_channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
1D encoder for autoencoder architectures.
- Parameters:
resolution (int) – Input sequence length.
in_channels (int) – Number of input channels.
ch (int) – Base number of channels.
ch_mult (list[int]) – Channel multipliers, one per resolution level.
num_res_blocks (int) – Number of residual blocks per resolution level.
z_channels (int) – Number of latent channels.
rngs (nnx.Rngs) – Random number generators for parameter initialization.
param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).
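How ch, ch_mult and resolution interact can be illustrated by tabulating the sequence length and channel count at each encoder level. This assumes a common convolutional-autoencoder layout: level i works with ch * ch_mult[i] channels and a 2x downsampling follows every level except the last. The helper `encoder_schedule` is hypothetical and does not reflect Encoder1D internals beyond that assumption.

```python
def encoder_schedule(ch, ch_mult, resolution):
    """Return (length, channels) at each encoder level, finest first.

    Hypothetical helper: level i is assumed to use ch * ch_mult[i] channels,
    with a 2x downsampling after every level except the last.
    """
    length = resolution
    schedule = []
    for i, mult in enumerate(ch_mult):
        schedule.append((length, ch * mult))
        if i < len(ch_mult) - 1:
            length //= 2  # assumed: a downsampling block halves the length
    return schedule

print(encoder_schedule(ch=32, ch_mult=[1, 2, 4], resolution=64))
# [(64, 32), (32, 64), (16, 128)]
```

The last entry gives the length of the latent sequence; a final layer would then map its channels to the latent channels derived from z_channels.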
- class gensbi.models.autoencoders.autoencoder_1d.ResnetBlock1D(in_channels, out_channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
1D residual block that maps in_channels to out_channels, so it can increase, decrease, or preserve the number of channels.
- Parameters:
in_channels (int) – Number of input channels.
out_channels (int) – Number of output channels.
rngs (nnx.Rngs) – Random number generators for parameter initialization.
param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).
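When in_channels and out_channels differ, the residual branch and the input no longer have the same shape, so the shortcut needs a projection. The NumPy sketch below assumes the usual remedy (a learned pointwise projection on the shortcut) and a pre-activation ordering with SiLU; pointwise matrix products stand in for the block's convolutions, normalization is omitted, and `resnet_block_1d`, `w1`, `w2`, `w_skip` are hypothetical names.

```python
import numpy as np

def silu(a):
    return a / (1.0 + np.exp(-a))

def resnet_block_1d(x, w1, w2, w_skip=None):
    """Residual block on x of shape (length, c_in); returns (length, c_out).

    w1: (c_in, c_out), w2: (c_out, c_out). Pointwise matmuls stand in for
    the block's convolutions; normalization is omitted.
    """
    h = silu(x) @ w1
    h = silu(h) @ w2
    if x.shape[-1] == h.shape[-1]:
        shortcut = x  # channel counts match: identity shortcut
    elif w_skip is not None:
        shortcut = x @ w_skip  # project c_in -> c_out so the shapes line up
    else:
        raise ValueError("w_skip is required when in and out channels differ")
    return shortcut + h

rng = np.random.default_rng(0)
x = rng.standard_normal((16, 8))
w1 = rng.standard_normal((8, 16)) * 0.1
w2 = rng.standard_normal((16, 16)) * 0.1
w_skip = rng.standard_normal((8, 16)) * 0.1
print(resnet_block_1d(x, w1, w2, w_skip).shape)  # (16, 16)
```

With matching channel counts the shortcut is the identity, so the block only has to learn a correction to its input.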
- class gensbi.models.autoencoders.autoencoder_1d.Upsample1D(in_channels, rngs, param_dtype=jnp.bfloat16)
Bases: flax.nnx.Module
1D upsampling block using nearest-neighbor interpolation and convolution.
- Parameters:
in_channels (int) – Number of input channels.
rngs (nnx.Rngs) – Random number generators for parameter initialization.
param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).
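Nearest-neighbor interpolation followed by a convolution can be sketched in NumPy as a repeat along the length axis and a kernel-3 convolution that preserves length. The factor 2, the kernel size 3 and the one-zero padding on each side are assumptions for illustration; `upsample_1d` is a hypothetical helper, not the Upsample1D API.

```python
import numpy as np

def upsample_1d(x, w):
    """Nearest-neighbor 2x upsampling, then a kernel-3 length-preserving convolution.

    x: (length, channels); w: (3, channels, channels). Returns (2 * length, channels).
    """
    x = np.repeat(x, 2, axis=0)  # each position is duplicated: a, a, b, b, ...
    xp = np.pad(x, ((1, 1), (0, 0)))  # one zero on each side keeps the length
    windows = np.stack([xp[i : i + 3] for i in range(x.shape[0])])
    return np.einsum("lkc,kcd->ld", windows, w)

rng = np.random.default_rng(0)
x = rng.standard_normal((8, 4))
w = rng.standard_normal((3, 4, 4)) * 0.1
print(upsample_1d(x, w).shape)  # (16, 4)
```

The convolution after the repeat lets the block smooth the blocky output of nearest-neighbor interpolation, a common alternative to transposed convolutions.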