import numpy as np
import matplotlib.pyplot as plt

# Public names exported by ``from <this module> import *``.
__all__ = [
    "plot_diagrams",
    "bottleneck_matching",
    "wasserstein_matching",
]


def _default_bounds(finite_values):
    """Return padded ``(lower, upper)`` axis limits enclosing ``finite_values``.

    The upper side is padded by one fifth of the data range and the lower side
    by half of that. When the range is degenerate (a single distinct value, or
    no finite values at all) a fixed padding of 1 is used so the axes never
    collapse to zero width.
    """
    if finite_values.size == 0:
        low = high = 0.0
    else:
        low = float(np.min(finite_values))
        high = float(np.max(finite_values))
    span = high - low
    buffer = 1.0 if span == 0 else span / 5
    return low - buffer / 2, high + buffer


def plot_diagrams(
    diagrams,
    plot_only=None,
    title=None,
    xy_range=None,
    labels=None,
    colormap="default",
    size=20,
    ax_color=np.array([0.0, 0.0, 0.0]),
    diagonal=True,
    lifetime=False,
    legend=True,
    show=False,
    ax=None,
):
    """A helper function to plot persistence diagrams.

    Parameters
    ----------
    diagrams: ndarray (n_pairs, 2) or list of diagrams
        A diagram or list of diagrams. If diagram is a list of diagrams,
        then plot all on the same plot using different colors. Empty
        diagrams are allowed and simply contribute no points.
    plot_only: list of int
        If specified, the indices (into ``diagrams``) of the only diagrams
        that should be plotted.
    title: string, default is None
        If title is defined, add it as title of the plot.
    xy_range: list of numeric [xmin, xmax, ymin, ymax]
        User provided range of axes. This is useful for comparing
        multiple persistence diagrams. When omitted, the range is derived
        from the finite values of all plotted diagrams, with padding.
    labels: string or list of strings
        Legend labels for each diagram. A single string is used for every
        diagram. If none are specified, we use H_0, H_1, H_2,... by default
        (one label per diagram).
    colormap: string, default is 'default'
        Name of a matplotlib style sheet, passed to ``plt.style.use``
        (this changes matplotlib's global style). See all available styles
        with

        .. code:: python

            import matplotlib.pyplot as plt
            print(plt.style.available)

    size: numeric, default is 20
        Marker size of each plotted point (matplotlib ``scatter``'s ``s``,
        in points squared).
    ax_color: any valid matplotlib color type.
        Color of the diagonal and of the lifetime horizon line.
        See [https://matplotlib.org/api/colors_api.html](https://matplotlib.org/api/colors_api.html) for complete API.
    diagonal: bool, default is True
        Plot the diagonal x=y line.
    lifetime: bool, default is False. If True, diagonal is turned to False.
        Plot life time of each point instead of birth and death.
        Essentially, visualize (x, y-x).
    legend: bool, default is True
        If true, show the legend.
    show: bool, default is False
        Call plt.show() after plotting. If you are using this function as
        part of a subplot, set show=False and call plt.show() only once at
        the end.
    ax: matplotlib Axes, default is None
        Axes to draw on. Defaults to ``plt.gca()``.
    """
    if ax is None:
        ax = plt.gca()
    plt.style.use(colormap)

    xlabel, ylabel = "Birth", "Death"

    if not isinstance(diagrams, list):
        # Must have diagrams as a list for processing downstream
        diagrams = [diagrams]

    if labels is None:
        # One default label per diagram, so more than nine diagrams are not
        # silently dropped by the ``zip`` in the plotting loop. Single-digit
        # labels keep their historical "$H_k$" spelling.
        labels = [
            "$H_{}$".format(k) if k < 10 else "$H_{{{}}}$".format(k)
            for k in range(len(diagrams))
        ]
    elif not isinstance(labels, list):
        # Broadcast a single label BEFORE applying ``plot_only``; indexing a
        # string with the ``plot_only`` indices would pick single characters.
        labels = [labels] * len(diagrams)

    if plot_only:
        diagrams = [diagrams[i] for i in plot_only]
        labels = [labels[i] for i in plot_only]

    # Construct float copies of each diagram so we can freely edit them
    # (inf replacement, lifetime transform) without touching the caller's
    # data. Empty diagrams are given shape (0, 2) so they concatenate and
    # index like every other diagram.
    prepared = []
    for dgm in diagrams:
        dgm = np.array(dgm, dtype=np.float32)
        if dgm.size == 0:
            dgm = dgm.reshape(0, 2)
        prepared.append(dgm)
    diagrams = prepared

    # find min and max of all visible diagrams
    if diagrams:
        concat_dgms = np.concatenate(diagrams).ravel()
    else:
        concat_dgms = np.empty(0, dtype=np.float32)
    has_inf = bool(np.any(np.isinf(concat_dgms)))
    finite_dgms = concat_dgms[np.isfinite(concat_dgms)]

    if xy_range is None or np.size(xy_range) == 0:
        # Square bounding box around the data with a buffer on all sides.
        x_down, x_up = _default_bounds(finite_dgms)
        y_down, y_up = x_down, x_up
    else:
        x_down, x_up, y_down, y_up = xy_range

    yr = y_up - y_down

    if lifetime:
        # Don't plot landscape and diagonal at the same time.
        diagonal = False

        # reset y axis so it doesn't go much below zero
        y_down = -yr * 0.05
        y_up = y_down + yr

        ylabel = "Lifetime"

        # set diagrams to be (x, y-x)
        for dgm in diagrams:
            dgm[:, 1] -= dgm[:, 0]

        # plot horizon line
        ax.plot([x_down, x_up], [0, 0], c=ax_color)

    if diagonal:
        ax.plot([x_down, x_up], [x_down, x_up], "--", c=ax_color)

    if has_inf:
        # put inf line slightly below top
        b_inf = y_down + yr * 0.95
        ax.plot([x_down, x_up], [b_inf, b_inf], "--", c="k", label=r"$\infty$")

        # draw infinite coordinates on the inf line
        for dgm in diagrams:
            dgm[np.isinf(dgm)] = b_inf

    for dgm, label in zip(diagrams, labels):
        # plot persistence pairs
        ax.scatter(dgm[:, 0], dgm[:, 1], size, label=label, edgecolor="none")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.set_xlim([x_down, x_up])
    ax.set_ylim([y_down, y_up])
    ax.set_aspect('equal', 'box')

    if title is not None:
        ax.set_title(title)

    if legend:
        ax.legend(loc="lower right")

    if show:
        plt.show()


