1
|
|
"""Matplotlib Plot posterior densities."""
|
2
|
2
|
from numbers import Number
|
3
|
|
|
4
|
2
|
import matplotlib.pyplot as plt
|
5
|
2
|
import numpy as np
|
6
|
|
|
7
|
2
|
from ....stats import hdi
|
8
|
2
|
from ....stats.density_utils import get_bins
|
9
|
2
|
from ...kdeplot import plot_kde
|
10
|
2
|
from ...plot_utils import (
|
11
|
|
_scale_fig_size,
|
12
|
|
calculate_point_estimate,
|
13
|
|
format_sig_figs,
|
14
|
|
make_label,
|
15
|
|
round_num,
|
16
|
|
)
|
17
|
2
|
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
|
18
|
|
|
19
|
|
|
20
|
2
|
def plot_posterior(
    ax,
    length_plotters,
    rows,
    cols,
    figsize,
    plotters,
    bw,
    circular,
    bins,
    kind,
    point_estimate,
    round_to,
    hdi_prob,
    multimodal,
    textsize,
    ref_val,
    rope,
    kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib posterior plot.

    Draw one posterior panel per entry of ``plotters`` and title each panel with
    ``make_label(var_name, selection)``.

    Parameters
    ----------
    ax : axes, array of axes, or None
        Axes to draw on. When None, ``length_plotters`` axes laid out as
        ``rows`` x ``cols`` are created with ``create_axes_grid``.
    plotters : iterable of (var_name, selection, values)
        One tuple per panel; ``values`` is flattened before plotting. Panels are
        paired with the (raveled) axes in order; extra entries on either side are
        ignored by ``zip``.
    kind : str
        ``"hist"`` dealiases ``kwargs`` for ``hist``; any other value dealiases
        them for ``plot``.
    kwargs : dict
        Drawing keyword arguments forwarded to every panel. ``linewidth``
        defaults to the value computed by ``_scale_fig_size``.
    backend_kwargs : dict or None
        Merged over ``backend_kwarg_defaults()`` (user values win) and passed to
        ``create_axes_grid``.
    show : bool or None
        Passed to ``backend_show``; when that returns True, ``plt.show()`` is called.

    The remaining parameters (``bw``, ``circular``, ``bins``, ``point_estimate``,
    ``round_to``, ``hdi_prob``, ``multimodal``, ``ref_val``, ``rope``) are forwarded
    unchanged to ``_plot_posterior_op``; ``figsize``/``textsize``/``rows``/``cols``
    feed ``_scale_fig_size``.

    Returns
    -------
    ax
        The axes that were drawn on (either passed in or newly created).
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    # User-supplied backend kwargs take precedence over the defaults.
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _) = _scale_fig_size(
        figsize, textsize, rows, cols
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("squeeze", True)

    # Normalize aliases against the artist that will consume the kwargs.
    if kind == "hist":
        kwargs = matplotlib_kwarg_dealiaser(kwargs, "hist")
    else:
        kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot")
    # Set for every kind: ``_plot_posterior_op`` requires a ``linewidth`` argument,
    # which it receives through ``**kwargs`` below.
    kwargs.setdefault("linewidth", _linewidth)

    if ax is None:
        _, ax = create_axes_grid(
            length_plotters,
            rows,
            cols,
            backend_kwargs=backend_kwargs,
        )

    for idx, ((var_name, selection, x), ax_) in enumerate(zip(plotters, np.ravel(ax))):
        _plot_posterior_op(
            idx,
            x.flatten(),
            var_name,
            selection,
            ax=ax_,
            bw=bw,
            circular=circular,
            bins=bins,
            kind=kind,
            point_estimate=point_estimate,
            round_to=round_to,
            hdi_prob=hdi_prob,
            multimodal=multimodal,
            ref_val=ref_val,
            rope=rope,
            ax_labelsize=ax_labelsize,
            xt_labelsize=xt_labelsize,
            **kwargs,
        )
        ax_.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True)

    if backend_show(show):
        plt.show()

    return ax
