Source code for gensbi.flow_matching.utils.utils

#FIXME: first pass

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional, Callable

import jax
import jax.numpy as jnp
from jax import Array

import matplotlib.pyplot as plt
import numpy as np

from einops import einsum


[docs] def unsqueeze_to_match(source: Array, target: Array, how: str = "suffix") -> Array: """ Unsqueeze the source array to match the dimensionality of the target array. Args: source (Array): The source array to be unsqueezed. target (Array): The target array to match the dimensionality of. how (str, optional): Whether to unsqueeze the source array at the beginning ("prefix") or end ("suffix"). Defaults to "suffix". Returns: Array: The unsqueezed source array. """ assert ( how == "prefix" or how == "suffix" ), f"{how} is not supported, only 'prefix' and 'suffix' are supported." dim_diff = len(target.shape) - len(source.shape) for _ in range(dim_diff): if how == "prefix": source = jnp.expand_dims(source, axis=0) elif how == "suffix": source = jnp.expand_dims(source, axis=-1) return source
[docs] def expand_tensor_like(input_array: Array, expand_to: Array) -> Array: """`input_array` is a 1d vector of length equal to the batch size of `expand_to`, expand `input_array` to have the same shape as `expand_to` along all remaining dimensions. Args: input_array (Array): (batch_size,). expand_to (Array): (batch_size, ...). Returns: Array: (batch_size, ...). """ assert len(input_array.shape) == 1, "Input array must be a 1d vector." assert ( input_array.shape[0] == expand_to.shape[0] ), f"The first (batch_size) dimension must match. Got shape {input_array.shape} and {expand_to.shape}." dim_diff = len(expand_to.shape) - len(input_array.shape) t_expanded = jnp.reshape(input_array, (-1,) + (1,) * dim_diff) return jnp.broadcast_to(t_expanded, expand_to.shape)