def plot_a_bar(p, q, c='b', linestyle='-'):
    """Draw one straight segment from point ``p`` to point ``q`` via plt.plot.

    ``p`` and ``q`` are indexable ``(x, y)`` points; the segment uses color
    ``c``, the given ``linestyle`` and a fixed line width of 1.
    """
    start_x, start_y = p[0], p[1]
    end_x, end_y = q[0], q[1]
    plt.plot(
        [start_x, end_x],
        [start_y, end_y],
        c=c,
        linestyle=linestyle,
        linewidth=1,
    )


def bottleneck_matching(I1, I2, matchidx, D, labels=None, ax=None):
    """Visualize the bottleneck matching between two diagrams.

    Both diagrams are drawn with ``plot_diagrams``; then the matched pair with
    the largest ``D[i, j]`` is highlighted by a green segment. A point whose
    partner index is out of range is joined to its orthogonal projection onto
    the diagonal y = x.

    Parameters
    ----------
    I1: ndarray (n_pairs, 2)
        A diagram
    I2: ndarray (m_pairs, 2)
        A diagram
    matchidx: sequence of (i, j) tuples of matched indices
        As returned when calling ``persim.bottleneck`` with ``matching=True``.
        An index ``i >= len(I1)`` (resp. ``j >= len(I2)``) means the other
        point is matched to the diagonal.
    D: ndarray
        Cross-similarity matrix, indexed as ``D[i, j]``; used to select the
        longest matched pair.
    labels: list of strings, default is ["dgm1", "dgm2"]
        Names of diagrams for the legend.
    ax: matplotlib Axes, default is None
        Axes to draw on; defaults to ``plt.gca()``. Both the diagrams and the
        matching segment are drawn on this axes.

    Examples
    --------

    bn_matching, (matchidx, D) = persim.bottleneck(A_h1, B_h1, matching=True)
    persim.bottleneck_matching(A_h1, B_h1, matchidx, D)

    """
    if labels is None:
        labels = ["dgm1", "dgm2"]
    if ax is None:
        ax = plt.gca()

    plot_diagrams([I1, I2], labels=labels, ax=ax)

    if len(matchidx) == 0:
        # Nothing is matched, so there is no segment to highlight.
        return

    # An empty diagram is replaced by a single placeholder point at the
    # origin so that the shape-based index checks below stay well defined.
    if I1.size == 0:
        I1 = np.array([[0, 0]])
    if I2.size == 0:
        I2 = np.array([[0, 0]])

    dists = [D[i, j] for (i, j) in matchidx]
    i, j = matchidx[np.argmax(dists)]
    n1, n2 = I1.shape[0], I2.shape[0]

    if i >= n1 and j >= n2:
        return
    if i >= n1:
        # I2[j] is matched to the diagonal: its orthogonal projection onto
        # y = x is the point (m, m) with m = (birth + death) / 2.
        x, y = I2[j, 0], I2[j, 1]
        mid = (x + y) / 2.0
        ax.plot([x, mid], [y, mid], "g")
    elif j >= n2:
        # I1[i] is matched to the diagonal.
        x, y = I1[i, 0], I1[i, 1]
        mid = (x + y) / 2.0
        ax.plot([x, mid], [y, mid], "g")
    else:
        ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")


