1
|
|
"""Bokeh Bayesian p-value Posterior predictive plot."""
|
2
|
2
|
import numpy as np
|
3
|
2
|
from bokeh.models import BoxAnnotation
|
4
|
2
|
from bokeh.models.annotations import Title
|
5
|
2
|
from scipy import stats
|
6
|
|
|
7
|
2
|
from ....stats.density_utils import kde
|
8
|
2
|
from ...kdeplot import plot_kde
|
9
|
2
|
from ...plot_utils import (
|
10
|
|
_scale_fig_size,
|
11
|
|
is_valid_quantile,
|
12
|
|
sample_reference_distribution,
|
13
|
|
vectorized_to_hex,
|
14
|
|
)
|
15
|
2
|
from .. import show_layout
|
16
|
2
|
from . import backend_kwarg_defaults, create_axes_grid
|
17
|
|
|
18
|
|
|
19
|
2
|
def plot_bpv(
    ax,
    length_plotters,
    rows,
    cols,
    obs_plotters,
    pp_plotters,
    total_pp_samples,
    kind,
    t_stat,
    bpv,
    plot_mean,
    reference,
    n_ref,
    hdi_prob,
    color,
    figsize,
    textsize,
    plot_ref_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh bpv plot.

    Draw one panel per variable comparing observed data with posterior
    predictive samples. The panel content depends on ``kind``:

    * ``"p_value"``: KDE of, for every predictive sample, the fraction of its
      values that are ``<=`` the matching observed values.
    * ``"u_value"``: KDE of, for every observed value, the fraction of
      predictive samples that are ``<=`` it.
    * anything else: KDE of the test statistic ``t_stat`` evaluated on every
      predictive sample, optionally annotated with the Bayesian p-value and
      the observed statistic.

    Parameters
    ----------
    ax : bokeh figure, array-like of figures, or None
        Axes to draw on. When None a new ``rows`` x ``cols`` grid is created.
    length_plotters : int
        Number of variables (panels) to draw.
    rows, cols : int
        Grid dimensions, used for sizing and (when ``ax`` is None) layout.
    obs_plotters, pp_plotters : sequence of 3-tuples
        ``(name, _, values)`` entries for observed and posterior predictive
        data, one per panel.
    total_pp_samples : int
        Number of predictive samples; predictive values are reshaped to
        ``(total_pp_samples, -1)``.
    kind : str
        ``"p_value"``, ``"u_value"``, or anything else for the t-statistic plot.
    t_stat : str, float or callable
        ``"mean"``, ``"median"``, ``"std"``, a quantile accepted by
        ``is_valid_quantile``, or a callable applied to the flattened observed
        values and to the 2-D predictive array.
    bpv : bool
        t-statistic plot only: add a legend entry with the Bayesian p-value.
    plot_mean : bool
        t-statistic plot only: mark the observed statistic at ``y = 0``.
    reference : {"analytical", "samples", None}
        How to draw the reference for the ``p_value``/``u_value`` plots.
    n_ref : int
        Number of reference curves when ``reference == "samples"``.
    hdi_prob : float
        Probability used for the shaded band of ``u_value`` + ``"analytical"``.
    color : color spec
        Main line color; converted with ``vectorized_to_hex``.
    figsize, textsize : optional
        Forwarded to ``_scale_fig_size``.
    plot_ref_kwargs : dict or None
        Styling for the reference curves/band. The caller's dict is not
        mutated.
    backend_kwargs : dict or None
        Overrides merged on top of ``backend_kwarg_defaults()``.
    show : bool
        Forwarded to ``show_layout``.

    Returns
    -------
    numpy.ndarray
        The 2-D array of bokeh figures that was drawn on.

    Raises
    ------
    ValueError
        If the number of non-None axes passed in ``ax`` differs from
        ``length_plotters``, or if ``t_stat`` is not supported.
    """
    if backend_kwargs is None:
        backend_kwargs = {}
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    color = vectorized_to_hex(color)

    # Work on a private copy: the defaults below and the per-panel pops in the
    # u_value branch must never leak back into the caller's dict.
    plot_ref_kwargs = {} if plot_ref_kwargs is None else dict(plot_ref_kwargs)
    if kind == "p_value" and reference == "analytical":
        plot_ref_kwargs.setdefault("line_color", "black")
        plot_ref_kwargs.setdefault("line_dash", "dashed")
    else:
        plot_ref_kwargs.setdefault("alpha", 0.1)
        plot_ref_kwargs.setdefault("line_color", color)

    (figsize, ax_labelsize, _, _, linewidth, markersize) = _scale_fig_size(
        figsize, textsize, rows, cols
    )

    if ax is None:
        axes = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            backend_kwargs=backend_kwargs,
        )
    else:
        axes = np.atleast_2d(ax)
        # Count real axes only; None entries are empty grid slots.
        n_axes = sum(item is not None for item in axes.ravel())
        if n_axes != length_plotters:
            raise ValueError(
                "Found {} variables to plot but {} axes instances. They must be equal.".format(
                    length_plotters, n_axes
                )
            )

    for i, ax_i in enumerate(item for item in axes.flatten() if item is not None):
        var_name, _, obs_vals = obs_plotters[i]
        pp_var_name, _, pp_vals = pp_plotters[i]

        obs_vals = obs_vals.flatten()
        pp_vals = pp_vals.reshape(total_pp_samples, -1)

        # Fresh copy per panel: the u_value/analytical branch pops keys from it,
        # which would otherwise raise KeyError on the second panel.
        ref_kwargs = dict(plot_ref_kwargs)

        if kind == "p_value":
            # One value per predictive sample (reduction over the last axis).
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.line(x_s, tstat_pit_dens, line_width=linewidth, line_color=color)
            if reference is not None:
                dist = stats.beta(obs_vals.size / 2, obs_vals.size / 2)
                if reference == "analytical":
                    # Draw the reference density over its central 99.99% interval.
                    lwb = dist.ppf((1 - 0.9999) / 2)
                    upb = 1 - lwb
                    x = np.linspace(lwb, upb, 500)
                    dens_ref = dist.pdf(x)
                    ax_i.line(x, dens_ref, **ref_kwargs)
                elif reference == "samples":
                    # Same (draws per curve, number of curves) order as the
                    # u_value branch below, so exactly n_ref curves are drawn.
                    x_ss, u_dens = sample_reference_distribution(
                        dist, (tstat_pit_dens.size, n_ref)
                    )
                    ax_i.multi_line(
                        list(x_ss.T), list(u_dens.T), line_width=linewidth, **ref_kwargs
                    )

        elif kind == "u_value":
            # One value per observation (reduction over the sample axis).
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.line(x_s, tstat_pit_dens, line_width=linewidth, line_color=color)
            if reference is not None:
                if reference == "analytical":
                    n_obs = obs_vals.size
                    hdi = stats.beta(n_obs / 2, n_obs / 2).ppf((1 - hdi_prob) / 2)
                    hdi_odds = (hdi / (1 - hdi), (1 - hdi) / hdi)
                    ax_i.add_layout(
                        BoxAnnotation(
                            bottom=hdi_odds[1],
                            top=hdi_odds[0],
                            # The line-style defaults are reused as fill settings
                            # for the band; popped so they are not passed twice.
                            fill_alpha=ref_kwargs.pop("alpha"),
                            fill_color=ref_kwargs.pop("line_color"),
                            **ref_kwargs,
                        )
                    )
                    ax_i.line([0, 1], [1, 1], line_color="white")
                elif reference == "samples":
                    dist = stats.uniform(0, 1)
                    x_ss, u_dens = sample_reference_distribution(
                        dist, (tstat_pit_dens.size, n_ref)
                    )
                    for x_ss_i, u_dens_i in zip(x_ss.T, u_dens.T):
                        ax_i.line(x_ss_i, u_dens_i, line_width=linewidth, **ref_kwargs)
            # NOTE(review): degenerate glyph at the origin, presumably to keep
            # (0, 0) inside the auto-ranged data bounds — confirm intent.
            ax_i.line(0, 0)

        else:
            if t_stat in ["mean", "median", "std"]:
                tfunc = {"mean": np.mean, "median": np.median, "std": np.std}[t_stat]
                obs_vals = tfunc(obs_vals)
                pp_vals = tfunc(pp_vals, axis=1)
            elif callable(t_stat):
                obs_vals = t_stat(obs_vals.flatten())
                pp_vals = t_stat(pp_vals)
            elif is_valid_quantile(t_stat):
                # Local name: do not rebind t_stat inside the per-panel loop.
                quantile = float(t_stat)
                obs_vals = np.quantile(obs_vals, q=quantile)
                pp_vals = np.quantile(pp_vals, q=quantile, axis=1)
            else:
                raise ValueError(f"T statistics {t_stat} not implemented")

            plot_kde(pp_vals, ax=ax_i, plot_kwargs={"color": color}, backend="bokeh", show=False)
            if bpv:
                p_value = np.mean(pp_vals <= obs_vals)
                # Invisible glyph used only to carry the legend entry.
                ax_i.line(0, 0, legend_label=f"bpv={p_value:.2f}", alpha=0)

            if plot_mean:
                # np.mean also accepts a plain float returned by a user t_stat.
                ax_i.circle(
                    np.mean(obs_vals), 0, fill_color=color, line_color="black", size=markersize
                )

        if var_name != pp_var_name:
            xlabel = "{} / {}".format(var_name, pp_var_name)
        else:
            xlabel = var_name
        _title = Title()
        _title.text = xlabel
        ax_i.title = _title
        ax_i.title.text_font_size = f"{int(ax_labelsize)}pt"

    show_layout(axes, show)

    return axes
|