gensbi.models.wrappers

Submodules

Classes

ConditionalWrapper

Wrapper for conditional models that handles input expansion and the calling convention.

JointWrapper

Wrapper for joint models that handles both conditioned and unconditioned inference.

UnconditionalWrapper

Wrapper for unconditional models that handles input expansion and the calling convention.

Package Contents

class gensbi.models.wrappers.ConditionalWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wrapper for conditional models that handles input expansion and the calling convention.

Parameters:

model – The conditional model instance to wrap.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None)

Call the wrapped model with expanded inputs.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • conditioned (bool | Array, optional) – Whether to use conditioning. Defaults to True.

  • guidance (Array | None, optional) – Guidance input. Defaults to None.

Returns:

Model output.

Return type:

Array
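To make the calling convention concrete, here is a minimal NumPy sketch. It is not gensbi code: `ToyConditionalModel` and `SketchConditionalWrapper` are hypothetical, the argument order of the inner model is invented for illustration, and "input expansion" is assumed to mean broadcasting a scalar or per-sample `t` to shape `(batch, 1)`. The real wrapper operates on JAX arrays and may expand inputs differently.

```python
import numpy as np


class ToyConditionalModel:
    """Hypothetical conditional model (not part of gensbi)."""

    def __call__(self, obs, obs_ids, cond, cond_ids, t, conditioned=True, guidance=None):
        # Per-sample summary of the conditioning values, shape (batch, 1).
        summary = cond.mean(axis=-1, keepdims=True)
        # Drop the conditioning signal when conditioned is False (bool or array mask).
        summary = np.where(conditioned, summary, 0.0)
        # guidance is accepted for signature parity but unused by this toy.
        return obs * t + summary


class SketchConditionalWrapper:
    """Hypothetical wrapper mirroring the documented (t, obs, obs_ids, cond, cond_ids) order."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, obs_ids, cond, cond_ids, conditioned=True, guidance=None):
        # Assumed "input expansion": broadcast scalar or per-sample t to (batch, 1).
        t = np.broadcast_to(np.asarray(t, dtype=float).reshape(-1, 1), (obs.shape[0], 1))
        # Reorder arguments into the (assumed) order the wrapped model expects.
        return self.model(obs, obs_ids, cond, cond_ids, t,
                          conditioned=conditioned, guidance=guidance)


wrapper = SketchConditionalWrapper(ToyConditionalModel())
obs = np.ones((4, 3))
cond = np.full((4, 2), 2.0)
out = wrapper(0.5, obs, np.arange(3), cond, np.arange(2))
print(out.shape)  # (4, 3)
```

In this sketch, passing `conditioned=False` zeroes out the conditioning summary, so the toy output reduces to `obs * t`.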

class gensbi.models.wrappers.JointWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wrapper for joint models that handles both conditioned and unconditioned inference.

Parameters:
  • model – The joint model instance to wrap.

  • conditioned (bool, optional) – Whether to use conditioning by default. Defaults to True.

__call__(t, obs, obs_ids, cond, cond_ids, conditioned=True, **kwargs)

Call the wrapped model for either conditioned or unconditioned inference.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • conditioned (bool, optional) – Whether to use conditioning. Defaults to True.

  • **kwargs – Additional keyword arguments passed to the model.

Returns:

Model output.

Return type:

Array

conditioned(obs, obs_ids, cond, cond_ids, t, **kwargs)

Perform conditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • cond (Array) – Conditioning values.

  • cond_ids (Array) – Conditioning identifiers.

  • t (Array) – Time steps.

  • **kwargs – Additional keyword arguments passed to the model.

Returns:

Conditioned output, restricted to the unconditioned variables (the conditioning variables are not included).

Return type:

Array

unconditioned(obs, obs_ids, t, **kwargs)

Perform unconditioned inference.

Parameters:
  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • t (Array) – Time steps.

  • **kwargs – Additional keyword arguments passed to the model.

Returns:

Unconditioned output.

Return type:

Array
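The interplay between `__call__`, `conditioned`, and `unconditioned` can be sketched with NumPy under stated assumptions. Nothing here is taken from the gensbi source: `toy_joint_model` and `SketchJointWrapper` are hypothetical, and the sketch assumes the joint model sees obs and cond tokens side by side together with a boolean condition mask, then keeps only the obs slice of the output (one reading of "only for unconditioned variables").

```python
import numpy as np


def toy_joint_model(t, x, node_ids, condition_mask):
    """Hypothetical joint model (not gensbi code).

    Conditioned tokens output 0; free tokens output x * t. node_ids is unused here.
    """
    return np.where(condition_mask, 0.0, x * t)


class SketchJointWrapper:
    """Hypothetical sketch of JointWrapper-style routing; the real class may differ."""

    def __init__(self, model):
        self.model = model

    def conditioned(self, obs, obs_ids, cond, cond_ids, t, **kwargs):
        n_obs = obs.shape[1]
        # Place obs and cond tokens side by side along the token axis.
        x = np.concatenate([obs, cond], axis=1)
        ids = np.concatenate([obs_ids, cond_ids])
        # True marks conditioning tokens, False marks the tokens being generated.
        mask = np.arange(x.shape[1]) >= n_obs
        out = self.model(t, x, ids, mask, **kwargs)
        # Return outputs for the unconditioned (obs) tokens only.
        return out[:, :n_obs]

    def unconditioned(self, obs, obs_ids, t, **kwargs):
        # No token is conditioned on, so every obs token is generated.
        mask = np.zeros(obs.shape[1], dtype=bool)
        return self.model(t, obs, obs_ids, mask, **kwargs)

    def __call__(self, t, obs, obs_ids, cond, cond_ids, conditioned=True, **kwargs):
        # Reshape t into a column so it broadcasts over the token axis.
        t = np.asarray(t, dtype=float).reshape(-1, 1)
        if conditioned:
            return self.conditioned(obs, obs_ids, cond, cond_ids, t, **kwargs)
        # cond and cond_ids are ignored on the unconditioned path.
        return self.unconditioned(obs, obs_ids, t, **kwargs)


wrapper = SketchJointWrapper(toy_joint_model)
obs, cond = np.ones((2, 3)), np.full((2, 2), 5.0)
out = wrapper(0.5, obs, np.arange(3), cond, np.arange(3, 5))
print(out.shape)  # (2, 3)
```

In this sketch, a single `__call__` with a `conditioned` flag lets the caller switch modes without changing the argument list; on the unconditioned path `cond` and `cond_ids` are simply ignored.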

class gensbi.models.wrappers.UnconditionalWrapper(model)

Bases: gensbi.utils.model_wrapping.ModelWrapper

Wrapper for unconditional models that handles input expansion and the calling convention.

Parameters:

model – The unconditional model instance to wrap.

__call__(t, obs, obs_ids, **kwargs)

Call the wrapped model with expanded inputs.

Parameters:
  • t (Array) – Time steps.

  • obs (Array) – Observations.

  • obs_ids (Array) – Observation identifiers.

  • **kwargs – Additional keyword arguments passed to the model.

Returns:

Model output.

Return type:

Array
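A minimal NumPy sketch of the `(t, obs, obs_ids, **kwargs)` convention follows. `toy_unconditional_model`, its `scale` keyword, and `SketchUnconditionalWrapper` are hypothetical; broadcasting `t` to `(batch, 1)` is an assumption about what "input expansion" means, not gensbi's documented behavior.

```python
import numpy as np


def toy_unconditional_model(obs, obs_ids, t, scale=1.0):
    """Hypothetical unconditional model (not gensbi code); obs_ids is unused here."""
    return scale * obs * t


class SketchUnconditionalWrapper:
    """Hypothetical wrapper mirroring the documented (t, obs, obs_ids, **kwargs) order."""

    def __init__(self, model):
        self.model = model

    def __call__(self, t, obs, obs_ids, **kwargs):
        # Assumed input expansion: broadcast t to (batch, 1) to align with obs.
        t = np.broadcast_to(np.asarray(t, dtype=float).reshape(-1, 1), (obs.shape[0], 1))
        # Any extra keyword arguments go straight through to the wrapped model.
        return self.model(obs, obs_ids, t, **kwargs)


wrapper = SketchUnconditionalWrapper(toy_unconditional_model)
out = wrapper(np.array([0.0, 1.0]), np.ones((2, 3)), np.arange(3), scale=2.0)
print(out.shape)  # (2, 3)
```

Here `scale=2.0` travels through `**kwargs` untouched, which is the role the `**kwargs` parameter plays in the signature above.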