# Source code for gensbi.models.wrappers.joint

from jax import Array
from typing import Optional


import jax.numpy as jnp
from jax import Array
from jax.typing import DTypeLike

from gensbi.utils.model_wrapping import ModelWrapper, _expand_dims, _expand_time


class JointWrapper(ModelWrapper):
    """
    Wrapper for joint models to handle both conditioned and unconditioned inference.

    The wrapped model is invoked with the keyword arguments ``obs``, ``t``,
    ``node_ids`` and ``condition_mask`` (plus any extra ``**kwargs``), where
    ``condition_mask`` marks which nodes are conditioning inputs (``True``)
    and which are to be modelled (``False``).

    Args:
        model: The joint model instance to wrap.
    """

    def __init__(self, model):
        """
        Initialize the JointWrapper.

        Args:
            model: The joint model instance to wrap.
        """
        super().__init__(model)

    def conditioned(
        self,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        t: Array,
        **kwargs,
    ) -> Array:
        """
        Perform conditioned inference.

        Observation and conditioning nodes are concatenated along axis 1 and
        passed to the model together with a 1-D mask that is ``False`` for the
        observation nodes and ``True`` for the conditioning nodes.

        ``cond`` is broadcast to the batch size of ``obs`` (axis 0), and
        ``cond_ids`` is broadcast to the leading dimension of ``obs_ids``, so a
        single conditioning sample can be shared across a batch.

        Args:
            obs (Array): Observations, with nodes along axis 1.
            obs_ids (Array): Observation identifiers (concatenated along axis 1).
            cond (Array): Conditioning values, with nodes along axis 1.
            cond_ids (Array): Conditioning identifiers (concatenated along axis 1).
            t (Array): Time steps.
            **kwargs: Additional keyword arguments passed to the model.

        Returns:
            Array: Conditioned output (only for unconditioned variables), i.e.
            the first ``obs.shape[1]`` entries along axis 1 of the model output.
        """
        obs_dim = obs.shape[1]
        cond_dim = cond.shape[1]

        # Share the conditioning values across the whole observation batch.
        cond = jnp.broadcast_to(cond, (obs.shape[0], *cond.shape[1:]))
        # Do the same for the identifiers: jnp.concatenate does not broadcast,
        # so a size-1 cond_ids next to a batched obs_ids would otherwise fail.
        # Inputs that already matched are left unchanged by this broadcast.
        cond_ids = jnp.broadcast_to(
            cond_ids, (jnp.shape(obs_ids)[0], *jnp.shape(cond_ids)[1:])
        )

        # Per-node mask over the concatenated sequence:
        # False = observation node (modelled), True = conditioning node.
        condition_mask = jnp.zeros((obs_dim + cond_dim,), dtype=jnp.bool_)
        condition_mask = condition_mask.at[obs_dim:].set(True)

        x = jnp.concatenate([obs, cond], axis=1)
        node_ids = jnp.concatenate([obs_ids, cond_ids], axis=1)

        res = self.model(
            obs=x,
            t=t,
            node_ids=node_ids,
            condition_mask=condition_mask,
            **kwargs,
        )
        # Keep only the outputs for the observation nodes.
        return res[:, :obs_dim]

    def unconditioned(
        self,
        obs: Array,
        obs_ids: Array,
        t: Array,
        **kwargs,
    ) -> Array:
        """
        Perform unconditioned inference.

        Every node is treated as unconditioned: the model receives an
        all-``False`` condition mask.

        Args:
            obs (Array): Observations.
            obs_ids (Array): Observation identifiers.
            t (Array): Time steps.
            **kwargs: Additional keyword arguments passed to the model.

        Returns:
            Array: Unconditioned output (the model output, not sliced).
        """
        # NOTE(review): this mask has the full shape of ``obs``, whereas
        # ``conditioned`` passes a 1-D per-node mask; presumably the model
        # broadcasts either form — confirm before unifying the two.
        condition_mask = jnp.zeros(obs.shape, dtype=jnp.bool_)
        return self.model(
            obs=obs,
            t=t,
            node_ids=obs_ids,
            condition_mask=condition_mask,
            **kwargs,
        )

    def __call__(
        self,
        t: Array,
        obs: Array,
        obs_ids: Array,
        cond: Array,
        cond_ids: Array,
        conditioned: bool = True,
        **kwargs,
    ) -> Array:
        """
        Call the wrapped model for either conditioned or unconditioned inference.

        Args:
            t (Array): Time steps; normalized with ``_expand_time``.
            obs (Array): Observations; normalized with ``_expand_dims``.
            obs_ids (Array): Observation identifiers.
            cond (Array): Conditioning values; normalized with ``_expand_dims``
                only when ``conditioned`` is True. They are not used otherwise,
                so a placeholder such as ``None`` may be passed in that case.
            cond_ids (Array): Conditioning identifiers (unused when
                ``conditioned`` is False).
            conditioned (bool, optional): Whether to use conditioning.
                Defaults to True.
            **kwargs: Additional keyword arguments passed to the model.

        Returns:
            Array: Model output.
        """
        t = _expand_time(t)
        obs = _expand_dims(obs)

        if not conditioned:
            # cond / cond_ids are not needed here, so they are not preprocessed.
            return self.unconditioned(obs, obs_ids, t, **kwargs)

        cond = _expand_dims(cond)
        return self.conditioned(obs, obs_ids, cond, cond_ids, t, **kwargs)