gensbi.models
Submodules
Classes
| Class | Description |
|---|---|
| ConditionalCFMLoss | Computes the continuous flow matching loss for the Conditional model. |
| ConditionalDiffLoss | Computes the diffusion score matching loss for the Conditional model. |
| ConditionalWrapper | Wrapper for conditional models to handle input expansion and calling convention. |
| Flux1 | Transformer model for flow matching on sequences. |
| Flux1Joint | Flux1Joint model for joint density estimation. |
| Flux1JointParams | Parameters for the Flux1Joint model. |
| Flux1Params | Parameters for the Flux1 model. |
| JointCFMLoss | Computes the continuous flow matching loss for the Joint model. |
| JointDiffLoss | Computes the diffusion score matching loss for the Joint model. |
| JointWrapper | Wrapper for joint models to handle both conditioned and unconditioned inference. |
| Simformer | Simformer model for joint density estimation. |
| SimformerParams | Parameters for the Simformer model. |
| UnconditionalCFMLoss | Computes the continuous flow matching loss for the Unconditional model. |
| UnconditionalDiffLoss | Computes the diffusion score matching loss for the Unconditional model. |
| UnconditionalWrapper | Wrapper for unconditional models to handle input expansion and calling convention. |

Package Contents
- class gensbi.models.ConditionalCFMLoss(path, reduction='mean', cfg_scale=None)
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
ConditionalCFMLoss computes the continuous flow matching loss for the Conditional model.
- Parameters:
path – Probability path (x-prediction training).
reduction (str, optional) – Reduction to apply to the output: 'none' (no reduction is applied), 'mean' (the output is averaged over sequence elements) or 'sum' (the output is summed over sequence elements). Defaults to 'mean'.
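As an illustration of the three `reduction` modes, here is a minimal NumPy sketch. It assumes the unreduced loss has shape (batch, seq_len) and that 'mean' and 'sum' reduce over the sequence axis only, keeping one value per sample; the helper `reduce_loss` is hypothetical and not part of gensbi.

```python
import numpy as np


def reduce_loss(per_element: np.ndarray, reduction: str = "mean") -> np.ndarray:
    """Apply a 'none' | 'mean' | 'sum' reduction (hypothetical helper).

    Assumes `per_element` has shape (batch, seq_len); 'mean' and 'sum'
    reduce over the sequence axis and keep one value per batch entry.
    """
    if reduction == "none":
        return per_element  # one loss value per sequence element
    if reduction == "mean":
        return per_element.mean(axis=1)
    if reduction == "sum":
        return per_element.sum(axis=1)
    raise ValueError(f"unknown reduction: {reduction!r}")


# Squared errors for 2 samples with 3 sequence elements each.
sq_err = np.array([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])
print(reduce_loss(sq_err, "none").shape)  # (2, 3)
print(reduce_loss(sq_err, "mean"))        # [2. 5.]
print(reduce_loss(sq_err, "sum"))         # [ 6. 15.]
```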
- __call__(vf, batch, cond, obs_ids, cond_ids)
Evaluates the continuous flow matching loss.
- Parameters:
vf (callable) – The vector field model to evaluate.
batch (tuple) – A tuple containing the input data (x_0, x_1, t).
cond (jnp.ndarray) – The conditioning data.
obs_ids (jnp.ndarray) – The observation IDs.
cond_ids (jnp.ndarray) – The conditioning IDs.
- Returns:
The computed loss.
- Return type:
jnp.ndarray
- cfg_scale = None
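The loss evaluated by `__call__` can be sketched for the common case of a linear (rectified-flow) path, where x_t = (1 - t) x_0 + t x_1 and the regression target is the velocity x_1 - x_0. This NumPy sketch works under those assumptions and is not gensbi's implementation: the real `path` object may parameterize the target differently (the parameter list mentions x-prediction training), and the call order `vf(t, x_t, obs_ids, cond, cond_ids)` is assumed from `ConditionalWrapper.__call__`. `cfm_loss_sketch` is a hypothetical name.

```python
import numpy as np


def cfm_loss_sketch(vf, batch, cond, obs_ids, cond_ids):
    """Conditional flow matching loss on an assumed linear path (hypothetical)."""
    x_0, x_1, t = batch                               # noise, data, times
    t_b = t.reshape(-1, *([1] * (x_0.ndim - 1)))      # broadcast t over non-batch axes
    x_t = (1.0 - t_b) * x_0 + t_b * x_1               # sample on the path
    target = x_1 - x_0                                # velocity of the linear path
    pred = vf(t, x_t, obs_ids, cond, cond_ids)
    return np.mean((pred - target) ** 2)              # 'mean' reduction


rng = np.random.default_rng(0)
x_0 = rng.normal(size=(4, 3, 1))   # (batch, seq_len, channels)
x_1 = rng.normal(size=(4, 3, 1))
t = rng.uniform(size=(4,))
cond = rng.normal(size=(4, 2, 1))

# A field that already knows the target velocity gives zero loss ...
oracle = lambda t_, x, oi, c, ci: x_1 - x_0
print(cfm_loss_sketch(oracle, (x_0, x_1, t), cond, None, None))  # 0.0
# ... while a zero field pays the mean squared target velocity.
zero_field = lambda t_, x, oi, c, ci: np.zeros_like(x)
print(np.isclose(cfm_loss_sketch(zero_field, (x_0, x_1, t), cond, None, None),
                 np.mean((x_1 - x_0) ** 2)))  # True
```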
- class gensbi.models.ConditionalDiffLoss(path)
Bases: flax.nnx.Module
ConditionalDiffLoss computes the diffusion score matching loss for the Conditional model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, cond, obs_ids, cond_ids)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
cond (jnp.ndarray) – The conditioning data.
obs_ids (jnp.ndarray) – The observation IDs.
cond_ids (jnp.ndarray) – The conditioning IDs.
- Returns:
Computed loss.
- Return type:
Array
- loss_fn
- path
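A denoising score matching loss of the kind described here can be sketched as follows. Assumptions: clean data x_1 is perturbed as x_1 + sigma * eps with eps ~ N(0, I), the model acts as a denoiser D(x_noisy, sigma, cond) that predicts x_1, and no sigma-dependent weighting is applied. A NumPy `Generator` stands in for the JAX PRNG key, and `denoising_loss_sketch` is hypothetical. Training a denoiser amounts to score matching because, for the optimal denoiser D*, Tweedie's formula gives grad log p_sigma(x) = (D*(x, sigma) - x) / sigma**2.

```python
import numpy as np


def denoising_loss_sketch(rng, denoiser, batch, cond):
    """Denoising score matching loss with assumed unweighted MSE (hypothetical)."""
    x_1, sigma = batch                                  # clean data, noise levels
    eps = rng.normal(size=x_1.shape)                    # stand-in for the JAX PRNG key
    sigma_b = sigma.reshape(-1, *([1] * (x_1.ndim - 1)))
    x_noisy = x_1 + sigma_b * eps                       # perturb the clean data
    pred = denoiser(x_noisy, sigma, cond)               # estimate of x_1
    return np.mean((pred - x_1) ** 2)


rng = np.random.default_rng(0)
x_1 = rng.normal(size=(8, 3, 1))
sigma = np.exp(rng.normal(size=(8,)))                   # positive noise levels
cond = rng.normal(size=(8, 2, 1))

oracle = lambda x, s, c: x_1                            # recovers the clean data exactly
zero_denoiser = lambda x, s, c: np.zeros_like(x)
print(denoising_loss_sketch(rng, oracle, (x_1, sigma), cond))  # 0.0
print(np.isclose(denoising_loss_sketch(rng, zero_denoiser, (x_1, sigma), cond),
                 np.mean(x_1 ** 2)))                            # True
```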
- class gensbi.models.ConditionalWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for conditional models to handle input expansion and calling convention.
- Parameters:
model – The conditional model instance to wrap.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
Call the wrapped model with expanded inputs.
- Parameters:
t (Array) – Time steps.
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
conditioned (bool | Array, optional) – Whether to use conditioning. Defaults to True.
guidance (Array | None, optional) – Optional guidance input.
- Returns:
Model output.
- Return type:
Array
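What "input expansion" involves is not spelled out here. One plausible reading, sketched below purely as an assumption, is that the wrapper broadcasts a scalar time to one value per sample and adds a trailing channel axis to 2-D inputs before forwarding everything to the model by keyword. `SketchConditionalWrapper` and `toy_model` are hypothetical and do not reproduce `gensbi.utils.model_wrapping.ModelWrapper`.

```python
import numpy as np


class SketchConditionalWrapper:
    """Hypothetical wrapper: normalizes shapes, then forwards by keyword."""

    def __init__(self, model):
        self.model = model

    @staticmethod
    def _expand(x):
        x = np.asarray(x)
        return x[..., None] if x.ndim == 2 else x  # (batch, seq) -> (batch, seq, 1)

    def __call__(self, t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None):
        obs = self._expand(obs)
        cond = self._expand(cond)
        # A scalar time becomes one time value per sample.
        t = np.broadcast_to(np.asarray(t, dtype=obs.dtype), (obs.shape[0],))
        return self.model(t=t, obs=obs, obs_ids=obs_ids, cond=cond, cond_ids=cond_ids,
                          conditioned=conditioned, guidance=guidance)


def toy_model(t, obs, obs_ids, cond, cond_ids, conditioned, guidance):
    # Reports the shapes it receives instead of computing a vector field.
    return t.shape, obs.shape, cond.shape


wrapped = SketchConditionalWrapper(toy_model)
print(wrapped(0.5, np.zeros((4, 3)), None, np.zeros((4, 2)), None))
# ((4,), (4, 3, 1), (4, 2, 1))
```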
- class gensbi.models.Flux1(params)
Bases: flax.nnx.Module
Transformer model for flow matching on sequences.
- Parameters:
params (Flux1Params)
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
obs_ids (jax.Array)
cond (jax.Array)
cond_ids (jax.Array)
conditioned (bool | jax.Array)
guidance (jax.Array | None)
- Return type:
jax.Array
- cond_in
- condition_embedding
- condition_null
- double_blocks
- final_layer
- in_channels
- num_heads
- obs_in
- out_channels
- params
- pe_embedder
- qkv_features
- single_blocks
- time_in
- vector_in
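The `conditioned` flag together with the `condition_null` attribute suggests support for classifier-free guidance. Assuming that `conditioned=False` makes the model ignore `cond` (for example by substituting a learned null condition), a guided velocity combines both predictions as v = v_uncond + s * (v_cond - v_uncond). The NumPy sketch below uses a hypothetical `toy_model` in place of `Flux1`; `cfg_velocity` is likewise hypothetical.

```python
import numpy as np


def cfg_velocity(model, t, obs, obs_ids, cond, cond_ids, scale):
    """Classifier-free guidance: blend conditional and unconditional outputs."""
    v_cond = model(t, obs, obs_ids, cond, cond_ids, conditioned=True)
    v_uncond = model(t, obs, obs_ids, cond, cond_ids, conditioned=False)
    # scale = 1 recovers v_cond, scale = 0 recovers v_uncond, scale > 1 extrapolates.
    return v_uncond + scale * (v_cond - v_uncond)


def toy_model(t, obs, obs_ids, cond, cond_ids, conditioned=True):
    # Hypothetical field: pulls obs toward cond when conditioned, toward 0 otherwise.
    target = cond if conditioned else np.zeros_like(obs)
    return target - obs


obs = np.array([[1.0], [2.0]])
cond = np.array([[3.0], [3.0]])
print(cfg_velocity(toy_model, 0.5, obs, None, cond, None, scale=2.0))
# [[5.]
#  [4.]]
```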
- class gensbi.models.Flux1Joint(params)
Bases: flax.nnx.Module
Flux1Joint model for joint density estimation.
- Parameters:
params (Flux1JointParams) – Parameters for the Flux1Joint model.
- __call__(t, obs, node_ids, condition_mask, guidance=None, edge_mask=None)
- Parameters:
t (jax.Array)
obs (jax.Array)
node_ids (jax.Array)
condition_mask (jax.Array)
guidance (jax.Array | None)
edge_mask (Optional[jax.Array])
- Return type:
jax.Array
- condition_embedding
- final_layer
- in_channels
- num_heads
- obs_in
- out_channels
- params
- pe_embedder
- qkv_features
- single_blocks
- time_in
- vector_in
- class gensbi.models.Flux1JointParams
Parameters for the Flux1Joint model.
- Parameters:
in_channels (int) – Number of input channels.
vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.
context_in_dim (int) – Dimension of the context input.
mlp_ratio (float) – Ratio of the MLP hidden width to the transformer's hidden width.
num_heads (int) – Number of attention heads.
depth (int) – Number of double stream blocks.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
obs_dim (int) – Observation dimension.
cond_dim (int) – Condition dimension.
theta (int) – Scaling factor for positional encoding.
guidance_embed (bool) – Whether to use guidance embedding.
param_dtype (DTypeLike) – Data type for model parameters.
- axes_dim: list[int]
- condition_dim: list[int]
- depth_single_blocks: int
- guidance_embed: bool = False
- in_channels: int
- joint_dim: int
- mlp_ratio: float
- num_heads: int
- param_dtype: jax.typing.DTypeLike
- qkv_bias: bool
- rngs: flax.nnx.Rngs
- theta: int = 10000
- vec_in_dim: int | None
- class gensbi.models.Flux1Params
Parameters for the Flux1 model.
- Parameters:
in_channels (int) – Number of input channels.
vec_in_dim (Union[int, None]) – Dimension of the vector input, if applicable.
context_in_dim (int) – Dimension of the context input.
mlp_ratio (float) – Ratio of the MLP hidden width to the transformer's hidden width.
num_heads (int) – Number of attention heads.
depth (int) – Number of double stream blocks.
depth_single_blocks (int) – Number of single stream blocks.
axes_dim (list[int]) – Dimensions of the axes for positional encoding.
qkv_bias (bool) – Whether to use bias in QKV layers.
rngs (nnx.Rngs) – Random number generators for initialization.
obs_dim (int) – Observation dimension.
cond_dim (int) – Condition dimension.
theta (int) – Scaling factor for positional encoding.
guidance_embed (bool) – Whether to use guidance embedding.
param_dtype (DTypeLike) – Data type for model parameters.
- axes_dim: list[int]
- cond_dim: int
- context_in_dim: int
- depth: int
- depth_single_blocks: int
- guidance_embed: bool = False
- in_channels: int
- mlp_ratio: float
- num_heads: int
- obs_dim: int
- param_dtype: jax.typing.DTypeLike
- qkv_bias: bool
- rngs: flax.nnx.Rngs
- theta: int = 10000
- vec_in_dim: int | None
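`theta` and `axes_dim` configure the positional encoding. Flux-style transformers use rotary positional embeddings (RoPE), in which feature pairs are rotated by position-dependent angles whose frequencies are set by `theta`. The sketch below implements single-axis RoPE in NumPy under that assumption; how gensbi's `pe_embedder` splits dimensions across `axes_dim` is not reproduced, and `rope_rotate` and `rot` are hypothetical. The property that makes RoPE useful is that the dot product of a rotated query and key depends only on their relative position.

```python
import numpy as np


def rope_rotate(x, positions, theta=10000):
    """Rotate feature pairs of x (seq_len, dim) by position-dependent angles.

    Pair (x[:, 2i], x[:, 2i + 1]) is rotated by positions * theta ** (-2i / dim),
    the usual RoPE convention (assumed here). dim must be even.
    """
    dim = x.shape[-1]
    freqs = theta ** (-np.arange(0, dim, 2) / dim)     # (dim // 2,)
    angles = positions[:, None] * freqs[None, :]       # (seq_len, dim // 2)
    cos, sin = np.cos(angles), np.sin(angles)
    x_even, x_odd = x[:, 0::2], x[:, 1::2]
    out = np.empty_like(x)
    out[:, 0::2] = x_even * cos - x_odd * sin
    out[:, 1::2] = x_even * sin + x_odd * cos
    return out


def rot(v, pos):
    # Rotate a single vector placed at position `pos`.
    return rope_rotate(v[None, :], np.array([pos], dtype=float))[0]


rng = np.random.default_rng(0)
q, k = rng.normal(size=8), rng.normal(size=8)
print(np.allclose(rot(q, 0.0), q))  # True: position 0 is the identity
print(np.isclose(rot(q, 3.0) @ rot(k, 1.0), rot(q, 7.0) @ rot(k, 5.0)))  # True: only the offset matters
```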
- class gensbi.models.JointCFMLoss(path, reduction='mean')
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
JointCFMLoss computes the continuous flow matching loss for the Joint model.
- Parameters:
path – Probability path for training.
reduction (str) – Reduction method (‘none’, ‘mean’, ‘sum’).
- __call__(vf, batch, *args, condition_mask=None, **kwargs)
Evaluate the continuous flow matching loss.
- Parameters:
vf (Callable) – Vector field model.
batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).
args (Optional[dict]) – Additional arguments.
condition_mask (Optional[Array]) – Mask for conditioning.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
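A common convention for joint models, following the Simformer scheme, is that `condition_mask` marks observed variables: those entries are held at their clean values and excluded from the loss, so the model only learns to transport the unobserved ones. The NumPy sketch below assumes that convention, a linear path, and a mask that is True for observed entries; `masked_joint_cfm_loss` is hypothetical and not gensbi's implementation.

```python
import numpy as np


def masked_joint_cfm_loss(vf, x_0, x_1, t, condition_mask):
    """Joint flow matching loss that ignores observed entries (hypothetical)."""
    t_b = t.reshape(-1, *([1] * (x_0.ndim - 1)))
    x_t = (1.0 - t_b) * x_0 + t_b * x_1
    x_t = np.where(condition_mask, x_1, x_t)       # observed entries stay clean
    target = x_1 - x_0                              # assumed linear-path velocity
    pred = vf(t, x_t, condition_mask)
    latent = ~condition_mask                        # entries the model must generate
    sq_err = (pred - target) ** 2 * latent
    return sq_err.sum() / max(latent.sum(), 1)      # mean over latent entries only


rng = np.random.default_rng(1)
x_0 = rng.normal(size=(4, 5, 1))
x_1 = rng.normal(size=(4, 5, 1))
t = rng.uniform(size=(4,))
condition_mask = rng.uniform(size=(4, 5, 1)) < 0.4    # True = observed

# Outputs on observed entries are ignored, so only the latent entries must be right.
vf = lambda t_, x, m: np.where(m, 100.0, x_1 - x_0)
print(masked_joint_cfm_loss(vf, x_0, x_1, t, condition_mask))  # 0.0
```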
- class gensbi.models.JointDiffLoss(path)
Bases: flax.nnx.Module
JointDiffLoss computes the diffusion score matching loss for the Joint model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, condition_mask=None, **kwargs)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
condition_mask (Optional[Array]) – Mask for conditioning.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
- loss_fn
- path
- class gensbi.models.JointWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for joint models to handle both conditioned and unconditioned inference.
- Parameters:
model – The joint model instance to wrap.
conditioned (bool, optional) – Whether to use conditioning by default. Defaults to True.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, **kwargs)
Call the wrapped model for either conditioned or unconditioned inference.
- Parameters:
t (Array) – Time steps.
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
conditioned (bool, optional) – Whether to use conditioning. If None, uses the default set at initialization.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Model output.
- Return type:
Array
- conditioned(obs, obs_ids, cond, cond_ids, t, **kwargs)
Perform conditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
t (Array) – Time steps.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Conditioned output (only for unconditioned variables).
- Return type:
Array
- unconditioned(obs, obs_ids, t, **kwargs)
Perform unconditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
t (Array) – Time steps.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Unconditioned output.
- Return type:
Array
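Conditioned inference with a joint model can be pictured as follows: observation and conditioning tokens are concatenated into one sequence, a condition mask flags the conditioning tokens, and only the outputs for the observation tokens are kept, consistent with the note that the conditioned output covers only the unconditioned variables. This NumPy sketch assumes inputs of shape (batch, n, channels), ids of shape (batch, n) and a mask of shape (batch, n, 1); `joint_conditioned_call` and `show_mask` are hypothetical, and only the joint model's argument order mirrors the documented `Simformer.__call__`.

```python
import numpy as np


def joint_conditioned_call(joint_model, obs, obs_ids, cond, cond_ids, t):
    """Run a joint model conditioned on `cond` and return only the obs outputs."""
    x = np.concatenate([obs, cond], axis=1)                    # one token sequence
    node_ids = np.concatenate([obs_ids, cond_ids], axis=1)
    condition_mask = np.concatenate(
        [np.zeros(obs.shape[:2] + (1,), dtype=bool),           # obs tokens: generated
         np.ones(cond.shape[:2] + (1,), dtype=bool)],          # cond tokens: observed
        axis=1,
    )
    out = joint_model(t, x, node_ids, condition_mask)
    return out[:, : obs.shape[1]]                              # drop conditioning tokens


obs = np.ones((2, 3, 1))
cond = 5.0 * np.ones((2, 2, 1))
obs_ids = np.tile(np.arange(3), (2, 1))
cond_ids = np.tile(np.arange(3, 5), (2, 1))

# Toy joint model that returns the mask it was given, as floats.
show_mask = lambda t, x, ids, m: m.astype(float)
print(joint_conditioned_call(show_mask, obs, obs_ids, cond, cond_ids, t=0.5).shape)  # (2, 3, 1)
```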
- class gensbi.models.Simformer(params)
Bases: flax.nnx.Module
Simformer model for joint density estimation.
- Parameters:
params (SimformerParams) – Parameters for the Simformer model.
- __call__(t, obs, node_ids, condition_mask, edge_mask=None)
Forward pass of the Simformer model.
- Parameters:
t (Array) – Time steps.
obs (Array) – Input data.
node_ids (Array) – Node identifiers.
condition_mask (Array) – Mask for conditioning.
edge_mask (Optional[Array]) – Mask for edges.
- Returns:
Model output.
- Return type:
Array
- condition_embedding
- dim_condition
- dim_id
- dim_value
- embedding_net_id
- embedding_net_value
- embedding_time
- in_channels
- output_fn
- params
- total_tokens
- transformer
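The optional `edge_mask` restricts which tokens may attend to which. A typical way to apply such a mask, shown here as a single-head NumPy sketch rather than the Simformer's actual code, is to set disallowed attention scores to -inf before the softmax. It assumes `edge_mask[i, j]` is True when token i may attend to token j and that every row allows at least one edge; an all-masked row would produce NaNs. `masked_attention` is hypothetical.

```python
import numpy as np


def masked_attention(q, k, v, edge_mask):
    """Single-head attention; edge_mask[i, j] = True lets token i attend to j."""
    scores = q @ k.T / np.sqrt(q.shape[-1])
    scores = np.where(edge_mask, scores, -np.inf)            # block disallowed edges
    scores = scores - scores.max(axis=-1, keepdims=True)     # numerically stable softmax
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights


rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(4, 8)) for _ in range(3))
causal = np.tril(np.ones((4, 4), dtype=bool))                 # token i sees tokens <= i
out, weights = masked_attention(q, k, v, causal)
print(np.allclose(weights[np.triu_indices(4, k=1)], 0.0))    # True: masked edges get zero weight
```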
- class gensbi.models.SimformerParams
Parameters for the Simformer model.
- Parameters:
rngs (nnx.Rngs) – Random number generators for initialization.
in_channels (int) – Number of input channels.
dim_value (int) – Dimension of the value embeddings.
dim_id (int) – Dimension of the ID embeddings.
dim_condition (int) – Dimension of the condition embeddings.
dim_joint (int) – Total dimension of the joint embeddings.
fourier_features (int) – Number of Fourier features for time embedding.
num_heads (int) – Number of attention heads.
num_layers (int) – Number of transformer layers.
widening_factor (int) – Widening factor for the transformer.
qkv_features (int) – Number of features for QKV layers.
num_hidden_layers (int) – Number of hidden layers in the transformer.
- dim_condition: int
- dim_id: int
- dim_joint: int
- dim_value: int
- fourier_features: int = 128
- in_channels: int
- num_heads: int
- num_layers: int
- qkv_features: int | None = None
- rngs: flax.nnx.Rngs
- widening_factor: int = 3
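`fourier_features` sets the size of the time embedding. A common choice is a random Fourier feature embedding of t; the sketch assumes half of the features are sines and half cosines over frequencies drawn once from a Gaussian. Its `scale` and `seed` arguments are hypothetical rather than `SimformerParams` fields, and `fourier_time_embedding` is not a gensbi function.

```python
import numpy as np


def fourier_time_embedding(t, fourier_features=128, scale=16.0, seed=0):
    """Random Fourier feature embedding of time t (hypothetical sketch).

    Returns an array of shape (len(t), fourier_features); fourier_features
    must be even because it is split evenly between sines and cosines.
    """
    rng = np.random.default_rng(seed)
    w = rng.normal(scale=scale, size=(fourier_features // 2,))  # fixed frequencies
    angles = 2.0 * np.pi * np.asarray(t)[:, None] * w[None, :]
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)


emb = fourier_time_embedding(np.array([0.0, 0.5, 1.0]))
print(emb.shape)  # (3, 128)
```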
- class gensbi.models.UnconditionalCFMLoss(path, reduction='mean')
Bases: gensbi.flow_matching.loss.ContinuousFMLoss
UnconditionalCFMLoss computes the continuous flow matching loss for the Unconditional model.
- Parameters:
path – Probability path for training.
reduction (str) – Reduction method (‘none’, ‘mean’, ‘sum’).
- __call__(vf, batch, *args, **kwargs)
Evaluate the continuous flow matching loss.
- Parameters:
vf (Callable) – Vector field model.
batch (Tuple[Array, Array, Array]) – Input data (x_0, x_1, t).
args (Optional[dict]) – Additional arguments.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
- class gensbi.models.UnconditionalDiffLoss(path)
Bases: flax.nnx.Module
UnconditionalDiffLoss computes the diffusion score matching loss for the Unconditional model.
- Parameters:
path – Probability path for training.
- __call__(key, model, batch, **kwargs)
Evaluate the diffusion score matching loss.
- Parameters:
key (jax.random.PRNGKey) – Random key for stochastic operations.
model (Callable) – F model.
batch (Tuple[Array, Array, Array]) – Input data (x_1, sigma).
condition_mask (Optional[Array]) – Mask for conditioning.
**kwargs – Additional keyword arguments.
- Returns:
Computed loss.
- Return type:
Array
- loss_fn
- path
- class gensbi.models.UnconditionalWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for unconditional models to handle input expansion and calling convention.
- Parameters:
model – The unconditional model instance to wrap.