migrate to ghactions for ci/cd
1 | 3 |
import numpy as np |
2 | 3 |
from scipy.spatial.distance import cityblock |
3 |
|
|
4 | 3 |
# Public names exported by ``from <this module> import *``.
__all__ = [
    "sliced_wasserstein",
]
5 |
|
|
6 | 3 |
def sliced_wasserstein(PD1, PD2, M=50):
    """Approximate the Sliced Wasserstein distance between two persistence diagrams.

    Implementation of the Sliced Wasserstein distance as described in
    "Sliced Wasserstein Kernel for Persistence Diagrams" by Mathieu Carriere,
    Marco Cuturi and Steve Oudot (https://arxiv.org/abs/1706.03358).

    Each diagram is augmented with the orthogonal projections onto the
    diagonal of the *other* diagram's points, so both point clouds have the
    same size ``m + n``. Both clouds are projected onto ``M`` directions
    evenly spaced over a half circle. For every direction the sorted
    projections are compared with the L1 distance, and the ``M`` costs are
    averaged.

    Parameters
    ----------
    PD1 : array-like of shape (m, 2)
        Persistence diagram, one (birth, death) pair per row. An empty
        array is accepted.
    PD2 : array-like of shape (n, 2)
        Persistence diagram, one (birth, death) pair per row. An empty
        array is accepted.
    M : int, default is 50
        Number of slicing directions used by the approximation.

    Returns
    -------
    sw : float
        Sliced Wasserstein distance between PD1 and PD2.

    Raises
    ------
    ValueError
        If ``M`` is smaller than 1, or a non-empty diagram is not of shape (k, 2).
    TypeError
        If ``M`` is not an integer.
    """
    import operator

    # Reject non-integral M the same way ``range(M)`` would.
    n_dirs = operator.index(M)
    if n_dirs < 1:
        raise ValueError("M must be a positive integer")

    def _as_diagram(pd):
        """Return *pd* as a float64 array of shape (k, 2) (k may be 0)."""
        arr = np.asarray(pd, dtype=np.float64)
        if arr.size == 0:
            return arr.reshape(0, 2)
        if arr.ndim != 2 or arr.shape[1] != 2:
            raise ValueError("persistence diagrams must have shape (k, 2)")
        return arr

    def _diagonal_projection(pd):
        """Orthogonal projection of each (b, d) onto the diagonal: ((b+d)/2, (b+d)/2)."""
        # Signed midpoint; taking a square root of a square here would lose
        # the sign and misplace points whose b + d is negative.
        mid = pd.mean(axis=1)
        return np.column_stack((mid, mid))

    pd1 = _as_diagram(PD1)
    pd2 = _as_diagram(PD2)

    # Augment each diagram with the diagonal projections of the other one,
    # so both clouds contain m + n points.
    cloud1 = np.concatenate((pd1, _diagonal_projection(pd2)))
    cloud2 = np.concatenate((pd2, _diagonal_projection(pd1)))

    # Angles pi * (0.5 + i / M) for i = 0..M-1, i.e. M directions evenly
    # spaced over [pi/2, 3pi/2). A half circle is enough, because opposite
    # directions yield the same sorted-L1 cost.
    angles = np.pi * (0.5 + np.arange(n_dirs) / n_dirs)
    directions = np.stack((np.cos(angles), np.sin(angles)))  # shape (2, M)

    # Column j holds the sorted projections onto direction j.
    proj1 = np.sort(cloud1 @ directions, axis=0)
    proj2 = np.sort(cloud2 @ directions, axis=0)

    # Per-direction L1 cost between sorted projections, averaged over the
    # M directions (equivalent to summing with weight 1/M).
    return float(np.abs(proj1 - proj2).sum(axis=0).mean())
Read our documentation on viewing source code .