migrate to ghactions for ci/cd
1 | 3 |
import numpy as np |
2 | 3 |
import matplotlib.pyplot as plt |
3 |
|
|
4 | 3 |
# Names exported by a star-import of this module.
__all__ = [
    "plot_diagrams",
    "bottleneck_matching",
    "wasserstein_matching",
]
5 |
|
|
6 |
|
|
7 | 3 |
def plot_diagrams(
    diagrams,
    plot_only=None,
    title=None,
    xy_range=None,
    labels=None,
    colormap="default",
    size=20,
    ax_color=np.array([0.0, 0.0, 0.0]),
    diagonal=True,
    lifetime=False,
    legend=True,
    show=False,
    ax=None
):
    """A helper function to plot persistence diagrams.

    Parameters
    ----------
    diagrams: ndarray (n_pairs, 2) or list of diagrams
        A diagram or list of diagrams. If diagram is a list of diagrams,
        then plot all on the same plot using different colors.
    plot_only: list of numeric
        If specified, an array of only the diagrams that should be plotted.
        ``None`` or an empty sequence plots every diagram.
    title: string, default is None
        If title is defined, add it as title of the plot.
    xy_range: list of numeric [xmin, xmax, ymin, ymax]
        User provided range of axes. This is useful for comparing
        multiple persistence diagrams. When omitted, the range is derived
        from the finite values of the plotted diagrams.
    labels: string or list of strings
        Legend labels for each diagram.
        If none are specified, we use H_0, H_1, H_2,... by default.
        A single string is used as the label of every diagram.
    colormap: string, default is 'default'
        Any of matplotlib color palettes.
        Some options are 'default', 'seaborn', 'sequential'.
        See all available styles with

        .. code:: python

            import matplotlib as mpl
            print(mpl.style.available)

    size: numeric, default is 20
        Pixel size of each point plotted.
    ax_color: any valid matplotlib color type.
        See https://matplotlib.org/api/colors_api.html for complete API.
    diagonal: bool, default is True
        Plot the diagonal x=y line.
    lifetime: bool, default is False. If True, diagonal is turned to False.
        Plot life time of each point instead of birth and death.
        Essentially, visualize (x, y-x).
    legend: bool, default is True
        If true, show the legend.
    show: bool, default is False
        Call plt.show() after plotting. If you are using self.plot() as part
        of a subplot, set show=False and call plt.show() only once at the end.
    ax: matplotlib Axis object, default is None
        Axes to draw on. If None, the current axes (``plt.gca()``) are used.
    """

    ax = plt.gca() if ax is None else ax
    plt.style.use(colormap)

    xlabel, ylabel = "Birth", "Death"

    if labels is None:
        # Default legend entries: $H_0$ ... $H_8$.
        labels = ["$H_{}$".format(i) for i in range(9)]

    if not isinstance(diagrams, list):
        # Must have diagrams as a list for processing downstream
        diagrams = [diagrams]

    # Broadcast a single label *before* filtering with plot_only; doing it
    # afterwards would index into the characters of the label string.
    if not isinstance(labels, list):
        labels = [labels] * len(diagrams)

    # Explicit None/empty test so that numpy arrays are accepted too
    # (their truth value is ambiguous).
    if plot_only is not None and np.size(plot_only) > 0:
        diagrams = [diagrams[i] for i in plot_only]
        labels = [labels[i] for i in plot_only]

    # Construct copy with proper type of each diagram
    # so we can freely edit them without touching the caller's arrays.
    diagrams = [dgm.astype(np.float32, copy=True) for dgm in diagrams]

    # find min and max of all visible diagrams
    concat_dgms = np.concatenate(diagrams).flatten()
    has_inf = np.any(np.isinf(concat_dgms))
    finite_dgms = concat_dgms[np.isfinite(concat_dgms)]

    if xy_range is None or np.size(xy_range) == 0:
        # define bounds of diagram
        if finite_dgms.size:
            ax_min, ax_max = np.min(finite_dgms), np.max(finite_dgms)
        else:
            # No finite coordinate to bound the plot with: use a unit range
            # instead of failing inside np.min on an empty array.
            ax_min, ax_max = 0.0, 1.0
        x_r = ax_max - ax_min

        # Give plot a nice buffer on all sides.
        # x_r is 0 when all finite values coincide (e.g. only one point);
        # use a fixed buffer then so the axis limits do not collapse.
        buffer = 1 if x_r == 0 else x_r / 5

        x_down = ax_min - buffer / 2
        x_up = ax_max + buffer

        y_down, y_up = x_down, x_up
    else:
        x_down, x_up, y_down, y_up = xy_range

    yr = y_up - y_down

    if lifetime:
        # Don't plot lifetimes and the diagonal at the same time.
        diagonal = False

        # reset y axis so it doesn't go much below zero
        y_down = -yr * 0.05
        y_up = y_down + yr

        # set custom ylabel
        ylabel = "Lifetime"

        # set diagrams to be (x, y-x)
        for dgm in diagrams:
            dgm[:, 1] -= dgm[:, 0]

        # plot horizon line
        ax.plot([x_down, x_up], [0, 0], c=ax_color)

    # Plot diagonal
    if diagonal:
        ax.plot([x_down, x_up], [x_down, x_up], "--", c=ax_color)

    # Plot inf line
    if has_inf:
        # put inf line slightly below top
        b_inf = y_down + yr * 0.95
        ax.plot([x_down, x_up], [b_inf, b_inf], "--", c="k", label=r"$\infty$")

        # replace each inf in each diagram with b_inf so it is drawn on the line
        for dgm in diagrams:
            dgm[np.isinf(dgm)] = b_inf

    # Plot each diagram
    for dgm, label in zip(diagrams, labels):
        # plot persistence pairs
        ax.scatter(dgm[:, 0], dgm[:, 1], size, label=label, edgecolor="none")

    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)

    ax.set_xlim([x_down, x_up])
    ax.set_ylim([y_down, y_up])
    ax.set_aspect('equal', 'box')

    if title is not None:
        ax.set_title(title)

    if legend is True:
        ax.legend(loc="lower right")

    if show is True:
        plt.show()
