gensbi.models package#

Subpackages#

Module contents#

class gensbi.models.Flux(*args: Any, **kwargs: Any)[source]#

Bases: Module

Transformer model for flow matching on sequences.

class gensbi.models.FluxCFMLoss(*args: Any, **kwargs: Any)[source]#

Bases: ContinuousFMLoss

class gensbi.models.FluxParams(in_channels: int, vec_in_dim: int, context_in_dim: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], qkv_bias: bool, rngs: flax.nnx.rnglib.Rngs, obs_dim: int | None = None, cond_dim: int | None = None, use_rope: bool = True, theta: int = 10000, guidance_embed: bool = False, qkv_multiplier: int = 1, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16)[source]#

Bases: object

axes_dim: list[int]#
cond_dim: int | None = None#
context_in_dim: int#
depth: int#
depth_single_blocks: int#
guidance_embed: bool = False#
in_channels: int#
mlp_ratio: float#
num_heads: int#
obs_dim: int | None = None#
param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = jax.numpy.bfloat16#

qkv_bias: bool#
qkv_multiplier: int = 1#
rngs: Rngs#
theta: int = 10000#
use_rope: bool = True#
vec_in_dim: int#

class gensbi.models.Simformer(*args: Any, **kwargs: Any)[source]#

Bases: Module

class gensbi.models.SimformerCFMLoss(*args: Any, **kwargs: Any)[source]#

Bases: ContinuousFMLoss

class gensbi.models.SimformerConditioner(*args: Any, **kwargs: Any)[source]#

Bases: Module

conditioned(obs: Array, obs_ids: Array, cond: Array, cond_ids: Array, t: Array, edge_mask: Array | None = None) → Array[source]#

Perform conditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • t (Array) – Time steps.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Conditioned output.

Return type:

Array

unconditioned(obs: Array, obs_ids: Array, t: Array, edge_mask: Array | None = None) → Array[source]#

Perform unconditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • t (Array) – Time steps.

  • edge_mask (Optional[Array]) – Mask for edges.

Returns:

Unconditioned output.

Return type:

Array

class gensbi.models.SimformerParams(rngs: flax.nnx.rnglib.Rngs, dim_value: int, dim_id: int, dim_condition: int, dim_joint: int, fourier_features: int = 128, num_heads: int = 4, num_layers: int = 6, widening_factor: int = 3, qkv_features: int = 8, num_hidden_layers: int = 1, dropout_rate: float = 0.0)[source]#

Bases: object

dim_condition: int#
dim_id: int#
dim_joint: int#
dim_value: int#
dropout_rate: float = 0.0#
fourier_features: int = 128#
num_heads: int = 4#
num_hidden_layers: int = 1#
num_layers: int = 6#
qkv_features: int = 8#
rngs: Rngs#
widening_factor: int = 3#