1
|
|
"""Bokeh Densityplot."""
|
2
|
2
|
from collections import defaultdict
|
3
|
2
|
from itertools import cycle
|
4
|
|
|
5
|
2
|
import matplotlib.pyplot as plt
|
6
|
2
|
import numpy as np
|
7
|
2
|
from bokeh.models.annotations import Legend, Title
|
8
|
|
|
9
|
2
|
from ....stats import hdi
|
10
|
2
|
from ....stats.density_utils import get_bins, histogram, kde
|
11
|
2
|
from ...plot_utils import _scale_fig_size, calculate_point_estimate, make_label, vectorized_to_hex
|
12
|
2
|
from .. import show_layout
|
13
|
2
|
from . import backend_kwarg_defaults, create_axes_grid
|
14
|
|
|
15
|
|
|
16
|
2
|
def plot_density(
    ax,
    all_labels,
    to_plot,
    colors,
    bw,
    circular,
    figsize,
    length_plotters,
    rows,
    cols,
    textsize,
    hdi_prob,
    point_estimate,
    hdi_markers,
    outline,
    shade,
    n_data,
    data_labels,
    backend_kwargs,
    show,
):
    """Draw density plots for several datasets on a grid of Bokeh figures.

    Parameters
    ----------
    ax : array-like of figures or None
        Existing figures to draw on. When ``None`` a new grid is requested
        from ``create_axes_grid``; otherwise the input is promoted with
        ``np.atleast_2d``.
    all_labels : iterable
        Labels paired, in order, with the non-``None`` figures of the
        flattened grid.
    to_plot : iterable
        One entry per dataset; each entry iterates over
        ``(var_name, selection, values)`` triples.
    colors : str or sequence
        ``"cycle"`` takes ``n_data`` colors from matplotlib's
        ``axes.prop_cycle``; any other string is repeated ``n_data`` times;
        anything else is forwarded unchanged to ``vectorized_to_hex``.
    bw, circular, hdi_prob, point_estimate, hdi_markers, outline, shade
        Forwarded untouched to ``_d_helper`` for every variable.
    figsize, textsize, rows, cols
        Forwarded to ``_scale_fig_size``; the resulting figure size, line
        width and marker size are the ones used for drawing.
    length_plotters
        Forwarded to ``create_axes_grid`` when ``ax`` is ``None``.
    n_data : int
        Number of colors to generate when ``colors`` is a string.
    data_labels : indexable or None
        Per-dataset legend label, indexed by the dataset position in
        ``to_plot``. When ``None`` or empty no legend is added.
    backend_kwargs : dict or None
        Overrides merged on top of ``backend_kwarg_defaults()``.
    show
        Forwarded to ``show_layout`` together with the figure grid.

    Returns
    -------
    The figure grid that was drawn on.
    """
    user_kwargs = {} if backend_kwargs is None else backend_kwargs
    # User-supplied values win over the backend defaults.
    backend_kwargs = {**backend_kwarg_defaults(), **user_kwargs}

    if colors == "cycle":
        palette = plt.rcParams["axes.prop_cycle"].by_key()["color"]
        # zip against range(n_data) truncates the endless cycle.
        colors = [tone for _, tone in zip(range(n_data), cycle(palette))]
    elif isinstance(colors, str):
        colors = [colors] * n_data
    colors = vectorized_to_hex(colors)

    figsize, _, _, _, line_width, markersize = _scale_fig_size(figsize, textsize, rows, cols)

    if ax is not None:
        ax = np.atleast_2d(ax)
    else:
        ax = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            squeeze=True,
            backend_kwargs=backend_kwargs,
        )

    # Empty grid cells are None and must not consume a label.
    usable_figures = (cell for cell in ax.flatten() if cell is not None)
    axis_map = dict(zip(all_labels, usable_figures))

    if data_labels is None:
        data_labels = {}

    # figure -> [(legend text, glyphs drawn for that dataset), ...]
    legend_items = defaultdict(list)
    for m_idx, plotters in enumerate(to_plot):
        for var_name, selection, values in plotters:
            label = make_label(var_name, selection)
            data_label = data_labels[m_idx] if data_labels else None

            plotted = _d_helper(
                vec=values.flatten(),
                vname=label,
                color=colors[m_idx],
                bw=bw,
                circular=circular,
                line_width=line_width,
                markersize=markersize,
                hdi_prob=hdi_prob,
                point_estimate=point_estimate,
                hdi_markers=hdi_markers,
                outline=outline,
                shade=shade,
                ax=axis_map[label],
            )
            if data_label is not None:
                legend_items[axis_map[label]].append((data_label, plotted))

    for figure, entries in legend_items.items():
        figure.add_layout(
            Legend(
                items=entries,
                location="center_right",
                orientation="horizontal",
            ),
            "above",
        )
        # Clicking a legend entry toggles visibility of its glyphs.
        figure.legend.click_policy = "hide"

    show_layout(ax, show)

    return ax