|
177 |
|
|
178 | 3 |
def plot_a_bar(p, q, c='b', linestyle='-', ax=None):
    """Draw a straight line segment between the points ``p`` and ``q``.

    Parameters
    ----------
    p: sequence of two numerics
        (x, y) coordinates of the first end point.
    q: sequence of two numerics
        (x, y) coordinates of the second end point.
    c: any valid matplotlib color, default is 'b'
        Color of the segment.
    linestyle: string, default is '-'
        Matplotlib line style of the segment.
    ax: matplotlib Axis object, default is None
        Axes to draw on. If None, the current pyplot axes are used.
    """
    # Fall back to the pyplot state machine only when no axes were given.
    target = plt if ax is None else ax
    target.plot([p[0], q[0]], [p[1], q[1]], c=c, linestyle=linestyle, linewidth=1)
|
180 |
|
|
181 | 3 |
def bottleneck_matching(I1, I2, matchidx, D, labels=["dgm1", "dgm2"], ax=None):
    """ Visualize bottleneck matching between two diagrams

    Both diagrams are plotted, and the single matched pair with the largest
    distance in ``D`` (the pair that realizes the bottleneck distance) is
    highlighted with a green segment.

    Parameters
    ----------
    I1: array
        A diagram
    I2: array
        A diagram
    matchidx: tuples of matched indices
        Pairs ``(i, j)`` matching point ``i`` of ``I1`` with point ``j`` of
        ``I2``, as returned when the distance is computed with
        ``matching=True``. An index at or beyond the number of points of its
        diagram means the other point is matched to the diagonal.
    D: array
        cross-similarity matrix, indexed as ``D[i, j]``
    labels: list of strings
        names of diagrams for legend. Default = ["dgm1", "dgm2"],
    ax: matplotlib Axis object
        For plotting on a particular axis. If None, the current axes are used.

    Examples
    --------

    bn_matching, (matchidx, D) = persim.bottleneck(A_h1, B_h1, matching=True)
    persim.bottleneck_matching(A_h1, B_h1, matchidx, D)

    """

    # Resolve the axes once so the diagrams and the matching edge are always
    # drawn on the same axes, even when `ax` is not the current pyplot axes.
    ax = plt.gca() if ax is None else ax

    plot_diagrams([I1, I2], labels=labels, ax=ax)

    # Rotation by 45 degrees: after `pt.dot(R)` the first coordinate of a
    # point is its position along the diagonal.
    cp = np.cos(np.pi / 4)
    sp = np.sin(np.pi / 4)
    R = np.array([[cp, -sp], [sp, cp]])
    if I1.size == 0:
        I1 = np.array([[0, 0]])
    if I2.size == 0:
        I2 = np.array([[0, 0]])
    I1Rot = I1.dot(R)
    I2Rot = I2.dot(R)

    # Nothing to highlight without matches (np.argmax fails on empty input).
    if len(matchidx) == 0:
        return

    dists = [D[i, j] for (i, j) in matchidx]
    (i, j) = matchidx[np.argmax(dists)]
    if i >= I1.shape[0] and j >= I2.shape[0]:
        # Diagonal matched to diagonal: no segment to draw.
        return
    if i >= I1.shape[0]:
        # I2[j] is matched to the diagonal: connect it to its projection.
        diagElem = np.array([I2Rot[j, 0], 0])
        diagElem = diagElem.dot(R.T)
        ax.plot([I2[j, 0], diagElem[0]], [I2[j, 1], diagElem[1]], "g")
    elif j >= I2.shape[0]:
        # I1[i] is matched to the diagonal: connect it to its projection.
        diagElem = np.array([I1Rot[i, 0], 0])
        diagElem = diagElem.dot(R.T)
        ax.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
    else:
        ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")
