Source code for gensbi.models.wrappers.unconditional

from jax import Array
from typing import Optional


import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike

from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time



class UnconditionalWrapper(ModelWrapper):
    """
    Adapter that lets an unconditional model be invoked as ``(t, obs, obs_ids)``.

    On each call the inputs are passed through the expansion helpers
    (``_expand_time`` for ``t``, ``_expand_dims`` for ``obs`` and ``obs_ids``)
    and then forwarded to the wrapped model using its keyword-based calling
    convention.

    Args:
        model: The unconditional model instance to wrap.
    """

    def __init__(self, model):
        """
        Create the wrapper around ``model``.

        Args:
            model: The unconditional model instance to wrap.
        """
        super().__init__(model)

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        **kwargs,
    ) -> Array:
        """
        Expand the inputs and evaluate the wrapped model.

        Args:
            t (Array): Time steps.
            obs (Array): Observations.
            obs_ids (Array): Observation identifiers; forwarded to the model
                as ``node_ids``.
            **kwargs: Additional keyword arguments passed through to the
                model unchanged.

        Returns:
            Array: Whatever the wrapped model returns.
        """
        time_in = _expand_time(t)
        obs_in = _expand_dims(obs)
        ids_in = _expand_dims(obs_ids)

        # All-False boolean mask with the shape of the expanded observations.
        # NOTE(review): presumably this tells the model that no entry is
        # conditioned on -- confirm against the wrapped model's contract.
        no_condition = jnp.zeros(obs_in.shape, dtype=jnp.bool_)

        return self.model(
            obs=obs_in,
            t=time_in,
            node_ids=ids_in,
            condition_mask=no_condition,
            **kwargs,
        )