# pylint: disable=too-many-lines, too-many-function-args, redefined-outer-name
"""Diagnostic functions for ArviZ."""
import warnings
from collections.abc import Sequence

import numpy as np
import pandas as pd
from scipy import stats

from ..data import convert_to_dataset
from ..utils import Numba, _numba_var, _stack, _var_names, conditional_jit
from .density_utils import histogram as _histogram
from .stats_utils import _circular_standard_deviation, _sqrt
from .stats_utils import autocov as _autocov
from .stats_utils import not_valid as _not_valid
from .stats_utils import quantile as _quantile
from .stats_utils import stats_variance_2d as svar
from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc

__all__ = ["bfmi", "ess", "rhat", "mcse", "geweke"]


def bfmi(data):
    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).

    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
    information.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        If InferenceData, the ``energy`` variable needs to be found in the ``sample_stats`` group.

    Returns
    -------
    z : array
        The Bayesian fraction of missing information of the model and trace. One element per
        chain in the trace.

    Examples
    --------
    Compute the BFMI of an InferenceData object:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('radon')
           ...: az.bfmi(data)

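    BFMI can also be computed directly from an ndarray of energy values, with one row per
    chain. The array below is synthetic and purely illustrative:

    .. ipython::

        In [1]: import numpy as np
           ...: energy = np.random.normal(size=(4, 500))
           ...: az.bfmi(energy)
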
    """
56 2
    if isinstance(data, np.ndarray):
57 2
        return _bfmi(data)
58

59 2
    dataset = convert_to_dataset(data, group="sample_stats")
60 2
    if not hasattr(dataset, "energy"):
61 2
        raise TypeError("Energy variable was not found.")
62 2
    return _bfmi(dataset.energy)
63

64

65 2
def ess(data, *, var_names=None, method="bulk", relative=False, prob=None):
66
    r"""Calculate estimate of the effective sample size (ess).
67

68
    Parameters
69
    ----------
70
    data : obj
71
        Any object that can be converted to an ``az.InferenceData`` object.
72
        Refer to documentation of ``az.convert_to_dataset`` for details.
73
        For ndarray: shape = (chain, draw).
74
        For n-dimensional ndarray transform first to dataset with ``az.convert_to_dataset``.
75
    var_names : str or list of str
76
        Names of variables to include in the return value Dataset.
77
    method : str, optional, default "bulk"
78
        Select ess method. Valid methods are:
79

80
        - "bulk"
81
        - "tail"     # prob, optional
82
        - "quantile" # prob
83
        - "mean" (old ess)
84
        - "sd"
85
        - "median"
86
        - "mad" (mean absolute deviance)
87
        - "z_scale"
88
        - "folded"
89
        - "identity"
90
        - "local"
91
    relative : bool
92
        Return relative ess
93
        `ress = ess / n`
94
    prob : float, or tuple of two floats, optional
95
        probability value for "tail", "quantile" or "local" ess functions.
96

97
    Returns
98
    -------
99
    xarray.Dataset
100
        Return the effective sample size, :math:`\hat{N}_{eff}`
101

102
    Notes
103
    -----
104
    The basic ess (:math:`N_{\mathit{eff}}`) diagnostic is computed by:
105

106
    .. math:: \hat{N}_{\mathit{eff}} = \frac{MN}{\hat{\tau}}
107

108
    .. math:: \hat{\tau} = -1 + 2 \sum_{t'=0}^K \hat{P}_{t'}
109

110
    where :math:`M` is the number of chains, :math:`N` the number of draws,
111
    :math:`\hat{\rho}_t` is the estimated _autocorrelation at lag :math:`t`, and
112
    :math:`K` is the last integer for which :math:`\hat{P}_{K} = \hat{\rho}_{2K} +
113
    \hat{\rho}_{2K+1}` is still positive.
114

115
    The current implementation is similar to Stan, which uses Geyer's initial monotone sequence
116
    criterion (Geyer, 1992; Geyer, 2011).
117

118
    References
119
    ----------
120
    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
121
    * https://mc-stan.org/docs/2_18/reference-manual/effective-sample-size-section.html
122
      Section 15.4.2
123
    * Gelman et al. BDA (2014) Formula 11.8
124

125
    Examples
126
    --------
127
    Calculate the effective_sample_size using the default arguments:
128

129
    .. ipython::
130

131
        In [1]: import arviz as az
132
           ...: data = az.load_arviz_data('non_centered_eight')
133
           ...: az.ess(data)
134

135
    Calculate the ress of some of the variables
136

137
    .. ipython::
138

139
        In [1]: az.ess(data, relative=True, var_names=["mu", "theta_t"])
140

141
    Calculate the ess using the "tail" method, leaving the `prob` argument at its default
142
    value.
143

144
    .. ipython::
145

146
        In [1]: az.ess(data, method="tail")
147

148
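    Calculate the ess using the "quantile" method, which requires the ``prob`` argument
    (the value 0.85 below is just an illustrative choice):

    .. ipython::

        In [1]: az.ess(data, method="quantile", prob=0.85)
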
    """
149 2
    methods = {
150
        "bulk": _ess_bulk,
151
        "tail": _ess_tail,
152
        "quantile": _ess_quantile,
153
        "mean": _ess_mean,
154
        "sd": _ess_sd,
155
        "median": _ess_median,
156
        "mad": _ess_mad,
157
        "z_scale": _ess_z_scale,
158
        "folded": _ess_folded,
159
        "identity": _ess_identity,
160
        "local": _ess_local,
161
    }
162

163 2
    if method not in methods:
164 2
        raise TypeError(
165
            "ess method {} not found. Valid methods are:\n{}".format(method, "\n    ".join(methods))
166
        )
167 2
    ess_func = methods[method]
168

169 2
    if (method == "quantile") and prob is None:
170 2
        raise TypeError("Quantile (prob) information needs to be defined.")
171

172 2
    if isinstance(data, np.ndarray):
173 2
        data = np.atleast_2d(data)
174 2
        if len(data.shape) < 3:
175 2
            if prob is not None:
176 2
                return ess_func(  # pylint: disable=unexpected-keyword-arg
177
                    data, prob=prob, relative=relative
178
                )
179
            else:
180 2
                return ess_func(data, relative=relative)
181
        else:
182 2
            msg = (
183
                "Only uni-dimensional ndarray variables are supported."
184
                " Please transform first to dataset with `az.convert_to_dataset`."
185
            )
186 2
            raise TypeError(msg)
187

188 2
    dataset = convert_to_dataset(data, group="posterior")
189 2
    var_names = _var_names(var_names, dataset)
190

191 2
    dataset = dataset if var_names is None else dataset[var_names]
192

193 2
    ufunc_kwargs = {"ravel": False}
194 2
    func_kwargs = {"relative": relative} if prob is None else {"prob": prob, "relative": relative}
195 2
    return _wrap_xarray_ufunc(ess_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs)
196

197

198 2
def rhat(data, *, var_names=None, method="rank"):
    r"""Compute estimate of rank normalized split R-hat for a set of traces.

    The rank normalized R-hat diagnostic tests for lack of convergence by comparing the variance
    between multiple chains to the variance within each chain. If convergence has been achieved,
    the between-chain and within-chain variances should be identical. To be most effective in
    detecting evidence for nonconvergence, each chain should have been initialized to starting
    values that are dispersed relative to the target distribution.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        At least 2 posterior chains are needed to compute this diagnostic of one or more
        stochastic parameters.
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with az.convert_to_dataset.
    var_names : list
        Names of variables to include in the rhat report
    method : str
        Select R-hat method. Valid methods are:
        - "rank"        # recommended by Vehtari et al. (2019)
        - "split"
        - "folded"
        - "z_scale"
        - "identity"

    Returns
    -------
    xarray.Dataset
      Returns dataset of the potential scale reduction factors, :math:`\hat{R}`

    Notes
    -----
    The diagnostic is computed by:

      .. math:: \hat{R} = \frac{\hat{V}}{W}

    where :math:`W` is the within-chain variance and :math:`\hat{V}` is the posterior variance
    estimate for the pooled rank-traces. This is the potential scale reduction factor, which
    converges to unity when each of the traces is a sample from the target posterior. Values
    greater than one indicate that one or more chains have not yet converged.

    Rank values are calculated over all the chains with `scipy.stats.rankdata`.
    Each chain is split in two and normalized with the z-transform following Vehtari et al. (2019).

    References
    ----------
    * Vehtari et al. (2019) see https://arxiv.org/abs/1903.08008
    * Gelman et al. BDA (2014)
    * Brooks and Gelman (1998)
    * Gelman and Rubin (1992)

    Examples
    --------
    Calculate the R-hat using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("non_centered_eight")
           ...: az.rhat(data)

    Calculate the R-hat of some variables using the folded method:

    .. ipython::

        In [1]: az.rhat(data, var_names=["mu", "theta_t"], method="folded")

    """
    methods = {
        "rank": _rhat_rank,
        "split": _rhat_split,
        "folded": _rhat_folded,
        "z_scale": _rhat_z_scale,
        "identity": _rhat_identity,
    }
    if method not in methods:
        raise TypeError(
            "R-hat method {} not found. Valid methods are:\n{}".format(
                method, "\n    ".join(methods)
            )
        )
    rhat_func = methods[method]

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            return rhat_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {}
    return _wrap_xarray_ufunc(
        rhat_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs
    )


def mcse(data, *, var_names=None, method="mean", prob=None):
    """Calculate Markov Chain Standard Error statistic.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
        For ndarray: shape = (chain, draw).
        For n-dimensional ndarray transform first to dataset with az.convert_to_dataset.
    var_names : list
        Names of variables to include in the mcse report
    method : str
        Select mcse method. Valid methods are:
        - "mean"
        - "sd"
        - "median"
        - "quantile"

    prob : float
        Quantile information. Required for the "quantile" method.

    Returns
    -------
    xarray.Dataset
        Return the mcse dataset

    Examples
    --------
    Calculate the Markov Chain Standard Error using the default arguments:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("non_centered_eight")
           ...: az.mcse(data)

    Calculate the Markov Chain Standard Error using the quantile method:

    .. ipython::

        In [1]: az.mcse(data, method="quantile", prob=0.7)

    """
    methods = {
        "mean": _mcse_mean,
        "sd": _mcse_sd,
        "median": _mcse_median,
        "quantile": _mcse_quantile,
    }
    if method not in methods:
        raise TypeError(
            "mcse method {} not found. Valid methods are:\n{}".format(
                method, "\n    ".join(methods)
            )
        )
    mcse_func = methods[method]

    if method == "quantile" and prob is None:
        raise TypeError("Quantile (prob) information needs to be defined.")

    if isinstance(data, np.ndarray):
        data = np.atleast_2d(data)
        if len(data.shape) < 3:
            if prob is not None:
                return mcse_func(data, prob=prob)  # pylint: disable=unexpected-keyword-arg
            else:
                return mcse_func(data)
        else:
            msg = (
                "Only uni-dimensional ndarray variables are supported."
                " Please transform first to dataset with `az.convert_to_dataset`."
            )
            raise TypeError(msg)

    dataset = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, dataset)

    dataset = dataset if var_names is None else dataset[var_names]

    ufunc_kwargs = {"ravel": False}
    func_kwargs = {} if prob is None else {"prob": prob}
    return _wrap_xarray_ufunc(
        mcse_func, dataset, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs
    )


@conditional_jit(forceobj=True)
def geweke(ary, first=0.1, last=0.5, intervals=20):
    r"""Compute z-scores for convergence diagnostics.

    Compare the mean of the first % of the series with the mean of the last % of the series. The
    series is divided into a number of segments for which this difference is computed. If the
    series has converged, these scores should oscillate between -1 and 1.

    Parameters
    ----------
    ary : 1D array-like
      The trace of some stochastic parameter.
    first : float
      The fraction of series at the beginning of the trace.
    last : float
      The fraction of series at the end to be compared with the section
      at the beginning.
    intervals : int
      The number of segments.

    Returns
    -------
    scores : ndarray
      Array of [i, score] pairs, where i is the starting index of each interval and score is the
      Geweke score on that interval.

    Notes
    -----
    The Geweke score on some series x is computed by:

      .. math:: \frac{E[x_s] - E[x_e]}{\sqrt{V[x_s] + V[x_e]}}

    where :math:`E` stands for the mean, :math:`V` the variance,
    :math:`x_s` a section at the start of the series and
    :math:`x_e` a section at the end of the series.

    References
    ----------
    * Geweke (1992)
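
    Examples
    --------
    Compute the Geweke scores for a single one-dimensional trace (the array below is
    synthetic, purely for illustration):

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: ary = np.random.normal(size=1000)
           ...: az.geweke(ary, first=0.1, last=0.5, intervals=20)
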
    """
434
    # Filter out invalid intervals
435 2
    return _geweke(ary, first, last, intervals)
436

437

438 2
def _geweke(ary, first, last, intervals):
439 2
    _numba_flag = Numba.numba_flag
440 2
    for interval in (first, last):
441 2
        if interval <= 0 or interval >= 1:
442 2
            raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))
443 2
    if first + last >= 1:
444 2
        raise ValueError("Invalid intervals for Geweke convergence analysis", (first, last))
445

446
    # Initialize list of z-scores
447 2
    zscores = []
448

449
    # Last index value
450 2
    end = len(ary) - 1
451

452
    # Start intervals going up to the <last>% of the chain
453 2
    last_start_idx = (1 - last) * end
454

455
    # Calculate starting indices
456 2
    start_indices = np.linspace(0, last_start_idx, num=intervals, endpoint=True, dtype=int)
457

458
    # Loop over start indices
459 2
    for start in start_indices:
460
        # Calculate slices
461 2
        first_slice = ary[start : start + int(first * (end - start))]
462 2
        last_slice = ary[int(end - last * (end - start)) :]
463

464 2
        z_score = first_slice.mean() - last_slice.mean()
465 2
        if _numba_flag:
466 2
            z_score /= _sqrt(svar(first_slice), svar(last_slice))
467
        else:
468 2
            z_score /= np.sqrt(first_slice.var() + last_slice.var())
469

470 2
        zscores.append([start, z_score])
471

472 2
    return np.array(zscores)
473

474

475 2
def ks_summary(pareto_tail_indices):
476
    """Display a summary of Pareto tail indices.
477

478
    Parameters
479
    ----------
480
    pareto_tail_indices : array
481
      Pareto tail indices.
482

483
    Returns
484
    -------
485
    df_k : dataframe
486
      Dataframe containing k diagnostic values.
487
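
    Examples
    --------
    Summarize a handful of synthetic Pareto tail indices (the import below assumes this
    module is reachable as ``arviz.stats.diagnostics``):

    .. ipython::

        In [1]: import numpy as np
           ...: from arviz.stats.diagnostics import ks_summary
           ...: ks_summary(np.array([0.05, 0.1, 0.3, 0.45, 0.65, 0.8, 1.2]))
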
    """
488 2
    _numba_flag = Numba.numba_flag
489 2
    if _numba_flag:
490 2
        bins = np.asarray([-np.Inf, 0.5, 0.7, 1, np.Inf])
491 2
        kcounts, *_ = _histogram(pareto_tail_indices, bins)
492
    else:
493 2
        kcounts, *_ = _histogram(pareto_tail_indices, bins=[-np.Inf, 0.5, 0.7, 1, np.Inf])
494 2
    kprop = kcounts / len(pareto_tail_indices) * 100
495 2
    df_k = pd.DataFrame(
496
        dict(_=["(good)", "(ok)", "(bad)", "(very bad)"], Count=kcounts, Pct=kprop)
497
    ).rename(index={0: "(-Inf, 0.5]", 1: " (0.5, 0.7]", 2: "   (0.7, 1]", 3: "   (1, Inf)"})
498

499 2
    if np.sum(kcounts[1:]) == 0:
500 2
        warnings.warn("All Pareto k estimates are good (k < 0.5)")
501 2
    elif np.sum(kcounts[2:]) == 0:
502 2
        warnings.warn("All Pareto k estimates are ok (k < 0.7)")
503

504 2
    return df_k
505

506

507 2
def _bfmi(energy):
508
    r"""Calculate the estimated Bayesian fraction of missing information (BFMI).
509

510
    BFMI quantifies how well momentum resampling matches the marginal energy distribution. For more
511
    information on BFMI, see https://arxiv.org/pdf/1604.00695v1.pdf. The current advice is that
512
    values smaller than 0.3 indicate poor sampling. However, this threshold is provisional and may
513
    change. See http://mc-stan.org/users/documentation/case-studies/pystan_workflow.html for more
514
    information.
515

516
    Parameters
517
    ----------
518
    energy : NumPy array
519
        Should be extracted from a gradient based sampler, such as in Stan or PyMC3. Typically,
520
        after converting a trace or fit to InferenceData, the energy will be in
521
        `data.sample_stats.energy`.
522

523
    Returns
524
    -------
525
    z : array
526
        The Bayesian fraction of missing information of the model and trace. One element per
527
        chain in the trace.
528
    """
529 2
    energy_mat = np.atleast_2d(energy)
530 2
    num = np.square(np.diff(energy_mat, axis=1)).mean(axis=1)  # pylint: disable=no-member
531 2
    if energy_mat.ndim == 2:
532 2
        den = _numba_var(svar, np.var, energy_mat, axis=1, ddof=1)
533
    else:
534 2
        den = np.var(energy, axis=1, ddof=1)
535 2
    return num / den
536

537

538 2
def _backtransform_ranks(arr, c=3 / 8):  # pylint: disable=invalid-name
539
    """Backtransformation of ranks.
540

541
    Parameters
542
    ----------
543
    arr : np.ndarray
544
        Ranks array
545
    c : float
546
        Fractional offset. Defaults to c = 3/8 as recommended by Blom (1958).
547

548
    Returns
549
    -------
550
    np.ndarray
551

552
    References
553
    ----------
554
    Blom, G. (1958). Statistical Estimates and Transformed Beta-Variables. Wiley; New York.
555
    """
556 2
    arr = np.asarray(arr)
557 2
    size = arr.size
558 2
    return (arr - c) / (size - 2 * c + 1)
559

560

561 2
def _z_scale(ary):
    """Calculate z_scale.

    Parameters
    ----------
    ary : np.ndarray

    Returns
    -------
    np.ndarray
    """
    ary = np.asarray(ary)
    rank = stats.rankdata(ary, method="average")
    rank = _backtransform_ranks(rank)
    z = stats.norm.ppf(rank)
    z = z.reshape(ary.shape)
    return z


def _split_chains(ary):
    """Split and stack chains."""
    ary = np.asarray(ary)
    if len(ary.shape) > 1:
        _, n_draw = ary.shape
    else:
        ary = np.atleast_2d(ary)
        _, n_draw = ary.shape
    half = n_draw // 2
    return _stack(ary[:, :half], ary[:, -half:])


def _z_fold(ary):
    """Fold and z-scale values."""
    ary = np.asarray(ary)
    ary = abs(ary - np.median(ary))
    ary = _z_scale(ary)
    return ary


def _rhat(ary):
    """Compute the rhat for a 2d array."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    _, num_samples = ary.shape

    # Calculate chain mean
    chain_mean = np.mean(ary, axis=1)
    # Calculate chain variance
    chain_var = _numba_var(svar, np.var, ary, axis=1, ddof=1)
    # Calculate between-chain variance
    between_chain_variance = num_samples * _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)
    # Calculate within-chain variance
    within_chain_variance = np.mean(chain_var)
    # Estimate of marginal posterior variance
    rhat_value = np.sqrt(
        (between_chain_variance / within_chain_variance + num_samples - 1) / (num_samples)
    )
    return rhat_value


def _rhat_rank(ary):
    """Compute the rank normalized rhat for 2d array.

    Computation follows https://arxiv.org/abs/1903.08008
    """
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    split_ary = _split_chains(ary)
    rhat_bulk = _rhat(_z_scale(split_ary))

    split_ary_folded = abs(split_ary - np.median(split_ary))
    rhat_tail = _rhat(_z_scale(split_ary_folded))

    rhat_rank = max(rhat_bulk, rhat_tail)
    return rhat_rank


def _rhat_folded(ary):
    """Calculate split-Rhat for folded z-values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    ary = _z_fold(_split_chains(ary))
    return _rhat(ary)


def _rhat_z_scale(ary):
    """Calculate split-Rhat for z-scaled values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(_z_scale(_split_chains(ary)))


def _rhat_split(ary):
    """Calculate split-Rhat."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(_split_chains(ary))


def _rhat_identity(ary):
    """Calculate Rhat for the raw values (no transformation)."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
        return np.nan
    return _rhat(ary)


def _ess(ary, relative=False):
    """Compute the effective sample size for a 2D array."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary, dtype=float)
    if _not_valid(ary, check_shape=False):
        return np.nan
    if (np.max(ary) - np.min(ary)) < np.finfo(float).resolution:  # pylint: disable=no-member
        return ary.size
    if len(ary.shape) < 2:
        ary = np.atleast_2d(ary)
    n_chain, n_draw = ary.shape
    acov = _autocov(ary, axis=1)
    chain_mean = ary.mean(axis=1)
    mean_var = np.mean(acov[:, 0]) * n_draw / (n_draw - 1.0)
    var_plus = mean_var * (n_draw - 1.0) / n_draw
    if n_chain > 1:
        var_plus += _numba_var(svar, np.var, chain_mean, axis=None, ddof=1)

    rho_hat_t = np.zeros(n_draw)
    rho_hat_even = 1.0
    rho_hat_t[0] = rho_hat_even
    rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, 1])) / var_plus
    rho_hat_t[1] = rho_hat_odd

    # Geyer's initial positive sequence
    t = 1
    while t < (n_draw - 3) and (rho_hat_even + rho_hat_odd) > 0.0:
        rho_hat_even = 1.0 - (mean_var - np.mean(acov[:, t + 1])) / var_plus
        rho_hat_odd = 1.0 - (mean_var - np.mean(acov[:, t + 2])) / var_plus
        if (rho_hat_even + rho_hat_odd) >= 0:
            rho_hat_t[t + 1] = rho_hat_even
            rho_hat_t[t + 2] = rho_hat_odd
        t += 2

    max_t = t - 2
    # improve estimation
    if rho_hat_even > 0:
        rho_hat_t[max_t + 1] = rho_hat_even
    # Geyer's initial monotone sequence
    t = 1
    while t <= max_t - 2:
        if (rho_hat_t[t + 1] + rho_hat_t[t + 2]) > (rho_hat_t[t - 1] + rho_hat_t[t]):
            rho_hat_t[t + 1] = (rho_hat_t[t - 1] + rho_hat_t[t]) / 2.0
            rho_hat_t[t + 2] = rho_hat_t[t + 1]
        t += 2

    ess = n_chain * n_draw
    tau_hat = -1.0 + 2.0 * np.sum(rho_hat_t[: max_t + 1]) + np.sum(rho_hat_t[max_t + 1 : max_t + 2])
    tau_hat = max(tau_hat, 1 / np.log10(ess))
    ess = (1 if relative else ess) / tau_hat
    if np.isnan(rho_hat_t).any():
        ess = np.nan
    return ess


def _ess_bulk(ary, relative=False):
    """Compute the effective sample size for the bulk."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    z_scaled = _z_scale(_split_chains(ary))
    ess_bulk = _ess(z_scaled, relative=relative)
    return ess_bulk


def _ess_tail(ary, prob=None, relative=False):
    """Compute the effective sample size for the tail.

    If `prob` defined, ess = min(qess(prob), qess(1-prob))
    """
    if prob is None:
        prob = (0.05, 0.95)
    elif not isinstance(prob, Sequence):
        prob = (prob, 1 - prob)

    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan

    prob_low, prob_high = prob
    quantile_low_ess = _ess_quantile(ary, prob_low, relative=relative)
    quantile_high_ess = _ess_quantile(ary, prob_high, relative=relative)
    return min(quantile_low_ess, quantile_high_ess)


def _ess_mean(ary, relative=False):
    """Compute the effective sample size for the mean."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_split_chains(ary), relative=relative)


def _ess_sd(ary, relative=False):
    """Compute the effective sample size for the sd."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ary = _split_chains(ary)
    return min(_ess(ary, relative=relative), _ess(ary ** 2, relative=relative))


def _ess_quantile(ary, prob, relative=False):
    """Compute the effective sample size for a specific quantile."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    (quantile,) = _quantile(ary, prob)
    iquantile = ary <= quantile
    return _ess(_split_chains(iquantile), relative=relative)


def _ess_local(ary, prob, relative=False):
    """Compute the effective sample size for a specific probability interval."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    if prob is None:
        raise TypeError("Prob not defined.")
    if len(prob) != 2:
        raise ValueError("Prob argument in ess local must be upper and lower bound")
    quantile = _quantile(ary, prob)
    iquantile = (quantile[0] <= ary) & (ary <= quantile[1])
    return _ess(_split_chains(iquantile), relative=relative)


def _ess_z_scale(ary, relative=False):
    """Calculate ess for z-scaled values."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_z_scale(_split_chains(ary)), relative=relative)


def _ess_folded(ary, relative=False):
    """Calculate split-ess for folded data."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(_z_fold(_split_chains(ary)), relative=relative)


def _ess_median(ary, relative=False):
    """Calculate split-ess for median."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess_quantile(ary, 0.5, relative=relative)


def _ess_mad(ary, relative=False):
    """Calculate split-ess for the median absolute deviation."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ary = abs(ary - np.median(ary))
    ary = ary <= np.median(ary)
    ary = _z_scale(_split_chains(ary))
    return _ess(ary, relative=relative)


def _ess_identity(ary, relative=False):
    """Calculate ess."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    return _ess(ary, relative=relative)


def _mcse_mean(ary):
    """Compute the Markov Chain mean error."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_mean(ary)
    if _numba_flag:
        sd = _sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1))
    else:
        sd = np.std(ary, ddof=1)
    mcse_mean_value = sd / np.sqrt(ess)
    return mcse_mean_value


def _mcse_sd(ary):
    """Compute the Markov Chain sd error."""
    _numba_flag = Numba.numba_flag
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_sd(ary)
    if _numba_flag:
        sd = float(_sqrt(svar(np.ravel(ary), ddof=1), np.zeros(1)))
    else:
        sd = np.std(ary, ddof=1)
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess) ** (ess - 1) - 1)
    mcse_sd_value = sd * fac_mcse_sd
    return mcse_sd_value


def _mcse_median(ary):
    """Compute the Markov Chain median error."""
    return _mcse_quantile(ary, 0.5)


def _mcse_quantile(ary, prob):
    """Compute the Markov Chain quantile error at quantile=prob."""
    ary = np.asarray(ary)
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
        return np.nan
    ess = _ess_quantile(ary, prob)
    probability = [0.1586553, 0.8413447]
    with np.errstate(invalid="ignore"):
        ppf = stats.beta.ppf(probability, ess * prob + 1, ess * (1 - prob) + 1)
    sorted_ary = np.sort(ary.ravel())
    size = sorted_ary.size
    ppf_size = ppf * size - 1
    th1 = sorted_ary[int(np.floor(np.nanmax((ppf_size[0], 0))))]
    th2 = sorted_ary[int(np.ceil(np.nanmin((ppf_size[1], size - 1))))]
    return (th2 - th1) / 2


def _mc_error(ary, batches=5, circular=False):
    """Calculate the simulation standard error, accounting for non-independent samples.

    The trace is divided into batches, and the standard deviation of the batch
    means is calculated.

    Parameters
    ----------
    ary : Numpy array
        An array containing MCMC samples
    batches : integer
        Number of batches
    circular : bool
        Whether to compute the error taking into account `ary` is a circular variable
        (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e. non-circular variables).

    Returns
    -------
    mc_error : float
        Simulation standard error
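
    Examples
    --------
    Estimate the batch-means error of a synthetic 1D draw (illustrative only; the import
    below assumes this module is reachable as ``arviz.stats.diagnostics``):

    .. ipython::

        In [1]: import numpy as np
           ...: from arviz.stats.diagnostics import _mc_error
           ...: _mc_error(np.random.normal(size=1000), batches=5)
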
    """
916 2
    _numba_flag = Numba.numba_flag
917 2
    if ary.ndim > 1:
918

919 2
        dims = np.shape(ary)
920 2
        trace = np.transpose([t.ravel() for t in ary])
921

922 2
        return np.reshape([_mc_error(t, batches) for t in trace], dims[1:])
923

924
    else:
925 2
        if _not_valid(ary, check_shape=False):
926 2
            return np.nan
927 2
        if batches == 1:
928 2
            if circular:
929 2
                if _numba_flag:
930 2
                    std = _circular_standard_deviation(ary, high=np.pi, low=-np.pi)
931
                else:
932 0
                    std = stats.circstd(ary, high=np.pi, low=-np.pi)
933
            else:
934 2
                if _numba_flag:
935 2
                    std = np.float(_sqrt(svar(ary), np.zeros(1)))
936
                else:
937 2
                    std = np.std(ary)
938 2
            return std / np.sqrt(len(ary))
939

940 2
        batched_traces = np.resize(ary, (batches, int(len(ary) / batches)))
941

942 2
        if circular:
943 2
            means = stats.circmean(batched_traces, high=np.pi, low=-np.pi, axis=1)
944 2
            if _numba_flag:
945 2
                std = _circular_standard_deviation(means, high=np.pi, low=-np.pi)
946
            else:
947 2
                std = stats.circstd(means, high=np.pi, low=-np.pi)
948
        else:
949 2
            means = np.mean(batched_traces, 1)
950 2
            if _numba_flag:
951 2
                std = _sqrt(svar(means), np.zeros(1))
952
            else:
953 2
                std = np.std(means)
954

955 2
        return std / np.sqrt(batches)
956

957

958 2
def _multichain_statistics(ary):
959
    """Calculate efficiently multichain statistics for summary.
960

961
    Parameters
962
    ----------
963
    ary : numpy.ndarray
964

965
    Returns
966
    -------
967
    tuple
968
        Order of return parameters is
969
            - mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat
970
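
    Examples
    --------
    Compute the full tuple of statistics for a synthetic two-chain trace (illustrative
    only; the import below assumes this module is reachable as ``arviz.stats.diagnostics``):

    .. ipython::

        In [1]: import numpy as np
           ...: from arviz.stats.diagnostics import _multichain_statistics
           ...: _multichain_statistics(np.random.normal(size=(2, 500)))
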
    """
971 2
    ary = np.atleast_2d(ary)
972 2
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=1)):
973 2
        return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
974
    # ess mean
975 2
    ess_mean_value = _ess_mean(ary)
976

977
    # ess sd
978 2
    ess_sd_value = _ess_sd(ary)
979

980
    # ess bulk
981 2
    z_split = _z_scale(_split_chains(ary))
982 2
    ess_bulk_value = _ess(z_split)
983

984
    # ess tail
985 2
    quantile05, quantile95 = _quantile(ary, [0.05, 0.95])
986 2
    iquantile05 = ary <= quantile05
987 2
    quantile05_ess = _ess(_split_chains(iquantile05))
988 2
    iquantile95 = ary <= quantile95
989 2
    quantile95_ess = _ess(_split_chains(iquantile95))
990 2
    ess_tail_value = min(quantile05_ess, quantile95_ess)
991

992 2
    if _not_valid(ary, shape_kwargs=dict(min_draws=4, min_chains=2)):
993 2
        rhat_value = np.nan
994
    else:
995
        # r_hat
996 2
        rhat_bulk = _rhat(z_split)
997 2
        ary_folded = np.abs(ary - np.median(ary))
998 2
        rhat_tail = _rhat(_z_scale(_split_chains(ary_folded)))
999 2
        rhat_value = max(rhat_bulk, rhat_tail)
1000

1001
    # mcse_mean
1002 2
    sd = np.std(ary, ddof=1)
1003 2
    mcse_mean_value = sd / np.sqrt(ess_mean_value)
1004

1005
    # mcse_sd
1006 2
    fac_mcse_sd = np.sqrt(np.exp(1) * (1 - 1 / ess_sd_value) ** (ess_sd_value - 1) - 1)
1007 2
    mcse_sd_value = sd * fac_mcse_sd
1008

1009 2
    return (
1010
        mcse_mean_value,
1011
        mcse_sd_value,
1012
        ess_mean_value,
1013
        ess_sd_value,
1014
        ess_bulk_value,
1015
        ess_tail_value,
1016
        rhat_value,
1017
    )
