Source code for gensbi.models.wrappers.conditional
from jax import Array
from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time
[docs]
class ConditionalWrapper(ModelWrapper):
    """
    Adapter that prepares inputs for a conditional model and invokes it.

    Array inputs are passed through the ``_expand_dims`` / ``_expand_time``
    helpers and the wrapped model is then called with keyword arguments only.

    Args:
        model: The conditional model instance to wrap.
    """

    def __init__(self, model):
        """
        Build the wrapper around ``model``.

        All setup is delegated to ``ModelWrapper.__init__``.

        Args:
            model: The conditional model instance to wrap.
        """
        super().__init__(model)

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,
        guidance: Array | None = None,
    ) -> Array:
        """
        Expand the inputs and evaluate the wrapped model.

        ``obs``, ``cond``, ``obs_ids`` and ``cond_ids`` go through
        ``_expand_dims``; ``t`` goes through ``_expand_time``.
        ``conditioned`` and ``guidance`` are forwarded untouched.

        Args:
            t (Array): Time steps.
            obs (Array): Observations.
            obs_ids (Array): Observation identifiers.
            cond (Array): Conditioning values.
            cond_ids (Array): Conditioning identifiers.
            conditioned (bool | Array, optional): Whether to use conditioning.
                Defaults to True.
            guidance (Array | None, optional): Optional guidance input.
                Defaults to None.

        Returns:
            Array: Whatever the wrapped model returns for the expanded inputs.
        """
        # Entries are evaluated top to bottom, and the mapping's insertion
        # order fixes the keyword order seen by the model.
        expanded_inputs = {
            "obs": _expand_dims(obs),
            "t": _expand_time(t),
            "cond": _expand_dims(cond),
            "obs_ids": _expand_dims(obs_ids),
            "cond_ids": _expand_dims(cond_ids),
        }
        # NOTE(review): self.model is presumably set by ModelWrapper.__init__;
        # confirm against gensbi.utils.model_wrapping.
        return self.model(
            **expanded_inputs,
            conditioned=conditioned,
            guidance=guidance,
        )