|
119
|
|
|
120
|
|
|
121
|
2
|
def _d_helper(
    vec,
    vname,
    color,
    bw,
    circular,
    line_width,
    markersize,
    hdi_prob,
    point_estimate,
    hdi_markers,
    outline,
    shade,
    ax,
):
    """Draw the density (or histogram) of one variable on a single figure.

    Float data (``vec.dtype.kind == "f"``) is drawn as a KDE curve computed
    on the values inside the HDI; any other dtype is drawn as a histogram.

    Parameters
    ----------
    vec : numpy array
        Values to plot; the caller passes a flattened array.
    vname : str
        Text used as the figure title.
    color
        Color applied to every glyph drawn for this variable.
    bw, circular
        Forwarded to ``kde`` and ``calculate_point_estimate``.
    line_width, markersize
        Width of the KDE lines and size of the marker glyphs.
    hdi_prob : float
        HDI probability. For float data, values outside the interval are
        dropped before the KDE unless it equals 1, and the density is
        multiplied by it.
    point_estimate
        When not ``None``, a circle is drawn at the estimate returned by
        ``calculate_point_estimate``.
    hdi_markers : bool
        Whether to draw diamonds at the two ends of the plotted interval.
    outline : bool
        For float data, draw the KDE curve and its two vertical edges.
        For other data, draw unfilled bars instead of filled ones.
    shade
        For float data, fill alpha of the area under the curve (skipped
        when falsy). For other data, fill alpha of the bars when
        ``outline`` is falsy.
    ax
        Bokeh figure to draw on.

    Returns
    -------
    list
        The objects returned by every glyph call made on ``ax``.
    """
    glyphs = []
    # Appearance of the glyphs when muted, shared by most calls below.
    muted = {"muted_color": color, "muted_alpha": 0.2}

    if vec.dtype.kind == "f":
        kept = vec
        if hdi_prob != 1:
            bounds = hdi(vec, hdi_prob, multimodal=False)
            inside = (vec >= bounds[0]) & (vec <= bounds[1])
            kept = vec[inside]

        x, density = kde(kept, circular=circular, bw=bw)
        # Rescale because the KDE was computed on the truncated sample only.
        density *= hdi_prob
        xmin, xmax = x[0], x[-1]
        ymin, ymax = density[0], density[-1]

        if outline:
            glyphs.append(ax.line(x, density, line_color=color, line_width=line_width))
            # Vertical strokes closing the curve at both ends of the interval.
            for x_edge, y_edge in ((xmin, ymin), (xmax, ymax)):
                glyphs.append(
                    ax.line(
                        [x_edge, x_edge],
                        [-y_edge / 100, y_edge],
                        line_color=color,
                        line_dash="solid",
                        line_width=line_width,
                        **muted,
                    )
                )

        if shade:
            # Polygon: along the baseline right-to-left, then along the curve
            # left-to-right, then back down to the baseline.
            patch_x = np.r_[x[::-1], x, x[-1:]]
            patch_y = np.r_[np.zeros_like(x), density, [0]]
            glyphs.append(
                ax.patch(
                    patch_x,
                    patch_y,
                    fill_color=color,
                    fill_alpha=shade,
                    **muted,
                )
            )
    else:
        xmin, xmax = hdi(vec, hdi_prob, multimodal=False)
        bins = get_bins(vec)
        _, hist, edges = histogram(vec, bins=bins)

        quad_kwargs = {
            "top": hist,
            "bottom": 0,
            "left": edges[:-1],
            "right": edges[1:],
            "line_color": color,
        }
        if outline:
            quad_kwargs["fill_color"] = None
        else:
            quad_kwargs["fill_color"] = color
            quad_kwargs["fill_alpha"] = shade
        glyphs.append(ax.quad(**quad_kwargs, **muted))

    if hdi_markers:
        for x_end in (xmin, xmax):
            glyphs.append(
                ax.diamond(x_end, 0, line_color="black", fill_color=color, size=markersize)
            )

    if point_estimate is not None:
        est = calculate_point_estimate(point_estimate, vec, bw, circular)
        glyphs.append(ax.circle(est, 0, fill_color=color, line_color="black", size=markersize))

    heading = Title()
    heading.text = vname
    ax.title = heading
    ax.title.text_font_size = "13pt"

    return glyphs
|