Source code for gensbi.flow_matching.utils.utils

from typing import Optional, Callable

import jax
import jax.numpy as jnp
from jax import Array

import matplotlib.pyplot as plt
import numpy as np

from einops import einsum


def unsqueeze_to_match(source: Array, target: Array, how: str = "suffix") -> Array:
    """
    Unsqueeze the source array to match the dimensionality of the target array.

    Singleton axes are added (in one reshape) until ``source`` has as many
    dimensions as ``target``. If ``source`` already has at least as many
    dimensions as ``target``, it is returned unchanged; no axes are removed.

    Args:
        source (Array): The source array to be unsqueezed.
        target (Array): The target array to match the dimensionality of. Only
            its number of dimensions is used.
        how (str, optional): Whether to unsqueeze the source array at the
            beginning ("prefix") or end ("suffix"). Defaults to "suffix".

    Returns:
        Array: The unsqueezed source array.

    Raises:
        ValueError: If ``how`` is neither "prefix" nor "suffix".
    """
    # Explicit raise rather than ``assert``: under ``python -O`` the assert is
    # stripped and an invalid ``how`` would silently return ``source`` as-is.
    if how not in ("prefix", "suffix"):
        raise ValueError(
            f"{how} is not supported, only 'prefix' and 'suffix' are supported."
        )

    dim_diff = len(target.shape) - len(source.shape)
    if dim_diff <= 0:
        # Nothing to add; keep the original object untouched.
        return source

    # A single reshape is equivalent to ``dim_diff`` successive expand_dims.
    ones = (1,) * dim_diff
    source_shape = tuple(source.shape)
    if how == "prefix":
        new_shape = ones + source_shape
    else:
        new_shape = source_shape + ones
    return jnp.reshape(source, new_shape)
def expand_tensor_like(input_array: Array, expand_to: Array) -> Array:
    """
    Broadcast a per-sample vector to the shape of a batched array.

    `input_array` is a 1d vector of length equal to the batch size of
    `expand_to`; it is expanded to have the same shape as `expand_to` along all
    remaining dimensions, so that ``out[i, ...] == input_array[i]``.

    Args:
        input_array (Array): (batch_size,).
        expand_to (Array): (batch_size, ...).

    Returns:
        Array: (batch_size, ...), the same shape as `expand_to`.

    Raises:
        ValueError: If `input_array` is not 1d, if `expand_to` has no leading
            axis, or if their first (batch_size) dimensions differ.
    """
    # Explicit raises instead of ``assert``: under ``python -O`` the asserts are
    # stripped and a length-1 `input_array` would silently broadcast across a
    # batch of a different size.
    if len(input_array.shape) != 1:
        raise ValueError("Input array must be a 1d vector.")
    if len(expand_to.shape) == 0 or input_array.shape[0] != expand_to.shape[0]:
        raise ValueError(
            f"The first (batch_size) dimension must match. Got shape {input_array.shape} and {expand_to.shape}."
        )

    # Append one singleton axis per trailing dimension of `expand_to`, then
    # broadcast those singleton axes out to the target shape.
    dim_diff = len(expand_to.shape) - len(input_array.shape)
    t_expanded = jnp.reshape(input_array, (input_array.shape[0],) + (1,) * dim_diff)
    return jnp.broadcast_to(t_expanded, expand_to.shape)