scikit-tda / persim
```python
import numpy as np
from scipy.spatial.distance import cityblock

__all__ = ["sliced_wasserstein"]


def sliced_wasserstein(PD1, PD2, M=50):
    """ Implementation of the Sliced Wasserstein distance as described in
        "Sliced Wasserstein Kernel for Persistence Diagrams" by Mathieu Carriere,
        Marco Cuturi, Steve Oudot (https://arxiv.org/abs/1706.03358)

        Parameters
        -----------
        PD1: np.array size (m,2)
            Persistence diagram
        PD2: np.array size (n,2)
            Persistence diagram
        M: int, default is 50
            Number of directions (slices) used in the approximation.

        Returns
        --------
        sw: float
            Sliced Wasserstein distance between PD1 and PD2
    """

    # Unit vector along the diagonal (angle pi/4).
    diag_theta = np.array(
        [np.cos(0.25 * np.pi), np.sin(0.25 * np.pi)], dtype=np.float32
    )

    # Signed length of each point's projection onto the diagonal.
    l_theta1 = [np.dot(diag_theta, x) for x in PD1]
    l_theta2 = [np.dot(diag_theta, x) for x in PD2]

    if (len(l_theta1) != PD1.shape[0]) or (len(l_theta2) != PD2.shape[0]):
        raise ValueError("The projected points and origin do not match")

    # Orthogonal projections onto the diagonal: (|b + d| / 2, |b + d| / 2).
    PD_delta1 = [[np.sqrt(x ** 2 / 2.0)] * 2 for x in l_theta1]
    PD_delta2 = [[np.sqrt(x ** 2 / 2.0)] * 2 for x in l_theta2]

    # Average the 1-D Wasserstein distance over M directions theta * pi,
    # with theta = 0.5, 0.5 + 1/M, ..., covering a half-circle of directions.
    sw = 0
    theta = 0.5
    step = 1.0 / M
    for i in range(M):
        l_theta = np.array(
            [np.cos(theta * np.pi), np.sin(theta * np.pi)], dtype=np.float32
        )

        # Each diagram is augmented with the other's diagonal projections,
        # so both projected point sets have the same size.
        V1 = [np.dot(l_theta, x) for x in PD1] + [np.dot(l_theta, x) for x in PD_delta2]

        V2 = [np.dot(l_theta, x) for x in PD2] + [np.dot(l_theta, x) for x in PD_delta1]

        # For equal-size 1-D point sets, the L1 distance between the sorted
        # projections is the Wasserstein-1 matching cost.
        sw += step * cityblock(sorted(V1), sorted(V2))
        theta += step

    return sw
```
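The loop above projects both augmented diagrams onto one direction at a time. As a hedged sketch, the same scheme can be written with NumPy array operations instead of a Python loop. The name `sliced_wasserstein_vectorized` and the parameter `num_directions` are hypothetical and not part of persim. Two assumptions differ from the code above: the diagonal projection is computed as (b + d) / 2 rather than `np.sqrt(x ** 2 / 2.0)` (which gives |b + d| / 2, so the two agree whenever b + d >= 0, the usual case for persistence diagrams), and directions are computed in float64 rather than float32, so results may differ slightly by rounding.

```python
import numpy as np


def sliced_wasserstein_vectorized(pd1, pd2, num_directions=50):
    """Hypothetical vectorized sketch of the sliced Wasserstein approximation.

    Uses the same directions as the loop in sliced_wasserstein
    (theta * pi for theta = 0.5, 0.5 + 1/M, ...), the same swapped
    diagonal augmentation, and the sorted-L1 form of the 1-D W1 distance.
    """
    # Accept lists or arrays; empty input becomes a (0, 2) array.
    pd1 = np.asarray(pd1, dtype=float).reshape(-1, 2)
    pd2 = np.asarray(pd2, dtype=float).reshape(-1, 2)

    # Assumption: orthogonal projection of (b, d) onto the diagonal
    # is ((b + d) / 2, (b + d) / 2), without taking an absolute value.
    diag1 = np.repeat(pd1.sum(axis=1, keepdims=True) / 2.0, 2, axis=1)
    diag2 = np.repeat(pd2.sum(axis=1, keepdims=True) / 2.0, 2, axis=1)

    # Augment each diagram with the other's diagonal projections so both
    # point sets contain len(pd1) + len(pd2) points.
    aug1 = np.vstack([pd1, diag2])
    aug2 = np.vstack([pd2, diag1])

    # Directions at angles (0.5 + k / M) * pi for k = 0, ..., M - 1.
    thetas = (0.5 + np.arange(num_directions) / num_directions) * np.pi
    directions = np.stack([np.cos(thetas), np.sin(thetas)])  # shape (2, M)

    # Project every point onto every direction and sort each column:
    # result has shape (n_points, M).
    proj1 = np.sort(aug1 @ directions, axis=0)
    proj2 = np.sort(aug2 @ directions, axis=0)

    # Per-direction W1 cost (sum of sorted differences), averaged over
    # directions; this equals step * sum(...) with step = 1 / M.
    return float(np.abs(proj1 - proj2).sum(axis=0).mean())


if __name__ == "__main__":
    a = [[0.0, 1.0], [0.5, 2.0]]
    b = [[0.2, 1.5]]
    print(sliced_wasserstein_vectorized(a, b))
```

For example, `sliced_wasserstein_vectorized([[0.0, 2.0]], [])` compares a single point with an empty diagram, so the whole cost comes from matching (0, 2) against its diagonal projection (1, 1). The design trades the per-direction Python loop for a single `(n, 2) @ (2, M)` matrix product, at the cost of holding an `(n, M)` array of projections in memory.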