def wasserstein_matching(I1, I2, matchidx, palette=None, labels=None, colors=None, ax=None):
    """Visualize a Wasserstein matching between two diagrams.

    Every matched pair is joined by a green segment. A point whose partner
    index is out of range is joined to its orthogonal projection onto the
    diagonal y = x. The two diagrams are then drawn with ``plot_diagrams``.

    Parameters
    ----------
    I1: ndarray (n_pairs, 2)
        A diagram
    I2: ndarray (m_pairs, 2)
        A diagram
    matchidx: iterable of (i, j) tuples of matched indices
        As returned when calling ``persim.wasserstein`` with ``matching=True``.
        An index ``i >= len(I1)`` (resp. ``j >= len(I2)``) means the other
        point is matched to the diagonal.
    palette: ignored
        Currently unused; accepted for backward compatibility.
    labels: list of strings, default is ["dgm1", "dgm2"]
        Names of diagrams for the legend.
    colors: ignored
        Currently unused; accepted for backward compatibility.
    ax: matplotlib Axes, default is None
        Axes to draw on; defaults to ``plt.gca()``. Both the matching
        segments and the diagrams are drawn on this axes.

    Examples
    --------

    wdist, (matchidx, D) = persim.wasserstein(A_h1, B_h1, matching=True)
    persim.wasserstein_matching(A_h1, B_h1, matchidx)

    """
    if labels is None:
        labels = ["dgm1", "dgm2"]
    if ax is None:
        ax = plt.gca()

    # An empty diagram is replaced by a single placeholder point at the
    # origin so that the shape-based index checks below stay well defined.
    # NOTE(review): the placeholder is also passed to plot_diagrams below,
    # so an empty diagram shows up as a point at (0, 0).
    if I1.size == 0:
        I1 = np.array([[0, 0]])
    if I2.size == 0:
        I2 = np.array([[0, 0]])

    n1, n2 = I1.shape[0], I2.shape[0]
    for i, j in matchidx:
        if i >= n1 and j >= n2:
            continue
        if i >= n1:
            # I2[j] is matched to the diagonal: its orthogonal projection
            # onto y = x is (m, m) with m = (birth + death) / 2.
            x, y = I2[j, 0], I2[j, 1]
            mid = (x + y) / 2.0
            ax.plot([x, mid], [y, mid], "g")
        elif j >= n2:
            # I1[i] is matched to the diagonal.
            x, y = I1[i, 0], I1[i, 1]
            mid = (x + y) / 2.0
            ax.plot([x, mid], [y, mid], "g")
        else:
            ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")

    plot_diagrams([I1, I2], labels=labels, ax=ax)

Read our documentation on viewing source code .

Loading