1 3
import numpy as np
2 3
from scipy.spatial.distance import cityblock
3

4 3
# Names exported by ``from <module> import *``.
__all__ = ["sliced_wasserstein"]
5

6 3
def _as_diagram(PD, name):
    """Coerce ``PD`` to a float64 array of shape (k, 2).

    An input with no elements is treated as an empty diagram; any other
    input that is not two-dimensional with exactly two columns raises
    ``ValueError``.
    """
    points = np.asarray(PD, dtype=np.float64)
    if points.size == 0:
        return points.reshape(0, 2)
    if points.ndim != 2 or points.shape[1] != 2:
        raise ValueError(
            "{} must have shape (k, 2), got {}".format(name, points.shape)
        )
    return points


def sliced_wasserstein(PD1, PD2, M=50):
    """ Implementation of Sliced Wasserstein distance as described in
        Sliced Wasserstein Kernel for Persistence Diagrams by Mathieu Carriere, Marco Cuturi, Steve Oudot (https://arxiv.org/abs/1706.03358)

        Each diagram is augmented with the orthogonal projections of the
        other diagram's points onto the diagonal, so both point sets have
        the same size. For each of ``M`` evenly spaced directions the two
        augmented sets are projected onto that direction, sorted, and
        compared with the L1 distance; the result is the average over all
        directions.

        Parameters
        -----------

        PD1: array-like, size (m,2)
            Persistence diagram
        PD2: array-like, size (n,2)
            Persistence diagram
        M: int, default is 50
            Number of directions used for the approximation. Must be >= 1.

        Returns
        --------
        sw: float
            Sliced Wasserstein distance between PD1 and PD2

        Raises
        -------
        ValueError
            If ``M`` is smaller than 1, or if a non-empty diagram does not
            have shape (k, 2).
    """
    if M < 1:
        raise ValueError("M must be a positive integer, got {}".format(M))

    PD1 = _as_diagram(PD1, "PD1")
    PD2 = _as_diagram(PD2, "PD2")

    # The orthogonal projection of (b, d) onto the diagonal is the midpoint
    # ((b + d) / 2, (b + d) / 2). Computing it directly keeps the sign of
    # b + d, so points with negative coordinates land on the correct side
    # of the origin (sqrt(x ** 2 / 2) would silently take an absolute value).
    mid1 = PD1.sum(axis=1) / 2.0
    mid2 = PD2.sum(axis=1) / 2.0
    PD_delta1 = np.column_stack((mid1, mid1))
    PD_delta2 = np.column_stack((mid2, mid2))

    # Each diagram is paired with the *other* diagram's diagonal points.
    augmented1 = np.vstack((PD1, PD_delta2))
    augmented2 = np.vstack((PD2, PD_delta1))

    # M directions with angles pi * (0.5 + i / M), i = 0 .. M - 1.
    step = 1.0 / M
    thetas = (0.5 + step * np.arange(M)) * np.pi
    directions = np.column_stack((np.cos(thetas), np.sin(thetas)))

    # Row i holds the sorted projections onto direction i.
    V1 = np.sort(np.dot(directions, augmented1.T), axis=1)
    V2 = np.sort(np.dot(directions, augmented2.T), axis=1)

    # Sum of per-direction L1 distances, each weighted by ``step``.
    return float(step * np.abs(V1 - V2).sum())

Read our documentation on viewing source code.

Loading