gensbi.models.wrappers
Submodules
Classes
- ConditionalWrapper – Wrapper for conditional models to handle input expansion and calling convention.
- JointWrapper – Wrapper for joint models to handle both conditioned and unconditioned inference.
- UnconditionalWrapper – Wrapper for unconditional models to handle input expansion and calling convention.
Package Contents
- class gensbi.models.wrappers.ConditionalWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for conditional models to handle input expansion and calling convention.
- Parameters:
model – The conditional model instance to wrap.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)
Call the wrapped model with expanded inputs.
- Parameters:
t (Array) – Time steps.
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
conditioned (bool | Array, optional) – Whether to use conditioning. Defaults to True.
guidance (Array | None, optional) – Optional guidance input.
- Returns:
Model output.
- Return type:
Array
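The expansion rules themselves live in `gensbi.utils.model_wrapping.ModelWrapper` and are not spelled out on this page. As a rough illustration only, the sketch below assumes that "input expansion" means broadcasting a scalar or per-sample time `t` to shape `(batch, 1, 1)`, and that the "calling convention" means forwarding every argument to the model by keyword. `ToyConditionalModel` and `SketchConditionalWrapper` are hypothetical stand-ins, not part of gensbi, and the `(batch, tokens, features)` array layout is likewise an assumption.

```python
import numpy as np


class ToyConditionalModel:
    """Hypothetical conditional model, used only to exercise the wrapper sketch."""

    def __call__(self, t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None):
        # t arrives with shape (batch, 1, 1), so it broadcasts over (batch, tokens, features)
        out = obs * t
        if conditioned:
            # Add the per-sample mean of the conditioning tokens
            out = out + cond.mean(axis=1, keepdims=True)
        return out


class SketchConditionalWrapper:
    """Hypothetical wrapper: expands t, then forwards every argument by keyword."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None):
        batch = obs.shape[0]
        # Assumption: t is a scalar or has shape (batch,); expand it to (batch, 1, 1)
        t = np.asarray(t, dtype=obs.dtype).reshape(-1, 1, 1)
        t = np.broadcast_to(t, (batch, 1, 1))
        return self.model(
            t=t, obs=obs, obs_ids=obs_ids, cond=cond, cond_ids=cond_ids,
            conditioned=conditioned, guidance=guidance,
        )


wrapper = SketchConditionalWrapper(ToyConditionalModel())
obs = np.ones((4, 3, 1))                  # (batch, obs tokens, features)
cond = np.full((4, 2, 1), 2.0)            # (batch, cond tokens, features)
obs_ids = np.arange(3).reshape(1, 3, 1)
cond_ids = np.arange(2).reshape(1, 2, 1)

out = wrapper(0.5, obs, obs_ids, cond, cond_ids)                            # scalar t
out_uncond = wrapper(0.5, obs, obs_ids, cond, cond_ids, conditioned=False)  # conditioning off
out_per_sample = wrapper(np.arange(4.0), obs, obs_ids, cond, cond_ids)      # one t per sample
```

With a scalar `t = 0.5`, every entry of `out` is `1 * 0.5 + 2.0 = 2.5`; switching `conditioned` off drops the conditioning term and leaves `0.5`.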
- class gensbi.models.wrappers.JointWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for joint models to handle both conditioned and unconditioned inference.
- Parameters:
model – The joint model instance to wrap.
conditioned (bool, optional) – Whether to use conditioning by default. Defaults to True.
- __call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, **kwargs)
Call the wrapped model for either conditioned or unconditioned inference.
- Parameters:
t (Array) – Time steps.
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
conditioned (bool, optional) – Whether to use conditioning. If None, uses the default set at initialization.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Model output.
- Return type:
Array
- conditioned(obs, obs_ids, cond, cond_ids, t, **kwargs)
Perform conditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
cond (Array) – Conditioning values.
cond_ids (Array) – Conditioning identifiers.
t (Array) – Time steps.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Conditioned output, returned only for the unconditioned variables (those not being conditioned on).
- Return type:
Array
- unconditioned(obs, obs_ids, t, **kwargs)
Perform unconditioned inference.
- Parameters:
obs (Array) – Observations.
obs_ids (Array) – Observation identifiers.
t (Array) – Time steps.
**kwargs – Additional keyword arguments passed to the model.
- Returns:
Unconditioned output.
- Return type:
Array
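Since `conditioned` returns output only for the unconditioned variables, one plausible reading is that the joint model sees observation and conditioning tokens together and the wrapper keeps only the observation part. The sketch below illustrates that reading under stated assumptions: tokens are concatenated along axis 1, the joint model returns one output per token, and the first `obs.shape[1]` outputs belong to `obs`. `ToyJointModel` and `SketchJointWrapper` are hypothetical, and the real `JointWrapper` may combine its inputs differently. The sketch mirrors the documented argument orders (`t` first in `__call__`, last in `conditioned` and `unconditioned`) and uses `conditioned=None` in `__call__` so that the default chosen at construction applies, as the parameter description suggests.

```python
import numpy as np


class ToyJointModel:
    """Hypothetical joint model: each output token is its input plus t plus the mean over all tokens."""

    def __call__(self, t, x, ids, **kwargs):
        # The token mean lets conditioning tokens influence the observation outputs
        return x + t + x.mean(axis=1, keepdims=True)


class SketchJointWrapper:
    """Hypothetical wrapper mirroring JointWrapper's method names and argument orders."""

    def __init__(self, model, conditioned=True):
        self.model = model
        self.default_conditioned = conditioned

    def conditioned(self, obs, obs_ids, cond, cond_ids, t, **kwargs):
        # Assumption: observation and conditioning tokens are concatenated along the token axis
        x = np.concatenate([obs, cond], axis=1)
        ids = np.concatenate([obs_ids, cond_ids], axis=1)
        out = self.model(t, x, ids, **kwargs)
        # Keep only the outputs for the unconditioned (observation) tokens
        return out[:, : obs.shape[1]]

    def unconditioned(self, obs, obs_ids, t, **kwargs):
        # Only the observation tokens are passed to the model
        return self.model(t, obs, obs_ids, **kwargs)

    def __call__(self, t, obs, obs_ids, cond, cond_ids, conditioned=None, **kwargs):
        # Fall back to the default chosen at construction time when no flag is given
        if conditioned is None:
            conditioned = self.default_conditioned
        if conditioned:
            return self.conditioned(obs, obs_ids, cond, cond_ids, t, **kwargs)
        return self.unconditioned(obs, obs_ids, t, **kwargs)


wrapper = SketchJointWrapper(ToyJointModel())
obs = np.zeros((2, 3, 1))                 # (batch, obs tokens, features)
cond = np.full((2, 1, 1), 4.0)            # (batch, cond tokens, features)
obs_ids = np.arange(3).reshape(1, 3, 1)
cond_ids = np.arange(3, 4).reshape(1, 1, 1)

out_c = wrapper(0.5, obs, obs_ids, cond, cond_ids)                     # default: conditioned
out_u = wrapper(0.5, obs, obs_ids, cond, cond_ids, conditioned=False)  # obs tokens only
out_default_u = SketchJointWrapper(ToyJointModel(), conditioned=False)(
    0.5, obs, obs_ids, cond, cond_ids
)
```

Because the toy model mixes information across tokens, `out_c` (every entry 1.5) differs from `out_u` (every entry 0.5), while both keep the observation shape `(2, 3, 1)`.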
- class gensbi.models.wrappers.UnconditionalWrapper(model)
Bases: gensbi.utils.model_wrapping.ModelWrapper
Wrapper for unconditional models to handle input expansion and calling convention.
- Parameters:
model – The unconditional model instance to wrap.