scikit-tda / persim
```python
"""

    Implementation of the Wasserstein distance using
    the Hungarian algorithm

    Author: Chris Tralie

"""
import numpy as np
from sklearn import metrics
from scipy import optimize
import warnings

__all__ = ["wasserstein"]


def wasserstein(dgm1, dgm2, matching=False):
    """
    Perform the Wasserstein distance matching between persistence diagrams.
    Assumes first two columns of dgm1 and dgm2 are the coordinates of the persistence
    points, but allows for other coordinate columns (which are ignored in
    diagonal matching).

    See the `distances` notebook for an example of how to use this.

    Parameters
    ------------

    dgm1: Mx(>=2)
        array of birth/death pairs for PD 1
    dgm2: Nx(>=2)
        array of birth/death pairs for PD 2
    matching: bool, default False
        if True, return matching information and cross-similarity matrix

    Returns
    ---------

    d: float
        Wasserstein distance between dgm1 and dgm2
    (matching, D): Only returns if `matching=True`
        (tuples of matched indices, (N+M)x(N+M) cross-similarity matrix)

    """

    # Drop points with non-finite death times from the first diagram
    S = np.array(dgm1)
    M = min(S.shape[0], S.size)
    if S.size > 0:
        S = S[np.isfinite(S[:, 1]), :]
        if S.shape[0] < M:
            warnings.warn(
                "dgm1 has points with non-finite death times; "
                "ignoring those points"
            )
            M = S.shape[0]
    # Same filtering for the second diagram
    T = np.array(dgm2)
    N = min(T.shape[0], T.size)
    if T.size > 0:
        T = T[np.isfinite(T[:, 1]), :]
        if T.shape[0] < N:
            warnings.warn(
                "dgm2 has points with non-finite death times; "
                "ignoring those points"
            )
            N = T.shape[0]

    # An empty diagram is replaced by a single point on the diagonal
    if M == 0:
        S = np.array([[0, 0]])
        M = 1
    if N == 0:
        T = np.array([[0, 0]])
        N = 1
    # Step 1: Compute CSM between S and T, including points on diagonal
    DUL = metrics.pairwise.pairwise_distances(S, T)

    # Put diagonal elements into the matrix
    # Rotate the diagrams to make it easy to find the straight line
    # distance to the diagonal: after the rotation, the second coordinate
    # of each point is (death - birth) / sqrt(2)
    cp = np.cos(np.pi/4)
    sp = np.sin(np.pi/4)
    R = np.array([[cp, -sp], [sp, cp]])
    S = S[:, 0:2].dot(R)
    T = T[:, 0:2].dot(R)
    D = np.zeros((M+N, M+N))
    D[0:M, 0:N] = DUL
    # Upper-right block: each point of S may only be matched to its own
    # diagonal projection. Off-diagonal entries are infinite; a finite
    # filler such as np.max(D) can be cheaper than a point's true
    # distance to the diagonal and produce an invalid matching.
    UR = np.inf*np.ones((M, M))
    np.fill_diagonal(UR, S[:, 1])
    D[0:M, N:N+M] = UR
    # Lower-left block: the same construction for the points of T
    UL = np.inf*np.ones((N, N))
    np.fill_diagonal(UL, T[:, 1])
    D[M:N+M, 0:N] = UL

    # Step 2: Run the hungarian algorithm
    matchi, matchj = optimize.linear_sum_assignment(D)
    matchdist = np.sum(D[matchi, matchj])

    if matching:
        matchidx = [(i, j) for i, j in zip(matchi, matchj)]
        return matchdist, (matchidx, D)

    return matchdist
```
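To see what the assignment step is minimizing, a tiny exhaustive reference can help when checking results by hand. The sketch below is not part of persim: `brute_force_wasserstein` is a hypothetical helper that uses only the standard library, builds the same (M+N)x(M+N) block layout, and tries every permutation, so it is only practical for a handful of points. It assumes every point is a finite (birth, death) pair, and it gives infinite cost to pairing a point with another point's diagonal slot, so each point can only fall to its own projection.

```python
import itertools
import math


def brute_force_wasserstein(dgm1, dgm2):
    """Exhaustive reference for tiny diagrams (hypothetical helper, not persim API)."""
    S = [(p[0], p[1]) for p in dgm1]  # only birth/death columns are used
    T = [(p[0], p[1]) for p in dgm2]
    M, N = len(S), len(T)
    size = M + N

    def to_diagonal(p):
        # Perpendicular Euclidean distance from (birth, death) to the line y = x
        return (p[1] - p[0]) / math.sqrt(2)

    # Augmented cost matrix; the lower-right block stays zero
    D = [[0.0] * size for _ in range(size)]
    for i in range(M):
        for j in range(N):
            # Upper-left block: point-to-point Euclidean distances
            D[i][j] = math.hypot(S[i][0] - T[j][0], S[i][1] - T[j][1])
        for k in range(M):
            # Upper-right block: a point of S may only use its own diagonal slot
            D[i][N + k] = to_diagonal(S[i]) if i == k else math.inf
    for j in range(N):
        for k in range(N):
            # Lower-left block: a point of T may only use its own diagonal slot
            D[M + k][j] = to_diagonal(T[j]) if j == k else math.inf

    # Try every row-to-column assignment and keep the cheapest total
    return min(
        sum(D[row][col] for row, col in enumerate(perm))
        for perm in itertools.permutations(range(size))
    )


if __name__ == "__main__":
    # Matching the two points directly (cost 0.5) beats sending both to the diagonal
    print(brute_force_wasserstein([(0, 1)], [(0, 1.5)]))  # 0.5
```

The cost grows factorially with M+N, which is why the function above relies on `optimize.linear_sum_assignment` instead; the exhaustive version is only meant as a cross-check on diagrams with a few points.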
