# pylint: disable=too-many-lines, too-many-function-args, redefined-outer-name
"""Diagnostic functions for ArviZ."""
import warnings
from collections.abc import Sequence

import numpy as np
import pandas as pd
from scipy import stats

from ..data import convert_to_dataset
from ..utils import Numba, _numba_var, _stack, _var_names, conditional_jit
from .density_utils import histogram as _histogram
from .stats_utils import _circular_standard_deviation, _sqrt
from .stats_utils import autocov as _autocov
from .stats_utils import not_valid as _not_valid
from .stats_utils import quantile as _quantile
from .stats_utils import stats_variance_2d as svar
from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc

__all__ = ["bfmi", "ess", "rhat", "mcse", "geweke"]


def bfmi(data):
    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).

    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
    information.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        If InferenceData, energy variable needs to be found.

    Returns
    -------
    z : array
        The Bayesian fraction of missing information of the model and trace. One element per
        chain in the trace.

    Examples
    --------
    Compute the BFMI of an InferenceData object

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('radon')
           ...: az.bfmi(data)

    """
    if isinstance(data, np.ndarray):
        return _bfmi(data)

    dataset = convert_to_dataset(data, group="sample_stats")
    if not hasattr(dataset, "energy"):
        raise TypeError("Energy variable was not found.")
    return _bfmi(dataset.energy)


def ess(data, *, var_names=None, method="bulk", relative=False, prob=None):
    r"""Calculate estimate of the effective sample size (ess).

    Parameters
    ----------
    data : obj
        Any object that can be converted to an ``az.InferenceData`` object.
        Refer to documentation of ``az.convert_to_dataset`` for details.
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
    var_names : str or list of str
        Names of variables to include in the return value Dataset.
    method : str, optional, default "bulk"
        Select ess method. Valid methods are:

        - "bulk"
        - "tail"     # prob, optional
        - "quantile" # prob
        - "mean" (old ess)
        - "sd"
        - "median"
        - "mad" (median absolute deviance)
        - "z_scale"
        - "folded"
        - "identity"
        - "local"
    relative : bool
        Return relative ess
        ``ress = ess / n``
    prob : float, or tuple of two floats, optional
        Probability value for "tail", "quantile" or "local" ess functions.

    Returns
    -------
    xarray.Dataset
        Return the effective sample size, :math:`\hat{N}_{eff}`

    Notes
    -----
    The basic ess (:math:`N_{\mathit{eff}}`) diagnostic is computed by:

    .. math:: \hat{N}_{\mathit{eff}} = \frac{MN}{\hat{\tau}}

    .. math:: \hat{\tau} = -1 + 2 \sum_{t'=0}^K \hat{P}_{t'}

    where :math:`M` is the number of chains, :math:`N` the number of draws,
    :math:`\hat{\rho}_t` is the estimated autocorrelation at lag :math:`t`, and
    :math:`K` is the last integer for which :math:`\hat{P}_{K} = \hat{\rho}_{2K} +
    \hat{\rho}_{2K+1}` is still positive.

    The current implementation is similar to Stan, which uses Geyer's initial monotone sequence
    criterion (Geyer, 1992; Geyer, 2011).

    References
    ----------
    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
    * https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html
      Section 15.4.2
    * Gelman et al. BDA (2014) Formula 11.8

    Examples
    --------
    Calculate the effective_sample_size using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('non_centered_eight')
           ...: az.ess(data)

    Calculate the relative ess (ress) of some of the variables

    .. ipython::

        In [1]: az.ess(data, relative=True, var_names=["mu", "theta_t"])

    Calculate the ess using the "tail" method, leaving the `prob` argument at its default
    value.

    .. ipython::

        In [1]: az.ess(data, method="tail")

    """
    methods = {
        "bulk": _ess_bulk,
        "tail": _ess_tail,
        "quantile": _ess_quantile,
        "mean": _ess_mean,
        "sd": _ess_sd,
        "median": _ess_median,
        "mad": _ess_mad,
        "z_scale": _ess_z_scale,
        "folded": _ess_folded,
        "identity": _ess_identity,
        "local": _ess_local,
    }

    if method not in methods:
        raise TypeError(
            "ess method {} not found. Valid methods are:\n{}".format(method, "\n ".join(methods))
        )
    ess_func = methods[method]

    if (method == "quantile") and prob is None:
        raise TypeError("Quantile (prob) information needs to be defined.")

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            if prob is not None:
                return ess_func(  # pylint: disable=unexpected-keyword-arg
                    data, prob=prob, relative=relative
                )
            else:
                return ess_func(data, relative=relative)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {"relative": relative} if prob is None else {"prob": prob, "relative": relative}
    return _wrap_xarray_ufunc(ess_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs)


def rhat(data, *, var_names=None, method="rank"):
    r"""Compute estimate of rank normalized split-R-hat for a set of traces.

    The rank normalized R-hat diagnostic tests for lack of convergence by comparing the variance
    between multiple chains to the variance within each chain. If convergence has been achieved,
    the between-chain and within-chain variances should be identical. To be most effective in
    detecting evidence for nonconvergence, each chain should have been initialized to starting
    values that are dispersed relative to the target distribution.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        At least 2 posterior chains are needed to compute this diagnostic of one or more
        stochastic parameters.
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with az.convert_to_dataset.
    var_names : list
        Names of variables to include in the rhat report
    method : str
        Select R-hat method. Valid methods are:
        - "rank"        # recommended by Vehtari et al. (2019)
        - "split"
        - "folded"
        - "z_scale"
        - "identity"

    Returns
    -------
    xarray.Dataset
        Returns dataset of the potential scale reduction factors, :math:`\hat{R}`

    Notes
    -----
    The diagnostic is computed by:

    .. math:: \hat{R} = \frac{\hat{V}}{W}

    where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
    estimate for the pooled rank-traces. This is the potential scale reduction factor, which
    converges to unity when each of the traces is a sample from the target posterior. Values
    greater than one indicate that one or more chains have not yet converged.

    Rank values are calculated over all the chains with `scipy.stats.rankdata`.
    Each chain is split in two and normalized with the z-transform following Vehtari et al. (2019).

    References
    ----------
    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
    * Gelman et al. BDA (2014)
    * Brooks and Gelman (1998)
    * Gelman and Rubin (1992)

    Examples
    --------
    Calculate the R-hat using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("non_centered_eight")
           ...: az.rhat(data)

    Calculate the R-hat of some variables using the folded method:

    .. ipython::

        In [1]: az.rhat(data, var_names=["mu", "theta_t"], method="folded")

    """
    methods = {
        "rank": _rhat_rank,
        "split": _rhat_split,
        "folded": _rhat_folded,
        "z_scale": _rhat_z_scale,
        "identity": _rhat_identity,
    }
    if method not in methods:
        raise TypeError(
            "R-hat method {} not found. Valid methods are:\n{}".format(
                method, "\n ".join(methods)
            )
        )
    rhat_func = methods[method]

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            return rhat_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {}
    return _wrap_xarray_ufunc(
        rhat_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs
    )


def mcse(data, *, var_names=None, method="mean", prob=None):
    """Calculate Markov Chain Standard Error statistic.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with az.convert_to_dataset.
    var_names : list
        Names of variables to include in the mcse report
    method : str
        Select mcse method. Valid methods are:
        - "mean"
        - "sd"
        - "median"
        - "quantile"

    prob : float
        Quantile information.

    Returns
    -------
    xarray.Dataset
        Return the mcse dataset

    Examples
    --------
    Calculate the Markov Chain Standard Error using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("non_centered_eight")
           ...: az.mcse(data)

    Calculate the Markov Chain Standard Error using the quantile method:

    .. ipython::

        In [1]: az.mcse(data, method="quantile", prob=0.7)

    """
    methods = {
        "mean": _mcse_mean,
        "sd": _mcse_sd,
        "median": _mcse_median,
        "quantile": _mcse_quantile,
    }
    if method not in methods:
        raise TypeError(
            "mcse method {} not found. Valid methods are:\n{}".format(
                method, "\n ".join(methods)
            )
        )
    mcse_func = methods[method]

    if method == "quantile" and prob is None:
        raise TypeError("Quantile (prob) information needs to be defined.")

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            if prob is not None:
                return mcse_func(data, prob=prob)  # pylint: disable=unexpected-keyword-arg
            else:
                return mcse_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {} if prob is None else {"prob": prob}
    return _wrap_xarray_ufunc(
        mcse_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs
    )


@conditional_jit(forceobj=True)
def geweke(ary, first=0.1, last=0.5, intervals=20):
    r"""Compute z-scores for convergence diagnostics.

    Compare the mean of the first % of series with the mean of the last % of series. x is divided
    into a number of segments for which this difference is computed. If the series is converged,
    this score should oscillate between -1 and 1.

    Parameters
    ----------
    ary : 1D array-like
        The trace of some stochastic parameter.
    first : float
        The fraction of series at the beginning of the trace.
    last : float
        The fraction of series at the end to be compared with the section
        at the beginning.
    intervals : int
        The number of segments.

    Returns
    -------
    scores : list [[]]
        Return a list of [i, score], where i is the starting index for each interval and score the
        Geweke score on the interval.

    Notes
    -----
    The Geweke score on some series x is computed by:

    .. math:: \frac{E[x_s] - E[x_e]}{\sqrt{V[x_s] + V[x_e]}}

    where :math:`E` stands for the mean, :math:`V` the variance,
    :math:`x_s` a section at the start of the series and
    :math:`x_e` a section at the end of the series.

    References
    ----------
    * Geweke (1992)
    """
    # Filter out invalid intervals
    return _geweke(ary, first, last, intervals)


def _geweke(ary, first, last, intervals):
    _numba_flag = Numba.numba_flag
    for interval in (first, last):
        if interval <= 0 or interval >= 1:
            raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))
    if first + last >= 1:
        raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))

    # Initialize list of z-scores
    zscores = []

    # Last index value
    end = len(ary) - 1

    # Start intervals going up to the <last>% of the chain
    last_start_idx = (1 - last) * end

    # Calculate starting indices
    start_indices = np.linspace(0, last_start_idx, num=intervals, endpoint=True, dtype=int)

    # Loop over start indices
    for start in start_indices:
        # Calculate slices
        first_slice = ary[start : start + int(first * (end - start))]
        last_slice = ary[int(end - last * (end - start)) :]

        z_score = first_slice.mean() - last_slice.mean()
        if _numba_flag:
            z_score /= _sqrt(svar(first_slice), svar(last_slice))
        else:
            z_score /= np.sqrt(first_slice.var() + last_slice.var())

        zscores.append([start, z_score])

    return np.array(zscores)


def ks_summary(pareto_tail_indices):
    """Display a summary of Pareto tail indices.

    Parameters
    ----------
    pareto_tail_indices : array
        Pareto tail indices.

    Returns
    -------
    df_k : dataframe
        Dataframe containing k diagnostic values.
    """
    _numba_flag = Numba.numba_flag
    if _numba_flag:
        bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
        kcounts, *_ = _histogram(pareto_tail_indices, bins)
    else:
        kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
    kprop = kcounts / len(pareto_tail_indices) * 100
    df_k = pd.DataFrame(
        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: "   (0.7, 1]", 3: "   (1, Inf)"})

    if np.sum(kcounts[1:]) == 0:
        warnings.warn("All Pareto k estimates are good (k < 0.5)")
    elif np.sum(kcounts[2:]) == 0:
        warnings.warn("All Pareto k estimates are ok (k < 0.7)")

    return df_k


def _bfmi(energy):
    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).

    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
    information.

    Parameters
    ----------
    energy : NumPy array
        Should be extracted from a gradient based sampler, such as in Stan or PyMC3. Typically,
        after converting a trace or fit to InferenceData, the energy will be in
        `data.sample_stats.energy`.

    Returns
    -------
    z : array
        The Bayesian fraction of missing information of the model and trace. One element per
        chain in the trace.
    """
    energy_mat = np.atleast_2d(energy)
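    # Per-chain BFMI: mean squared first difference of the energy divided by the
    # energy variance within the chain (Betancourt, 2016; see the arXiv link above).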
    num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
    if energy_mat.ndim == 2:
        den = _numba_var(svar, np.var, energy_mat, axis=1, ddof=1)
    else:
        den = np.var(energy, axis=1, ddof=1)
    return num / den


def _backtransform_ranks(arr, c=3 / 8):  # pylint: disable=invalid-name
    """Backtransformation of ranks.

    Parameters
    ----------
    arr : np.ndarray
        Ranks array
    c : float
        Fractional offset. Defaults to c = 3/8 as recommended by Blom (1958).

    Returns
    -------
    np.ndarray

    References
    ----------
    Blom, G. (1958). Statistical Estimates and Transformed Beta-Variables. Wiley; New York.
    """
    arr = np.asarray(arr)
    size = arr.size
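    # Blom's formula maps integer ranks 1..S to fractional positions in (0, 1),
    # here (rank - c) / (S - 2c + 1), so they can be fed to the normal inverse CDF.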
    return (arr - c) / (size - 2 * c + 1)


def _z_scale(ary):
    """Calculate z_scale.

    Parameters
    ----------
    ary : np.ndarray

    Returns
    -------
    np.ndarray
    """
    ary = np.asarray(ary)
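    # Rank-normalize: average ranks across all chains and draws, back-transform the ranks
    # to (0, 1) with Blom's formula, then map through the standard normal quantile function.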
    rank = stats.rankdata(ary, method="average")
    rank = _backtransform_ranks(rank)
    z = stats.norm.ppf(rank)
    z = z.reshape(ary.shape)
    return z


def _split_chains(ary):
    """Split and stack chains."""
    ary = np.asarray(ary)
    if len(ary.shape) > 1:
        _, n_draw = ary.shape
    else:
        ary = np.atleast_2d(ary)
        _, n_draw = ary.shape
    half = n_draw // 2
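    # Stack the first and second halves of every chain as separate chains; with an odd
    # number of draws the middle draw is dropped.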
    return _stack(ary[:, :half], ary[:, -half:])


def _z_fold(ary):
    """Fold and z-scale values."""
    ary = np.asarray(ary)
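    # Folding around the median moves tail behaviour into the bulk, which is what makes the
    # folded, rank-normalized diagnostics sensitive to the tails of the distribution.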
    ary = abs(ary - np.median(ary))
    ary = _z_scale(ary)
    return ary


def _rhat(ary):
    """Compute the rhat for a 2d array."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Calculate chain mean
    chain_mean = np.mean(ary, axis=1)
    # Calculate chain variance
    chain_var = _numba_var(svar, np.var, ary, axis=1, ddof=1)
    # Calculate between-chain variance
    between_chain_variance = num_samples * _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)
    # Calculate within-chain variance
    within_chain_variance = np.mean(chain_var)
    # Estimate of marginal posterior variance
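    # rhat = sqrt(var_hat / W) with var_hat = (n - 1) / n * W + B / n, written below in the
    # equivalent form sqrt((B / W + n - 1) / n).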
    rhat_value = np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
    )
    return rhat_value


def _rhat_rank(ary):
    """Compute the rank normalized rhat for 2d array.

    Computation follows https://arxiv.org/abs/1903.08008
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    split_ary = _split_chains(ary)
    rhat_bulk = _rhat(_z_scale(split_ary))

    split_ary_folded = abs(split_ary - np.median(split_ary))
    rhat_tail = _rhat(_z_scale(split_ary_folded))

    rhat_rank = max(rhat_bulk, rhat_tail)
    return rhat_rank


def _rhat_folded(ary):
    """Calculate split-Rhat for folded z-values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    ary = _z_fold(_split_chains(ary))
    return _rhat(ary)


def _rhat_z_scale(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(_z_scale(_split_chains(ary)))


def _rhat_split(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(_split_chains(ary))


def _rhat_identity(ary):
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(ary)


def _ess(ary, relative=False):
    """Compute the effective sample size for a 2D array."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
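    # A (numerically) constant array carries no autocorrelation information; treat every
    # draw as independent and return the total number of draws.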
    if (np.max(ary) - np.min(ary)) < np.finfo(float).resolution:  # pylint: disable=no-member
        return ary.size
    if len(ary.shape) < 2:
        ary = np.atleast_2d(ary)
    n_chain, n_draw = ary.shape
    acov = _autocov(ary, axis=1)
    chain_mean = ary.mean(axis=1)
    mean_var = np.mean(acov[:, 0]) * n_draw / (n_draw - 1.0)
    var_plus = mean_var * (n_draw - 1.0) / n_draw
    if n_chain > 1:
        var_plus += _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)

    rho_hat_t = np.zeros(n_draw)
    rho_hat_even = 1.0
    rho_hat_t[0] = rho_hat_even
    rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
    rho_hat_t[1] = rho_hat_odd

    # Geyer's initial positive sequence
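    # Autocorrelation estimates are accumulated in pairs rho_{2k} + rho_{2k+1}; the sum over
    # lags is truncated at the first pair whose sum is no longer positive (Geyer, 1992).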
    t = 1
    while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
        rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
        rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
        if (rho_hat_even + rho_hat_odd) >= 0:
            rho_hat_t[t + 1] = rho_hat_even
            rho_hat_t[t + 2] = rho_hat_odd
        t += 2

    max_t = t - 2
    # improve estimation
    if rho_hat_even > 0:
        rho_hat_t[max_t + 1] = rho_hat_even
    # Geyer's initial monotone sequence
    t = 1
    while t <= max_t - 2:
        if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
            rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
            rho_hat_t[t + 2] = rho_hat_t[t + 1]
        t += 2

    ess = n_chain * n_draw
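    # tau_hat = -1 + 2 * sum of the (monotonized) autocorrelations, plus any leftover even
    # term; bounding it below by 1 / log10(ess) keeps antithetic chains from inflating the
    # ESS without limit (following the Stan implementation referenced in the ess docstring).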
    tau_hat = -1.0 + 2.0 * np.sum(rho_hat_t[: max_t + 1]) + np.sum(rho_hat_t[max_t + 1 : max_t + 2])
    tau_hat = max(tau_hat, 1 / np.log10(ess))
    ess = (1 if relative else ess) / tau_hat
    if np.isnan(rho_hat_t).any():
        ess = np.nan
    return ess


def _ess_bulk(ary, relative=False):
    """Compute the effective sample size for the bulk."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    z_scaled = _z_scale(_split_chains(ary))
    ess_bulk = _ess(z_scaled, relative=relative)
    return ess_bulk


def _ess_tail(ary, prob=None, relative=False):
    """Compute the effective sample size for the tail.

    If `prob` defined, ess = min(qess(prob), qess(1-prob))
    """
    if prob is None:
        prob = (0.05, 0.95)
    elif not isinstance(prob, Sequence):
        prob = (prob, 1 - prob)

    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan

    prob_low, prob_high = prob
    quantile_low_ess = _ess_quantile(ary, prob_low, relative=relative)
    quantile_high_ess = _ess_quantile(ary, prob_high, relative=relative)
    return min(quantile_low_ess, quantile_high_ess)


def _ess_mean(ary, relative=False):
    """Compute the effective sample size for the mean."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_split_chains(ary), relative=relative)


def _ess_sd(ary, relative=False):
    """Compute the effective sample size for the sd."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ary = _split_chains(ary)
    return min(_ess(ary, relative=relative), _ess(ary ** 2, relative=relative))


def _ess_quantile(ary, prob, relative=False):
    """Compute the effective sample size for a specific quantile."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    (quantile,) = _quantile(ary, prob)
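    # The quantile ESS is the ESS of the indicator variable I(theta <= quantile),
    # computed on the split chains.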
    iquantile = ary <= quantile
    return _ess(_split_chains(iquantile), relative=relative)


def _ess_local(ary, prob, relative=False):
    """Compute the effective sample size for a specific probability interval (local ess)."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    if len(prob) != 2:
        raise ValueError("Prob argument in ess local must be upper and lower bound")
    quantile = _quantile(ary, prob)
    iquantile = (quantile[0] <= ary) & (ary <= quantile[1])
    return _ess(_split_chains(iquantile), relative=relative)


def _ess_z_scale(ary, relative=False):
    """Calculate ess for z-scale."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_z_scale(_split_chains(ary)), relative=relative)


def _ess_folded(ary, relative=False):
    """Calculate split-ess for folded data."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_z_fold(_split_chains(ary)), relative=relative)


def _ess_median(ary, relative=False):
    """Calculate split-ess for median."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess_quantile(ary, 0.5, relative=relative)


def _ess_mad(ary, relative=False):
    """Calculate split-ess for median absolute deviance."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ary = abs(ary - np.median(ary))
    ary = ary <= np.median(ary)
    ary = _z_scale(_split_chains(ary))
    return _ess(ary, relative=relative)


def _ess_identity(ary, relative=False):
    """Calculate ess."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(ary, relative=relative)


def _mcse_mean(ary):
    """Compute the Markov Chain mean error."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_mean(ary)
    if _numba_flag:
        sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1))
    else:
        sd = np.std(ary, ddof=1)
    mcse_mean_value = sd / np.sqrt(ess)
    return mcse_mean_value


def _mcse_sd(ary):
    """Compute the Markov Chain sd error."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_sd(ary)
    if _numba_flag:
        sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)))
    else:
        sd = np.std(ary, ddof=1)
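    # Approximate scaling factor for the Monte Carlo error of the standard deviation,
    # sqrt(e * (1 - 1/ess)**(ess - 1) - 1); this approximation assumes roughly normal draws.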
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd
    return mcse_sd_value


def _mcse_median(ary):
    """Compute the Markov Chain median error."""
    return _mcse_quantile(ary, 0.5)


def _mcse_quantile(ary, prob):
    """Compute the Markov Chain quantile error at quantile=prob."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_quantile(ary, prob)
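    # The probabilities below are the standard normal CDF at -1 and +1 (a central ~68%
    # interval); a Beta(ess * prob + 1, ess * (1 - prob) + 1) interval on the order statistic
    # is mapped back to sample values and half its width is reported as the error.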
    probability = [0.1586553, 0.8413447]
    with np.errstate(invalid="ignore"):
        ppf = stats.beta.ppf(probability, ess * prob + 1, ess * (1 - prob) + 1)
    sorted_ary = np.sort(ary.ravel())
    size = sorted_ary.size
    ppf_size = ppf * size - 1
    th1 = sorted_ary[int(np.floor(np.nanmax((ppf_size[0], 0))))]
    th2 = sorted_ary[int(np.ceil(np.nanmin((ppf_size[1], size - 1))))]
    return (th2 - th1) / 2


def _mc_error(ary, batches=5, circular=False):
    """Calculate the simulation standard error, accounting for non-independent samples.

    The trace is divided into batches, and the standard deviation of the batch
    means is calculated.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    batches : integer
        Number of batches
    circular : bool
        Whether to compute the error taking into account `ary` is a circular variable
        (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).

    Returns
    -------
    mc_error : float
        Simulation standard error
    """
    _numba_flag = Numba.numba_flag
    if ary.ndim > 1:

        dims = np.shape(ary)
        trace = np.transpose([t.ravel() for t in ary])

        return np.reshape([_mc_error(t, batches) for t in trace], dims[1:])

    else:
        if _not_valid(ary, check_shape=False):
            return np.nan
        if batches == 1:
            if circular:
                if _numba_flag:
                    std = _circular_standard_deviation(ary, high=np.pi, low=-np.pi)
                else:
                    std = stats.circstd(ary, high=np.pi, low=-np.pi)
            else:
                if _numba_flag:
                    std = float(_sqrt(svar(ary), np.zeros(1)))
                else:
                    std = np.std(ary)
            return std / np.sqrt(len(ary))
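        # np.resize repeats values when the trace length is not an exact multiple of
        # `batches`, so every batch ends up with the same number of draws.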
        batched_traces = np.resize(ary, (batches, int(len(ary) / batches)))

        if circular:
            means = stats.circmean(batched_traces, high=np.pi, low=-np.pi, axis=1)
            if _numba_flag:
                std = _circular_standard_deviation(means, high=np.pi, low=-np.pi)
            else:
                std = stats.circstd(means, high=np.pi, low=-np.pi)
        else:
            means = np.mean(batched_traces, 1)
            if _numba_flag:
                std = _sqrt(svar(means), np.zeros(1))
            else:
                std = np.std(means)

        return std / np.sqrt(batches)


def _multichain_statistics(ary):
    """Efficiently calculate multichain statistics for summary.

    Parameters
    ----------
    ary : numpy.ndarray

    Returns
    -------
    tuple
        Order of return parameters is
        - mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat
    """
    ary = np.atleast_2d(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
    # ess mean
    ess_mean_value = _ess_mean(ary)

    # ess sd
    ess_sd_value = _ess_sd(ary)

    # ess bulk
    z_split = _z_scale(_split_chains(ary))
    ess_bulk_value = _ess(z_split)

    # ess tail
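    # (smallest ESS of the 0.05- and 0.95-quantile indicator variables, matching
    # _ess_tail with its default prob=(0.05, 0.95))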
    quantile05, quantile95 = _quantile(ary, [0.05, 0.95])
    iquantile05 = ary <= quantile05
    quantile05_ess = _ess(_split_chains(iquantile05))
    iquantile95 = ary <= quantile95
    quantile95_ess = _ess(_split_chains(iquantile95))
    ess_tail_value = min(quantile05_ess, quantile95_ess)

    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        rhat_value = np.nan
    else:
        # r_hat
        rhat_bulk = _rhat(z_split)
        ary_folded = np.abs(ary - np.median(ary))
        rhat_tail = _rhat(_z_scale(_split_chains(ary_folded)))
        rhat_value = max(rhat_bulk, rhat_tail)

    # mcse_mean
    sd = np.std(ary, ddof=1)
    mcse_mean_value = sd / np.sqrt(ess_mean_value)

    # mcse_sd
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd

    return (
        mcse_mean_value,
        mcse_sd_value,
        ess_mean_value,
        ess_sd_value,
        ess_bulk_value,
        ess_tail_value,
        rhat_value,
    )