gensbi.models.autoencoders.autoencoder_1d

Classes

AttnBlock1D

1D Self-attention block for sequence data.

AutoEncoder1D

1D Autoencoder model with Gaussian latent space.

Decoder1D

1D Decoder for autoencoder architectures.

Downsample1D

1D Downsampling block using strided convolution.

Encoder1D

1D Encoder for autoencoder architectures.

ResnetBlock1D

1D Residual block with optional channel up/downsampling.

Upsample1D

1D Upsampling block using nearest-neighbor interpolation and convolution.

Module Contents

class gensbi.models.autoencoders.autoencoder_1d.AttnBlock1D(in_channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

1D Self-attention block for sequence data.

Parameters:
  • in_channels (int) – Number of input channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)

Forward pass for the attention block.

Parameters:

x (Array) – Input tensor.

Returns:

Output tensor after residual attention.

Return type:

Array

attention(h_)

Compute self-attention for 1D input.

Parameters:

h_ (Array) – Input tensor of shape (batch, length, channels).

Returns:

Output tensor after attention.

Return type:

Array

in_channels
k
norm
proj_out
q
v
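The attribute names (norm, q, k, v, proj_out) suggest single-head scaled dot-product attention over the sequence axis followed by a residual connection, as in the LDM/Flux autoencoders. The sketch below is a minimal pure-JAX illustration under that assumption: the learned q, k, v and proj_out layers are replaced by plain weight matrices, the norm layer is omitted, and the function name self_attention_1d is hypothetical. It is not the GenSBI implementation.

```python
import jax
import jax.numpy as jnp


def self_attention_1d(x, w_q, w_k, w_v, w_out):
    """Hypothetical single-head self-attention over the length axis.

    Assumed layout: (batch, length, channels). The matrices stand in for the
    block's q, k, v and proj_out layers; normalization is omitted.
    """
    q = x @ w_q  # (B, L, C)
    k = x @ w_k
    v = x @ w_v
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    # Attention weights between every pair of positions: (B, L, L)
    weights = jax.nn.softmax(jnp.einsum("blc,bmc->blm", q, k) * scale, axis=-1)
    h = jnp.einsum("blm,bmc->blc", weights, v)
    # Residual connection around the output projection
    return x + h @ w_out


key = jax.random.PRNGKey(0)
k_x, k_q, k_k, k_v, k_o = jax.random.split(key, 5)
B, L, C = 2, 16, 8
x = jax.random.normal(k_x, (B, L, C))
w_q, w_k, w_v, w_out = (jax.random.normal(k, (C, C)) / jnp.sqrt(C) for k in (k_q, k_k, k_v, k_o))
y = self_attention_1d(x, w_q, w_k, w_v, w_out)
print(y.shape)  # (2, 16, 8)
```

Because of the residual connection, a zero output projection makes the block an identity map, which is a convenient sanity check.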
class gensbi.models.autoencoders.autoencoder_1d.AutoEncoder1D(params)

Bases: flax.nnx.Module

1D Autoencoder model with Gaussian latent space.

Parameters:

params (AutoEncoderParams) – Configuration parameters for the autoencoder.

__call__(x, key)

Forward pass: encode the input, then decode it.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – PRNG key for sampling the latent variable.

Returns:

Reconstructed output.

Return type:

Array

decode(z)

Decode latent representation back to data space.

Parameters:

z (Array) – Latent tensor.

Returns:

Reconstructed output.

Return type:

Array

encode(x, key)

Encode input data into the latent space.

Parameters:
  • x (Array) – Input tensor.

  • key (Array) – PRNG key for sampling the latent variable.

Returns:

Latent representation.

Return type:

Array
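Since the latent space is Gaussian and encode takes a PRNG key, sampling presumably uses the diagonal-Gaussian reparameterization: the encoder output is split into a mean and a log-variance, and a sample is drawn as mean + std * eps. The sketch below assumes the split happens on the last (channel) axis, following the Flux autoencoder convention; the helper name sample_diagonal_gaussian is hypothetical.

```python
import jax
import jax.numpy as jnp


def sample_diagonal_gaussian(moments, key):
    """Hypothetical helper: split moments into (mean, logvar) on the channel axis and sample.

    Assumes moments has shape (batch, length, 2 * z_channels).
    """
    mean, logvar = jnp.split(moments, 2, axis=-1)
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(key, mean.shape, mean.dtype)
    # Noise enters only through eps, so gradients flow through mean and logvar
    return mean + std * eps


moments = jnp.zeros((4, 8, 2 * 3))  # mean = 0, logvar = 0 (unit variance)
z = sample_diagonal_gaussian(moments, jax.random.PRNGKey(0))
print(z.shape)  # (4, 8, 3)
```

Reusing the same key yields the same sample, so reproducible encodings only require passing a fixed key.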

Decoder1D
Encoder1D
reg
scale_factor
shift_factor
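The scale_factor and shift_factor attributes point to a latent normalization step. In the Flux autoencoder, whose attribute names these match, encode applies z = scale * (z - shift) after sampling, and decode inverts it with z = z / scale + shift before running the decoder. The sketch below assumes that convention; the factor values are made up for illustration, and the real ones come from AutoEncoderParams.

```python
import jax.numpy as jnp

# Hypothetical values for illustration only
SCALE_FACTOR = 0.5
SHIFT_FACTOR = 0.1


def normalize_latent(z, scale=SCALE_FACTOR, shift=SHIFT_FACTOR):
    # Applied at the end of encode (assumed Flux convention)
    return scale * (z - shift)


def denormalize_latent(z, scale=SCALE_FACTOR, shift=SHIFT_FACTOR):
    # Applied at the start of decode, before the decoder network
    return z / scale + shift


z = jnp.linspace(-2.0, 2.0, 5)
z_roundtrip = denormalize_latent(normalize_latent(z))
print(jnp.allclose(z, z_roundtrip))  # True
```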
class gensbi.models.autoencoders.autoencoder_1d.Decoder1D(ch, out_ch, ch_mult, num_res_blocks, in_channels, resolution, z_channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

1D Decoder for autoencoder architectures.

Parameters:
  • ch (int) – Base number of channels.

  • out_ch (int) – Number of output channels.

  • ch_mult (list[int]) – Channel multipliers, one per resolution level.

  • num_res_blocks (int) – Number of residual blocks per resolution.

  • in_channels (int) – Number of input channels.

  • resolution (int) – Output sequence length.

  • z_channels (int) – Number of latent channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(z)

Forward pass for the decoder.

Parameters:

z (Array) – Latent tensor of shape (batch, latent_length, latent_channels).

Returns:

Reconstructed output tensor.

Return type:

Array

ch
conv_in
conv_out
ffactor
in_channels
mid
norm_out
num_res_blocks
num_resolutions
resolution
up
z_shape
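The ffactor and z_shape attributes relate the latent length to the output length. Assuming the Flux-style convention of one factor-of-2 resampling between consecutive resolution levels, the latent length is resolution // 2 ** (len(ch_mult) - 1). The helper latent_shape_1d below is hypothetical and only reproduces this arithmetic; the (batch, length, channels) layout follows the shape documented for __call__.

```python
def latent_shape_1d(resolution, ch_mult, z_channels, batch=1):
    """Hypothetical helper: expected latent shape for a decoder with len(ch_mult) levels.

    Assumes one factor-of-2 resampling between consecutive levels.
    """
    num_resolutions = len(ch_mult)
    ffactor = 2 ** (num_resolutions - 1)  # total length change across all levels
    if resolution % ffactor != 0:
        raise ValueError(f"resolution={resolution} is not divisible by {ffactor}")
    return (batch, resolution // ffactor, z_channels)


print(latent_shape_1d(resolution=64, ch_mult=[1, 2, 4], z_channels=4))  # (1, 16, 4)
```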
class gensbi.models.autoencoders.autoencoder_1d.Downsample1D(in_channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

1D Downsampling block using strided convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)

Downsample the input tensor by a factor of 2.

Parameters:

x (Array) – Input tensor of shape (batch, length, channels).

Returns:

Downsampled tensor.

Return type:

Array

conv
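A strided convolution halves the sequence length. The sketch below uses jax.lax.conv_general_dilated with kernel size 3 and stride 2; the asymmetric (0, 1) padding is an assumption borrowed from the LDM/Flux downsampling block, and it gives exactly length // 2 for even lengths. The weights are random placeholders, not the block's conv parameters, and downsample_1d is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import lax


def downsample_1d(x, w, b):
    """Strided 1D convolution: (batch, length, channels) -> (batch, length // 2, out_channels).

    w has shape (kernel_width, in_channels, out_channels); b has shape (out_channels,).
    """
    y = lax.conv_general_dilated(
        x,
        w,
        window_strides=(2,),
        padding=[(0, 1)],  # pad one position on the right only (assumed convention)
        dimension_numbers=("NWC", "WIO", "NWC"),
    )
    return y + b


C = 8
x = jnp.ones((2, 16, C))
w = jax.random.normal(jax.random.PRNGKey(0), (3, C, C)) * 0.1
b = jnp.zeros((C,))
print(downsample_1d(x, w, b).shape)  # (2, 8, 8)
```

With this padding the output length is (length - 2) // 2 + 1, so an odd length of 15 maps to 7.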
class gensbi.models.autoencoders.autoencoder_1d.Encoder1D(resolution, in_channels, ch, ch_mult, num_res_blocks, z_channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

1D Encoder for autoencoder architectures.

Parameters:
  • resolution (int) – Input sequence length.

  • in_channels (int) – Number of input channels.

  • ch (int) – Base number of channels.

  • ch_mult (list[int]) – Channel multipliers, one per resolution level.

  • num_res_blocks (int) – Number of residual blocks per resolution.

  • z_channels (int) – Number of latent channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)

Forward pass for the encoder.

Parameters:

x (Array) – Input tensor of shape (batch, length, channels).

Returns:

Encoded latent representation.

Return type:

Array

ch
conv_in
conv_out
down
in_ch_mult
in_channels
mid
norm_out
num_res_blocks
num_resolutions
resolution
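The in_ch_mult attribute suggests the Flux-style channel schedule: resolution level i reads ch * in_ch_mult[i] channels and writes ch * ch_mult[i], with in_ch_mult = (1,) + tuple(ch_mult). In the Flux encoder, conv_out then emits 2 * z_channels channels (mean and log-variance) for the Gaussian latent. The helper encoder_channel_plan below is hypothetical and only reproduces this bookkeeping under those assumptions.

```python
def encoder_channel_plan(ch, ch_mult, z_channels):
    """Hypothetical helper: (in, out) channels per resolution level, plus conv_out width.

    Assumes in_ch_mult = (1,) + tuple(ch_mult) and a 2 * z_channels output
    (mean and log-variance), following the Flux encoder.
    """
    in_ch_mult = (1,) + tuple(ch_mult)
    levels = [(ch * in_ch_mult[i], ch * ch_mult[i]) for i in range(len(ch_mult))]
    return levels, 2 * z_channels


levels, out_channels = encoder_channel_plan(ch=32, ch_mult=[1, 2, 4], z_channels=4)
print(levels)        # [(32, 32), (32, 64), (64, 128)]
print(out_channels)  # 8
```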
class gensbi.models.autoencoders.autoencoder_1d.ResnetBlock1D(in_channels, out_channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

1D Residual block with optional channel up/downsampling.

Parameters:
  • in_channels (int) – Number of input channels.

  • out_channels (int) – Number of output channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)

Forward pass for the residual block.

Parameters:

x (Array) – Input tensor.

Returns:

Output tensor after residual connection.

Return type:

Array

conv1
conv2
in_channels
norm1
norm2
out_channels
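The attributes norm1, conv1, norm2 and conv2 match the pre-activation residual block used in LDM/Flux autoencoders: normalize, apply swish, convolve (twice), then add the input back. The attribute list above does not name a shortcut layer, so how the skip path is handled when in_channels differs from out_channels is not documented here; the sketch uses a 1x1 (matrix) projection as an assumption. Group normalization without affine parameters, 4 groups, and random weights are illustrative choices, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def group_norm(x, num_groups, eps=1e-6):
    # Normalize over (length, channels-in-group) per sample and group; no affine params.
    b, l, c = x.shape
    g = x.reshape(b, l, num_groups, c // num_groups)
    mean = g.mean(axis=(1, 3), keepdims=True)
    var = g.var(axis=(1, 3), keepdims=True)
    return ((g - mean) / jnp.sqrt(var + eps)).reshape(b, l, c)


def conv1d_same(x, w):
    # Length-preserving 1D convolution; w has shape (kernel_width, in_ch, out_ch).
    return lax.conv_general_dilated(x, w, (1,), "SAME", dimension_numbers=("NWC", "WIO", "NWC"))


def resnet_block_1d(x, w1, w2, w_skip=None, num_groups=4):
    h = conv1d_same(jax.nn.swish(group_norm(x, num_groups)), w1)
    h = conv1d_same(jax.nn.swish(group_norm(h, num_groups)), w2)
    # Project the skip path when the channel count changes (assumed 1x1 projection).
    skip = x if w_skip is None else x @ w_skip
    return skip + h


keys = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(keys[0], (2, 16, 8))
w1 = jax.random.normal(keys[1], (3, 8, 16)) * 0.1
w2 = jax.random.normal(keys[2], (3, 16, 16)) * 0.1
w_skip = jax.random.normal(keys[3], (8, 16)) * 0.1
y = resnet_block_1d(x, w1, w2, w_skip)
print(y.shape)  # (2, 16, 16)
```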
class gensbi.models.autoencoders.autoencoder_1d.Upsample1D(in_channels, rngs, param_dtype=jnp.bfloat16)

Bases: flax.nnx.Module

1D Upsampling block using nearest-neighbor interpolation and convolution.

Parameters:
  • in_channels (int) – Number of input channels.

  • rngs (nnx.Rngs) – Random number generators for parameter initialization.

  • param_dtype (DTypeLike) – Data type for parameters (default: jnp.bfloat16).

__call__(x)

Upsample the input tensor by a factor of 2.

Parameters:

x (Array) – Input tensor of shape (batch, length, channels).

Returns:

Upsampled tensor.

Return type:

Array

conv
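Nearest-neighbor upsampling by a factor of 2 repeats every position along the length axis, and a length-preserving convolution then mixes neighboring features. The sketch below reproduces both steps with jnp.repeat and jax.lax.conv_general_dilated; the kernel is a placeholder rather than the block's conv parameters, and upsample_1d is a hypothetical name.

```python
import jax
import jax.numpy as jnp
from jax import lax


def upsample_1d(x, w):
    """Nearest-neighbor x2 along length, then a length-preserving 1D convolution.

    x: (batch, length, channels); w: (kernel_width, channels, out_channels).
    """
    x = jnp.repeat(x, 2, axis=1)  # [a, b] -> [a, a, b, b] along the length axis
    return lax.conv_general_dilated(x, w, (1,), "SAME", dimension_numbers=("NWC", "WIO", "NWC"))


x = jnp.arange(3.0).reshape(1, 3, 1)  # values 0, 1, 2
identity = jnp.zeros((3, 1, 1)).at[1, 0, 0].set(1.0)  # kernel that copies the centre tap
print(upsample_1d(x, identity)[0, :, 0])  # [0. 0. 1. 1. 2. 2.]
```

With a kernel that only passes the centre tap, the convolution is an identity map, which isolates the interpolation step.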