233 |
|
|
234 |
|
|
235 | 3 |
def wasserstein_matching(I1, I2, matchidx, palette=None, labels=["dgm1", "dgm2"], colors=None, ax=None):
    """ Visualize Wasserstein matching between two diagrams

    Every matched pair is joined by a green segment; points matched to the
    diagonal are joined to their projection onto it. The two diagrams are
    then plotted on the same axes.

    Parameters
    ----------
    I1: array
        A diagram
    I2: array
        A diagram
    matchidx: tuples of matched indices
        Pairs ``(i, j)`` matching point ``i`` of ``I1`` with point ``j`` of
        ``I2``, as returned when the distance is computed with
        ``matching=True``. An index at or beyond the number of points of its
        diagram means the other point is matched to the diagonal.
    palette: unused
        Accepted for interface compatibility; currently ignored.
    labels: list of strings
        names of diagrams for legend. Default = ["dgm1", "dgm2"],
    colors: unused
        Accepted for interface compatibility; currently ignored.
    ax: matplotlib Axis object
        For plotting on a particular axis. If None, the current axes are used.

    Examples
    --------

    wd, (matchidx, D) = persim.wasserstein(A_h1, B_h1, matching=True)
    persim.wasserstein_matching(A_h1, B_h1, matchidx)

    """

    # Resolve the axes once so the matching edges and the diagrams are always
    # drawn on the same axes, even when `ax` is not the current pyplot axes.
    ax = plt.gca() if ax is None else ax

    # Rotation by 45 degrees: after `pt.dot(R)` the first coordinate of a
    # point is its position along the diagonal.
    cp = np.cos(np.pi / 4)
    sp = np.sin(np.pi / 4)
    R = np.array([[cp, -sp], [sp, cp]])
    if I1.size == 0:
        I1 = np.array([[0, 0]])
    if I2.size == 0:
        I2 = np.array([[0, 0]])
    I1Rot = I1.dot(R)
    I2Rot = I2.dot(R)
    for (i, j) in matchidx:
        if i >= I1.shape[0] and j >= I2.shape[0]:
            # Diagonal matched to diagonal: no segment to draw.
            continue
        if i >= I1.shape[0]:
            # I2[j] is matched to the diagonal: connect it to its projection.
            diagElem = np.array([I2Rot[j, 0], 0])
            diagElem = diagElem.dot(R.T)
            ax.plot([I2[j, 0], diagElem[0]], [I2[j, 1], diagElem[1]], "g")
        elif j >= I2.shape[0]:
            # I1[i] is matched to the diagonal: connect it to its projection.
            diagElem = np.array([I1Rot[i, 0], 0])
            diagElem = diagElem.dot(R.T)
            ax.plot([I1[i, 0], diagElem[0]], [I1[i, 1], diagElem[1]], "g")
        else:
            ax.plot([I1[i, 0], I2[j, 0]], [I1[i, 1], I2[j, 1]], "g")

    plot_diagrams([I1, I2], labels=labels, ax=ax)
Read our documentation on viewing source code.