1
|
|
"""Matplotlib khatplot."""
|
2
|
2
|
import warnings
|
3
|
|
|
4
|
2
|
import matplotlib as mpl
|
5
|
2
|
import matplotlib.cm as cm
|
6
|
2
|
import matplotlib.pyplot as plt
|
7
|
2
|
import numpy as np
|
8
|
2
|
from matplotlib.colors import to_rgba_array
|
9
|
|
|
10
|
2
|
from ....stats.density_utils import histogram
|
11
|
2
|
from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels, vectorized_to_hex
|
12
|
2
|
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
|
13
|
|
|
14
|
|
|
15
|
2
|
def plot_khat(
    hover_label,
    hover_format,
    ax,
    figsize,
    xdata,
    khats,
    kwargs,
    annotate,
    coord_labels,
    show_bins,
    hlines_kwargs,
    xlabels,
    legend,
    color,
    dims,
    textsize,
    markersize,
    n_data_points,
    bin_format,
    backend_kwargs,
    show,
):
    """Draw the khat scatter plot with matplotlib.

    Parameters
    ----------
    hover_label : bool
        Request hover annotations. Disabled (with a ``UserWarning``) when the
        active matplotlib backend is not interactive.
    hover_format : str
        Format string forwarded to ``_make_hover_annotation``.
    ax : matplotlib axes or None
        Axes to draw on. When ``None`` a new one is created with
        ``create_axes_grid``.
    figsize, textsize
        Forwarded to ``_scale_fig_size`` to derive figure and font sizes.
    xdata : array-like
        x coordinate of every point; must support integer indexing.
    khats : numpy.ndarray or object exposing ``.values``
        Shape parameter values. Non-ndarray inputs are flattened through
        ``khats.values.flatten()``.
    kwargs : dict or None
        Extra keyword arguments for ``ax.scatter`` (dealiased first).
    annotate : bool
        Write the matching ``coord_labels`` entry next to points with
        ``khat > 1``.
    coord_labels : sequence
        Label of every data point, indexed by position.
    show_bins : bool
        Add a text column with the count and percentage of points that fall
        between the reference lines.
    hlines_kwargs : dict or None
        Extra keyword arguments for ``ax.hlines`` (dealiased first).
    xlabels : bool
        Use ``coord_labels`` as x tick labels.
    legend : bool
        Draw a legend. Only honoured when ``color`` names an entry of
        ``dims``; forced to ``False`` otherwise.
    color : str or array-like
        Either a name contained in ``dims`` (colors are then derived with
        ``color_from_dim``), a color string, or anything accepted by
        ``to_rgba_array``. Values that ``to_rgba_array`` rejects are mapped
        through the colormap instead.
    dims : container
        Names for which ``color`` is interpreted as a dimension.
    markersize : float or None
        Scatter ``s`` value. Defaults to the square of the marker size
        returned by ``_scale_fig_size``.
    n_data_points : int
        Number of data points; used for the bin percentages and for placing
        the bin text.
    bin_format : str
        Format string called as ``bin_format.format(count, percentage)``.
    backend_kwargs : dict or None
        Overrides merged on top of ``backend_kwarg_defaults``.
    show : bool or None
        Forwarded to ``backend_show`` to decide whether ``plt.show`` is called.

    Returns
    -------
    ax : matplotlib axes
        The axes the plot was drawn on.
    """
    if hover_label and mpl.get_backend() not in mpl.rcsetup.interactive_bk:
        hover_label = False
        warnings.warn(
            "hover labels are only available with interactive backends. To switch to an "
            "interactive backend from ipython or jupyter, use `%matplotlib` there should be "
            "no need to restart the kernel. For other cases, see "
            "https://matplotlib.org/3.1.0/tutorials/introductory/usage.html#backends",
            UserWarning,
        )

    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(constrained_layout=(not xlabels)),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, _, xt_labelsize, linewidth, scaled_markersize) = _scale_fig_size(
        figsize, textsize
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True

    hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
    hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
    hlines_kwargs.setdefault("alpha", 0.7)
    hlines_kwargs.setdefault("zorder", -1)
    hlines_kwargs.setdefault("color", "C1")
    hlines_kwargs["color"] = vectorized_to_hex(hlines_kwargs["color"])

    if markersize is None:
        # ``s`` in scatter must be the markersize squared for dots to have the same size
        markersize = scaled_markersize ** 2

    kwargs = matplotlib_kwarg_dealiaser(kwargs, "scatter")
    kwargs.setdefault("s", markersize)
    kwargs.setdefault("marker", "+")
    color_mapping = None
    cmap = None
    if isinstance(color, str):
        if color in dims:
            colors, color_mapping = color_from_dim(khats, color)
            # plt.get_cmap accepts a name, a Colormap instance or None, so a
            # colormap object passed through ``kwargs`` works as well as a name.
            cmap = plt.get_cmap(kwargs.get("cmap", plt.rcParams["image.cmap"]))
            rgba_c = cmap(colors)
        else:
            legend = False
            rgba_c = to_rgba_array(np.full(n_data_points, color))
    else:
        legend = False
        try:
            rgba_c = to_rgba_array(color)
        except ValueError:
            # Not a color specification: map the values through the colormap.
            cmap = plt.get_cmap(kwargs.get("cmap", plt.rcParams["image.cmap"]))
            rgba_c = cmap(color)

    khats = khats if isinstance(khats, np.ndarray) else khats.values.flatten()
    # Opacity grows with khat: 0.5 by default, 0.7 above 0.5 and 1.0 above 1.
    alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1)
    rgba_c[:, 3] = alphas
    rgba_c = vectorized_to_hex(rgba_c)

    if ax is None:
        fig, ax = create_axes_grid(
            1,
            backend_kwargs=backend_kwargs,
        )
    else:
        fig = ax.get_figure()

    sc_plot = ax.scatter(xdata, khats, c=rgba_c, **kwargs)

    if annotate:
        # Iterate over positions rather than over the values of ``xdata`` so the
        # lookups into ``khats`` and ``coord_labels`` stay aligned even when
        # ``xdata`` is not simply ``arange(n)``.
        for position in np.flatnonzero(khats > 1):
            ax.text(
                xdata[position],
                khats[position],
                coord_labels[position],
                horizontalalignment="center",
                verticalalignment="bottom",
                fontsize=0.8 * xt_labelsize,
            )

    xmin, xmax = ax.get_xlim()
    if show_bins:
        # Make room on the right for the bin count text.
        xmax += n_data_points / 12
    # Compare the limits before and after drawing the reference lines: keep the
    # lower of both minimums and the lower of both maximums.
    ylims1 = ax.get_ylim()
    ax.hlines([0, 0.5, 0.7, 1], xmin=xmin, xmax=xmax, linewidth=linewidth, **hlines_kwargs)
    ylims2 = ax.get_ylim()
    ymin = min(ylims1[0], ylims2[0])
    ymax = min(ylims1[1], ylims2[1])
    if show_bins:
        bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
        # Drop thresholds that fall outside the visible range.
        bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
        hist, _, _ = histogram(khats, bin_edges)
        for idx, count in enumerate(hist):
            ax.text(
                (n_data_points - 1 + xmax) / 2,
                np.mean(bin_edges[idx : idx + 2]),
                bin_format.format(count, count / n_data_points * 100),
                horizontalalignment="center",
                verticalalignment="center",
            )
    ax.set_ylim(ymin, ymax)
    ax.set_xlim(xmin, xmax)

    ax.set_xlabel("Data Point", fontsize=ax_labelsize)
    ax.set_ylabel(r"Shape parameter k", fontsize=ax_labelsize)
    ax.tick_params(labelsize=xt_labelsize)
    if xlabels:
        set_xticklabels(ax, coord_labels)
        fig.autofmt_xdate()
        fig.tight_layout()
    if legend:
        ncols = len(color_mapping) // 6 + 1
        # Empty scatters only create the legend handles, one per mapped label.
        for label, float_color in color_mapping.items():
            ax.scatter([], [], c=[cmap(float_color)], label=label, **kwargs)
        ax.legend(ncol=ncols, title=color)

    if hover_label and mpl.get_backend() in mpl.rcsetup.interactive_bk:
        _make_hover_annotation(fig, ax, sc_plot, coord_labels, rgba_c, hover_format)

    if backend_show(show):
        plt.show()

    return ax
