"""Source code for ``gensbi.models.wrappers.conditional``."""



from jax import Array


from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time




class ConditionalWrapper(ModelWrapper):
    """
    Adapt a conditional model to the ``ModelWrapper`` calling convention.

    The wrapper normalizes the shapes of its array inputs with the
    ``_expand_dims`` / ``_expand_time`` helpers from
    ``gensbi.utils.model_wrapping``. It then forwards everything to the
    wrapped model as keyword arguments.

    Args:
        model: The conditional model instance to wrap.
    """

    def __init__(self, model):
        """
        Build the wrapper around ``model``.

        All setup is delegated to ``ModelWrapper.__init__``.

        Args:
            model: The conditional model instance to wrap.
        """
        super().__init__(model)

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool | Array = True,
        guidance: Array | None = None,
    ) -> Array:
        """
        Evaluate the wrapped model on shape-normalized inputs.

        ``t`` goes through ``_expand_time``. ``obs``, ``cond``, ``obs_ids``
        and ``cond_ids`` each go through ``_expand_dims``. (The exact shape
        changes are defined by those helpers, not here.) ``conditioned``
        and ``guidance`` are forwarded unchanged.

        Args:
            t (Array): Time steps.
            obs (Array): Observations.
            obs_ids (Array): Observation identifiers.
            cond (Array): Conditioning values.
            cond_ids (Array): Conditioning identifiers.
            conditioned (bool | Array, optional): Whether to use
                conditioning. Defaults to True.
            guidance (Array | None, optional): Optional guidance input,
                passed through as-is.

        Returns:
            Array: Whatever the wrapped model returns.
        """
        # Entries are built in the same order as the model's keyword
        # arguments (obs, t, cond, obs_ids, cond_ids). This keeps both the
        # helper-call order and the keyword order seen by the model stable.
        model_inputs = dict(
            obs=_expand_dims(obs),
            t=_expand_time(t),
            cond=_expand_dims(cond),
            obs_ids=_expand_dims(obs_ids),
            cond_ids=_expand_dims(cond_ids),
        )
        return self.model(
            **model_inputs,
            conditioned=conditioned,
            guidance=guidance,
        )