kylanoconnor's picture
Initial PLONK deployment for Hugging Face Spaces
fac3244
raw
history blame
1.36 kB
"""Copyright (c) Meta Platforms, Inc. and affiliates."""
import math
import torch
from geoopt.manifolds import Sphere as geoopt_Sphere
class Sphere(geoopt_Sphere):
    """geoopt ``Sphere`` with a guarded vector transport and a uniform base distribution."""

    def transp(self, x, y, v):
        """Transport the tangent vector ``v`` at ``x`` to the tangent space at ``y``.

        Uses the closed form ``v - <y, v> / (1 + <x, y>) * (x + y)``.  When
        ``1 + <x, y> <= 1e-3`` (for unit vectors: ``x`` and ``y`` are (nearly)
        antipodal) the formula is ill-conditioned and ``-v`` is returned instead.

        Parameters
        ----------
        x, y : torch.Tensor
            Source and target points; the inner product is taken over the last axis.
        v : torch.Tensor
            Tangent vector at ``x``.

        Returns
        -------
        torch.Tensor
            The transported vector, same shape as ``v`` (after broadcasting).
        """
        denom = 1 + self.inner(x, x, y, keepdim=True)
        cond = denom.gt(1e-3)
        # torch.where back-propagates through BOTH branches, so an inf/NaN in the
        # discarded branch (denom == 0 at exactly antipodal points) would turn the
        # gradient into NaN.  Dividing by 1 on the rejected entries keeps that
        # branch finite without changing the selected forward values.
        safe_denom = torch.where(cond, denom, torch.ones_like(denom))
        res = v - self.inner(x, y, v, keepdim=True) / safe_denom * (x + y)
        return torch.where(cond, res, -v)

    def uniform_logprob(self, x):
        """Log-density of the uniform distribution on the unit sphere.

        With ``d = x.shape[-1]`` the ambient dimension, the sphere ``S^{d-1}`` has
        surface area ``2 * pi^(d/2) / Gamma(d/2)``, so the log-density is
        ``lgamma(d/2) - log(2) - (d/2) * log(pi)``, independent of the point.

        Returns
        -------
        torch.Tensor
            A tensor shaped like ``x[..., 0]`` filled with that constant.
        """
        dim = x.shape[-1]
        return torch.full_like(
            x[..., 0],
            math.lgamma(dim / 2) - (math.log(2) + (dim / 2) * math.log(math.pi)),
        )

    def random_base(self, *args, **kwargs):
        """Sample from the base (uniform) distribution; arguments go to ``random_uniform``."""
        return self.random_uniform(*args, **kwargs)

    def base_logprob(self, *args, **kwargs):
        """Log-density of the base (uniform) distribution; see ``uniform_logprob``."""
        return self.uniform_logprob(*args, **kwargs)
def geodesic(manifold, start_point, end_point):
    """Return a function tracing the geodesic from ``start_point`` towards ``end_point``.

    The curve is ``gamma(t) = expmap(start_point, t * logmap(start_point, end_point))``.
    ``t = 0`` gives ``start_point``; ``t = 1`` gives ``end_point`` whenever the
    manifold's ``logmap``/``expmap`` pair round-trips.

    Parameters
    ----------
    manifold : object
        Anything exposing ``logmap(x, y)`` and ``expmap(x, u)`` (e.g. a geoopt manifold).
    start_point, end_point : torch.Tensor
        Points of shape ``(..., k)``.

    Returns
    -------
    callable
        ``path(t)`` evaluating the geodesic at the times ``t``.
    """
    shooting_tangent_vec = manifold.logmap(start_point, end_point)

    def path(t):
        """Generate parameterized function for geodesic curve.

        Parameters
        ----------
        t : array-like or scalar, shape=[n_points,]
            Times at which to compute points of the geodesics.  Lists, numpy
            arrays, Python scalars and tensors of any float dtype are accepted;
            they are cast to the dtype/device of the tangent vector.

        Returns
        -------
        torch.Tensor
            Points along the geodesic, presumably of shape ``(..., n_points, k)``
            (depends on ``manifold.expmap`` broadcasting — confirm per manifold).
        """
        # einsum requires a tensor whose dtype/device matches the tangent vector;
        # atleast_1d lets a single scalar time behave like a length-1 batch.
        times = torch.atleast_1d(
            torch.as_tensor(
                t,
                dtype=shooting_tangent_vec.dtype,
                device=shooting_tangent_vec.device,
            )
        )
        # Outer product over the time axis: (n,) x (..., k) -> (..., n, k).
        tangent_vecs = torch.einsum("i,...k->...ik", times, shooting_tangent_vec)
        # Insert a time axis so the base point broadcasts against every time step.
        return manifold.expmap(start_point.unsqueeze(-2), tangent_vecs)

    return path