1
|
|
"""Matplotlib loopitplot."""
|
2
|
2
|
import matplotlib.pyplot as plt
|
3
|
2
|
import numpy as np
|
4
|
2
|
from matplotlib.colors import hsv_to_rgb, rgb_to_hsv, to_hex, to_rgb
|
5
|
2
|
from xarray import DataArray
|
6
|
|
|
7
|
2
|
from ....stats.density_utils import kde
|
8
|
2
|
from ...plot_utils import _scale_fig_size
|
9
|
2
|
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
|
10
|
|
|
11
|
|
|
12
|
2
|
def plot_loo_pit(
    ax,
    figsize,
    ecdf,
    loo_pit,
    loo_pit_ecdf,
    unif_ecdf,
    p975,
    p025,
    fill_kwargs,
    ecdf_fill,
    use_hdi,
    x_vals,
    hdi_kwargs,
    hdi_odds,
    n_unif,
    unif,
    plot_unif_kwargs,
    loo_pit_kde,
    legend,
    y_hat,
    y,
    color,
    textsize,
    credible_interval,
    plot_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib loo pit plot.

    Draws either an ECDF-difference view (``ecdf=True``) or a density view
    (``ecdf=False``) of the LOO-PIT values, together with reference artists
    for comparison.

    Parameters
    ----------
    ax : matplotlib axes or None
        Axes to draw on. When ``None`` a single-axes figure is created with
        ``create_axes_grid``.
    figsize, textsize
        Passed to ``_scale_fig_size`` to derive the figure size, tick label
        size and base line width.
    ecdf : bool
        Select the ECDF-difference view instead of the density view.
    loo_pit, loo_pit_ecdf : array-like
        ECDF view only: ``loo_pit - loo_pit_ecdf`` is plotted against
        ``loo_pit``, padded with the points (0, 0) and (1, 0).
    unif_ecdf, p975, p025 : array-like
        ECDF view only: the envelope is ``p975 - unif_ecdf`` and
        ``p025 - unif_ecdf`` against ``unif_ecdf`` (presumably the 97.5% and
        2.5% quantiles - confirm with the caller).
    fill_kwargs : dict or None
        Keyword arguments for ``fill_between`` when ``ecdf_fill``. The dict
        is copied before defaults are added, so the caller's dict is not
        modified.
    ecdf_fill : bool
        ECDF view only: fill the envelope instead of drawing its two edges.
    use_hdi : bool
        Density view only: shade a horizontal band with ``axhspan`` instead
        of drawing one KDE per row of ``unif``.
    x_vals, loo_pit_kde : array-like
        Density view only: the LOO-PIT density curve (``loo_pit_kde``
        against ``x_vals``).
    hdi_kwargs : dict or None
        Keyword arguments for ``axhspan`` when ``use_hdi``. The dict is
        copied before defaults are added.
    hdi_odds : sequence
        Unpacked positionally into ``axhspan`` (its ``ymin, ymax`` ...).
    n_unif : int
        Number of rows of ``unif`` to compute a KDE for.
    unif : 2D array
        Uniform reference samples, indexed as ``unif[idx, :]``.
    plot_unif_kwargs, plot_kwargs : dict or None
        Line keyword arguments for the reference curves and for the LOO-PIT
        curve respectively (normalised by ``matplotlib_kwarg_dealiaser``).
    legend : bool
        Whether to draw a legend.
    y_hat, y
        Only used to build the default legend label of the LOO-PIT curve.
    color : color spec
        Color of the LOO-PIT curve. It always overrides
        ``plot_kwargs["color"]``. The reference artists default to a lighter,
        desaturated variant of it.
    credible_interval : float
        Formatted as-is into the "% credible interval" legend labels.
    backend_kwargs : dict or None
        Figure-creation keyword arguments, merged over
        ``backend_kwarg_defaults()``.
    show : bool or None
        Passed to ``backend_show`` to decide whether to call ``plt.show()``.

    Returns
    -------
    ax
        The axes that were drawn on.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, _, _, xt_labelsize, linewidth, _) = _scale_fig_size(figsize, textsize, 1, 1)
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True

    if ax is None:
        _, ax = create_axes_grid(1, backend_kwargs=backend_kwargs)

    plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "plot")
    plot_kwargs["color"] = to_hex(color)
    plot_kwargs.setdefault("linewidth", linewidth * 1.4)
    plot_kwargs.setdefault("label", _loo_pit_label(y, y_hat, ecdf))
    plot_kwargs.setdefault("zorder", 5)

    # Reference artists default to the main color with halved saturation and
    # a value pushed halfway towards 1 (i.e. lighter).
    plot_unif_kwargs = matplotlib_kwarg_dealiaser(plot_unif_kwargs, "plot")
    light_color = rgb_to_hsv(to_rgb(plot_kwargs.get("color")))
    light_color[1] /= 2  # pylint: disable=unsupported-assignment-operation
    light_color[2] += (1 - light_color[2]) / 2  # pylint: disable=unsupported-assignment-operation
    light_hex = to_hex(hsv_to_rgb(light_color))
    plot_unif_kwargs.setdefault("color", light_hex)
    plot_unif_kwargs.setdefault("alpha", 0.5)
    plot_unif_kwargs.setdefault("linewidth", 0.6 * linewidth)

    # NOTE(review): the "% credible interval" labels print credible_interval
    # verbatim, i.e. they assume it is already a percentage; a fraction such
    # as 0.99 would render as "0.99%". Confirm against the caller.
    if ecdf:
        # Step-style lines are used when there are fewer than 100 points.
        drawstyle = "steps-mid" if loo_pit.size < 100 else "default"
        plot_kwargs.setdefault("drawstyle", drawstyle)
        plot_unif_kwargs.setdefault("drawstyle", drawstyle)

        if ecdf_fill:
            # Copy so the setdefault calls below never modify the caller's dict.
            fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)
            fill_kwargs.setdefault("color", light_hex)
            fill_kwargs.setdefault("alpha", 0.5)
            # Match the fill's step style to the LOO-PIT line's drawstyle.
            fill_kwargs.setdefault(
                "step", "mid" if plot_kwargs["drawstyle"] == "steps-mid" else None
            )
            fill_kwargs.setdefault("label", "{:.3g}% credible interval".format(credible_interval))
    elif use_hdi:
        # Copy so the setdefault calls below never modify the caller's dict.
        hdi_kwargs = {} if hdi_kwargs is None else dict(hdi_kwargs)
        hdi_kwargs.setdefault("color", light_hex)
        hdi_kwargs.setdefault("alpha", 0.35)
        hdi_kwargs.setdefault("label", "Uniform HDI")

    if ecdf:
        ax.plot(
            np.hstack((0, loo_pit, 1)), np.hstack((0, loo_pit - loo_pit_ecdf, 0)), **plot_kwargs
        )

        if ecdf_fill:
            ax.fill_between(unif_ecdf, p975 - unif_ecdf, p025 - unif_ecdf, **fill_kwargs)
        else:
            ax.plot(unif_ecdf, p975 - unif_ecdf, unif_ecdf, p025 - unif_ecdf, **plot_unif_kwargs)
    else:
        if use_hdi:
            ax.axhspan(*hdi_odds, **hdi_kwargs)
        else:
            # Buffers are only needed for the per-sample uniform KDE curves.
            # Assumes kde() returns grids of len(loo_pit_kde) points - TODO confirm.
            x_ss = np.empty((n_unif, len(loo_pit_kde)))
            u_dens = np.empty((n_unif, len(loo_pit_kde)))
            for idx in range(n_unif):
                x_ss[idx], u_dens[idx] = kde(unif[idx, :])
            ax.plot(x_ss.T, u_dens.T, **plot_unif_kwargs)
        ax.plot(x_vals, loo_pit_kde, **plot_kwargs)
        ax.set_xlim(0, 1)
        ax.set_ylim(0, None)
    ax.tick_params(labelsize=xt_labelsize)

    if legend:
        # The filled envelope (ECDF view) and the HDI band (density view) have
        # their own labels. In every other case the reference lines carry no
        # label unless the user gave one, so add an empty proxy line to give
        # them a single legend entry.
        reference_labelled = (ecdf and ecdf_fill) or (use_hdi and not ecdf)
        if not reference_labelled and "label" not in plot_unif_kwargs:
            label = "{:.3g}% credible interval".format(credible_interval) if ecdf else "Uniform"
            ax.plot([], label=label, **plot_unif_kwargs)
        ax.legend()

    if backend_show(show):
        plt.show()

    return ax


def _loo_pit_label(y, y_hat, ecdf):
    """Return the default legend label for the LOO-PIT curve.

    ``y`` is checked first, then ``y_hat``. The first one that is a string, or
    a ``DataArray`` with a non-``None`` name, provides the prefix. If neither
    qualifies, the label is just the plot kind.
    """
    kind = "LOO-PIT ECDF" if ecdf else "LOO-PIT"
    for source in (y, y_hat):
        if isinstance(source, str):
            return "{} {}".format(source, kind)
        if isinstance(source, DataArray) and source.name is not None:
            return "{} {}".format(source.name, kind)
    return kind
|