scikit-tda / persim
1
"""

    Implementation of the Wasserstein distance between persistence
    diagrams using the Hungarian algorithm

    Author: Chris Tralie

"""
9 3
import numpy as np
10 3
from sklearn import metrics
11 3
from scipy import optimize
12 3
import warnings
13

14 3
__all__ = ["wasserstein"]


def _finite_points(dgm, name):
    """
    Return ``dgm`` as an array with its non-finite-death rows removed.

    Parameters
    ------------
    dgm: array-like
        persistence diagram whose column 1 holds death times
    name: str
        label used in the warning message ("dgm1" or "dgm2")

    Returns
    ---------
    pts: ndarray
        the rows of ``dgm`` whose death time is finite. If no such rows
        exist (including when ``dgm`` is empty), the single point
        ``[[0, 0]]`` is returned instead: it sits on the diagonal, so it
        does not change the optimal cost, but it keeps the cost matrix
        non-empty.
    """
    pts = np.array(dgm)
    if pts.size > 0:
        finite = np.isfinite(pts[:, 1])
        if not np.all(finite):
            warnings.warn(
                "{} has points with non-finite death times; "
                "ignoring those points".format(name)
            )
            pts = pts[finite, :]
    if pts.size == 0:
        pts = np.array([[0, 0]])
    return pts


def wasserstein(dgm1, dgm2, matching=False):
    """
    Perform the Wasserstein distance matching between persistence diagrams.

    Computes the minimum total cost of a partial matching between the points
    of dgm1 and dgm2, where every unmatched point is sent to the diagonal.
    Matching two points costs their Euclidean distance (over all columns);
    sending a point (b, d) to the diagonal costs |d - b| / sqrt(2), its
    Euclidean distance to the line y = x. The optimal matching is found with
    the Hungarian algorithm (scipy.optimize.linear_sum_assignment).

    Assumes the first two columns of dgm1 and dgm2 are the (birth, death)
    coordinates of the persistence points, but allows for other coordinate
    columns (which are ignored in diagonal matching). Points whose death
    time is not finite are dropped with a warning.

    See the `distances` notebook for an example of how to use this.

    Parameters
    ------------

    dgm1: Mx(>=2)
        array of birth/death pairs for PD 1
    dgm2: Nx(>=2)
        array of birth/death pairs for PD 2
    matching: bool, default False
        if True, return matching information and the cost matrix

    Returns
    ---------

    d: float
        Wasserstein distance between dgm1 and dgm2
    (matching, D): Only returns if `matching=True`
        (list of matched index pairs (i, j) into D, (N+M)x(N+M) cost matrix).
        Rows 0..M-1 of D are the points of dgm1 and rows M..M+N-1 are
        diagonal slots; columns 0..N-1 are the points of dgm2 and columns
        N..N+M-1 are diagonal slots. M and N count the finite points kept
        (an empty diagram is replaced by the single point (0, 0)).

    """
    from scipy.spatial.distance import cdist

    S = _finite_points(dgm1, "dgm1")
    T = _finite_points(dgm2, "dgm2")
    M = S.shape[0]
    N = T.shape[0]

    # Step 1: build the (M+N)x(M+N) cost matrix.
    # Point-to-point costs are Euclidean distances computed directly
    # (not via the |x|^2 - 2x.y + |y|^2 expansion, which loses precision
    # for nearby points).
    D = np.zeros((M + N, M + N))
    D[0:M, 0:N] = cdist(S, T)

    # Euclidean distance from each point (b, d) to the diagonal y = x.
    diag_S = np.abs(S[:, 1] - S[:, 0]) / np.sqrt(2)
    diag_T = np.abs(T[:, 1] - T[:, 0]) / np.sqrt(2)

    # A point may take ANY diagonal slot at the cost of its own distance to
    # the diagonal. Keeping these blocks row-/column-constant (instead of
    # putting a finite placeholder off the block diagonal) means permuting
    # diagonal slots never changes the cost, so the solver cannot use cheap
    # placeholder entries to avoid paying for unmatched points.
    D[0:M, N:N + M] = diag_S[:, np.newaxis]
    D[M:M + N, 0:N] = diag_T[np.newaxis, :]
    # The bottom-right NxM block stays 0: diagonal-to-diagonal is free.

    # Step 2: Run the hungarian algorithm
    matchi, matchj = optimize.linear_sum_assignment(D)
    matchdist = np.sum(D[matchi, matchj])

    if matching:
        matchidx = list(zip(matchi, matchj))
        return matchdist, (matchidx, D)

    return matchdist

Read our documentation on viewing source code .

Loading