gensbi.models package
Subpackages
- gensbi.models.flux1 package
  - Submodules
    - gensbi.models.flux1.layers module
    - gensbi.models.flux1.loss module
    - gensbi.models.flux1.math module
    - gensbi.models.flux1.model module
      - Flux
      - FluxParams
        - FluxParams.axes_dim
        - FluxParams.cond_dim
        - FluxParams.context_in_dim
        - FluxParams.depth
        - FluxParams.depth_single_blocks
        - FluxParams.guidance_embed
        - FluxParams.in_channels
        - FluxParams.mlp_ratio
        - FluxParams.num_heads
        - FluxParams.obs_dim
        - FluxParams.param_dtype
        - FluxParams.qkv_bias
        - FluxParams.qkv_multiplier
        - FluxParams.rngs
        - FluxParams.theta
        - FluxParams.use_rope
        - FluxParams.vec_in_dim
      - Identity
  - Module contents
    - Flux
    - FluxCFMLoss
    - FluxParams
      - FluxParams.axes_dim
      - FluxParams.cond_dim
      - FluxParams.context_in_dim
      - FluxParams.depth
      - FluxParams.depth_single_blocks
      - FluxParams.guidance_embed
      - FluxParams.in_channels
      - FluxParams.mlp_ratio
      - FluxParams.num_heads
      - FluxParams.obs_dim
      - FluxParams.param_dtype
      - FluxParams.qkv_bias
      - FluxParams.qkv_multiplier
      - FluxParams.rngs
      - FluxParams.theta
      - FluxParams.use_rope
      - FluxParams.vec_in_dim
- gensbi.models.simformer package
  - Submodules
    - gensbi.models.simformer.embedding module
    - gensbi.models.simformer.loss module
    - gensbi.models.simformer.simformer module
      - Simformer
      - SimformerConditioner
      - SimformerParams
        - SimformerParams.dim_condition
        - SimformerParams.dim_id
        - SimformerParams.dim_joint
        - SimformerParams.dim_value
        - SimformerParams.dropout_rate
        - SimformerParams.fourier_features
        - SimformerParams.num_heads
        - SimformerParams.num_hidden_layers
        - SimformerParams.num_layers
        - SimformerParams.qkv_features
        - SimformerParams.rngs
        - SimformerParams.widening_factor
    - gensbi.models.simformer.transformer module
  - Module contents
    - Simformer
    - SimformerCFMLoss
    - SimformerConditioner
    - SimformerParams
      - SimformerParams.dim_condition
      - SimformerParams.dim_id
      - SimformerParams.dim_joint
      - SimformerParams.dim_value
      - SimformerParams.dropout_rate
      - SimformerParams.fourier_features
      - SimformerParams.num_heads
      - SimformerParams.num_hidden_layers
      - SimformerParams.num_layers
      - SimformerParams.qkv_features
      - SimformerParams.rngs
      - SimformerParams.widening_factor
Module contents
- class gensbi.models.Flux(*args: Any, **kwargs: Any)
Bases: Module
Transformer model for flow matching on sequences.
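A flow matching model such as `Flux` is trained to predict a velocity field v(x, t); under the common convention, samples are drawn by integrating dx/dt = v(x, t) from noise at t = 0 to data at t = 1. The sketch below shows that generic sampling loop with forward Euler steps. It is not the gensbi sampling API: `euler_sample` and `straight_line_field` are hypothetical names, the `(x, t)` call signature is an assumption, and a real model call will need further inputs (for example, conditioning data) that this toy field ignores. NumPy is used so the sketch runs without JAX.

```python
import numpy as np


def euler_sample(vector_field, x0, num_steps=100):
    """Integrate dx/dt = vector_field(x, t) from t = 0 to t = 1 with forward Euler."""
    x = np.asarray(x0, dtype=np.float64)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = i * dt                          # left endpoint of the current step
        x = x + dt * vector_field(x, t)     # one explicit Euler update
    return x


# Hypothetical toy field standing in for a trained model: the exact velocity
# that moves any start point to `target` along a straight line by t = 1.
target = np.array([1.0, -2.0, 0.5])


def straight_line_field(x, t):
    return (target - x) / (1.0 - t)


noise = np.random.default_rng(0).normal(size=3)
sample = euler_sample(straight_line_field, noise)
```

With this toy field every Euler step lands exactly on the straight line between the start point and `target`, so the final sample matches `target` up to rounding; a learned field gives no such guarantee and usually needs more steps or a higher-order solver.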
- class gensbi.models.FluxCFMLoss(*args: Any, **kwargs: Any)
Bases: ContinuousFMLoss
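`FluxCFMLoss` derives from `ContinuousFMLoss`, which by its name implements a continuous conditional flow matching (CFM) objective. A common form uses the linear path x_t = (1 - t) x0 + t x1 between a noise sample x0 and a data sample x1, and regresses the model output onto the path velocity x1 - x0. The sketch below computes that loss; which probability path and time weighting `ContinuousFMLoss` actually uses has not been checked, and `cfm_loss` is a hypothetical helper.

```python
import numpy as np


def cfm_loss(vector_field, x0, x1, t):
    """Conditional flow matching loss on a linear probability path.

    x0: reference (noise) samples, shape (batch, dim)
    x1: data samples, shape (batch, dim)
    t:  times in [0, 1], shape (batch,)
    """
    t_b = t[:, None]                     # broadcast time over the feature axis
    x_t = (1.0 - t_b) * x0 + t_b * x1    # point on the straight path at time t
    target = x1 - x0                     # velocity of that path, d x_t / dt
    pred = vector_field(x_t, t)
    return np.mean(np.sum((pred - target) ** 2, axis=-1))


rng = np.random.default_rng(0)
x0 = rng.normal(size=(8, 3))
x1 = rng.normal(size=(8, 3))
t = rng.uniform(size=8)

# An oracle that knows the noise/data pairing reaches zero loss; a zero field does not.
oracle_loss = cfm_loss(lambda x, s: x1 - x0, x0, x1, t)
zero_loss = cfm_loss(lambda x, s: np.zeros_like(x), x0, x1, t)
```

A trained network never sees the pairing, so it converges toward the conditional expectation of x1 - x0 given x_t and its loss stays above zero.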
- class gensbi.models.FluxParams(in_channels: int, vec_in_dim: int, context_in_dim: int, mlp_ratio: float, num_heads: int, depth: int, depth_single_blocks: int, axes_dim: list[int], qkv_bias: bool, rngs: flax.nnx.rnglib.Rngs, obs_dim: int | None = None, cond_dim: int | None = None, use_rope: bool = True, theta: int = 10000, guidance_embed: bool = False, qkv_multiplier: int = 1, param_dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType] = <class 'jax.numpy.bfloat16'>)
Bases: object
- axes_dim: list[int]
- cond_dim: int | None = None
- context_in_dim: int
- depth: int
- depth_single_blocks: int
- guidance_embed: bool = False
- in_channels: int
- mlp_ratio: float
- num_heads: int
- obs_dim: int | None = None
- param_dtype: alias of jax.numpy.bfloat16
- qkv_bias: bool
- qkv_multiplier: int = 1
- rngs: Rngs
- theta: int = 10000
- use_rope: bool = True
- vec_in_dim: int
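`use_rope`, `theta` and `axes_dim` suggest the multi-axis rotary position embedding (RoPE) of the reference Flux implementation: each position axis i owns a block of `axes_dim[i]` channels, and channel pair k in that block is rotated by the angle pos_i * theta**(-2k / axes_dim[i]). In the reference code, `sum(axes_dim)` must equal the per-head dimension. The NumPy sketch below reproduces that standard construction; that `gensbi.models.flux1.math` follows it exactly is an assumption, and `rope_angles`, `multi_axis_angles`, `apply_rope` and `score` are hypothetical names.

```python
import numpy as np


def rope_angles(pos, dim, theta=10000):
    """Rotation angles pos * omega_k for one axis, with omega_k = theta ** (-2k / dim)."""
    assert dim % 2 == 0, "each axis needs an even number of channels"
    omega = 1.0 / theta ** (np.arange(0, dim, 2) / dim)     # (dim // 2,)
    return np.einsum("...n,d->...nd", pos, omega)            # (..., n, dim // 2)


def multi_axis_angles(ids, axes_dim, theta=10000):
    """ids: (..., n, n_axes) positions. Axis i owns axes_dim[i] channels."""
    return np.concatenate(
        [rope_angles(ids[..., i].astype(np.float64), d, theta)
         for i, d in enumerate(axes_dim)],
        axis=-1,
    )                                                        # (..., n, sum(axes_dim) // 2)


def apply_rope(x, angles):
    """Rotate each channel pair (x[2k], x[2k + 1]) by angles[k]."""
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    cos, sin = np.cos(angles), np.sin(angles)
    rotated = np.stack([x_even * cos - x_odd * sin,
                        x_even * sin + x_odd * cos], axis=-1)
    return rotated.reshape(x.shape)                          # re-interleave the pairs


# Two position axes with 4 channels each, so vectors have sum(axes_dim) = 8 channels.
axes_dim = [4, 4]
rng = np.random.default_rng(0)
q = rng.normal(size=(1, 8))
k = rng.normal(size=(1, 8))


def score(q_ids, k_ids):
    """Dot product of a rotated query and key placed at the given positions."""
    qr = apply_rope(q, multi_axis_angles(np.array(q_ids), axes_dim))
    kr = apply_rope(k, multi_axis_angles(np.array(k_ids), axes_dim))
    return float(np.sum(qr * kr))
```

Because the rotations compose, `score([[3, 1]], [[7, 4]])` equals `score([[5, 2]], [[9, 5]])`: attention scores depend only on the offset between positions on each axis, which is the property that makes RoPE a relative encoding.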
- class gensbi.models.SimformerCFMLoss(*args: Any, **kwargs: Any)
Bases: ContinuousFMLoss
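Simformer-style models operate on a joint sequence of variables in which a random subset is marked as observed. A plausible variant of the flow matching loss, assumed here rather than taken from the gensbi source, keeps observed tokens at their data values and averages the squared error only over the remaining latent tokens. `masked_cfm_loss` and its `condition_mask` argument are hypothetical.

```python
import numpy as np


def masked_cfm_loss(vector_field, x0, x1, t, condition_mask):
    """Flow matching loss restricted to unobserved (latent) tokens.

    x0, x1: (batch, n_tokens) noise and data values, one scalar per variable
    t: (batch,) times in [0, 1]
    condition_mask: bool (batch, n_tokens), True where a variable is conditioned on
    """
    t_b = t[:, None]
    x_t = (1.0 - t_b) * x0 + t_b * x1
    x_t = np.where(condition_mask, x1, x_t)   # conditioned tokens stay at their data value
    target = x1 - x0
    pred = vector_field(x_t, t, condition_mask)
    latent = ~condition_mask
    sq_err = np.where(latent, (pred - target) ** 2, 0.0)
    return np.sum(sq_err) / max(int(latent.sum()), 1)   # mean over latent tokens only


rng = np.random.default_rng(1)
x0 = rng.normal(size=(4, 5))
x1 = rng.normal(size=(4, 5))
t = rng.uniform(size=4)
mask = rng.uniform(size=(4, 5)) < 0.4

# Errors on conditioned tokens are ignored, so this field still scores zero.
loss = masked_cfm_loss(lambda x, s, m: np.where(m, 1e3, x1 - x0), x0, x1, t, mask)
```

Resampling the mask at every step is what lets a single network answer many conditioning patterns at inference time.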
- class gensbi.models.SimformerConditioner(*args: Any, **kwargs: Any)
Bases: Module
- conditioned(obs: Array, obs_ids: Array, cond: Array, cond_ids: Array, t: Array, edge_mask: Array | None = None) → Array
Perform conditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
t (Array) – Time steps.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Conditioned output.
- Return type:
Array
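The two entry points differ only in whether conditioning tokens (`cond`, `cond_ids`) are supplied. One way to picture this, assuming the Simformer convention of a single token sequence with a per-token condition flag (not verified against `SimformerConditioner`), is to append the conditioning tokens to the observation tokens and flag them. `joint_tokens` and the array shapes below are hypothetical.

```python
import numpy as np


def joint_tokens(obs, obs_ids, cond=None, cond_ids=None):
    """Stack latent and conditioning tokens into one sequence plus a condition flag.

    obs: (batch, n_obs, value_dim) values to be generated; obs_ids: (batch, n_obs)
    cond: (batch, n_cond, value_dim) observed values; cond_ids: (batch, n_cond)
    With cond=None this reduces to the unconditioned case: no token is flagged.
    """
    obs_flag = np.zeros(obs.shape[:2], dtype=bool)
    if cond is None:
        return obs, obs_ids, obs_flag
    values = np.concatenate([obs, cond], axis=1)
    ids = np.concatenate([obs_ids, cond_ids], axis=1)
    flags = np.concatenate([obs_flag, np.ones(cond.shape[:2], dtype=bool)], axis=1)
    return values, ids, flags


batch = 2
obs = np.zeros((batch, 3, 1))
obs_ids = np.tile(np.arange(3), (batch, 1))         # variable indices 0, 1, 2
cond = np.ones((batch, 2, 1))
cond_ids = np.tile(np.arange(3, 5), (batch, 1))     # variable indices 3, 4

values, ids, flags = joint_tokens(obs, obs_ids, cond, cond_ids)
```

The identifiers let the model tell variables apart regardless of their position in the sequence, so the same weights can serve any split between generated and observed variables.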
- unconditioned(obs: Array, obs_ids: Array, t: Array, edge_mask: Array | None = None) → Array
Perform unconditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
t (Array) – Time steps.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Unconditioned output.
- Return type:
Array
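Both methods accept an optional `edge_mask`. In Simformer-style transformers such a mask typically restricts which tokens may attend to which, so that known dependency structure between variables can be imposed. The sketch below applies a boolean mask inside single-head scaled dot-product attention; the convention that `edge_mask[i, j]` is True when token i may attend to token j, and the helper name `masked_attention`, are assumptions.

```python
import numpy as np


def masked_attention(q, k, v, edge_mask=None):
    """Single-head scaled dot-product attention with an optional boolean mask.

    q, k, v: (n, d). edge_mask: bool (n, n); edge_mask[i, j] == True lets
    token i attend to token j (assumed convention).
    """
    logits = q @ k.T / np.sqrt(q.shape[-1])
    if edge_mask is not None:
        logits = np.where(edge_mask, logits, -np.inf)   # blocked edges get zero weight
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerically stable softmax
    weights = np.exp(logits)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v


rng = np.random.default_rng(0)
n, d = 4, 8
q, k, v = (rng.normal(size=(n, d)) for _ in range(3))

self_only = masked_attention(q, k, v, np.eye(n, dtype=bool))      # each token sees only itself
causal = masked_attention(q, k, v, np.tril(np.ones((n, n), dtype=bool)))
```

Every row of the mask needs at least one True entry (a self-edge is the usual choice); a fully masked row would make the softmax undefined and produce NaNs.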
- class gensbi.models.SimformerParams(rngs: flax.nnx.rnglib.Rngs, dim_value: int, dim_id: int, dim_condition: int, dim_joint: int, fourier_features: int = 128, num_heads: int = 4, num_layers: int = 6, widening_factor: int = 3, qkv_features: int = 8, num_hidden_layers: int = 1, dropout_rate: float = 0.0)
Bases: object
- dim_condition: int
- dim_id: int
- dim_joint: int
- dim_value: int
- dropout_rate: float = 0.0
- fourier_features: int = 128
- num_heads: int = 4
- num_hidden_layers: int = 1
- num_layers: int = 6
- qkv_features: int = 8
- rngs: Rngs
- widening_factor: int = 3
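`fourier_features` (default 128) presumably sets the width of a Fourier-feature embedding of the time t, a standard way to give a transformer a smooth, multi-scale encoding of a scalar input. The sketch below builds the common Gaussian random Fourier features [cos(2πwt), sin(2πwt)] with frozen random frequencies w. Whether gensbi uses random or learned frequencies, and the `scale` value, are assumptions; `make_fourier_embedding` is a hypothetical name.

```python
import numpy as np


def make_fourier_embedding(num_features=128, scale=16.0, seed=0):
    """Return t -> [cos(2*pi*w*t), sin(2*pi*w*t)] with fixed random frequencies w."""
    assert num_features % 2 == 0, "half the features are cosines, half are sines"
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=scale, size=num_features // 2)   # frozen after construction

    def embed(t):
        angles = 2.0 * np.pi * np.asarray(t, dtype=np.float64)[..., None] * w
        return np.concatenate([np.cos(angles), np.sin(angles)], axis=-1)

    return embed


embed = make_fourier_embedding(num_features=128)
features = embed(np.array([0.0, 0.25, 1.0]))   # shape (3, 128)
```

Larger `scale` values mix in higher frequencies, which lets the network resolve small changes in t at the cost of a rougher embedding.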