|
167
|
|
|
168
|
|
|
169
|
2
|
def _make_hover_annotation(fig, ax, sc_plot, coord_labels, rgba_c, hover_format):
    """Label the scatter point under the mouse cursor.

    A hidden annotation is created on ``ax`` and a ``motion_notify_event``
    callback is connected to ``fig.canvas``. While the cursor is over a point
    of ``sc_plot`` the annotation is shown with the text
    ``hover_format.format(idx, coord_labels[idx])`` and a box colored with
    ``rgba_c[idx]``; once the cursor is inside ``ax`` but off every point, a
    visible annotation is hidden again.
    """
    tooltip = ax.annotate(
        "",
        xy=(0, 0),
        xytext=(0, 0),
        textcoords="offset points",
        bbox={"boxstyle": "round", "fc": "w", "alpha": 0.4},
        arrowprops={"arrowstyle": "->"},
    )
    tooltip.set_visible(False)
    # Axes midpoints are captured once, at creation time.
    x_center = np.mean(ax.get_xlim())
    y_center = np.mean(ax.get_ylim())
    shift = 10

    def _refresh_tooltip(hit_info):
        """Move the tooltip to the first hit point and update text and color."""
        point = hit_info["ind"][0]
        location = sc_plot.get_offsets()[point]
        label = hover_format.format(point, coord_labels[point])
        # The text is pushed towards the centre of the axes so it stays inside.
        in_right_half = location[0] > x_center
        in_upper_half = location[1] > y_center
        tooltip.xy = location
        tooltip.set_position(
            (-shift if in_right_half else shift, -shift if in_upper_half else shift)
        )
        tooltip.set_text(label)
        tooltip.get_bbox_patch().set_facecolor(rgba_c[point])
        tooltip.set_ha("right" if in_right_half else "left")
        tooltip.set_va("top" if in_upper_half else "bottom")

    def _on_motion(event):
        """Show, update or hide the tooltip in response to mouse movement."""
        was_visible = tooltip.get_visible()
        if not event.inaxes == ax:
            return
        hit, hit_info = sc_plot.contains(event)
        if hit:
            _refresh_tooltip(hit_info)
            tooltip.set_visible(True)
            fig.canvas.draw_idle()
        elif was_visible:
            tooltip.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", _on_motion)
|