|
99
|
|
|
100
|
|
|
101
|
2
|
def _plot_posterior_op(
    idx,
    values,
    var_name,
    selection,
    ax,
    bw,
    circular,
    linewidth,
    bins,
    kind,
    point_estimate,
    hdi_prob,
    multimodal,
    ref_val,
    rope,
    ax_labelsize,
    xt_labelsize,
    round_to=None,
    **kwargs,
):  # noqa: D202
    """Artist to draw posterior.

    Draw a KDE or a histogram of ``values`` on ``ax``, style the axes, then
    annotate the panel with the HDI (unless ``hdi_prob == "hide"``), the point
    estimate, the reference value and the ROPE.

    Parameters
    ----------
    idx : int
        Position of this panel; selects the entry of ``ref_val`` when it is a list.
    values : np.ndarray
        Samples for this panel. A KDE is drawn only when ``kind == "kde"`` and the
        dtype is floating point; every other case draws a histogram.
    var_name, selection
        Identify the panel. When ``ref_val`` / ``rope`` are dictionaries, the first
        entry under ``var_name`` whose non-value keys all match ``selection`` is used.
    ax
        Axes to draw on.
    linewidth : float
        Line width for the KDE, the reference line, the HDI bar (x2) and the
        ROPE band (x5).
    bins
        Histogram bins; when None they are computed with ``get_bins`` for integer
        samples and set to ``"auto"`` otherwise.
    point_estimate
        Passed to ``calculate_point_estimate``; a falsy value skips the annotation.
    hdi_prob : float or "hide"
        Probability mass of the HDI passed to ``hdi``; ``"hide"`` skips it.
    ref_val : None, number, list, or dict
        Dict form: ``{var_name: [{"ref_val": value, <coord>: <label>, ...}, ...]}``.
    rope : None, length-2 iterable, or dict
        Dict form: ``{var_name: [{"rope": (lo, hi), <coord>: <label>, ...}, ...]}``.
    round_to : int, optional
        Rounding used for the HDI labels and the point-estimate label.
    **kwargs
        Forwarded to ``plot_kde`` as ``plot_kwargs`` (after popping ``fill_alpha``
        for the fill) or to ``Axes.hist``.

    Raises
    ------
    ValueError
        If ``ref_val`` or ``rope`` has an unsupported form.
    """

    def format_as_percent(x, decimals=0):
        """Format the fraction ``x`` as a percentage with ``decimals`` decimals."""
        return "{0:.{1:d}f}%".format(100 * x, decimals)

    def display_ref_val():
        """Draw the reference line and the share of samples below / at-or-above it."""
        if ref_val is None:
            return
        elif isinstance(ref_val, dict):
            val = None
            # First entry whose coordinate keys all match this panel's selection wins.
            for sel in ref_val.get(var_name, []):
                if all(
                    k in selection and selection[k] == v for k, v in sel.items() if k != "ref_val"
                ):
                    val = sel["ref_val"]
                    break
            if val is None:
                return
        elif isinstance(ref_val, list):
            val = ref_val[idx]
        elif isinstance(ref_val, Number):
            val = ref_val
        else:
            raise ValueError(
                "Argument `ref_val` must be None, a constant, a list or a "
                'dictionary like {"var_name": [{"ref_val": ref_val}]}'
            )
        less_than_ref_probability = (values < val).mean()
        greater_than_ref_probability = (values >= val).mean()
        ref_in_posterior = "{} <{:g}< {}".format(
            format_as_percent(less_than_ref_probability, 1),
            val,
            format_as_percent(greater_than_ref_probability, 1),
        )
        ax.axvline(val, ymin=0.05, ymax=0.75, color="C1", lw=linewidth, alpha=0.65)
        ax.text(
            values.mean(),
            plot_height * 0.6,
            ref_in_posterior,
            size=ax_labelsize,
            color="C1",
            weight="semibold",
            horizontalalignment="center",
        )

    def display_rope():
        """Draw the ROPE as a thick band near the x-axis with its bounds as labels."""
        if rope is None:
            return
        elif isinstance(rope, dict):
            vals = None
            for sel in rope.get(var_name, []):
                # pylint: disable=line-too-long
                if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
                    vals = sel["rope"]
                    break
            if vals is None:
                return
        # Only a flat 2-element sequence is accepted: a scalar has no len() and
        # would otherwise escape as a TypeError, and strings / nested inputs of
        # length 2 would only fail later while drawing.
        elif np.ndim(rope) == 1 and len(rope) == 2:
            vals = rope
        else:
            raise ValueError(
                "Argument `rope` must be None, a dictionary like "
                '{"var_name": [{"rope": (lo, hi)}]}, or an '
                "iterable of length 2"
            )

        ax.plot(
            vals,
            (plot_height * 0.02, plot_height * 0.02),
            lw=linewidth * 5,
            color="C2",
            solid_capstyle="butt",
            zorder=0,
            alpha=0.7,
        )
        text_props = {"size": ax_labelsize, "color": "C2"}
        ax.text(
            vals[0],
            plot_height * 0.2,
            f"{vals[0]} ",
            weight="semibold",
            horizontalalignment="right",
            **text_props,
        )
        ax.text(
            vals[1],
            plot_height * 0.2,
            f" {vals[1]}",
            weight="semibold",
            horizontalalignment="left",
            **text_props,
        )

    def display_point_estimate():
        """Label the point estimate above the density; skipped when it is falsy."""
        if not point_estimate:
            return
        point_value = calculate_point_estimate(point_estimate, values, bw, circular)
        sig_figs = format_sig_figs(point_value, round_to)
        point_text = "{point_estimate}={point_value:.{sig_figs}g}".format(
            point_estimate=point_estimate, point_value=point_value, sig_figs=sig_figs
        )
        ax.text(
            point_value,
            plot_height * 0.8,
            point_text,
            size=ax_labelsize,
            horizontalalignment="center",
        )

    def display_hdi():
        """Draw each HDI interval as a black bar with its bounds and probability."""
        # np.atleast_2d lets a single (lo, hi) interval and the multimodal case
        # (several intervals) share the same loop.
        # pylint: disable=line-too-long
        hdi_intervals = hdi(values, hdi_prob=hdi_prob, multimodal=multimodal)  # type: np.ndarray

        for hdi_i in np.atleast_2d(hdi_intervals):
            ax.plot(
                hdi_i,
                (plot_height * 0.02, plot_height * 0.02),
                lw=linewidth * 2,
                color="k",
                solid_capstyle="butt",
            )
            ax.text(
                hdi_i[0],
                plot_height * 0.07,
                round_num(hdi_i[0], round_to) + " ",
                size=ax_labelsize,
                horizontalalignment="right",
            )
            ax.text(
                hdi_i[1],
                plot_height * 0.07,
                " " + round_num(hdi_i[1], round_to),
                size=ax_labelsize,
                horizontalalignment="left",
            )
            ax.text(
                (hdi_i[0] + hdi_i[1]) / 2,
                plot_height * 0.3,
                format_as_percent(hdi_prob) + " HDI",
                size=ax_labelsize,
                horizontalalignment="center",
            )

    def format_axes():
        """Hide the y-axis and all spines except the grey bottom one."""
        ax.yaxis.set_ticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)
        ax.spines["bottom"].set_visible(True)
        ax.xaxis.set_ticks_position("bottom")
        ax.tick_params(
            axis="x", direction="out", width=1, length=3, color="0.5", labelsize=xt_labelsize
        )
        ax.spines["bottom"].set_color("0.5")

    if kind == "kde" and values.dtype.kind == "f":
        kwargs.setdefault("linewidth", linewidth)
        plot_kde(
            values,
            bw=bw,
            circular=circular,
            fill_kwargs={"alpha": kwargs.pop("fill_alpha", 0)},
            plot_kwargs=kwargs,
            ax=ax,
            rug=False,
            show=False,
        )
    else:
        # Histogram path: kind="hist", or kind="kde" with non-float samples.
        if bins is None:
            if values.dtype.kind == "i":
                xmin = values.min()
                xmax = values.max()
                bins = get_bins(values)
                # Pad the x-range by half a unit on each side of the integer range.
                ax.set_xlim(xmin - 0.5, xmax + 0.5)
            else:
                bins = "auto"
        kwargs.setdefault("align", "left")
        kwargs.setdefault("color", "C0")
        ax.hist(values, bins=bins, alpha=0.35, **kwargs)

    # Annotations are positioned at fractions of the y-axis upper limit reached
    # after drawing; the nested display_* helpers read this name when called.
    plot_height = ax.get_ylim()[1]

    format_axes()
    if hdi_prob != "hide":
        display_hdi()
    display_point_estimate()
    display_ref_val()
    display_rope()
|