Source code for gensbi.models.wrappers.unconditional
from jax import Array
from typing import Optional
import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike
from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time
[docs]
class UnconditionalWrapper(ModelWrapper):
    """
    Adapter exposing an unconditional model through a ``(t, obs, obs_ids)`` call.

    Each call preprocesses its inputs with the ``_expand_time`` /
    ``_expand_dims`` helpers and forwards them to the wrapped model together
    with an all-``False`` ``condition_mask``.

    Args:
        model: The unconditional model instance to wrap.
    """

    def __init__(self, model):
        """
        Create the wrapper around ``model``.

        Construction is delegated entirely to ``ModelWrapper.__init__``.

        Args:
            model: The unconditional model instance to wrap.
        """
        super().__init__(model)

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        **kwargs,
    ) -> Array:
        """
        Evaluate the wrapped model on expanded inputs.

        Args:
            t (Array): Time steps; passed through ``_expand_time``.
            obs (Array): Observations; passed through ``_expand_dims``.
            obs_ids (Array): Observation identifiers; passed through
                ``_expand_dims`` and handed to the model as ``node_ids``.
            **kwargs: Extra keyword arguments forwarded to the model
                unchanged. They must not include ``obs``, ``t``,
                ``node_ids`` or ``condition_mask``: those are supplied here
                and a duplicate keyword raises ``TypeError``.

        Returns:
            Array: Whatever the wrapped model returns.
        """
        time_in = _expand_time(t)
        obs_in = _expand_dims(obs)
        ids_in = _expand_dims(obs_ids)

        # Boolean mask matching the expanded observations, False everywhere
        # (presumably meaning "no entry is conditioned on" -- confirm against
        # the wrapped model's contract).
        no_condition = jnp.zeros(obs_in.shape, dtype=jnp.bool_)

        # NOTE(review): ``self.model`` is not assigned in this class; it is
        # presumably set by ``ModelWrapper.__init__``.
        return self.model(
            obs=obs_in,
            t=time_in,
            node_ids=ids_in,
            condition_mask=no_condition,
            **kwargs,
        )