# Source code for gensbi.flow_matching.solver.utils
#FIXME: first pass
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
import jax
import jax.numpy as jnp
from jax import Array
# [docs]
def get_nearest_times(time_grid: Array, t_discretization: Array) -> Array:
    """Snap each query time to the closest point of a reference discretization.

    For every entry of ``time_grid``, the entry of ``t_discretization`` with
    the smallest absolute difference is returned. When two reference points
    are equally close, the one appearing first in ``t_discretization`` wins
    (``jnp.argmin`` returns the first minimizing index). ``t_discretization``
    does not need to be sorted.

    The full table of pairwise distances (``time_grid.shape + (M,)``) is
    materialized, so memory scales with ``time_grid.size * M``.

    Args:
        time_grid (Array): Query times to find nearest neighbors for. The
            usual case is a 1-D array of shape (N,), but any shape (including
            a scalar) is accepted.
        t_discretization (Array): Reference time points to match against,
            shape (M,) with M >= 1.

    Returns:
        Array: Values taken from ``t_discretization``, with the same shape as
        ``time_grid`` (shape (N,) for 1-D input).

    Raises:
        ValueError: If ``t_discretization`` is empty.
    """
    time_grid = jnp.asarray(time_grid)
    t_discretization = jnp.asarray(t_discretization)

    # Shapes are static under jit, so this Python-level check is trace-safe.
    if t_discretization.size == 0:
        raise ValueError("t_discretization must contain at least one time point")

    # Add a trailing axis so every query broadcasts against all M reference
    # points: (..., 1) - (M,) -> (..., M). For 1-D input this is identical to
    # expand_dims(axis=1), and it also handles scalar and N-d queries.
    distances = jnp.abs(time_grid[..., None] - t_discretization)

    # Reduce over the reference axis; ties resolve to the first index.
    nearest_indices = jnp.argmin(distances, axis=-1)
    return t_discretization[nearest_indices]