1
|
|
# pylint: disable=too-many-lines
|
2
|
2
|
"""Statistical functions in ArviZ."""
|
3
|
2
|
import logging
|
4
|
2
|
import warnings
|
5
|
2
|
from collections import OrderedDict
|
6
|
2
|
from copy import deepcopy
|
7
|
2
|
from typing import Any, Dict, List, Optional, Union
|
8
|
|
|
9
|
2
|
import numpy as np
|
10
|
2
|
import pandas as pd
|
11
|
2
|
import scipy.stats as st
|
12
|
2
|
import xarray as xr
|
13
|
2
|
from scipy.optimize import minimize
|
14
|
|
|
15
|
2
|
from ..data import CoordSpec, DimSpec, InferenceData, convert_to_dataset, convert_to_inference_data
|
16
|
2
|
from ..rcparams import rcParams
|
17
|
2
|
from ..utils import Numba, _numba_var, _var_names, credible_interval_warning, get_coords
|
18
|
2
|
from .density_utils import get_bins as _get_bins
|
19
|
2
|
from .density_utils import histogram as _histogram
|
20
|
2
|
from .density_utils import kde as _kde
|
21
|
2
|
from .diagnostics import _mc_error, _multichain_statistics, ess
|
22
|
2
|
from .stats_utils import ELPDData, _circular_standard_deviation
|
23
|
2
|
from .stats_utils import get_log_likelihood as _get_log_likelihood
|
24
|
2
|
from .stats_utils import logsumexp as _logsumexp
|
25
|
2
|
from .stats_utils import make_ufunc as _make_ufunc
|
26
|
2
|
from .stats_utils import stats_variance_2d as svar
|
27
|
2
|
from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc
|
28
|
|
|
29
|
2
|
_log = logging.getLogger(__name__)
|
30
|
|
|
31
|
2
|
__all__ = [
|
32
|
|
"apply_test_function",
|
33
|
|
"compare",
|
34
|
|
"hdi",
|
35
|
|
"hpd",
|
36
|
|
"loo",
|
37
|
|
"loo_pit",
|
38
|
|
"psislw",
|
39
|
|
"r2_score",
|
40
|
|
"summary",
|
41
|
|
"waic",
|
42
|
|
]
|
43
|
|
|
44
|
|
|
45
|
2
|
def compare(
|
46
|
|
dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None
|
47
|
|
):
|
48
|
|
r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation.
|
49
|
|
|
50
|
|
LOO is leave-one-out (PSIS-LOO `loo`) cross-validation and
|
51
|
|
WAIC is the widely applicable information criterion.
|
52
|
|
Read more theory here - in a paper by some of the leading authorities
|
53
|
|
on model selection dx.doi.org/10.1111/1467-9868.00353
|
54
|
|
|
55
|
|
Parameters
|
56
|
|
----------
|
57
|
|
dataset_dict: dict[str] -> InferenceData
|
58
|
|
A dictionary of model names and InferenceData objects
|
59
|
|
ic: str
|
60
|
|
Information Criterion (PSIS-LOO `loo` or WAIC `waic`) used to compare models. Defaults to
|
61
|
|
``rcParams["stats.information_criterion"]``.
|
62
|
|
method: str
|
63
|
|
Method used to estimate the weights for each model. Available options are:
|
64
|
|
|
65
|
|
- 'stacking' : stacking of predictive distributions.
|
66
|
|
- 'BB-pseudo-BMA' : (default) pseudo-Bayesian Model averaging using Akaike-type
|
67
|
|
weighting. The weights are stabilized using the Bayesian bootstrap.
|
68
|
|
- 'pseudo-BMA': pseudo-Bayesian Model averaging using Akaike-type
|
69
|
|
weighting, without Bootstrap stabilization (not recommended).
|
70
|
|
|
71
|
|
For more information read https://arxiv.org/abs/1704.02030
|
72
|
|
b_samples: int
|
73
|
|
Number of samples taken by the Bayesian bootstrap estimation.
|
74
|
|
Only useful when method = 'BB-pseudo-BMA'.
|
75
|
|
alpha: float
|
76
|
|
The shape parameter in the Dirichlet distribution used for the Bayesian bootstrap. Only
|
77
|
|
useful when method = 'BB-pseudo-BMA'. When alpha=1 (default), the distribution is uniform
|
78
|
|
        on the simplex. A smaller alpha will keep the final weights further away from 0 and 1.
|
79
|
|
seed: int or np.random.RandomState instance
|
80
|
|
If int or RandomState, use it for seeding Bayesian bootstrap. Only
|
81
|
|
        useful when method = 'BB-pseudo-BMA'. If None (default), the global
|
82
|
|
np.random state is used.
|
83
|
|
scale: str
|
84
|
|
Output scale for IC. Available options are:
|
85
|
|
|
86
|
|
- `log` : (default) log-score (after Vehtari et al. (2017))
|
87
|
|
- `negative_log` : -1 * (log-score)
|
88
|
|
- `deviance` : -2 * (log-score)
|
89
|
|
|
90
|
|
A higher log-score (or a lower deviance) indicates a model with better predictive
|
91
|
|
accuracy.
|
92
|
|
|
93
|
|
Returns
|
94
|
|
-------
|
95
|
|
A DataFrame, ordered from best to worst model (measured by information criteria).
|
96
|
|
The index reflects the key with which the models are passed to this function. The columns are:
|
97
|
|
rank: The rank-order of the models. 0 is the best.
|
98
|
|
IC: Information Criteria (PSIS-LOO `loo` or WAIC `waic`).
|
99
|
|
        Higher IC indicates higher out-of-sample predictive fit ("better" model). Defaults to LOO.
|
100
|
|
        If `scale` is `deviance` or `negative_log`, smaller IC indicates
|
101
|
|
higher out-of-sample predictive fit ("better" model).
|
102
|
|
pIC: Estimated effective number of parameters.
|
103
|
|
dIC: Relative difference between each IC (PSIS-LOO `loo` or WAIC `waic`)
|
104
|
|
and the lowest IC (PSIS-LOO `loo` or WAIC `waic`).
|
105
|
|
        It is always 0 for the top-ranked model.
|
106
|
|
weight: Relative weight for each model.
|
107
|
|
        This can be loosely interpreted as the probability of each model (among the compared models)
|
108
|
|
given the data. By default the uncertainty in the weights estimation is considered using
|
109
|
|
Bayesian bootstrap.
|
110
|
|
SE: Standard error of the IC estimate.
|
111
|
|
        If method = BB-pseudo-BMA, these values are estimated using the Bayesian bootstrap.
|
112
|
|
dSE: Standard error of the difference in IC between each model and the top-ranked model.
|
113
|
|
It's always 0 for the top-ranked model.
|
114
|
|
warning: A value of 1 indicates that the computation of the IC may not be reliable.
|
115
|
|
        This could be an indication of WAIC/LOO starting to fail; see
|
116
|
|
http://arxiv.org/abs/1507.04544 for details.
|
117
|
|
scale: Scale used for the IC.
|
118
|
|
|
119
|
|
Examples
|
120
|
|
--------
|
121
|
|
    Compare the centered and non-centered models of the eight schools problem:
|
122
|
|
|
123
|
|
.. ipython::
|
124
|
|
|
125
|
|
In [1]: import arviz as az
|
126
|
|
...: data1 = az.load_arviz_data("non_centered_eight")
|
127
|
|
...: data2 = az.load_arviz_data("centered_eight")
|
128
|
|
...: compare_dict = {"non centered": data1, "centered": data2}
|
129
|
|
...: az.compare(compare_dict)
|
130
|
|
|
131
|
|
Compare the models using LOO-CV, returning the IC in log scale and calculating the
|
132
|
|
weights using the stacking method.
|
133
|
|
|
134
|
|
.. ipython::
|
135
|
|
|
136
|
|
In [1]: az.compare(compare_dict, ic="loo", method="stacking", scale="log")
|
137
|
|
|
138
|
|
See Also
|
139
|
|
--------
|
140
|
|
    loo : Compute the Pareto-smoothed importance sampling leave-one-out cross-validation.
|
141
|
|
waic : Compute the widely applicable information criterion.
|
142
|
|
|
143
|
|
"""
|
144
|
2
|
names = list(dataset_dict.keys())
|
145
|
2
|
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
|
146
|
2
|
if scale == "log":
|
147
|
2
|
scale_value = 1
|
148
|
2
|
ascending = False
|
149
|
2
|
warnings.warn(
|
150
|
|
"\nThe scale is now log by default. Use 'scale' argument or "
|
151
|
|
"'stats.ic_scale' rcParam if you rely on a specific value.\nA higher "
|
152
|
|
"log-score (or a lower deviance) indicates a model with better predictive "
|
153
|
|
"accuracy."
|
154
|
|
)
|
155
|
|
else:
|
156
|
2
|
if scale == "negative_log":
|
157
|
2
|
scale_value = -1
|
158
|
|
else:
|
159
|
2
|
scale_value = -2
|
160
|
2
|
ascending = True
|
161
|
|
|
162
|
2
|
ic = rcParams["stats.information_criterion"] if ic is None else ic.lower()
|
163
|
2
|
if ic == "loo":
|
164
|
2
|
ic_func = loo
|
165
|
2
|
df_comp = pd.DataFrame(
|
166
|
|
index=names,
|
167
|
|
columns=[
|
168
|
|
"rank",
|
169
|
|
"loo",
|
170
|
|
"p_loo",
|
171
|
|
"d_loo",
|
172
|
|
"weight",
|
173
|
|
"se",
|
174
|
|
"dse",
|
175
|
|
"warning",
|
176
|
|
"loo_scale",
|
177
|
|
],
|
178
|
|
)
|
179
|
2
|
scale_col = "loo_scale"
|
180
|
2
|
elif ic == "waic":
|
181
|
2
|
ic_func = waic
|
182
|
2
|
df_comp = pd.DataFrame(
|
183
|
|
index=names,
|
184
|
|
columns=[
|
185
|
|
"rank",
|
186
|
|
"waic",
|
187
|
|
"p_waic",
|
188
|
|
"d_waic",
|
189
|
|
"weight",
|
190
|
|
"se",
|
191
|
|
"dse",
|
192
|
|
"warning",
|
193
|
|
"waic_scale",
|
194
|
|
],
|
195
|
|
)
|
196
|
2
|
scale_col = "waic_scale"
|
197
|
|
else:
|
198
|
2
|
raise NotImplementedError("The information criterion {} is not supported.".format(ic))
|
199
|
|
|
200
|
2
|
if method.lower() not in ["stacking", "bb-pseudo-bma", "pseudo-bma"]:
|
201
|
2
|
raise ValueError("The method {}, to compute weights, is not supported.".format(method))
|
202
|
|
|
203
|
2
|
ic_se = "{}_se".format(ic)
|
204
|
2
|
p_ic = "p_{}".format(ic)
|
205
|
2
|
ic_i = "{}_i".format(ic)
|
206
|
|
|
207
|
2
|
ics = pd.DataFrame()
|
208
|
2
|
names = []
|
209
|
2
|
for name, dataset in dataset_dict.items():
|
210
|
2
|
names.append(name)
|
211
|
2
|
ics = ics.append([ic_func(dataset, pointwise=True, scale=scale)])
|
212
|
2
|
ics.index = names
|
213
|
2
|
ics.sort_values(by=ic, inplace=True, ascending=ascending)
|
214
|
2
|
ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())
|
215
|
|
|
216
|
2
|
if method.lower() == "stacking":
|
217
|
2
|
rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
|
218
|
2
|
exp_ic_i = np.exp(ic_i_val / scale_value)
|
219
|
2
|
last_col = cols - 1
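        # Stacking chooses the weights that maximize the combined log predictive density,
        # sum_i log( sum_k w_k * p(y_i | y_-i, M_k) ): `exp_ic_i` recovers the pointwise
        # predictive densities from the (scaled) log scores, and the weights live on the
        # simplex, with the last weight implied by `w_fuller` defined below.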
|
220
|
|
|
221
|
2
|
def w_fuller(weights):
|
222
|
2
|
return np.concatenate((weights, [max(1.0 - np.sum(weights), 0.0)]))
|
223
|
|
|
224
|
2
|
def log_score(weights):
|
225
|
2
|
w_full = w_fuller(weights)
|
226
|
2
|
score = 0.0
|
227
|
2
|
for i in range(rows):
|
228
|
2
|
score += np.log(np.dot(exp_ic_i[i], w_full))
|
229
|
2
|
return -score
|
230
|
|
|
231
|
2
|
def gradient(weights):
|
232
|
2
|
w_full = w_fuller(weights)
|
233
|
2
|
grad = np.zeros(last_col)
|
234
|
2
|
for k in range(last_col - 1):
|
235
|
0
|
for i in range(rows):
|
236
|
0
|
grad[k] += (exp_ic_i[i, k] - exp_ic_i[i, last_col]) / np.dot(
|
237
|
|
exp_ic_i[i], w_full
|
238
|
|
)
|
239
|
2
|
return -grad
|
240
|
|
|
241
|
2
|
theta = np.full(last_col, 1.0 / cols)
|
242
|
2
|
bounds = [(0.0, 1.0) for _ in range(last_col)]
|
243
|
2
|
constraints = [
|
244
|
|
{"type": "ineq", "fun": lambda x: 1.0 - np.sum(x)},
|
245
|
|
{"type": "ineq", "fun": np.sum},
|
246
|
|
]
|
247
|
|
|
248
|
2
|
weights = minimize(
|
249
|
|
fun=log_score, x0=theta, jac=gradient, bounds=bounds, constraints=constraints
|
250
|
|
)
|
251
|
|
|
252
|
2
|
weights = w_fuller(weights["x"])
|
253
|
2
|
ses = ics[ic_se]
|
254
|
|
|
255
|
2
|
elif method.lower() == "bb-pseudo-bma":
|
256
|
2
|
rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
|
257
|
2
|
ic_i_val = ic_i_val * rows
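        # Bayesian bootstrap: each Dirichlet(alpha, ..., alpha) draw re-weights the pointwise
        # IC values; for every replicate the Akaike-type weights are computed from the
        # re-weighted totals, and the replicates are averaged afterwards.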
|
258
|
|
|
259
|
2
|
b_weighting = st.dirichlet.rvs(alpha=[alpha] * rows, size=b_samples, random_state=seed)
|
260
|
2
|
weights = np.zeros((b_samples, cols))
|
261
|
2
|
z_bs = np.zeros_like(weights)
|
262
|
2
|
for i in range(b_samples):
|
263
|
2
|
z_b = np.dot(b_weighting[i], ic_i_val)
|
264
|
2
|
u_weights = np.exp((z_b - np.min(z_b)) / scale_value)
|
265
|
2
|
z_bs[i] = z_b # pylint: disable=unsupported-assignment-operation
|
266
|
2
|
weights[i] = u_weights / np.sum(u_weights)
|
267
|
|
|
268
|
2
|
weights = weights.mean(axis=0)
|
269
|
2
|
ses = pd.Series(z_bs.std(axis=0), index=names) # pylint: disable=no-member
|
270
|
|
|
271
|
2
|
elif method.lower() == "pseudo-bma":
|
272
|
2
|
min_ic = ics.iloc[0][ic]
|
273
|
2
|
z_rv = np.exp((ics[ic] - min_ic) / scale_value)
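        # Pseudo-BMA: weights are proportional to exp(elpd_k - elpd_best), i.e. plain
        # Akaike-type weighting of the IC differences without bootstrap stabilization.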
|
274
|
2
|
weights = z_rv / np.sum(z_rv)
|
275
|
2
|
ses = ics[ic_se]
|
276
|
|
|
277
|
2
|
if np.any(weights):
|
278
|
2
|
min_ic_i_val = ics[ic_i].iloc[0]
|
279
|
2
|
for idx, val in enumerate(ics.index):
|
280
|
2
|
res = ics.loc[val]
|
281
|
2
|
if scale_value < 0:
|
282
|
2
|
diff = res[ic_i] - min_ic_i_val
|
283
|
|
else:
|
284
|
2
|
diff = min_ic_i_val - res[ic_i]
|
285
|
2
|
d_ic = np.sum(diff)
|
286
|
2
|
d_std_err = np.sqrt(len(diff) * np.var(diff))
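            # The SE of the IC difference uses the usual plug-in estimate for a sum of
            # n pointwise terms: sqrt(n * var(pointwise differences)).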
|
287
|
2
|
std_err = ses.loc[val]
|
288
|
2
|
weight = weights[idx]
|
289
|
2
|
df_comp.at[val] = (
|
290
|
|
idx,
|
291
|
|
res[ic],
|
292
|
|
res[p_ic],
|
293
|
|
d_ic,
|
294
|
|
weight,
|
295
|
|
std_err,
|
296
|
|
d_std_err,
|
297
|
|
res["warning"],
|
298
|
|
res[scale_col],
|
299
|
|
)
|
300
|
|
|
301
|
2
|
return df_comp.sort_values(by=ic, ascending=ascending)
|
302
|
|
|
303
|
|
|
304
|
2
|
def _ic_matrix(ics, ic_i):
|
305
|
|
"""Store the previously computed pointwise predictive accuracy values (ics) in a 2D matrix."""
|
306
|
2
|
cols, _ = ics.shape
|
307
|
2
|
rows = len(ics[ic_i].iloc[0])
|
308
|
2
|
ic_i_val = np.zeros((rows, cols))
|
309
|
|
|
310
|
2
|
for idx, val in enumerate(ics.index):
|
311
|
2
|
ic = ics.loc[val][ic_i]
|
312
|
|
|
313
|
2
|
if len(ic) != rows:
|
314
|
2
|
raise ValueError("The number of observations should be the same across all models")
|
315
|
|
|
316
|
2
|
ic_i_val[:, idx] = ic
|
317
|
|
|
318
|
2
|
return rows, cols, ic_i_val
|
319
|
|
|
320
|
|
|
321
|
2
|
def hpd(
|
322
|
|
# pylint: disable=unused-argument
|
323
|
|
ary,
|
324
|
|
hdi_prob=None,
|
325
|
|
circular=False,
|
326
|
|
multimodal=False,
|
327
|
|
skipna=False,
|
328
|
|
group="posterior",
|
329
|
|
var_names=None,
|
330
|
|
filter_vars=None,
|
331
|
|
coords=None,
|
332
|
|
max_modes=10,
|
333
|
|
**kwargs,
|
334
|
|
):
|
335
|
|
"""Pending deprecation. Please refer to :func:`~arviz.hdi`."""
|
336
|
|
# pylint: enable=unused-argument
|
337
|
0
|
warnings.warn(
|
338
|
|
("hpd will be deprecated " "Please replace hdi"),
|
339
|
|
)
|
340
|
0
|
return hdi(
|
341
|
|
ary,
|
342
|
|
hdi_prob,
|
343
|
|
circular,
|
344
|
|
multimodal,
|
345
|
|
skipna,
|
346
|
|
group,
|
347
|
|
var_names,
|
348
|
|
filter_vars,
|
349
|
|
coords,
|
350
|
|
max_modes,
|
351
|
|
**kwargs,
|
352
|
|
)
|
353
|
|
|
354
|
|
|
355
|
2
|
def hdi(
|
356
|
|
ary,
|
357
|
|
hdi_prob=None,
|
358
|
|
circular=False,
|
359
|
|
multimodal=False,
|
360
|
|
skipna=False,
|
361
|
|
group="posterior",
|
362
|
|
var_names=None,
|
363
|
|
filter_vars=None,
|
364
|
|
coords=None,
|
365
|
|
max_modes=10,
|
366
|
|
**kwargs,
|
367
|
|
):
|
368
|
|
"""
|
369
|
|
Calculate highest density interval (HDI) of array for given probability.
|
370
|
|
|
371
|
|
The HDI is the minimum width Bayesian credible interval (BCI).
|
372
|
|
|
373
|
|
Parameters
|
374
|
|
----------
|
375
|
|
ary: obj
|
376
|
|
object containing posterior samples.
|
377
|
|
Any object that can be converted to an az.InferenceData object.
|
378
|
|
Refer to documentation of az.convert_to_dataset for details.
|
379
|
|
hdi_prob: float, optional
|
380
|
|
HDI prob for which interval will be computed. Defaults to ``stats.hdi_prob`` rcParam.
|
381
|
|
circular: bool, optional
|
382
|
|
Whether to compute the hdi taking into account `x` is a circular variable
|
383
|
|
(in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
|
384
|
|
Only works if multimodal is False.
|
385
|
|
multimodal: bool, optional
|
386
|
|
If true it may compute more than one hdi interval if the distribution is multimodal and the
|
387
|
|
modes are well separated.
|
388
|
|
skipna: bool, optional
|
389
|
|
        If True, ignores nan values when computing the hdi interval. Defaults to False.
|
390
|
|
group: str, optional
|
391
|
|
Specifies which InferenceData group should be used to calculate hdi.
|
392
|
|
Defaults to 'posterior'
|
393
|
|
var_names: list, optional
|
394
|
|
Names of variables to include in the hdi report. Prefix the variables by `~`
|
395
|
|
when you want to exclude them from the report: `["~beta"]` instead of `["beta"]`
|
396
|
|
(see `az.summary` for more details).
|
397
|
|
filter_vars: {None, "like", "regex"}, optional, default=None
|
398
|
|
If `None` (default), interpret var_names as the real variables names. If "like",
|
399
|
|
interpret var_names as substrings of the real variables names. If "regex",
|
400
|
|
interpret var_names as regular expressions on the real variables names. A la
|
401
|
|
`pandas.filter`.
|
402
|
|
coords: mapping, optional
|
403
|
|
        Specifies the subset over which to calculate hdi.
|
404
|
|
max_modes: int, optional
|
405
|
|
Specifies the maximum number of modes for multimodal case.
|
406
|
|
kwargs: dict, optional
|
407
|
|
Additional keywords passed to :func:`~arviz.wrap_xarray_ufunc`.
|
408
|
|
|
409
|
|
Returns
|
410
|
|
-------
|
411
|
|
np.ndarray or xarray.Dataset, depending upon input
|
412
|
|
lower(s) and upper(s) values of the interval(s).
|
413
|
|
|
414
|
|
See Also
|
415
|
|
--------
|
416
|
|
plot_hdi : Plot HDI intervals for regression data.
|
417
|
|
xarray.Dataset.quantile : Calculate quantiles of array for given probabilities.
|
418
|
|
|
419
|
|
Examples
|
420
|
|
--------
|
421
|
|
Calculate the HDI of a Normal random variable:
|
422
|
|
|
423
|
|
.. ipython::
|
424
|
|
|
425
|
|
In [1]: import arviz as az
|
426
|
|
...: import numpy as np
|
427
|
|
...: data = np.random.normal(size=2000)
|
428
|
|
...: az.hdi(data, hdi_prob=.68)
|
429
|
|
|
430
|
|
Calculate the HDI of a dataset:
|
431
|
|
|
432
|
|
.. ipython::
|
433
|
|
|
434
|
|
In [1]: import arviz as az
|
435
|
|
...: data = az.load_arviz_data('centered_eight')
|
436
|
|
...: az.hdi(data)
|
437
|
|
|
438
|
|
We can also calculate the HDI of some of the variables of dataset:
|
439
|
|
|
440
|
|
.. ipython::
|
441
|
|
|
442
|
|
In [1]: az.hdi(data, var_names=["mu", "theta"])
|
443
|
|
|
444
|
|
If we want to calculate the HDI over specified dimension of dataset,
|
445
|
|
we can pass `input_core_dims` by kwargs:
|
446
|
|
|
447
|
|
.. ipython::
|
448
|
|
|
449
|
|
In [1]: az.hdi(data, input_core_dims = [["chain"]])
|
450
|
|
|
451
|
|
We can also calculate the hdi over a particular selection over all groups:
|
452
|
|
|
453
|
|
.. ipython::
|
454
|
|
|
455
|
|
In [1]: az.hdi(data, coords={"chain":[0, 1, 3]}, input_core_dims = [["draw"]])
|
456
|
|
|
457
|
|
"""
|
458
|
2
|
if hdi_prob is None:
|
459
|
2
|
hdi_prob = rcParams["stats.hdi_prob"]
|
460
|
|
else:
|
461
|
2
|
if not 1 >= hdi_prob > 0:
|
462
|
2
|
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
|
463
|
|
|
464
|
2
|
func_kwargs = {
|
465
|
|
"hdi_prob": hdi_prob,
|
466
|
|
"skipna": skipna,
|
467
|
|
"out_shape": (max_modes, 2) if multimodal else (2,),
|
468
|
|
}
|
469
|
2
|
kwargs.setdefault("output_core_dims", [["hdi", "mode"] if multimodal else ["hdi"]])
|
470
|
2
|
if not multimodal:
|
471
|
2
|
func_kwargs["circular"] = circular
|
472
|
|
else:
|
473
|
2
|
func_kwargs["max_modes"] = max_modes
|
474
|
|
|
475
|
2
|
func = _hdi_multimodal if multimodal else _hdi
|
476
|
|
|
477
|
2
|
isarray = isinstance(ary, np.ndarray)
|
478
|
2
|
if isarray and ary.ndim <= 1:
|
479
|
2
|
func_kwargs.pop("out_shape")
|
480
|
2
|
hdi_data = func(ary, **func_kwargs) # pylint: disable=unexpected-keyword-arg
|
481
|
2
|
return hdi_data[~np.isnan(hdi_data).all(axis=1), :] if multimodal else hdi_data
|
482
|
|
|
483
|
2
|
if isarray and ary.ndim == 2:
|
484
|
2
|
warnings.warn(
|
485
|
|
"hdi currently interprets 2d data as (draw, shape) but this will change in "
|
486
|
|
"a future release to (chain, draw) for coherence with other functions",
|
487
|
|
FutureWarning,
|
488
|
|
)
|
489
|
2
|
ary = np.expand_dims(ary, 0)
|
490
|
|
|
491
|
2
|
ary = convert_to_dataset(ary, group=group)
|
492
|
2
|
if coords is not None:
|
493
|
2
|
ary = get_coords(ary, coords)
|
494
|
2
|
var_names = _var_names(var_names, ary, filter_vars)
|
495
|
2
|
ary = ary[var_names] if var_names else ary
|
496
|
|
|
497
|
2
|
hdi_coord = xr.DataArray(["lower", "higher"], dims=["hdi"], attrs=dict(hdi_prob=hdi_prob))
|
498
|
2
|
hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).assign_coords(
|
499
|
|
{"hdi": hdi_coord}
|
500
|
|
)
|
501
|
2
|
hdi_data = hdi_data.dropna("mode", how="all") if multimodal else hdi_data
|
502
|
2
|
return hdi_data.x.values if isarray else hdi_data
|
503
|
|
|
504
|
|
|
505
|
2
|
def _hdi(ary, hdi_prob, circular, skipna):
|
506
|
|
"""Compute hpi over the flattened array."""
|
507
|
2
|
ary = ary.flatten()
|
508
|
2
|
if skipna:
|
509
|
2
|
nans = np.isnan(ary)
|
510
|
2
|
if not nans.all():
|
511
|
2
|
ary = ary[~nans]
|
512
|
2
|
n = len(ary)
|
513
|
|
|
514
|
2
|
if circular:
|
515
|
2
|
mean = st.circmean(ary, high=np.pi, low=-np.pi)
|
516
|
2
|
ary = ary - mean
|
517
|
2
|
ary = np.arctan2(np.sin(ary), np.cos(ary))
|
518
|
|
|
519
|
2
|
ary = np.sort(ary)
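    # The HDI is found by a sliding window over the sorted draws: every window of
    # floor(hdi_prob * n) consecutive points is a candidate credible interval, and the
    # narrowest one is returned below.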
|
520
|
2
|
interval_idx_inc = int(np.floor(hdi_prob * n))
|
521
|
2
|
n_intervals = n - interval_idx_inc
|
522
|
2
|
interval_width = ary[interval_idx_inc:] - ary[:n_intervals]
|
523
|
|
|
524
|
2
|
if len(interval_width) == 0:
|
525
|
0
|
raise ValueError("Too few elements for interval calculation. ")
|
526
|
|
|
527
|
2
|
min_idx = np.argmin(interval_width)
|
528
|
2
|
hdi_min = ary[min_idx]
|
529
|
2
|
hdi_max = ary[min_idx + interval_idx_inc]
|
530
|
|
|
531
|
2
|
if circular:
|
532
|
2
|
hdi_min = hdi_min + mean
|
533
|
2
|
hdi_max = hdi_max + mean
|
534
|
2
|
hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
|
535
|
2
|
hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
|
536
|
|
|
537
|
2
|
hdi_interval = np.array([hdi_min, hdi_max])
|
538
|
|
|
539
|
2
|
return hdi_interval
|
540
|
|
|
541
|
|
|
542
|
2
|
def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
|
543
|
|
"""Compute HDI if the distribution is multimodal."""
|
544
|
2
|
ary = ary.flatten()
|
545
|
2
|
if skipna:
|
546
|
0
|
ary = ary[~np.isnan(ary)]
|
547
|
|
|
548
|
2
|
if ary.dtype.kind == "f":
|
549
|
2
|
bins, density = _kde(ary)
|
550
|
2
|
lower, upper = bins[0], bins[-1]
|
551
|
2
|
range_x = upper - lower
|
552
|
2
|
dx = range_x / len(density)
|
553
|
|
else:
|
554
|
0
|
bins = _get_bins(ary)
|
555
|
0
|
_, density, _ = _histogram(ary, bins=bins)
|
556
|
0
|
dx = np.diff(bins)[0]
|
557
|
|
|
558
|
2
|
density *= dx
|
559
|
|
|
560
|
2
|
idx = np.argsort(-density)
|
561
|
2
|
intervals = bins[idx][density[idx].cumsum() <= hdi_prob]
|
562
|
2
|
intervals.sort()
|
563
|
|
|
564
|
2
|
intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
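    # Grid points were ranked by density and kept until their mass reaches hdi_prob;
    # wherever two consecutive kept points are more than about one grid step apart (the
    # 1.1 factor adds tolerance) the run is split, and each run becomes one HDI mode.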
|
565
|
|
|
566
|
2
|
hdi_intervals = np.full((max_modes, 2), np.nan)
|
567
|
2
|
for i, interval in enumerate(intervals_splitted):
|
568
|
2
|
if i == max_modes:
|
569
|
0
|
warnings.warn(
|
570
|
|
"found more modes than {0}, returning only the first {0} modes".format(max_modes)
|
571
|
|
)
|
572
|
0
|
break
|
573
|
2
|
if interval.size == 0:
|
574
|
0
|
hdi_intervals[i] = np.asarray([bins[0], bins[0]])
|
575
|
|
else:
|
576
|
2
|
hdi_intervals[i] = np.asarray([interval[0], interval[-1]])
|
577
|
|
|
578
|
2
|
return np.array(hdi_intervals)
|
579
|
|
|
580
|
|
|
581
|
2
|
def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
|
582
|
|
"""Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).
|
583
|
|
|
584
|
|
Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
|
585
|
|
importance sampling leave-one-out cross-validation (PSIS-LOO-CV). Also calculates LOO's
|
586
|
|
standard error and the effective number of parameters. Read more theory here
|
587
|
|
https://arxiv.org/abs/1507.04544 and here https://arxiv.org/abs/1507.02646
|
588
|
|
|
589
|
|
Parameters
|
590
|
|
----------
|
591
|
|
data: obj
|
592
|
|
Any object that can be converted to an az.InferenceData object. Refer to documentation of
|
593
|
|
az.convert_to_inference_data for details
|
594
|
|
pointwise: bool, optional
|
595
|
|
If True the pointwise predictive accuracy will be returned. Defaults to
|
596
|
|
``stats.ic_pointwise`` rcParam.
|
597
|
|
var_name : str, optional
|
598
|
|
The name of the variable in log_likelihood groups storing the pointwise log
|
599
|
|
likelihood data to use for loo computation.
|
600
|
|
reff: float, optional
|
601
|
|
Relative MCMC efficiency, `ess / n` i.e. number of effective samples divided by the number
|
602
|
|
of actual samples. Computed from trace by default.
|
603
|
|
scale: str
|
604
|
|
Output scale for loo. Available options are:
|
605
|
|
|
606
|
|
- `log` : (default) log-score
|
607
|
|
- `negative_log` : -1 * log-score
|
608
|
|
- `deviance` : -2 * log-score
|
609
|
|
|
610
|
|
A higher log-score (or a lower deviance or negative log_score) indicates a model with
|
611
|
|
better predictive accuracy.
|
612
|
|
|
613
|
|
Returns
|
614
|
|
-------
|
615
|
|
    ELPDData object (inherits from pandas.Series) with the following row/attributes:
|
616
|
|
loo: approximated expected log pointwise predictive density (elpd)
|
617
|
|
loo_se: standard error of loo
|
618
|
|
p_loo: effective number of parameters
|
619
|
|
shape_warn: bool
|
620
|
|
True if the estimated shape parameter of
|
621
|
|
Pareto distribution is greater than 0.7 for one or more samples
|
622
|
|
loo_i: array of pointwise predictive accuracy, only if pointwise True
|
623
|
|
pareto_k: array of Pareto shape values, only if pointwise True
|
624
|
|
loo_scale: scale of the loo results
|
625
|
|
|
626
|
|
    The returned object has a custom print method that overrides the pd.Series method.
|
627
|
|
|
628
|
|
Examples
|
629
|
|
--------
|
630
|
|
Calculate LOO of a model:
|
631
|
|
|
632
|
|
.. ipython::
|
633
|
|
|
634
|
|
In [1]: import arviz as az
|
635
|
|
...: data = az.load_arviz_data("centered_eight")
|
636
|
|
...: az.loo(data)
|
637
|
|
|
638
|
|
Calculate LOO of a model and return the pointwise values:
|
639
|
|
|
640
|
|
.. ipython::
|
641
|
|
|
642
|
|
In [2]: data_loo = az.loo(data, pointwise=True)
|
643
|
|
...: data_loo.loo_i
|
644
|
|
"""
|
645
|
2
|
inference_data = convert_to_inference_data(data)
|
646
|
2
|
log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
|
647
|
2
|
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
|
648
|
|
|
649
|
2
|
log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
|
650
|
2
|
shape = log_likelihood.shape
|
651
|
2
|
n_samples = shape[-1]
|
652
|
2
|
n_data_points = np.product(shape[:-1])
|
653
|
2
|
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
|
654
|
|
|
655
|
2
|
if scale == "deviance":
|
656
|
2
|
scale_value = -2
|
657
|
2
|
elif scale == "log":
|
658
|
2
|
scale_value = 1
|
659
|
2
|
elif scale == "negative_log":
|
660
|
2
|
scale_value = -1
|
661
|
|
else:
|
662
|
2
|
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
|
663
|
|
|
664
|
2
|
if reff is None:
|
665
|
2
|
if not hasattr(inference_data, "posterior"):
|
666
|
2
|
raise TypeError("Must be able to extract a posterior group from data.")
|
667
|
2
|
posterior = inference_data.posterior
|
668
|
2
|
n_chains = len(posterior.chain)
|
669
|
2
|
if n_chains == 1:
|
670
|
2
|
reff = 1.0
|
671
|
|
else:
|
672
|
2
|
ess_p = ess(posterior, method="mean")
|
673
|
|
# this mean is over all data variables
|
674
|
2
|
reff = (
|
675
|
|
np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
|
676
|
|
)
|
677
|
|
|
678
|
2
|
log_weights, pareto_shape = psislw(-log_likelihood, reff)
|
679
|
2
|
log_weights += log_likelihood
|
680
|
|
|
681
|
2
|
warn_mg = False
|
682
|
2
|
if np.any(pareto_shape > 0.7):
|
683
|
2
|
warnings.warn(
|
684
|
|
"Estimated shape parameter of Pareto distribution is greater than 0.7 for "
|
685
|
|
"one or more samples. You should consider using a more robust model, this is because "
|
686
|
|
"importance sampling is less likely to work well if the marginal posterior and "
|
687
|
|
"LOO posterior are very different. This is more likely to happen with a non-robust "
|
688
|
|
"model and highly influential observations."
|
689
|
|
)
|
690
|
2
|
warn_mg = True
|
691
|
|
|
692
|
2
|
ufunc_kwargs = {"n_dims": 1, "ravel": False}
|
693
|
2
|
kwargs = {"input_core_dims": [["sample"]]}
|
694
|
2
|
loo_lppd_i = scale_value * _wrap_xarray_ufunc(
|
695
|
|
_logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs
|
696
|
|
)
|
697
|
2
|
loo_lppd = loo_lppd_i.values.sum()
|
698
|
2
|
loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5
|
699
|
|
|
700
|
2
|
lppd = np.sum(
|
701
|
|
_wrap_xarray_ufunc(
|
702
|
|
_logsumexp,
|
703
|
|
log_likelihood,
|
704
|
|
func_kwargs={"b_inv": n_samples},
|
705
|
|
ufunc_kwargs=ufunc_kwargs,
|
706
|
|
**kwargs,
|
707
|
|
).values
|
708
|
|
)
|
709
|
2
|
p_loo = lppd - loo_lppd / scale_value
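    # Effective number of parameters: p_loo = lppd - elpd_loo. Dividing by scale_value
    # maps loo_lppd back to the log scale whatever reporting scale was requested.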
|
710
|
|
|
711
|
2
|
if pointwise:
|
712
|
2
|
if np.equal(loo_lppd, loo_lppd_i).all(): # pylint: disable=no-member
|
713
|
0
|
warnings.warn(
|
714
|
|
"The point-wise LOO is the same with the sum LOO, please double check "
|
715
|
|
"the Observed RV in your model to make sure it returns element-wise logp."
|
716
|
|
)
|
717
|
2
|
return ELPDData(
|
718
|
|
data=[
|
719
|
|
loo_lppd,
|
720
|
|
loo_lppd_se,
|
721
|
|
p_loo,
|
722
|
|
n_samples,
|
723
|
|
n_data_points,
|
724
|
|
warn_mg,
|
725
|
|
loo_lppd_i.rename("loo_i"),
|
726
|
|
pareto_shape,
|
727
|
|
scale,
|
728
|
|
],
|
729
|
|
index=[
|
730
|
|
"loo",
|
731
|
|
"loo_se",
|
732
|
|
"p_loo",
|
733
|
|
"n_samples",
|
734
|
|
"n_data_points",
|
735
|
|
"warning",
|
736
|
|
"loo_i",
|
737
|
|
"pareto_k",
|
738
|
|
"loo_scale",
|
739
|
|
],
|
740
|
|
)
|
741
|
|
|
742
|
|
else:
|
743
|
2
|
return ELPDData(
|
744
|
|
data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
|
745
|
|
index=["loo", "loo_se", "p_loo", "n_samples", "n_data_points", "warning", "loo_scale"],
|
746
|
|
)
|
747
|
|
|
748
|
|
|
749
|
2
|
def psislw(log_weights, reff=1.0):
|
750
|
|
"""
|
751
|
|
Pareto smoothed importance sampling (PSIS).
|
752
|
|
|
753
|
|
Parameters
|
754
|
|
----------
|
755
|
|
log_weights: array
|
756
|
|
Array of size (n_observations, n_samples)
|
757
|
|
reff: float
|
758
|
|
relative MCMC efficiency, `ess / n`
|
759
|
|
|
760
|
|
Returns
|
761
|
|
-------
|
762
|
|
lw_out: array
|
763
|
|
Smoothed log weights
|
764
|
|
kss: array
|
765
|
|
Pareto tail indices
|
766
|
|
|
767
|
|
References
|
768
|
|
----------
|
769
|
|
* Vehtari et al. (2015) see https://arxiv.org/abs/1507.02646
|
770
|
|
|
771
|
|
Examples
|
772
|
|
--------
|
773
|
|
Get Pareto smoothed importance sampling (PSIS) log weights:
|
774
|
|
|
775
|
|
.. ipython::
|
776
|
|
|
777
|
|
In [1]: import arviz as az
|
778
|
|
...: data = az.load_arviz_data("centered_eight")
|
779
|
|
...: log_likelihood = data.sample_stats.log_likelihood.stack(sample=("chain", "draw"))
|
780
|
|
...: az.psislw(-log_likelihood, reff=0.8)
|
781
|
|
|
782
|
|
"""
|
783
|
2
|
if hasattr(log_weights, "sample"):
|
784
|
2
|
n_samples = len(log_weights.sample)
|
785
|
2
|
shape = [size for size, dim in zip(log_weights.shape, log_weights.dims) if dim != "sample"]
|
786
|
|
else:
|
787
|
0
|
n_samples = log_weights.shape[-1]
|
788
|
0
|
shape = log_weights.shape[:-1]
|
789
|
|
# precalculate constants
|
790
|
2
|
cutoff_ind = -int(np.ceil(min(n_samples / 5.0, 3 * (n_samples / reff) ** 0.5))) - 1
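    # The tail size follows the PSIS recommendation (Vehtari et al.), min(0.2 * S,
    # 3 * sqrt(S / reff)) draws, so `cutoff_ind` points at the largest log weight that is
    # still kept in the body and is used as the tail threshold.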
|
791
|
2
|
cutoffmin = np.log(np.finfo(float).tiny) # pylint: disable=no-member, assignment-from-no-return
|
792
|
2
|
k_min = 1.0 / 3
|
793
|
|
|
794
|
|
# create output array with proper dimensions
|
795
|
2
|
out = tuple([np.empty_like(log_weights), np.empty(shape)])
|
796
|
|
|
797
|
|
# define kwargs
|
798
|
2
|
func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "k_min": k_min, "out": out}
|
799
|
2
|
ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
|
800
|
2
|
kwargs = {"input_core_dims": [["sample"]], "output_core_dims": [["sample"], []]}
|
801
|
2
|
log_weights, pareto_shape = _wrap_xarray_ufunc(
|
802
|
|
_psislw, log_weights, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs, **kwargs
|
803
|
|
)
|
804
|
2
|
if isinstance(log_weights, xr.DataArray):
|
805
|
2
|
log_weights = log_weights.rename("log_weights").rename(sample="sample")
|
806
|
2
|
if isinstance(pareto_shape, xr.DataArray):
|
807
|
2
|
pareto_shape = pareto_shape.rename("pareto_shape")
|
808
|
2
|
return log_weights, pareto_shape
|
809
|
|
|
810
|
|
|
811
|
2
|
def _psislw(log_weights, cutoff_ind, cutoffmin, k_min=1.0 / 3):
|
812
|
|
"""
|
813
|
|
Pareto smoothed importance sampling (PSIS) for a 1D vector.
|
814
|
|
|
815
|
|
Parameters
|
816
|
|
----------
|
817
|
|
log_weights: array
|
818
|
|
        Array of length n_samples
|
819
|
|
cutoff_ind: int
|
820
|
|
cutoffmin: float
|
821
|
|
k_min: float
|
822
|
|
|
823
|
|
Returns
|
824
|
|
-------
|
825
|
|
lw_out: array
|
826
|
|
Smoothed log weights
|
827
|
|
kss: float
|
828
|
|
Pareto tail index
|
829
|
|
"""
|
830
|
2
|
x = np.asarray(log_weights)
|
831
|
|
|
832
|
|
# improve numerical accuracy
|
833
|
2
|
x -= np.max(x)
|
834
|
|
# sort the array
|
835
|
2
|
x_sort_ind = np.argsort(x)
|
836
|
|
# divide log weights into body and right tail
|
837
|
2
|
xcutoff = max(x[x_sort_ind[cutoff_ind]], cutoffmin)
|
838
|
|
|
839
|
2
|
expxcutoff = np.exp(xcutoff)
|
840
|
2
|
(tailinds,) = np.where(x > xcutoff) # pylint: disable=unbalanced-tuple-unpacking
|
841
|
2
|
x_tail = x[tailinds]
|
842
|
2
|
tail_len = len(x_tail)
|
843
|
2
|
if tail_len <= 4:
|
844
|
|
# not enough tail samples for gpdfit
|
845
|
2
|
k = np.inf
|
846
|
|
else:
|
847
|
|
# order of tail samples
|
848
|
2
|
x_tail_si = np.argsort(x_tail)
|
849
|
|
# fit generalized Pareto distribution to the right tail samples
|
850
|
2
|
x_tail = np.exp(x_tail) - expxcutoff
|
851
|
2
|
k, sigma = _gpdfit(x_tail[x_tail_si])
|
852
|
|
|
853
|
2
|
if k >= k_min:
|
854
|
|
# no smoothing if short tail or GPD fit failed
|
855
|
|
# compute ordered statistic for the fit
|
856
|
2
|
sti = np.arange(0.5, tail_len) / tail_len
|
857
|
2
|
smoothed_tail = _gpinv(sti, k, sigma)
|
858
|
2
|
smoothed_tail = np.log( # pylint: disable=assignment-from-no-return
|
859
|
|
smoothed_tail + expxcutoff
|
860
|
|
)
|
861
|
|
# place the smoothed tail into the output array
|
862
|
2
|
x[tailinds[x_tail_si]] = smoothed_tail
|
863
|
|
    # truncate smoothed values at 0, the largest raw (shifted) log weight
|
864
|
2
|
x[x > 0] = 0
|
865
|
|
# renormalize weights
|
866
|
2
|
x -= _logsumexp(x)
|
867
|
|
|
868
|
2
|
return x, k
|
869
|
|
|
870
|
|
|
871
|
2
|
def _gpdfit(ary):
|
872
|
|
"""Estimate the parameters for the Generalized Pareto Distribution (GPD).
|
873
|
|
|
874
|
|
Empirical Bayes estimate for the parameters of the generalized Pareto
|
875
|
|
distribution given the data.
|
876
|
|
|
877
|
|
Parameters
|
878
|
|
----------
|
879
|
|
ary: array
|
880
|
|
sorted 1D data array
|
881
|
|
|
882
|
|
Returns
|
883
|
|
-------
|
884
|
|
k: float
|
885
|
|
estimated shape parameter
|
886
|
|
sigma: float
|
887
|
|
estimated scale parameter
|
888
|
|
"""
|
889
|
2
|
prior_bs = 3
|
890
|
2
|
prior_k = 10
|
891
|
2
|
n = len(ary)
|
892
|
2
|
m_est = 30 + int(n ** 0.5)
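    # Empirical-Bayes fit over a grid of `m_est` candidate values of b = -k/sigma
    # (following Zhang & Stephens, 2009, as used for PSIS): each candidate is weighted by
    # its relative profile likelihood, and the weighted mean gives the posterior b.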
|
893
|
|
|
894
|
2
|
b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5))
|
895
|
2
|
b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1]
|
896
|
2
|
b_ary += 1 / ary[-1]
|
897
|
|
|
898
|
2
|
k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1) # pylint: disable=no-member
|
899
|
2
|
len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1)
|
900
|
2
|
weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)
|
901
|
|
|
902
|
|
# remove negligible weights
|
903
|
2
|
real_idxs = weights >= 10 * np.finfo(float).eps
|
904
|
2
|
if not np.all(real_idxs):
|
905
|
2
|
weights = weights[real_idxs]
|
906
|
2
|
b_ary = b_ary[real_idxs]
|
907
|
|
# normalise weights
|
908
|
2
|
weights /= weights.sum()
|
909
|
|
|
910
|
|
# posterior mean for b
|
911
|
2
|
b_post = np.sum(b_ary * weights)
|
912
|
|
# estimate for k
|
913
|
2
|
k_post = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member
|
914
|
|
# add prior for k_post
|
915
|
2
|
k_post = (n * k_post + prior_k * 0.5) / (n + prior_k)
|
916
|
2
|
sigma = -k_post / b_post
|
917
|
|
|
918
|
2
|
return k_post, sigma
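# A minimal usage sketch for the empirical-Bayes GPD fit above (illustrative only and
# kept as a comment so nothing runs at import time; the 0.3 shape value and the sample
# size are arbitrary assumptions):
#
#     import numpy as np
#     import scipy.stats as st
#     tail = np.sort(st.genpareto.rvs(0.3, size=500, random_state=0))
#     k_hat, sigma_hat = _gpdfit(tail)  # k_hat should land near 0.3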
|
919
|
|
|
920
|
|
|
921
|
2
|
def _gpinv(probs, kappa, sigma):
|
922
|
|
"""Inverse Generalized Pareto distribution function."""
|
923
|
|
# pylint: disable=unsupported-assignment-operation, invalid-unary-operand-type
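    # Quantile function of the GPD: x(p) = sigma * ((1 - p)**(-kappa) - 1) / kappa, with
    # the exponential limit x(p) = -sigma * log(1 - p) as kappa -> 0; expm1/log1p keep
    # the computation stable for small kappa and p.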
|
924
|
2
|
x = np.full_like(probs, np.nan)
|
925
|
2
|
if sigma <= 0:
|
926
|
2
|
return x
|
927
|
2
|
ok = (probs > 0) & (probs < 1)
|
928
|
2
|
if np.all(ok):
|
929
|
2
|
if np.abs(kappa) < np.finfo(float).eps:
|
930
|
2
|
x = -np.log1p(-probs)
|
931
|
|
else:
|
932
|
2
|
x = np.expm1(-kappa * np.log1p(-probs)) / kappa
|
933
|
2
|
x *= sigma
|
934
|
|
else:
|
935
|
2
|
if np.abs(kappa) < np.finfo(float).eps:
|
936
|
2
|
x[ok] = -np.log1p(-probs[ok])
|
937
|
|
else:
|
938
|
2
|
x[ok] = np.expm1(-kappa * np.log1p(-probs[ok])) / kappa
|
939
|
2
|
x *= sigma
|
940
|
2
|
x[probs == 0] = 0
|
941
|
2
|
if kappa >= 0:
|
942
|
2
|
x[probs == 1] = np.inf
|
943
|
|
else:
|
944
|
2
|
x[probs == 1] = -sigma / kappa
|
945
|
2
|
return x
|
946
|
|
|
947
|
|
|
948
|
2
|
def r2_score(y_true, y_pred):
|
949
|
|
"""R² for Bayesian regression models. Only valid for linear models.
|
950
|
|
|
951
|
|
Parameters
|
952
|
|
----------
|
953
|
|
y_true: array-like of shape = (n_samples) or (n_samples, n_outputs)
|
954
|
|
Ground truth (correct) target values.
|
955
|
|
y_pred: array-like of shape = (n_samples) or (n_samples, n_outputs)
|
956
|
|
Estimated target values.
|
957
|
|
|
958
|
|
Returns
|
959
|
|
-------
|
960
|
|
Pandas Series with the following indices:
|
961
|
|
r2: Bayesian R²
|
962
|
|
r2_std: standard deviation of the Bayesian R².
|
963
|
|
|
964
|
|
Examples
|
965
|
|
--------
|
966
|
|
Calculate R² for Bayesian regression models :
|
967
|
|
|
968
|
|
.. ipython::
|
969
|
|
|
970
|
|
In [1]: import arviz as az
|
971
|
|
...: data = az.load_arviz_data('regression1d')
|
972
|
|
...: y_true = data.observed_data["y"].values
|
973
|
|
...: y_pred = data.posterior_predictive.stack(sample=("chain", "draw"))["y"].values.T
|
974
|
|
...: az.r2_score(y_true, y_pred)
|
975
|
|
|
976
|
|
"""
|
977
|
2
|
_numba_flag = Numba.numba_flag
|
978
|
2
|
if y_pred.ndim == 1:
|
979
|
2
|
var_y_est = _numba_var(svar, np.var, y_pred)
|
980
|
2
|
var_e = _numba_var(svar, np.var, (y_true - y_pred))
|
981
|
|
else:
|
982
|
2
|
var_y_est = _numba_var(svar, np.var, y_pred.mean(0))
|
983
|
2
|
var_e = _numba_var(svar, np.var, (y_true - y_pred), axis=0)
|
984
|
2
|
r_squared = var_y_est / (var_y_est + var_e)
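    # Bayesian R^2 per posterior draw: Var(predicted) / (Var(predicted) + Var(residuals)),
    # following Gelman, Goodrich, Gabry and Vehtari (2019).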
|
985
|
|
|
986
|
2
|
return pd.Series([np.mean(r_squared), np.std(r_squared)], index=["r2", "r2_std"])
|
987
|
|
|
988
|
|
|
989
|
2
|
def summary(
|
990
|
|
data,
|
991
|
|
var_names: Optional[List[str]] = None,
|
992
|
|
filter_vars=None,
|
993
|
|
fmt: str = "wide",
|
994
|
|
kind: str = "all",
|
995
|
|
round_to=None,
|
996
|
|
include_circ=None,
|
997
|
|
circ_var_names=None,
|
998
|
|
stat_funcs=None,
|
999
|
|
extend=True,
|
1000
|
|
hdi_prob=None,
|
1001
|
|
order="C",
|
1002
|
|
index_origin=None,
|
1003
|
|
skipna=False,
|
1004
|
|
coords: Optional[CoordSpec] = None,
|
1005
|
|
dims: Optional[DimSpec] = None,
|
1006
|
|
credible_interval=None,
|
1007
|
|
) -> Union[pd.DataFrame, xr.Dataset]:
|
1008
|
|
"""Create a data frame with summary statistics.
|
1009
|
|
|
1010
|
|
Parameters
|
1011
|
|
----------
|
1012
|
|
data: obj
|
1013
|
|
Any object that can be converted to an az.InferenceData object
|
1014
|
|
Refer to documentation of az.convert_to_dataset for details
|
1015
|
|
var_names: list
|
1016
|
|
Names of variables to include in summary. Prefix the variables by `~` when you
|
1017
|
|
want to exclude them from the summary: `["~beta"]` instead of `["beta"]` (see
|
1018
|
|
examples below).
|
1019
|
|
filter_vars: {None, "like", "regex"}, optional, default=None
|
1020
|
|
If `None` (default), interpret var_names as the real variables names. If "like",
|
1021
|
|
interpret var_names as substrings of the real variables names. If "regex",
|
1022
|
|
interpret var_names as regular expressions on the real variables names. A la
|
1023
|
|
`pandas.filter`.
|
1024
|
|
fmt: {'wide', 'long', 'xarray'}
|
1025
|
|
Return format is either pandas.DataFrame {'wide', 'long'} or xarray.Dataset {'xarray'}.
|
1026
|
|
kind: {'all', 'stats', 'diagnostics'}
|
1027
|
|
Whether to include the `stats`: `mean`, `sd`, `hdi_3%`, `hdi_97%`, or the `diagnostics`:
|
1028
|
|
`mcse_mean`, `mcse_sd`, `ess_bulk`, `ess_tail`, and `r_hat`. Default to include `all` of
|
1029
|
|
them.
|
1030
|
|
round_to: int
|
1031
|
|
Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
|
1032
|
|
include_circ: boolean
|
1033
|
|
Whether to include circular statistics
|
1034
|
|
deprecated: Please see circ_var_names
|
1035
|
|
circ_var_names: list
|
1036
|
|
A list of circular variables to compute circular stats for
|
1037
|
|
stat_funcs: dict
|
1038
|
|
A list of functions or a dict of functions with function names as keys used to calculate
|
1039
|
|
statistics. By default, the mean, standard deviation, simulation standard error, and
|
1040
|
|
highest posterior density intervals are included.
|
1041
|
|
|
1042
|
|
        The functions will be given one argument, the samples for a variable as an nD array.
|
1043
|
|
The functions should be in the style of a ufunc and return a single number. For example,
|
1044
|
|
        `np.mean` or `np.var` would both work.
|
1045
|
|
extend: boolean
|
1046
|
|
If True, use the statistics returned by ``stat_funcs`` in addition to, rather than in place
|
1047
|
|
of, the default statistics. This is only meaningful when ``stat_funcs`` is not None.
|
1048
|
|
hdi_prob: float, optional
|
1049
|
|
HDI interval to compute. Defaults to 0.94. This is only meaningful when ``stat_funcs`` is
|
1050
|
|
None.
|
1051
|
|
order: {"C", "F"}
|
1052
|
|
If fmt is "wide", use either C or F unpacking order. Defaults to C.
|
1053
|
|
index_origin: int
|
1054
|
|
If fmt is "wide, select n-based indexing for multivariate parameters.
|
1055
|
|
        Defaults to rcParam ``data.index_origin``, which is 0.
|
1056
|
|
skipna: bool
|
1057
|
|
        If true, ignores nan values when computing the summary statistics. It does not affect the
|
1058
|
|
behaviour of the functions passed to ``stat_funcs``. Defaults to false.
|
1059
|
|
coords: Dict[str, List[Any]], optional
|
1060
|
|
Coordinates specification to be used if the ``fmt`` is ``'xarray'``.
|
1061
|
|
dims: Dict[str, List[str]], optional
|
1062
|
|
Dimensions specification for the variables to be used if the ``fmt`` is ``'xarray'``.
|
1063
|
|
credible_interval: float, optional
|
1064
|
|
deprecated: Please see hdi_prob
|
1065
|
|
|
1066
|
|
Returns
|
1067
|
|
-------
|
1068
|
|
pandas.DataFrame or xarray.Dataset
|
1069
|
|
        Return type dictated by `fmt` argument.
|
1070
|
|
Return value will contain summary statistics for each variable. Default statistics are:
|
1071
|
|
`mean`, `sd`, `hdi_3%`, `hdi_97%`, `mcse_mean`, `mcse_sd`, `ess_bulk`, `ess_tail`, and
|
1072
|
|
`r_hat`.
|
1073
|
|
`r_hat` is only computed for traces with 2 or more chains.
|
1074
|
|
|
1075
|
|
Examples
|
1076
|
|
--------
|
1077
|
|
.. ipython::
|
1078
|
|
|
1079
|
|
In [1]: import arviz as az
|
1080
|
|
...: data = az.load_arviz_data("centered_eight")
|
1081
|
|
...: az.summary(data, var_names=["mu", "tau"])
|
1082
|
|
|
1083
|
|
You can use `filter_vars` to select variables without having to specify all the exact
|
1084
|
|
names. Use `filter_vars="like"` to select based on partial naming:
|
1085
|
|
|
1086
|
|
.. ipython::
|
1087
|
|
|
1088
|
|
In [1]: az.summary(data, var_names=["the"], filter_vars="like")
|
1089
|
|
|
1090
|
|
Use `filter_vars="regex"` to select based on regular expressions, and prefix the variables
|
1091
|
|
you want to exclude by `~`. Here, we exclude from the summary all the variables
|
1092
|
|
starting with the letter t:
|
1093
|
|
|
1094
|
|
.. ipython::
|
1095
|
|
|
1096
|
|
In [1]: az.summary(data, var_names=["~^t"], filter_vars="regex")
|
1097
|
|
|
1098
|
|
Other statistics can be calculated by passing a list of functions
|
1099
|
|
or a dictionary with key, function pairs.
|
1100
|
|
|
1101
|
|
.. ipython::
|
1102
|
|
|
1103
|
|
In [1]: import numpy as np
|
1104
|
|
...: def median_sd(x):
|
1105
|
|
...: median = np.percentile(x, 50)
|
1106
|
|
...: sd = np.sqrt(np.mean((x-median)**2))
|
1107
|
|
...: return sd
|
1108
|
|
...:
|
1109
|
|
...: func_dict = {
|
1110
|
|
...: "std": np.std,
|
1111
|
|
...: "median_std": median_sd,
|
1112
|
|
...: "5%": lambda x: np.percentile(x, 5),
|
1113
|
|
...: "median": lambda x: np.percentile(x, 50),
|
1114
|
|
...: "95%": lambda x: np.percentile(x, 95),
|
1115
|
|
...: }
|
1116
|
|
...: az.summary(
|
1117
|
|
...: data,
|
1118
|
|
...: var_names=["mu", "tau"],
|
1119
|
|
...: stat_funcs=func_dict,
|
1120
|
|
...: extend=False
|
1121
|
|
...: )
|
1122
|
|
|
1123
|
|
"""
|
1124
|
2
|
if include_circ:
|
1125
|
2
|
warnings.warn(
|
1126
|
|
"include_circ is deprecated and will be ignored. Use circ_var_names instead",
|
1127
|
|
DeprecationWarning,
|
1128
|
|
)
|
1129
|
|
|
1130
|
2
|
if credible_interval:
|
1131
|
0
|
        hdi_prob = credible_interval_warning(credible_interval, hdi_prob)
|
1132
|
|
|
1133
|
2
|
extra_args = {} # type: Dict[str, Any]
|
1134
|
2
|
if coords is not None:
|
1135
|
0
|
extra_args["coords"] = coords
|
1136
|
2
|
if dims is not None:
|
1137
|
0
|
extra_args["dims"] = dims
|
1138
|
2
|
if index_origin is None:
|
1139
|
2
|
index_origin = rcParams["data.index_origin"]
|
1140
|
2
|
if hdi_prob is None:
|
1141
|
2
|
hdi_prob = rcParams["stats.hdi_prob"]
|
1142
|
|
else:
|
1143
|
0
|
if not 1 >= hdi_prob > 0:
|
1144
|
0
|
raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
|
1145
|
2
|
posterior = convert_to_dataset(data, group="posterior", **extra_args)
|
1146
|
2
|
var_names = _var_names(var_names, posterior, filter_vars)
|
1147
|
2
|
posterior = posterior if var_names is None else posterior[var_names]
|
1148
|
|
|
1149
|
2
|
fmt_group = ("wide", "long", "xarray")
|
1150
|
2
|
if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
|
1151
|
2
|
raise TypeError("Invalid format: '{}'. Formatting options are: {}".format(fmt, fmt_group))
|
1152
|
|
|
1153
|
2
|
unpack_order_group = ("C", "F")
|
1154
|
2
|
if not isinstance(order, str) or (order.upper() not in unpack_order_group):
|
1155
|
2
|
raise TypeError(
|
1156
|
|
"Invalid order: '{}'. Unpacking options are: {}".format(order, unpack_order_group)
|
1157
|
|
)
|
1158
|
|
|
1159
|
2
|
alpha = 1 - hdi_prob
|
1160
|
|
|
1161
|
2
|
extra_metrics = []
|
1162
|
2
|
extra_metric_names = []
|
1163
|
|
|
1164
|
2
|
if stat_funcs is not None:
|
1165
|
2
|
if isinstance(stat_funcs, dict):
|
1166
|
2
|
for stat_func_name, stat_func in stat_funcs.items():
|
1167
|
2
|
extra_metrics.append(
|
1168
|
|
xr.apply_ufunc(
|
1169
|
|
_make_ufunc(stat_func), posterior, input_core_dims=(("chain", "draw"),)
|
1170
|
|
)
|
1171
|
|
)
|
1172
|
2
|
extra_metric_names.append(stat_func_name)
|
1173
|
|
else:
|
1174
|
2
|
for stat_func in stat_funcs:
|
1175
|
2
|
extra_metrics.append(
|
1176
|
|
xr.apply_ufunc(
|
1177
|
|
_make_ufunc(stat_func), posterior, input_core_dims=(("chain", "draw"),)
|
1178
|
|
)
|
1179
|
|
)
|
1180
|
2
|
extra_metric_names.append(stat_func.__name__)
|
1181
|
|
|
1182
|
2
|
if extend and kind in ["all", "stats"]:
|
1183
|
2
|
mean = posterior.mean(dim=("chain", "draw"), skipna=skipna)
|
1184
|
|
|
1185
|
2
|
sd = posterior.std(dim=("chain", "draw"), ddof=1, skipna=skipna)
|
1186
|
|
|
1187
|
2
|
hdi_post = hdi(posterior, hdi_prob=hdi_prob, multimodal=False, skipna=skipna)
|
1188
|
2
|
hdi_lower = hdi_post.sel(hdi="lower", drop=True)
|
1189
|
2
|
hdi_higher = hdi_post.sel(hdi="higher", drop=True)
|
1190
|
|
|
1191
|
2
|
if circ_var_names:
|
1192
|
2
|
nan_policy = "omit" if skipna else "propagate"
|
1193
|
2
|
circ_mean = xr.apply_ufunc(
|
1194
|
|
_make_ufunc(st.circmean),
|
1195
|
|
posterior,
|
1196
|
|
kwargs=dict(high=np.pi, low=-np.pi, nan_policy=nan_policy),
|
1197
|
|
input_core_dims=(("chain", "draw"),),
|
1198
|
|
)
|
1199
|
2
|
_numba_flag = Numba.numba_flag
|
1200
|
2
|
func = None
|
1201
|
2
|
if _numba_flag:
|
1202
|
2
|
func = _circular_standard_deviation
|
1203
|
2
|
kwargs_circ_std = dict(high=np.pi, low=-np.pi, skipna=skipna)
|
1204
|
|
else:
|
1205
|
2
|
func = st.circstd
|
1206
|
2
|
kwargs_circ_std = dict(high=np.pi, low=-np.pi, nan_policy=nan_policy)
|
1207
|
2
|
circ_sd = xr.apply_ufunc(
|
1208
|
|
_make_ufunc(func),
|
1209
|
|
posterior,
|
1210
|
|
kwargs=kwargs_circ_std,
|
1211
|
|
input_core_dims=(("chain", "draw"),),
|
1212
|
|
)
|
1213
|
|
|
1214
|
2
|
circ_mcse = xr.apply_ufunc(
|
1215
|
|
_make_ufunc(_mc_error),
|
1216
|
|
posterior,
|
1217
|
|
kwargs=dict(circular=True),
|
1218
|
|
input_core_dims=(("chain", "draw"),),
|
1219
|
|
)
|
1220
|
|
|
1221
|
2
|
circ_hdi = hdi(posterior, hdi_prob=hdi_prob, circular=True, skipna=skipna)
|
1222
|
2
|
circ_hdi_lower = circ_hdi.sel(hdi="lower", drop=True)
|
1223
|
2
|
circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True)
|
1224
|
|
|
1225
|
2
|
if kind in ["all", "diagnostics"]:
|
1226
|
2
|
mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc(
|
1227
|
|
_make_ufunc(_multichain_statistics, n_output=7, ravel=False),
|
1228
|
|
posterior,
|
1229
|
|
input_core_dims=(("chain", "draw"),),
|
1230
|
|
output_core_dims=tuple([] for _ in range(7)),
|
1231
|
|
)
|
1232
|
|
|
1233
|
|
# Combine metrics
|
1234
|
2
|
metrics = []
|
1235
|
2
|
metric_names = []
|
1236
|
2
|
if extend:
|
1237
|
2
|
metrics_names_ = (
|
1238
|
|
"mean",
|
1239
|
|
"sd",
|
1240
|
|
"hdi_{:g}%".format(100 * alpha / 2),
|
1241
|
|
"hdi_{:g}%".format(100 * (1 - alpha / 2)),
|
1242
|
|
"mcse_mean",
|
1243
|
|
"mcse_sd",
|
1244
|
|
"ess_mean",
|
1245
|
|
"ess_sd",
|
1246
|
|
"ess_bulk",
|
1247
|
|
"ess_tail",
|
1248
|
|
"r_hat",
|
1249
|
|
)
|
1250
|
2
|
if kind == "all":
|
1251
|
2
|
metrics_ = (
|
1252
|
|
mean,
|
1253
|
|
sd,
|
1254
|
|
hdi_lower,
|
1255
|
|
hdi_higher,
|
1256
|
|
mcse_mean,
|
1257
|
|
mcse_sd,
|
1258
|
|
ess_mean,
|
1259
|
|
ess_sd,
|
1260
|
|
ess_bulk,
|
1261
|
|
ess_tail,
|
1262
|
|
r_hat,
|
1263
|
|
)
|
1264
|
2
|
elif kind == "stats":
|
1265
|
2
|
metrics_ = (mean, sd, hdi_lower, hdi_higher)
|
1266
|
2
|
metrics_names_ = metrics_names_[:4]
|
1267
|
2
|
elif kind == "diagnostics":
|
1268
|
2
|
metrics_ = (mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat)
|
1269
|
2
|
metrics_names_ = metrics_names_[4:]
|
1270
|
2
|
metrics.extend(metrics_)
|
1271
|
2
|
metric_names.extend(metrics_names_)
|
1272
|
|
|
1273
|
2
|
if circ_var_names:
|
1274
|
|
|
1275
|
2
|
if kind != "diagnostics":
|
1276
|
2
|
for metric, circ_stat in zip(
|
1277
|
|
# Replace only the first 5 statistics for their circular equivalent
|
1278
|
|
metrics[:5],
|
1279
|
|
(circ_mean, circ_sd, circ_hdi_lower, circ_hdi_higher, circ_mcse),
|
1280
|
|
):
|
1281
|
2
|
for circ_var in circ_var_names:
|
1282
|
2
|
metric[circ_var] = circ_stat[circ_var]
|
1283
|
|
|
1284
|
2
|
metrics.extend(extra_metrics)
|
1285
|
2
|
metric_names.extend(extra_metric_names)
|
1286
|
2
|
joined = (
|
1287
|
|
xr.concat(metrics, dim="metric").assign_coords(metric=metric_names).reset_coords(drop=True)
|
1288
|
|
)
|
1289
|
|
|
1290
|
2
|
if fmt.lower() == "wide":
|
1291
|
2
|
dfs = []
|
1292
|
2
|
for var_name, values in joined.data_vars.items():
|
1293
|
2
|
if len(values.shape[1:]):
|
1294
|
2
|
metric = list(values.metric.values)
|
1295
|
2
|
data_dict = OrderedDict()
|
1296
|
2
|
for idx in np.ndindex(values.shape[1:] if order == "C" else values.shape[1:][::-1]):
|
1297
|
2
|
if order == "F":
|
1298
|
2
|
idx = tuple(idx[::-1])
|
1299
|
2
|
ser = pd.Series(values[(Ellipsis, *idx)].values, index=metric)
|
1300
|
2
|
key_index = ",".join(map(str, (i + index_origin for i in idx)))
|
1301
|
2
|
key = "{}[{}]".format(var_name, key_index)
|
1302
|
2
|
data_dict[key] = ser
|
1303
|
2
|
df = pd.DataFrame.from_dict(data_dict, orient="index")
|
1304
|
2
|
df = df.loc[list(data_dict.keys())]
|
1305
|
|
else:
|
1306
|
2
|
df = values.to_dataframe()
|
1307
|
2
|
df.index = list(df.index)
|
1308
|
2
|
df = df.T
|
1309
|
2
|
dfs.append(df)
|
1310
|
2
|
summary_df = pd.concat(dfs, sort=False)
|
1311
|
2
|
elif fmt.lower() == "long":
|
1312
|
2
|
df = joined.to_dataframe().reset_index().set_index("metric")
|
1313
|
2
|
df.index = list(df.index)
|
1314
|
2
|
summary_df = df
|
1315
|
|
else:
|
1316
|
|
# format is 'xarray'
|
1317
|
2
|
summary_df = joined
|
1318
|
2
|
if (round_to is not None) and (round_to not in ("None", "none")):
|
1319
|
0
|
summary_df = summary_df.round(round_to)
|
1320
|
2
|
elif round_to not in ("None", "none") and (fmt.lower() in ("long", "wide")):
|
1321
|
|
# Don't round xarray object by default (even with "none")
|
1322
|
2
|
decimals = {
|
1323
|
|
col: 3
|
1324
|
|
if col not in {"ess_mean", "ess_sd", "ess_bulk", "ess_tail", "r_hat"}
|
1325
|
|
else 2
|
1326
|
|
if col == "r_hat"
|
1327
|
|
else 0
|
1328
|
|
for col in summary_df.columns
|
1329
|
|
}
|
1330
|
2
|
summary_df = summary_df.round(decimals)
|
1331
|
|
|
1332
|
2
|
return summary_df
|
1333
|
|
|
1334
|
|
|
1335
|
2
|
def waic(data, pointwise=None, var_name=None, scale=None):
|
1336
|
|
"""Compute the widely applicable information criterion.
|
1337
|
|
|
1338
|
|
Estimates the expected log pointwise predictive density (elpd) using WAIC. Also calculates the
|
1339
|
|
WAIC's standard error and the effective number of parameters.
|
1340
|
|
Read more theory here https://arxiv.org/abs/1507.04544 and here https://arxiv.org/abs/1004.2316
|
1341
|
|
|
1342
|
|
Parameters
|
1343
|
|
----------
|
1344
|
|
data: obj
|
1345
|
|
Any object that can be converted to an az.InferenceData object. Refer to documentation of
|
1346
|
|
``az.convert_to_inference_data`` for details
|
1347
|
|
pointwise: bool
|
1348
|
|
If True the pointwise predictive accuracy will be returned. Defaults to
|
1349
|
|
``stats.ic_pointwise`` rcParam.
|
1350
|
|
var_name : str, optional
|
1351
|
|
The name of the variable in log_likelihood groups storing the pointwise log
|
1352
|
|
likelihood data to use for waic computation.
|
1353
|
|
scale: str
|
1354
|
|
Output scale for WAIC. Available options are:
|
1355
|
|
|
1356
|
|
- `log` : (default) log-score
|
1357
|
|
- `negative_log` : -1 * log-score
|
1358
|
|
- `deviance` : -2 * log-score
|
1359
|
|
|
1360
|
|
A higher log-score (or a lower deviance or negative log_score) indicates a model with
|
1361
|
|
better predictive accuracy.
|
1362
|
|
|
1363
|
|
Returns
|
1364
|
|
-------
|
1365
|
|
    ELPDData object (inherits from pandas.Series) with the following row/attributes:
|
1366
|
|
waic: approximated expected log pointwise predictive density (elpd)
|
1367
|
|
waic_se: standard error of waic
|
1368
|
|
    p_waic: effective number of parameters
|
1369
|
|
var_warn: bool
|
1370
|
|
True if posterior variance of the log predictive densities exceeds 0.4
|
1371
|
|
waic_i: xarray.DataArray with the pointwise predictive accuracy, only if pointwise=True
|
1372
|
|
waic_scale: scale of the reported waic results
|
1373
|
|
|
1374
|
|
    The returned object has a custom print method that overrides the pd.Series method.
|
1375
|
|
|
1376
|
|
Examples
|
1377
|
|
--------
|
1378
|
|
Calculate WAIC of a model:
|
1379
|
|
|
1380
|
|
.. ipython::
|
1381
|
|
|
1382
|
|
In [1]: import arviz as az
|
1383
|
|
...: data = az.load_arviz_data("centered_eight")
|
1384
|
|
...: az.waic(data)
|
1385
|
|
|
1386
|
|
Calculate WAIC of a model and return the pointwise values:
|
1387
|
|
|
1388
|
|
.. ipython::
|
1389
|
|
|
1390
|
|
In [2]: data_waic = az.waic(data, pointwise=True)
|
1391
|
|
...: data_waic.waic_i
|
1392
|
|
"""
|
1393
|
2
|
inference_data = convert_to_inference_data(data)
|
1394
|
2
|
log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
|
1395
|
2
|
scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
|
1396
|
2
|
pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
|
1397
|
|
|
1398
|
2
|
if scale == "deviance":
|
1399
|
2
|
scale_value = -2
|
1400
|
2
|
elif scale == "log":
|
1401
|
2
|
scale_value = 1
|
1402
|
2
|
elif scale == "negative_log":
|
1403
|
2
|
scale_value = -1
|
1404
|
|
else:
|
1405
|
2
|
raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
|
1406
|
|
|
1407
|
2
|
log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
|
1408
|
2
|
shape = log_likelihood.shape
|
1409
|
2
|
n_samples = shape[-1]
|
1410
|
2
|
n_data_points = np.product(shape[:-1])
|
1411
|
|
|
1412
|
2
|
ufunc_kwargs = {"n_dims": 1, "ravel": False}
|
1413
|
2
|
kwargs = {"input_core_dims": [["sample"]]}
|
1414
|
2
|
lppd_i = _wrap_xarray_ufunc(
|
1415
|
|
_logsumexp,
|
1416
|
|
log_likelihood,
|
1417
|
|
func_kwargs={"b_inv": n_samples},
|
1418
|
|
ufunc_kwargs=ufunc_kwargs,
|
1419
|
|
**kwargs,
|
1420
|
|
)

    vars_lpd = log_likelihood.var(dim="sample")
    warn_mg = False
    if np.any(vars_lpd > 0.4):
        warnings.warn(
            (
                "For one or more samples the posterior variance of the log predictive "
                "densities exceeds 0.4. This could be an indication of WAIC starting to fail. \n"
                "See http://arxiv.org/abs/1507.04544 for details"
            )
        )
        warn_mg = True
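    # vars_lpd is the pointwise p_waic estimate; the 0.4 threshold is the heuristic
    # suggested in the paper referenced in the warning above.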

    waic_i = scale_value * (lppd_i - vars_lpd)
    waic_se = (n_data_points * np.var(waic_i.values)) ** 0.5
    waic_sum = np.sum(waic_i.values)
    p_waic = np.sum(vars_lpd.values)
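    # waic_i = scale_value * (lppd_i - p_waic_i); the standard error uses the usual
    # normal approximation, se = sqrt(n_data_points * var(waic_i)).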

    if pointwise:
        if np.equal(waic_sum, waic_i).all():  # pylint: disable=no-member
            warnings.warn(
                """The pointwise WAIC is the same as the sum WAIC, please double check
            the Observed RV in your model to make sure it returns element-wise logp.
            """
            )
        return ELPDData(
            data=[
                waic_sum,
                waic_se,
                p_waic,
                n_samples,
                n_data_points,
                warn_mg,
                waic_i.rename("waic_i"),
                scale,
            ],
            index=[
                "waic",
                "waic_se",
                "p_waic",
                "n_samples",
                "n_data_points",
                "warning",
                "waic_i",
                "waic_scale",
            ],
        )
    else:
        return ELPDData(
            data=[waic_sum, waic_se, p_waic, n_samples, n_data_points, warn_mg, scale],
            index=[
                "waic",
                "waic_se",
                "p_waic",
                "n_samples",
                "n_data_points",
                "warning",
                "waic_scale",
            ],
        )


def loo_pit(idata=None, *, y=None, y_hat=None, log_weights=None):
    """Compute leave-one-out (PSIS-LOO) probability integral transform (PIT) values.

    Parameters
    ----------
    idata: InferenceData
        InferenceData object.
    y: array, DataArray or str
        Observed data. If str, ``idata`` must be present and contain the observed data group.
    y_hat: array, DataArray or str
        Posterior predictive samples for ``y``. It must have the same shape as ``y`` plus an
        extra dimension at the end of size n_samples (chains and draws stacked). If str or
        None, ``idata`` must contain the posterior predictive group. If None, ``y_hat`` is
        taken equal to ``y``, so ``y`` must be a str too.
    log_weights: array or DataArray
        Smoothed log weights. They must have the same shape as ``y_hat``.

    Returns
    -------
    loo_pit: array or DataArray
        Value of the LOO-PIT at each observed data point.

    Examples
    --------
    Calculate LOO-PIT values using as test quantity the observed values themselves.

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("centered_eight")
           ...: az.loo_pit(idata=data, y="obs")

    Calculate LOO-PIT values using as test quantity the square of the difference between
    each observation and `mu`. Both ``y`` and ``y_hat`` inputs will be array-like,
    but ``idata`` will still be passed in order to calculate the ``log_weights`` from
    there.

    .. ipython::

        In [1]: T = data.observed_data.obs - data.posterior.mu.median(dim=("chain", "draw"))
           ...: T_hat = data.posterior_predictive.obs - data.posterior.mu
           ...: T_hat = T_hat.stack(sample=("chain", "draw"))
           ...: az.loo_pit(idata=data, y=T**2, y_hat=T_hat**2)
    """
    y_str = ""
    if idata is not None and not isinstance(idata, InferenceData):
        raise ValueError("idata must be of type InferenceData or None")

    if idata is None:
        if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat, log_weights)):
            raise ValueError(
                "all 3 y, y_hat and log_weights must be array or DataArray when idata is None "
                "but they are of types {}".format([type(arg) for arg in (y, y_hat, log_weights)])
            )

    else:
        if y_hat is None and isinstance(y, str):
            y_hat = y
        elif y_hat is None:
            raise ValueError("y_hat cannot be None if y is not a str")
        if isinstance(y, str):
            y_str = y
            y = idata.observed_data[y].values
        elif not isinstance(y, (np.ndarray, xr.DataArray)):
            raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y)))
        if isinstance(y_hat, str):
            y_hat = idata.posterior_predictive[y_hat].stack(sample=("chain", "draw")).values
        elif not isinstance(y_hat, (np.ndarray, xr.DataArray)):
            raise ValueError(
                "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat))
            )
        if log_weights is None:
            if y_str:
                try:
                    log_likelihood = _get_log_likelihood(idata, var_name=y_str)
                except TypeError:
                    log_likelihood = _get_log_likelihood(idata)
            else:
                log_likelihood = _get_log_likelihood(idata)
            log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
            posterior = convert_to_dataset(idata, group="posterior")
            n_chains = len(posterior.chain)
            n_samples = len(log_likelihood.sample)
            ess_p = ess(posterior, method="mean")
            # this mean is over all data variables
            reff = (
                (np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples)
                if n_chains > 1
                else 1
            )
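            # reff is the relative efficiency (average ESS divided by the total number of
            # draws) required by PSIS; the smoothed log weights for leaving out each
            # observation come from psislw applied to the negative pointwise log likelihood.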
            log_weights = psislw(-log_likelihood, reff=reff)[0].values
        elif not isinstance(log_weights, (np.ndarray, xr.DataArray)):
            raise ValueError(
                "log_weights must be None or of types array or DataArray, not {}".format(
                    type(log_weights)
                )
            )

    if len(y.shape) + 1 != len(y_hat.shape):
        raise ValueError(
            "y_hat must have 1 more dimension than y, but y_hat has {} dims and y has "
            "{} dims".format(len(y_hat.shape), len(y.shape))
        )

    if y.shape != y_hat.shape[:-1]:
        raise ValueError(
            "y has shape: {} which should be equal to y_hat shape (omitting the last "
            "dimension): {}".format(y.shape, y_hat.shape)
        )

    if y_hat.shape != log_weights.shape:
        raise ValueError(
            "y_hat and log_weights must have the same shape but have shapes {} and {}".format(
                y_hat.shape, log_weights.shape
            )
        )

    kwargs = {
        "input_core_dims": [[], ["sample"], ["sample"]],
        "output_core_dims": [[]],
        "join": "left",
    }
    ufunc_kwargs = {"n_dims": 1}

    return _wrap_xarray_ufunc(_loo_pit, y, y_hat, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs)


def _loo_pit(y, y_hat, log_weights):
    """Compute LOO-PIT values."""
    sel = y_hat <= y
    if np.sum(sel) > 0:
        value = np.exp(_logsumexp(log_weights[sel]))
        return min(1, value)
    else:
        return 0


def apply_test_function(
    idata,
    func,
    group="both",
    var_names=None,
    pointwise=False,
    out_data_shape=None,
    out_pp_shape=None,
    out_name_data="T",
    out_name_pp=None,
    func_args=None,
    func_kwargs=None,
    ufunc_kwargs=None,
    wrap_data_kwargs=None,
    wrap_pp_kwargs=None,
    inplace=True,
    overwrite=None,
):
    """Apply a Bayesian test function to an InferenceData object.

    Parameters
    ----------
    idata: InferenceData
        InferenceData object on which to apply the test function. This function will add
        new variables to the InferenceData object to store the result without modifying the
        existing ones.
    func: callable
        Callable that calculates the test function. It must have the following call signature
        ``func(y, theta, *args, **kwargs)`` (where ``y`` is the observed data or posterior
        predictive and ``theta`` the model parameters) even if not all the arguments are
        used.
    group: str, optional
        Group on which to apply the test function. Can be observed_data, posterior_predictive
        or both.
    var_names: dict group -> var_names, optional
        Mapping from group name to the variables to be passed to func. It can be a dict of
        strings or lists of strings. There is also the option of using ``both`` as key,
        in which case the same variables are used in the observed data and posterior
        predictive groups.
    pointwise: bool, optional
        If True, apply the test function to each observation and sample, otherwise apply
        the test function to each sample.
    out_data_shape, out_pp_shape: tuple, optional
        Output shape of the test function applied to the observed/posterior predictive data.
        If None, the default depends on the value of pointwise.
    out_name_data, out_name_pp: str, optional
        Name of the variables to add to the observed_data and posterior_predictive datasets
        respectively. ``out_name_pp`` can be ``None``, in which case it will be taken equal
        to ``out_name_data``.
    func_args: sequence, optional
        Passed as is to ``func``
    func_kwargs: mapping, optional
        Passed as is to ``func``
    wrap_data_kwargs, wrap_pp_kwargs: mapping, optional
        kwargs passed to ``az.stats.wrap_xarray_ufunc``. By default, some suitable
        input_core_dims are used.
    inplace: bool, optional
        If True, add the variables inplace, otherwise return a copy of idata with the
        variables added.
    overwrite: bool, optional
        Overwrite data in case ``out_name_data`` or ``out_name_pp`` are already variables in
        the dataset. If ``None``, it will be the opposite of inplace.

    Returns
    -------
    idata: InferenceData
        Output InferenceData object. If ``inplace=True``, it is the same input object modified
        inplace.

    Notes
    -----
    This function is provided as a convenience to apply scalar functions, or functions
    working on low-dimensional arrays, to whole InferenceData objects. It is not optimized
    and will generally be slower than an equivalent vectorized computation.

    Examples
    --------
    Use ``apply_test_function`` to wrap ``np.min`` for illustration purposes, and plot the
    results.

    .. plot::
        :context: close-figs

        >>> import numpy as np
        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> az.apply_test_function(idata, lambda y, theta: np.min(y))
        >>> T = np.asscalar(idata.observed_data.T)
        >>> az.plot_posterior(idata, var_names=["T"], group="posterior_predictive", ref_val=T)

    """
    out = idata if inplace else deepcopy(idata)

    valid_groups = ("observed_data", "posterior_predictive", "both")
    if group not in valid_groups:
        raise ValueError(
            "Invalid group argument. Must be one of {} not {}.".format(valid_groups, group)
        )
    if overwrite is None:
        overwrite = not inplace

    if out_name_pp is None:
        out_name_pp = out_name_data

    if func_args is None:
        func_args = tuple()

    if func_kwargs is None:
        func_kwargs = {}

    if ufunc_kwargs is None:
        ufunc_kwargs = {}
    ufunc_kwargs.setdefault("check_shape", False)
    ufunc_kwargs.setdefault("ravel", False)

    if wrap_data_kwargs is None:
        wrap_data_kwargs = {}
    if wrap_pp_kwargs is None:
        wrap_pp_kwargs = {}
    if var_names is None:
        var_names = {}

    both_var_names = var_names.pop("both", None)
    var_names.setdefault("posterior", list(out.posterior.data_vars))

    in_posterior = out.posterior[var_names["posterior"]]
    if isinstance(in_posterior, xr.Dataset):
        in_posterior = in_posterior.to_array().squeeze()
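    # in_posterior stacks the selected posterior variables into a single array; it is the
    # ``theta`` argument passed to ``func``.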

    groups = ("posterior_predictive", "observed_data") if group == "both" else [group]
    for grp in groups:
        out_group_shape = out_data_shape if grp == "observed_data" else out_pp_shape
        out_name_group = out_name_data if grp == "observed_data" else out_name_pp
        wrap_group_kwargs = wrap_data_kwargs if grp == "observed_data" else wrap_pp_kwargs
        if not hasattr(out, grp):
            raise ValueError("InferenceData object must have {} group".format(grp))
        if not overwrite and out_name_group in getattr(out, grp).data_vars:
            raise ValueError(
                "Variable {} is already present in group {} and overwrite is False".format(
                    out_name_group, grp
                )
            )
        var_names.setdefault(
            grp, list(getattr(out, grp).data_vars) if both_var_names is None else both_var_names
        )
        in_group = getattr(out, grp)[var_names[grp]]
        if isinstance(in_group, xr.Dataset):
            in_group = in_group.to_array(dim="{}_var".format(grp)).squeeze()

        if pointwise:
            out_group_shape = in_group.shape if out_group_shape is None else out_group_shape
        elif grp == "observed_data":
            out_group_shape = () if out_group_shape is None else out_group_shape
        elif grp == "posterior_predictive":
            out_group_shape = in_group.shape[:2] if out_group_shape is None else out_group_shape
        loop_dims = in_group.dims[: len(out_group_shape)]
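        # With pointwise=True the output keeps the full shape of the input group; otherwise
        # observed_data collapses to a scalar and posterior_predictive keeps only its first
        # two (typically chain and draw) dimensions. loop_dims are looped over, everything
        # else is passed whole to ``func``.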

        wrap_group_kwargs.setdefault(
            "input_core_dims",
            [
                [dim for dim in dataset.dims if dim not in loop_dims]
                for dataset in [in_group, in_posterior]
            ],
        )
        func_kwargs["out"] = np.empty(out_group_shape)

        out_group = getattr(out, grp)
        try:
            out_group[out_name_group] = _wrap_xarray_ufunc(
                func,
                in_group.values,
                in_posterior.values,
                func_args=func_args,
                func_kwargs=func_kwargs,
                ufunc_kwargs=ufunc_kwargs,
                **wrap_group_kwargs,
            )
        except IndexError:
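            # If passing the raw arrays fails because their shapes do not align, broadcast
            # the group and the posterior against each other (excluding the core dims)
            # and retry.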
            excluded_dims = set(
                wrap_group_kwargs["input_core_dims"][0] + wrap_group_kwargs["input_core_dims"][1]
            )
            out_group[out_name_group] = _wrap_xarray_ufunc(
                func,
                *xr.broadcast(in_group, in_posterior, exclude=excluded_dims),
                func_args=func_args,
                func_kwargs=func_kwargs,
                ufunc_kwargs=ufunc_kwargs,
                **wrap_group_kwargs,
            )
        setattr(out, grp, out_group)

    return out