# Source code for gensbi.flow_matching.utils.manifolds_not_implemented.manifold

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.

import abc

import torch.nn as nn
from torch import Tensor


class Manifold(nn.Module, metaclass=abc.ABCMeta):
    """A manifold class that contains projection operations and logarithm and exponential maps.

    This is an abstract base class (``abc.ABCMeta`` metaclass): it cannot be
    instantiated directly, and a concrete subclass must override all of
    :meth:`expmap`, :meth:`logmap`, :meth:`projx` and :meth:`proju` before it
    can be instantiated. It derives from :class:`torch.nn.Module`.
    """

    @abc.abstractmethod
    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Computes exponential map :math:`\exp_x(u)`.

        Args:
            x (Tensor): point on the manifold
            u (Tensor): tangent vector at point :math:`x`

        Raises:
            NotImplementedError: if not implemented

        Returns:
            Tensor: the point :math:`\exp_x(u)` reached from :math:`x` along :math:`u`
        """
        raise NotImplementedError

    @abc.abstractmethod
    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Computes logarithmic map :math:`\log_x(y)`.

        Args:
            x (Tensor): point on the manifold
            y (Tensor): point on the manifold

        Raises:
            NotImplementedError: if not implemented

        Returns:
            Tensor: tangent vector at point :math:`x`
        """
        raise NotImplementedError

    @abc.abstractmethod
    def projx(self, x: Tensor) -> Tensor:
        """Project point :math:`x` on the manifold.

        Args:
            x (Tensor): point to be projected

        Raises:
            NotImplementedError: if not implemented

        Returns:
            Tensor: projected point on the manifold
        """
        raise NotImplementedError

    @abc.abstractmethod
    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Project vector :math:`u` on a tangent space for :math:`x`.

        Args:
            x (Tensor): point on the manifold
            u (Tensor): vector to be projected

        Raises:
            NotImplementedError: if not implemented

        Returns:
            Tensor: projected tangent vector
        """
        raise NotImplementedError
class Euclidean(Manifold):
    """The Euclidean manifold.

    In flat space the exponential and logarithmic maps reduce to vector
    addition and subtraction, and both projections are the identity.
    """

    def expmap(self, x: Tensor, u: Tensor) -> Tensor:
        r"""Computes the exponential map :math:`\exp_x(u) = x + u`.

        Args:
            x (Tensor): point on the manifold
            u (Tensor): tangent vector at point :math:`x`

        Returns:
            Tensor: the point :math:`x + u`
        """
        return x + u

    def logmap(self, x: Tensor, y: Tensor) -> Tensor:
        r"""Computes the logarithmic map :math:`\log_x(y) = y - x`.

        Args:
            x (Tensor): point on the manifold
            y (Tensor): point on the manifold

        Returns:
            Tensor: tangent vector :math:`y - x` at point :math:`x`
        """
        return y - x

    def projx(self, x: Tensor) -> Tensor:
        """Project point :math:`x` on the manifold (identity in Euclidean space).

        Args:
            x (Tensor): point to be projected

        Returns:
            Tensor: ``x`` itself (the same object, not a copy)
        """
        return x

    def proju(self, x: Tensor, u: Tensor) -> Tensor:
        """Project vector :math:`u` on the tangent space at :math:`x` (identity here).

        Args:
            x (Tensor): point on the manifold (unused)
            u (Tensor): vector to be projected

        Returns:
            Tensor: ``u`` itself (the same object, not a copy)
        """
        return u