# pylint: disable=too-many-lines
"""Statistical functions in ArviZ."""
import logging
import warnings
from collections import OrderedDict
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pandas as pd
import scipy.stats as st
import xarray as xr
from scipy.optimize import minimize

from ..data import CoordSpec, DimSpec, InferenceData, convert_to_dataset, convert_to_inference_data
from ..rcparams import rcParams
from ..utils import Numba, _numba_var, _var_names, credible_interval_warning, get_coords
from .density_utils import get_bins as _get_bins
from .density_utils import histogram as _histogram
from .density_utils import kde as _kde
from .diagnostics import _mc_error, _multichain_statistics, ess
from .stats_utils import ELPDData, _circular_standard_deviation
from .stats_utils import get_log_likelihood as _get_log_likelihood
from .stats_utils import logsumexp as _logsumexp
from .stats_utils import make_ufunc as _make_ufunc
from .stats_utils import stats_variance_2d as svar
from .stats_utils import wrap_xarray_ufunc as _wrap_xarray_ufunc

_log = logging.getLogger(__name__)

__all__ = [
    "apply_test_function",
    "compare",
    "hdi",
    "hpd",
    "loo",
    "loo_pit",
    "psislw",
    "r2_score",
    "summary",
    "waic",
]


def compare(
    dataset_dict, ic=None, method="BB-pseudo-BMA", b_samples=1000, alpha=1, seed=None, scale=None
):
    r"""Compare models based on PSIS-LOO `loo` or WAIC `waic` cross-validation.

    LOO is leave-one-out (PSIS-LOO `loo`) cross-validation and
    WAIC is the widely applicable information criterion.
    Read more theory here - in a paper by some of the leading authorities
    on model selection dx.doi.org/10.1111/1467-9868.00353

    Parameters
    ----------
    dataset_dict: dict[str] -> InferenceData
        A dictionary of model names and InferenceData objects
    ic: str
        Information Criterion (PSIS-LOO `loo` or WAIC `waic`) used to compare models. Defaults to
        ``rcParams["stats.information_criterion"]``.
    method: str
        Method used to estimate the weights for each model. Available options are:

        - 'stacking' : stacking of predictive distributions.
        - 'BB-pseudo-BMA' : (default) pseudo-Bayesian Model averaging using Akaike-type
          weighting. The weights are stabilized using the Bayesian bootstrap.
        - 'pseudo-BMA': pseudo-Bayesian Model averaging using Akaike-type
          weighting, without Bootstrap stabilization (not recommended).

        For more information read https://arxiv.org/abs/1704.02030
    b_samples: int
        Number of samples taken by the Bayesian bootstrap estimation.
        Only useful when method = 'BB-pseudo-BMA'.
    alpha: float
        The shape parameter in the Dirichlet distribution used for the Bayesian bootstrap. Only
        useful when method = 'BB-pseudo-BMA'. When alpha=1 (default), the distribution is uniform
        on the simplex. A smaller alpha will keep the final weights further away from 0 and 1.
    seed: int or np.random.RandomState instance
        If int or RandomState, use it for seeding the Bayesian bootstrap. Only
        useful when method = 'BB-pseudo-BMA'. If None (default), the global
        np.random state is used.
    scale: str
        Output scale for IC. Available options are:

        - `log` : (default) log-score (after Vehtari et al. (2017))
        - `negative_log` : -1 * (log-score)
        - `deviance` : -2 * (log-score)

        A higher log-score (or a lower deviance) indicates a model with better predictive
        accuracy.

    Returns
    -------
    A DataFrame, ordered from best to worst model (measured by information criteria).
    The index reflects the key with which the models are passed to this function. The columns are:
    rank: The rank-order of the models. 0 is the best.
    IC: Information Criteria (PSIS-LOO `loo` or WAIC `waic`).
        Higher IC indicates higher out-of-sample predictive fit ("better" model). Default LOO.
        If `scale` is `deviance` or `negative_log` smaller IC indicates
        higher out-of-sample predictive fit ("better" model).
    pIC: Estimated effective number of parameters.
    dIC: Relative difference between each IC (PSIS-LOO `loo` or WAIC `waic`)
          and the lowest IC (PSIS-LOO `loo` or WAIC `waic`).
          The top-ranked model is always 0.
    weight: Relative weight for each model.
        This can be loosely interpreted as the probability of each model (among the compared models)
        given the data. By default the uncertainty in the weights estimation is considered using
        Bayesian bootstrap.
    SE: Standard error of the IC estimate.
        If method = BB-pseudo-BMA these values are estimated using Bayesian bootstrap.
    dSE: Standard error of the difference in IC between each model and the top-ranked model.
        It's always 0 for the top-ranked model.
    warning: A value of 1 indicates that the computation of the IC may not be reliable.
        This could be an indication of WAIC/LOO starting to fail, see
        http://arxiv.org/abs/1507.04544 for details.
    scale: Scale used for the IC.

    Examples
    --------
    Compare the centered and non centered models of the eight schools problem:

    .. ipython::

        In [1]: import arviz as az
           ...: data1 = az.load_arviz_data("non_centered_eight")
           ...: data2 = az.load_arviz_data("centered_eight")
           ...: compare_dict = {"non centered": data1, "centered": data2}
           ...: az.compare(compare_dict)

    Compare the models using LOO-CV, returning the IC in log scale and calculating the
    weights using the stacking method.

    .. ipython::

        In [1]: az.compare(compare_dict, ic="loo", method="stacking", scale="log")

    See Also
    --------
    loo : Compute the Pareto Smoothed importance sampling Leave One Out cross-validation.
    waic : Compute the widely applicable information criterion.

    """
    names = list(dataset_dict.keys())
    scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
    if scale == "log":
        scale_value = 1
        ascending = False
        warnings.warn(
            "\nThe scale is now log by default. Use 'scale' argument or "
            "'stats.ic_scale' rcParam if you rely on a specific value.\nA higher "
            "log-score (or a lower deviance) indicates a model with better predictive "
            "accuracy."
        )
    else:
        if scale == "negative_log":
            scale_value = -1
        else:
            scale_value = -2
        ascending = True

    ic = rcParams["stats.information_criterion"] if ic is None else ic.lower()
    if ic == "loo":
        ic_func = loo
        df_comp = pd.DataFrame(
            index=names,
            columns=[
                "rank",
                "loo",
                "p_loo",
                "d_loo",
                "weight",
                "se",
                "dse",
                "warning",
                "loo_scale",
            ],
        )
        scale_col = "loo_scale"
    elif ic == "waic":
        ic_func = waic
        df_comp = pd.DataFrame(
            index=names,
            columns=[
                "rank",
                "waic",
                "p_waic",
                "d_waic",
                "weight",
                "se",
                "dse",
                "warning",
                "waic_scale",
            ],
        )
        scale_col = "waic_scale"
    else:
        raise NotImplementedError("The information criterion {} is not supported.".format(ic))

    if method.lower() not in ["stacking", "bb-pseudo-bma", "pseudo-bma"]:
        raise ValueError("The method {}, to compute weights, is not supported.".format(method))

    ic_se = "{}_se".format(ic)
    p_ic = "p_{}".format(ic)
    ic_i = "{}_i".format(ic)

    ics = pd.DataFrame()
    names = []
    for name, dataset in dataset_dict.items():
        names.append(name)
        ics = ics.append([ic_func(dataset, pointwise=True, scale=scale)])
    ics.index = names
    ics.sort_values(by=ic, inplace=True, ascending=ascending)
    ics[ic_i] = ics[ic_i].apply(lambda x: x.values.flatten())

    if method.lower() == "stacking":
        rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
        exp_ic_i = np.exp(ic_i_val / scale_value)
        last_col = cols - 1

        def w_fuller(weights):
            return np.concatenate((weights, [max(1.0 - np.sum(weights), 0.0)]))

        def log_score(weights):
            w_full = w_fuller(weights)
            score = 0.0
            for i in range(rows):
                score += np.log(np.dot(exp_ic_i[i], w_full))
            return -score

        def gradient(weights):
            w_full = w_fuller(weights)
            grad = np.zeros(last_col)
            for k in range(last_col - 1):
                for i in range(rows):
                    grad[k] += (exp_ic_i[i, k] - exp_ic_i[i, last_col]) / np.dot(
                        exp_ic_i[i], w_full
                    )
            return -grad

        theta = np.full(last_col, 1.0 / cols)
        bounds = [(0.0, 1.0) for _ in range(last_col)]
        constraints = [
            {"type": "ineq", "fun": lambda x: 1.0 - np.sum(x)},
            {"type": "ineq", "fun": np.sum},
        ]
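        # Stacking (Yao et al. 2018, https://arxiv.org/abs/1704.02030): choose the weights
        # that maximize the combined log score sum_i log(sum_k w_k * exp(elpd_ik)), i.e.
        # minimize ``log_score`` above. Only the first K-1 weights are free parameters;
        # ``w_fuller`` appends the last one as 1 - sum(w), and the bounds plus the
        # constraints above keep the solution on the probability simplex.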

        weights = minimize(
            fun=log_score, x0=theta, jac=gradient, bounds=bounds, constraints=constraints
        )

        weights = w_fuller(weights["x"])
        ses = ics[ic_se]

    elif method.lower() == "bb-pseudo-bma":
        rows, cols, ic_i_val = _ic_matrix(ics, ic_i)
        ic_i_val = ic_i_val * rows
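        # Bayesian bootstrap: each Dirichlet(alpha) draw re-weights the n observations,
        # giving a pseudo-replicate of the total IC (z_b) per model; Akaike-type weights
        # are computed for every replicate and averaged, and the spread of z_b across
        # replicates is used below as the standard error of the IC.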

        b_weighting = st.dirichlet.rvs(alpha=[alpha] * rows, size=b_samples, random_state=seed)
        weights = np.zeros((b_samples, cols))
        z_bs = np.zeros_like(weights)
        for i in range(b_samples):
            z_b = np.dot(b_weighting[i], ic_i_val)
            u_weights = np.exp((z_b - np.min(z_b)) / scale_value)
            z_bs[i] = z_b  # pylint: disable=unsupported-assignment-operation
            weights[i] = u_weights / np.sum(u_weights)

        weights = weights.mean(axis=0)
        ses = pd.Series(z_bs.std(axis=0), index=names)  # pylint: disable=no-member

    elif method.lower() == "pseudo-bma":
        min_ic = ics.iloc[0][ic]
        z_rv = np.exp((ics[ic] - min_ic) / scale_value)
        weights = z_rv / np.sum(z_rv)
        ses = ics[ic_se]

    if np.any(weights):
        min_ic_i_val = ics[ic_i].iloc[0]
        for idx, val in enumerate(ics.index):
            res = ics.loc[val]
            if scale_value < 0:
                diff = res[ic_i] - min_ic_i_val
            else:
                diff = min_ic_i_val - res[ic_i]
            d_ic = np.sum(diff)
            d_std_err = np.sqrt(len(diff) * np.var(diff))
            std_err = ses.loc[val]
            weight = weights[idx]
            df_comp.at[val] = (
                idx,
                res[ic],
                res[p_ic],
                d_ic,
                weight,
                std_err,
                d_std_err,
                res["warning"],
                res[scale_col],
            )

    return df_comp.sort_values(by=ic, ascending=ascending)


def _ic_matrix(ics, ic_i):
    """Store the previously computed pointwise predictive accuracy values (ics) in a 2D matrix."""
    cols, _ = ics.shape
    rows = len(ics[ic_i].iloc[0])
    ic_i_val = np.zeros((rows, cols))

    for idx, val in enumerate(ics.index):
        ic = ics.loc[val][ic_i]

        if len(ic) != rows:
            raise ValueError("The number of observations should be the same across all models")

        ic_i_val[:, idx] = ic

    return rows, cols, ic_i_val


def hpd(
    # pylint: disable=unused-argument
    ary,
    hdi_prob=None,
    circular=False,
    multimodal=False,
    skipna=False,
    group="posterior",
    var_names=None,
    filter_vars=None,
    coords=None,
    max_modes=10,
    **kwargs,
):
    """Pending deprecation. Please refer to :func:`~arviz.hdi`."""
    # pylint: enable=unused-argument
    warnings.warn(
        ("hpd will be deprecated. Please use hdi instead"),
    )
    return hdi(
        ary,
        hdi_prob,
        circular,
        multimodal,
        skipna,
        group,
        var_names,
        filter_vars,
        coords,
        max_modes,
        **kwargs,
    )


def hdi(
    ary,
    hdi_prob=None,
    circular=False,
    multimodal=False,
    skipna=False,
    group="posterior",
    var_names=None,
    filter_vars=None,
    coords=None,
    max_modes=10,
    **kwargs,
):
    """
    Calculate highest density interval (HDI) of array for given probability.

    The HDI is the minimum width Bayesian credible interval (BCI).

    Parameters
    ----------
    ary: obj
        object containing posterior samples.
        Any object that can be converted to an az.InferenceData object.
        Refer to documentation of az.convert_to_dataset for details.
    hdi_prob: float, optional
        HDI prob for which interval will be computed. Defaults to ``stats.hdi_prob`` rcParam.
    circular: bool, optional
        Whether to compute the hdi taking into account `x` is a circular variable
        (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e. non-circular variables).
        Only works if multimodal is False.
    multimodal: bool, optional
        If true it may compute more than one hdi interval if the distribution is multimodal and the
        modes are well separated.
    skipna: bool, optional
        If true ignores nan values when computing the hdi interval. Defaults to false.
    group: str, optional
        Specifies which InferenceData group should be used to calculate hdi.
        Defaults to 'posterior'.
    var_names: list, optional
        Names of variables to include in the hdi report. Prefix the variables by `~`
        when you want to exclude them from the report: `["~beta"]` instead of `["beta"]`
        (see `az.summary` for more details).
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
    coords: mapping, optional
        Specifies the subset over which to calculate the hdi.
    max_modes: int, optional
        Specifies the maximum number of modes for multimodal case.
    kwargs: dict, optional
        Additional keywords passed to :func:`~arviz.wrap_xarray_ufunc`.

    Returns
    -------
    np.ndarray or xarray.Dataset, depending upon input
        lower(s) and upper(s) values of the interval(s).

    See Also
    --------
    plot_hdi : Plot HDI intervals for regression data.
    xarray.Dataset.quantile : Calculate quantiles of array for given probabilities.

    Examples
    --------
    Calculate the HDI of a Normal random variable:

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: data = np.random.normal(size=2000)
           ...: az.hdi(data, hdi_prob=.68)

    Calculate the HDI of a dataset:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('centered_eight')
           ...: az.hdi(data)

    We can also calculate the HDI of some of the variables of the dataset:

    .. ipython::

        In [1]: az.hdi(data, var_names=["mu", "theta"])

    If we want to calculate the HDI over a specified dimension of the dataset,
    we can pass `input_core_dims` by kwargs:

    .. ipython::

        In [1]: az.hdi(data, input_core_dims = [["chain"]])

    We can also calculate the hdi over a particular selection over all groups:

    .. ipython::

        In [1]: az.hdi(data, coords={"chain":[0, 1, 3]}, input_core_dims = [["draw"]])

    """
    if hdi_prob is None:
        hdi_prob = rcParams["stats.hdi_prob"]
    else:
        if not 1 >= hdi_prob > 0:
            raise ValueError("The value of hdi_prob should be in the interval (0, 1]")

    func_kwargs = {
        "hdi_prob": hdi_prob,
        "skipna": skipna,
        "out_shape": (max_modes, 2) if multimodal else (2,),
    }
    kwargs.setdefault("output_core_dims", [["hdi", "mode"] if multimodal else ["hdi"]])
    if not multimodal:
        func_kwargs["circular"] = circular
    else:
        func_kwargs["max_modes"] = max_modes

    func = _hdi_multimodal if multimodal else _hdi

    isarray = isinstance(ary, np.ndarray)
    if isarray and ary.ndim <= 1:
        func_kwargs.pop("out_shape")
        hdi_data = func(ary, **func_kwargs)  # pylint: disable=unexpected-keyword-arg
        return hdi_data[~np.isnan(hdi_data).all(axis=1), :] if multimodal else hdi_data

    if isarray and ary.ndim == 2:
        warnings.warn(
            "hdi currently interprets 2d data as (draw, shape) but this will change in "
            "a future release to (chain, draw) for coherence with other functions",
            FutureWarning,
        )
        ary = np.expand_dims(ary, 0)

    ary = convert_to_dataset(ary, group=group)
    if coords is not None:
        ary = get_coords(ary, coords)
    var_names = _var_names(var_names, ary, filter_vars)
    ary = ary[var_names] if var_names else ary

    hdi_coord = xr.DataArray(["lower", "higher"], dims=["hdi"], attrs=dict(hdi_prob=hdi_prob))
    hdi_data = _wrap_xarray_ufunc(func, ary, func_kwargs=func_kwargs, **kwargs).assign_coords(
        {"hdi": hdi_coord}
    )
    hdi_data = hdi_data.dropna("mode", how="all") if multimodal else hdi_data
    return hdi_data.x.values if isarray else hdi_data


def _hdi(ary, hdi_prob, circular, skipna):
    """Compute hdi over the flattened array."""
    ary = ary.flatten()
    if skipna:
        nans = np.isnan(ary)
        if not nans.all():
            ary = ary[~nans]
    n = len(ary)

    if circular:
        mean = st.circmean(ary, high=np.pi, low=-np.pi)
        ary = ary - mean
        ary = np.arctan2(np.sin(ary), np.cos(ary))

    ary = np.sort(ary)
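    # The HDI is the narrowest interval containing hdi_prob of the draws: slide a window
    # of fixed length ``interval_idx_inc`` over the sorted sample and keep the placement
    # with the smallest width. This assumes the distribution is unimodal.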
    interval_idx_inc = int(np.floor(hdi_prob * n))
    n_intervals = n - interval_idx_inc
    interval_width = ary[interval_idx_inc:] - ary[:n_intervals]

    if len(interval_width) == 0:
        raise ValueError("Too few elements for interval calculation. ")

    min_idx = np.argmin(interval_width)
    hdi_min = ary[min_idx]
    hdi_max = ary[min_idx + interval_idx_inc]

    if circular:
        hdi_min = hdi_min + mean
        hdi_max = hdi_max + mean
        hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
        hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))

    hdi_interval = np.array([hdi_min, hdi_max])

    return hdi_interval


def _hdi_multimodal(ary, hdi_prob, skipna, max_modes):
    """Compute HDI if the distribution is multimodal."""
    ary = ary.flatten()
    if skipna:
        ary = ary[~np.isnan(ary)]

    if ary.dtype.kind == "f":
        bins, density = _kde(ary)
        lower, upper = bins[0], bins[-1]
        range_x = upper - lower
        dx = range_x / len(density)
    else:
        bins = _get_bins(ary)
        _, density, _ = _histogram(ary, bins=bins)
        dx = np.diff(bins)[0]

    density *= dx
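    # Highest density region: rank the grid points by (normalized) density, keep the most
    # probable ones until hdi_prob of the mass is covered, then split the kept points into
    # contiguous runs; each run is reported as one interval (one mode).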

    idx = np.argsort(-density)
    intervals = bins[idx][density[idx].cumsum() <= hdi_prob]
    intervals.sort()

    intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)

    hdi_intervals = np.full((max_modes, 2), np.nan)
    for i, interval in enumerate(intervals_splitted):
        if i == max_modes:
            warnings.warn(
                "found more modes than {0}, returning only the first {0} modes".format(max_modes)
            )
            break
        if interval.size == 0:
            hdi_intervals[i] = np.asarray([bins[0], bins[0]])
        else:
            hdi_intervals[i] = np.asarray([interval[0], interval[-1]])

    return np.array(hdi_intervals)


def loo(data, pointwise=None, var_name=None, reff=None, scale=None):
    """Compute Pareto-smoothed importance sampling leave-one-out cross-validation (PSIS-LOO-CV).

    Estimates the expected log pointwise predictive density (elpd) using Pareto-smoothed
    importance sampling leave-one-out cross-validation (PSIS-LOO-CV). Also calculates LOO's
    standard error and the effective number of parameters. Read more theory here
    https://arxiv.org/abs/1507.04544 and here https://arxiv.org/abs/1507.02646

    Parameters
    ----------
    data: obj
        Any object that can be converted to an az.InferenceData object. Refer to documentation of
        az.convert_to_inference_data for details
    pointwise: bool, optional
        If True the pointwise predictive accuracy will be returned. Defaults to
        ``stats.ic_pointwise`` rcParam.
    var_name : str, optional
        The name of the variable in log_likelihood groups storing the pointwise log
        likelihood data to use for loo computation.
    reff: float, optional
        Relative MCMC efficiency, `ess / n` i.e. number of effective samples divided by the number
        of actual samples. Computed from trace by default.
    scale: str
        Output scale for loo. Available options are:

        - `log` : (default) log-score
        - `negative_log` : -1 * log-score
        - `deviance` : -2 * log-score

        A higher log-score (or a lower deviance or negative log_score) indicates a model with
        better predictive accuracy.

    Returns
    -------
    ELPDData object (inherits from pandas.Series) with the following row/attributes:
    loo: approximated expected log pointwise predictive density (elpd)
    loo_se: standard error of loo
    p_loo: effective number of parameters
    shape_warn: bool
        True if the estimated shape parameter of
        Pareto distribution is greater than 0.7 for one or more samples
    loo_i: array of pointwise predictive accuracy, only if pointwise True
    pareto_k: array of Pareto shape values, only if pointwise True
    loo_scale: scale of the loo results

        The returned object has a custom print method that overrides pd.Series method.

    Examples
    --------
    Calculate LOO of a model:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("centered_eight")
           ...: az.loo(data)

    Calculate LOO of a model and return the pointwise values:

    .. ipython::

        In [2]: data_loo = az.loo(data, pointwise=True)
           ...: data_loo.loo_i
    """
    inference_data = convert_to_inference_data(data)
    log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise

    log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
    shape = log_likelihood.shape
    n_samples = shape[-1]
    n_data_points = np.product(shape[:-1])
    scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()

    if scale == "deviance":
        scale_value = -2
    elif scale == "log":
        scale_value = 1
    elif scale == "negative_log":
        scale_value = -1
    else:
        raise TypeError('Valid scale values are "deviance", "log", "negative_log"')

    if reff is None:
        if not hasattr(inference_data, "posterior"):
            raise TypeError("Must be able to extract a posterior group from data.")
        posterior = inference_data.posterior
        n_chains = len(posterior.chain)
        if n_chains == 1:
            reff = 1.0
        else:
            ess_p = ess(posterior, method="mean")
            # this mean is over all data variables
            reff = (
                np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples
            )

    log_weights, pareto_shape = psislw(-log_likelihood, reff)
    log_weights += log_likelihood
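    # psislw returns self-normalized smoothed log weights for each observation; adding the
    # pointwise log likelihood back gives the log of the weighted integrand, so the
    # logsumexp below estimates log p(y_i | y_{-i}) for every observation.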

    warn_mg = False
    if np.any(pareto_shape > 0.7):
        warnings.warn(
            "Estimated shape parameter of Pareto distribution is greater than 0.7 for "
            "one or more samples. You should consider using a more robust model, this is because "
            "importance sampling is less likely to work well if the marginal posterior and "
            "LOO posterior are very different. This is more likely to happen with a non-robust "
            "model and highly influential observations."
        )
        warn_mg = True

    ufunc_kwargs = {"n_dims": 1, "ravel": False}
    kwargs = {"input_core_dims": [["sample"]]}
    loo_lppd_i = scale_value * _wrap_xarray_ufunc(
        _logsumexp, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs
    )
    loo_lppd = loo_lppd_i.values.sum()
    loo_lppd_se = (n_data_points * np.var(loo_lppd_i.values)) ** 0.5

    lppd = np.sum(
        _wrap_xarray_ufunc(
            _logsumexp,
            log_likelihood,
            func_kwargs={"b_inv": n_samples},
            ufunc_kwargs=ufunc_kwargs,
            **kwargs,
        ).values
    )
    p_loo = lppd - loo_lppd / scale_value

    if pointwise:
        if np.equal(loo_lppd, loo_lppd_i).all():  # pylint: disable=no-member
            warnings.warn(
                "The point-wise LOO is the same as the sum LOO, please double check "
                "the Observed RV in your model to make sure it returns element-wise logp."
            )
        return ELPDData(
            data=[
                loo_lppd,
                loo_lppd_se,
                p_loo,
                n_samples,
                n_data_points,
                warn_mg,
                loo_lppd_i.rename("loo_i"),
                pareto_shape,
                scale,
            ],
            index=[
                "loo",
                "loo_se",
                "p_loo",
                "n_samples",
                "n_data_points",
                "warning",
                "loo_i",
                "pareto_k",
                "loo_scale",
            ],
        )

    else:
        return ELPDData(
            data=[loo_lppd, loo_lppd_se, p_loo, n_samples, n_data_points, warn_mg, scale],
            index=["loo", "loo_se", "p_loo", "n_samples", "n_data_points", "warning", "loo_scale"],
        )


def psislw(log_weights, reff=1.0):
    """
    Pareto smoothed importance sampling (PSIS).

    Parameters
    ----------
    log_weights: array
        Array of size (n_observations, n_samples)
    reff: float
        relative MCMC efficiency, `ess / n`

    Returns
    -------
    lw_out: array
        Smoothed log weights
    kss: array
        Pareto tail indices

    References
    ----------
    * Vehtari et al. (2015) see https://arxiv.org/abs/1507.02646

    Examples
    --------
    Get Pareto smoothed importance sampling (PSIS) log weights:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("centered_eight")
           ...: log_likelihood = data.sample_stats.log_likelihood.stack(sample=("chain", "draw"))
           ...: az.psislw(-log_likelihood, reff=0.8)

    """
    if hasattr(log_weights, "sample"):
        n_samples = len(log_weights.sample)
        shape = [size for size, dim in zip(log_weights.shape, log_weights.dims) if dim != "sample"]
    else:
        n_samples = log_weights.shape[-1]
        shape = log_weights.shape[:-1]
    # precalculate constants
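    # the tail holds roughly min(0.2 * n_samples, 3 * sqrt(n_samples / reff)) of the largest
    # weights (Vehtari et al., https://arxiv.org/abs/1507.02646); cutoff_ind indexes the
    # sorted log weight just below that tail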
    cutoff_ind = -int(np.ceil(min(n_samples / 5.0, 3 * (n_samples / reff) ** 0.5))) - 1
    cutoffmin = np.log(np.finfo(float).tiny)  # pylint: disable=no-member, assignment-from-no-return
    k_min = 1.0 / 3

    # create output array with proper dimensions
    out = tuple([np.empty_like(log_weights), np.empty(shape)])

    # define kwargs
    func_kwargs = {"cutoff_ind": cutoff_ind, "cutoffmin": cutoffmin, "k_min": k_min, "out": out}
    ufunc_kwargs = {"n_dims": 1, "n_output": 2, "ravel": False, "check_shape": False}
    kwargs = {"input_core_dims": [["sample"]], "output_core_dims": [["sample"], []]}
    log_weights, pareto_shape = _wrap_xarray_ufunc(
        _psislw, log_weights, ufunc_kwargs=ufunc_kwargs, func_kwargs=func_kwargs, **kwargs
    )
    if isinstance(log_weights, xr.DataArray):
        log_weights = log_weights.rename("log_weights").rename(sample="sample")
    if isinstance(pareto_shape, xr.DataArray):
        pareto_shape = pareto_shape.rename("pareto_shape")
    return log_weights, pareto_shape


def _psislw(log_weights, cutoff_ind, cutoffmin, k_min=1.0 / 3):
    """
    Pareto smoothed importance sampling (PSIS) for a 1D vector.

    Parameters
    ----------
    log_weights: array
        Array of length n_samples
    cutoff_ind: int
    cutoffmin: float
    k_min: float

    Returns
    -------
    lw_out: array
        Smoothed log weights
    kss: float
        Pareto tail index
    """
    x = np.asarray(log_weights)

    # improve numerical accuracy
    x -= np.max(x)
    # sort the array
    x_sort_ind = np.argsort(x)
    # divide log weights into body and right tail
    xcutoff = max(x[x_sort_ind[cutoff_ind]], cutoffmin)

    expxcutoff = np.exp(xcutoff)
    (tailinds,) = np.where(x > xcutoff)  # pylint: disable=unbalanced-tuple-unpacking
    x_tail = x[tailinds]
    tail_len = len(x_tail)
    if tail_len <= 4:
        # not enough tail samples for gpdfit
        k = np.inf
    else:
        # order of tail samples
        x_tail_si = np.argsort(x_tail)
        # fit generalized Pareto distribution to the right tail samples
        x_tail = np.exp(x_tail) - expxcutoff
        k, sigma = _gpdfit(x_tail[x_tail_si])

        if k >= k_min:
            # no smoothing if k < k_min (short tail or GPD fit failed)
            # compute ordered statistic for the fit
            sti = np.arange(0.5, tail_len) / tail_len
            smoothed_tail = _gpinv(sti, k, sigma)
            smoothed_tail = np.log(  # pylint: disable=assignment-from-no-return
                smoothed_tail + expxcutoff
            )
            # place the smoothed tail into the output array
            x[tailinds[x_tail_si]] = smoothed_tail
            # truncate smoothed values to the largest raw weight 0
            x[x > 0] = 0
    # renormalize weights
    x -= _logsumexp(x)

    return x, k


def _gpdfit(ary):
    """Estimate the parameters for the Generalized Pareto Distribution (GPD).

    Empirical Bayes estimate for the parameters of the generalized Pareto
    distribution given the data.

    Parameters
    ----------
    ary: array
        sorted 1D data array

    Returns
    -------
    k: float
        estimated shape parameter
    sigma: float
        estimated scale parameter
    """
    prior_bs = 3
    prior_k = 10
    n = len(ary)
    m_est = 30 + int(n ** 0.5)
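    # Empirical Bayes fit in the spirit of Zhang & Stephens (2009), as used for PSIS:
    # evaluate a grid of m_est candidate values of b = -k/sigma (regularized by the weak
    # priors above), weight each candidate by its profile likelihood, and average the
    # candidates to obtain the posterior mean of b.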

    b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5))
    b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1]
    b_ary += 1 / ary[-1]

    k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1)  # pylint: disable=no-member
    len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1)
    weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1)

    # remove negligible weights
    real_idxs = weights >= 10 * np.finfo(float).eps
    if not np.all(real_idxs):
        weights = weights[real_idxs]
        b_ary = b_ary[real_idxs]
    # normalise weights
    weights /= weights.sum()

    # posterior mean for b
    b_post = np.sum(b_ary * weights)
    # estimate for k
    k_post = np.log1p(-b_post * ary).mean()  # pylint: disable=invalid-unary-operand-type,no-member
    # add prior for k_post
    k_post = (n * k_post + prior_k * 0.5) / (n + prior_k)
    sigma = -k_post / b_post

    return k_post, sigma


def _gpinv(probs, kappa, sigma):
    """Inverse Generalized Pareto distribution function."""
    # pylint: disable=unsupported-assignment-operation, invalid-unary-operand-type
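    # GPD quantile function: Q(p) = sigma * ((1 - p)**(-kappa) - 1) / kappa,
    # with the limit Q(p) = -sigma * log(1 - p) as kappa -> 0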
    x = np.full_like(probs, np.nan)
    if sigma <= 0:
        return x
    ok = (probs > 0) & (probs < 1)
    if np.all(ok):
        if np.abs(kappa) < np.finfo(float).eps:
            x = -np.log1p(-probs)
        else:
            x = np.expm1(-kappa * np.log1p(-probs)) / kappa
        x *= sigma
    else:
        if np.abs(kappa) < np.finfo(float).eps:
            x[ok] = -np.log1p(-probs[ok])
        else:
            x[ok] = np.expm1(-kappa * np.log1p(-probs[ok])) / kappa
        x *= sigma
        x[probs == 0] = 0
        if kappa >= 0:
            x[probs == 1] = np.inf
        else:
            x[probs == 1] = -sigma / kappa
    return x


def r2_score(y_true, y_pred):
    """R² for Bayesian regression models. Only valid for linear models.

    Parameters
    ----------
    y_true: array-like of shape = (n_samples) or (n_samples, n_outputs)
        Ground truth (correct) target values.
    y_pred: array-like of shape = (n_samples) or (n_samples, n_outputs)
        Estimated target values.

    Returns
    -------
    Pandas Series with the following indices:
    r2: Bayesian R²
    r2_std: standard deviation of the Bayesian R².

    Examples
    --------
    Calculate R² for Bayesian regression models:

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data('regression1d')
           ...: y_true = data.observed_data["y"].values
           ...: y_pred = data.posterior_predictive.stack(sample=("chain", "draw"))["y"].values.T
           ...: az.r2_score(y_true, y_pred)

    """
    _numba_flag = Numba.numba_flag
    if y_pred.ndim == 1:
        var_y_est = _numba_var(svar, np.var, y_pred)
        var_e = _numba_var(svar, np.var, (y_true - y_pred))
    else:
        var_y_est = _numba_var(svar, np.var, y_pred.mean(0))
        var_e = _numba_var(svar, np.var, (y_true - y_pred), axis=0)
    r_squared = var_y_est / (var_y_est + var_e)

    return pd.Series([np.mean(r_squared), np.std(r_squared)], index=["r2", "r2_std"])


def summary(
    data,
    var_names: Optional[List[str]] = None,
    filter_vars=None,
    fmt: str = "wide",
    kind: str = "all",
    round_to=None,
    include_circ=None,
    circ_var_names=None,
    stat_funcs=None,
    extend=True,
    hdi_prob=None,
    order="C",
    index_origin=None,
    skipna=False,
    coords: Optional[CoordSpec] = None,
    dims: Optional[DimSpec] = None,
    credible_interval=None,
) -> Union[pd.DataFrame, xr.Dataset]:
    """Create a data frame with summary statistics.

    Parameters
    ----------
    data: obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
    var_names: list
        Names of variables to include in summary. Prefix the variables by `~` when you
        want to exclude them from the summary: `["~beta"]` instead of `["beta"]` (see
        examples below).
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
    fmt: {'wide', 'long', 'xarray'}
        Return format is either pandas.DataFrame {'wide', 'long'} or xarray.Dataset {'xarray'}.
    kind: {'all', 'stats', 'diagnostics'}
        Whether to include the `stats`: `mean`, `sd`, `hdi_3%`, `hdi_97%`, or the `diagnostics`:
        `mcse_mean`, `mcse_sd`, `ess_bulk`, `ess_tail`, and `r_hat`. Defaults to include `all` of
        them.
    round_to: int
        Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
    include_circ: boolean
        Whether to include circular statistics
        deprecated: Please see circ_var_names
    circ_var_names: list
        A list of circular variables to compute circular stats for
    stat_funcs: dict or list
        A list of functions or a dict of functions with function names as keys used to calculate
        statistics. By default, the mean, standard deviation, simulation standard error, and
        highest posterior density intervals are included.

        The functions will be given one argument, the samples for a variable as an nD array.
        The functions should be in the style of a ufunc and return a single number. For example,
        `np.mean` or `np.var` would both work.
    extend: boolean
        If True, use the statistics returned by ``stat_funcs`` in addition to, rather than in place
        of, the default statistics. This is only meaningful when ``stat_funcs`` is not None.
    hdi_prob: float, optional
        HDI interval to compute. Defaults to 0.94. This is only meaningful when ``stat_funcs`` is
        None.
    order: {"C", "F"}
        If fmt is "wide", use either C or F unpacking order. Defaults to C.
    index_origin: int
        If fmt is "wide", select n-based indexing for multivariate parameters.
        Defaults to rcParam data.index_origin, which is 0.
    skipna: bool
        If true ignores nan values when computing the summary statistics, it does not affect the
        behaviour of the functions passed to ``stat_funcs``. Defaults to false.
    coords: Dict[str, List[Any]], optional
        Coordinates specification to be used if the ``fmt`` is ``'xarray'``.
    dims: Dict[str, List[str]], optional
        Dimensions specification for the variables to be used if the ``fmt`` is ``'xarray'``.
    credible_interval: float, optional
        deprecated: Please see hdi_prob

    Returns
    -------
    pandas.DataFrame or xarray.Dataset
        Return type dictated by `fmt` argument.
        Return value will contain summary statistics for each variable. Default statistics are:
        `mean`, `sd`, `hdi_3%`, `hdi_97%`, `mcse_mean`, `mcse_sd`, `ess_bulk`, `ess_tail`, and
        `r_hat`.
        `r_hat` is only computed for traces with 2 or more chains.

    Examples
    --------
    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("centered_eight")
           ...: az.summary(data, var_names=["mu", "tau"])

    You can use `filter_vars` to select variables without having to specify all the exact
    names. Use `filter_vars="like"` to select based on partial naming:

    .. ipython::

        In [1]: az.summary(data, var_names=["the"], filter_vars="like")

    Use `filter_vars="regex"` to select based on regular expressions, and prefix the variables
    you want to exclude by `~`. Here, we exclude from the summary all the variables
    starting with the letter t:

    .. ipython::

        In [1]: az.summary(data, var_names=["~^t"], filter_vars="regex")

    Other statistics can be calculated by passing a list of functions
    or a dictionary with key, function pairs.

    .. ipython::

        In [1]: import numpy as np
           ...: def median_sd(x):
           ...:     median = np.percentile(x, 50)
           ...:     sd = np.sqrt(np.mean((x-median)**2))
           ...:     return sd
           ...:
           ...: func_dict = {
           ...:     "std": np.std,
           ...:     "median_std": median_sd,
           ...:     "5%": lambda x: np.percentile(x, 5),
           ...:     "median": lambda x: np.percentile(x, 50),
           ...:     "95%": lambda x: np.percentile(x, 95),
           ...: }
           ...: az.summary(
           ...:     data,
           ...:     var_names=["mu", "tau"],
           ...:     stat_funcs=func_dict,
           ...:     extend=False
           ...: )

    """
    if include_circ:
        warnings.warn(
            "include_circ is deprecated and will be ignored. Use circ_var_names instead",
            DeprecationWarning,
        )

    if credible_interval:
        hdi_prob = credible_interval_warning(credible_interval, hdi_prob)

    extra_args = {}  # type: Dict[str, Any]
    if coords is not None:
        extra_args["coords"] = coords
    if dims is not None:
        extra_args["dims"] = dims
    if index_origin is None:
        index_origin = rcParams["data.index_origin"]
    if hdi_prob is None:
        hdi_prob = rcParams["stats.hdi_prob"]
    else:
        if not 1 >= hdi_prob > 0:
            raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
    posterior = convert_to_dataset(data, group="posterior", **extra_args)
    var_names = _var_names(var_names, posterior, filter_vars)
    posterior = posterior if var_names is None else posterior[var_names]

    fmt_group = ("wide", "long", "xarray")
    if not isinstance(fmt, str) or (fmt.lower() not in fmt_group):
        raise TypeError("Invalid format: '{}'. Formatting options are: {}".format(fmt, fmt_group))

    unpack_order_group = ("C", "F")
    if not isinstance(order, str) or (order.upper() not in unpack_order_group):
        raise TypeError(
            "Invalid order: '{}'. Unpacking options are: {}".format(order, unpack_order_group)
        )

    alpha = 1 - hdi_prob

    extra_metrics = []
    extra_metric_names = []

    if stat_funcs is not None:
        if isinstance(stat_funcs, dict):
            for stat_func_name, stat_func in stat_funcs.items():
                extra_metrics.append(
                    xr.apply_ufunc(
                        _make_ufunc(stat_func), posterior, input_core_dims=(("chain", "draw"),)
                    )
                )
                extra_metric_names.append(stat_func_name)
        else:
            for stat_func in stat_funcs:
                extra_metrics.append(
                    xr.apply_ufunc(
                        _make_ufunc(stat_func), posterior, input_core_dims=(("chain", "draw"),)
                    )
                )
                extra_metric_names.append(stat_func.__name__)

    if extend and kind in ["all", "stats"]:
        mean = posterior.mean(dim=("chain", "draw"), skipna=skipna)

        sd = posterior.std(dim=("chain", "draw"), ddof=1, skipna=skipna)

        hdi_post = hdi(posterior, hdi_prob=hdi_prob, multimodal=False, skipna=skipna)
        hdi_lower = hdi_post.sel(hdi="lower", drop=True)
        hdi_higher = hdi_post.sel(hdi="higher", drop=True)

    if circ_var_names:
        nan_policy = "omit" if skipna else "propagate"
        circ_mean = xr.apply_ufunc(
            _make_ufunc(st.circmean),
            posterior,
            kwargs=dict(high=np.pi, low=-np.pi, nan_policy=nan_policy),
            input_core_dims=(("chain", "draw"),),
        )
        _numba_flag = Numba.numba_flag
        func = None
        if _numba_flag:
            func = _circular_standard_deviation
            kwargs_circ_std = dict(high=np.pi, low=-np.pi, skipna=skipna)
        else:
            func = st.circstd
            kwargs_circ_std = dict(high=np.pi, low=-np.pi, nan_policy=nan_policy)
        circ_sd = xr.apply_ufunc(
            _make_ufunc(func),
            posterior,
            kwargs=kwargs_circ_std,
            input_core_dims=(("chain", "draw"),),
        )

        circ_mcse = xr.apply_ufunc(
            _make_ufunc(_mc_error),
            posterior,
            kwargs=dict(circular=True),
            input_core_dims=(("chain", "draw"),),
        )

        circ_hdi = hdi(posterior, hdi_prob=hdi_prob, circular=True, skipna=skipna)
        circ_hdi_lower = circ_hdi.sel(hdi="lower", drop=True)
        circ_hdi_higher = circ_hdi.sel(hdi="higher", drop=True)

    if kind in ["all", "diagnostics"]:
        mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat = xr.apply_ufunc(
            _make_ufunc(_multichain_statistics, n_output=7, ravel=False),
            posterior,
            input_core_dims=(("chain", "draw"),),
            output_core_dims=tuple([] for _ in range(7)),
        )

    # Combine metrics
    metrics = []
    metric_names = []
    if extend:
        metrics_names_ = (
            "mean",
            "sd",
            "hdi_{:g}%".format(100 * alpha / 2),
            "hdi_{:g}%".format(100 * (1 - alpha / 2)),
            "mcse_mean",
            "mcse_sd",
            "ess_mean",
            "ess_sd",
            "ess_bulk",
            "ess_tail",
            "r_hat",
        )
        if kind == "all":
            metrics_ = (
                mean,
                sd,
                hdi_lower,
                hdi_higher,
                mcse_mean,
                mcse_sd,
                ess_mean,
                ess_sd,
                ess_bulk,
                ess_tail,
                r_hat,
            )
        elif kind == "stats":
            metrics_ = (mean, sd, hdi_lower, hdi_higher)
            metrics_names_ = metrics_names_[:4]
        elif kind == "diagnostics":
            metrics_ = (mcse_mean, mcse_sd, ess_mean, ess_sd, ess_bulk, ess_tail, r_hat)
            metrics_names_ = metrics_names_[4:]
        metrics.extend(metrics_)
        metric_names.extend(metrics_names_)

    if circ_var_names:

        if kind != "diagnostics":
            for metric, circ_stat in zip(
                # Replace only the first 5 statistics for their circular equivalent
                metrics[:5],
                (circ_mean, circ_sd, circ_hdi_lower, circ_hdi_higher, circ_mcse),
            ):
                for circ_var in circ_var_names:
                    metric[circ_var] = circ_stat[circ_var]

    metrics.extend(extra_metrics)
    metric_names.extend(extra_metric_names)
    joined = (
        xr.concat(metrics, dim="metric").assign_coords(metric=metric_names).reset_coords(drop=True)
    )

    if fmt.lower() == "wide":
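        # "wide" format: one row per scalar parameter. Multidimensional variables are
        # unpacked into rows labelled "var[i,j,...]" (indices offset by index_origin and
        # walked in C or F order), with one column per summary metric.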
        dfs = []
        for var_name, values in joined.data_vars.items():
            if len(values.shape[1:]):
                metric = list(values.metric.values)
                data_dict = OrderedDict()
                for idx in np.ndindex(values.shape[1:] if order == "C" else values.shape[1:][::-1]):
                    if order == "F":
                        idx = tuple(idx[::-1])
                    ser = pd.Series(values[(Ellipsis, *idx)].values, index=metric)
                    key_index = ",".join(map(str, (i + index_origin for i in idx)))
                    key = "{}[{}]".format(var_name, key_index)
                    data_dict[key] = ser
                df = pd.DataFrame.from_dict(data_dict, orient="index")
                df = df.loc[list(data_dict.keys())]
            else:
                df = values.to_dataframe()
                df.index = list(df.index)
                df = df.T
            dfs.append(df)
        summary_df = pd.concat(dfs, sort=False)
    elif fmt.lower() == "long":
        df = joined.to_dataframe().reset_index().set_index("metric")
        df.index = list(df.index)
        summary_df = df
    else:
        # format is 'xarray'
        summary_df = joined
    if (round_to is not None) and (round_to not in ("None", "none")):
        summary_df = summary_df.round(round_to)
    elif round_to not in ("None", "none") and (fmt.lower() in ("long", "wide")):
        # Don't round xarray object by default (even with "none")
        decimals = {
            col: 3
            if col not in {"ess_mean", "ess_sd", "ess_bulk", "ess_tail", "r_hat"}
            else 2
            if col == "r_hat"
            else 0
            for col in summary_df.columns
        }
        summary_df = summary_df.round(decimals)

    return summary_df

1334

1335 2
def waic(data, pointwise=None, var_name=None, scale=None):
1336
    """Compute the widely applicable information criterion.
1337

1338
    Estimates the expected log pointwise predictive density (elpd) using WAIC. Also calculates the
1339
    WAIC's standard error and the effective number of parameters.
1340
    Read more theory here https://arxiv.org/abs/1507.04544 and here https://arxiv.org/abs/1004.2316
1341

1342
    Parameters
1343
    ----------
1344
    data: obj
1345
        Any object that can be converted to an az.InferenceData object. Refer to documentation of
1346
        ``az.convert_to_inference_data`` for details
1347
    pointwise: bool
1348
        If True the pointwise predictive accuracy will be returned. Defaults to
1349
        ``stats.ic_pointwise`` rcParam.
1350
    var_name : str, optional
1351
        The name of the variable in log_likelihood groups storing the pointwise log
1352
        likelihood data to use for waic computation.
1353
    scale: str
1354
        Output scale for WAIC. Available options are:
1355

1356
        - `log` : (default) log-score
1357
        - `negative_log` : -1 * log-score
1358
        - `deviance` : -2 * log-score
1359

1360
        A higher log-score (or a lower deviance or negative log_score) indicates a model with
1361
        better predictive accuracy.
1362

1363
    Returns
1364
    -------
1365
    ELPDData object (inherits from panda.Series) with the following row/attributes:
1366
    waic: approximated expected log pointwise predictive density (elpd)
1367
    waic_se: standard error of waic
1368
    p_waic: effective number parameters
1369
    var_warn: bool
1370
        True if posterior variance of the log predictive densities exceeds 0.4
1371
    waic_i: xarray.DataArray with the pointwise predictive accuracy, only if pointwise=True
1372
    waic_scale: scale of the reported waic results
1373

1374
        The returned object has a custom print method that overrides pd.Series method.
1375

1376
    Examples
1377
    --------
1378
    Calculate WAIC of a model:
1379

1380
    .. ipython::
1381

1382
        In [1]: import arviz as az
1383
           ...: data = az.load_arviz_data("centered_eight")
1384
           ...: az.waic(data)
1385

1386
    Calculate WAIC of a model and return the pointwise values:
1387

1388
    .. ipython::
1389

1390
        In [2]: data_waic = az.waic(data, pointwise=True)
1391
           ...: data_waic.waic_i
1392
    """
1393 2
    inference_data = convert_to_inference_data(data)
1394 2
    log_likelihood = _get_log_likelihood(inference_data, var_name=var_name)
1395 2
    scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
1396 2
    pointwise = rcParams["stats.ic_pointwise"] if pointwise is None else pointwise
1397

1398 2
    if scale == "deviance":
1399 2
        scale_value = -2
1400 2
    elif scale == "log":
1401 2
        scale_value = 1
1402 2
    elif scale == "negative_log":
1403 2
        scale_value = -1
1404
    else:
1405 2
        raise TypeError('Valid scale values are "deviance", "log", "negative_log"')
1406

    log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
    shape = log_likelihood.shape
    n_samples = shape[-1]
    n_data_points = np.product(shape[:-1])

    ufunc_kwargs = {"n_dims": 1, "ravel": False}
    kwargs = {"input_core_dims": [["sample"]]}
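    # lppd_i: pointwise log predictive density, i.e. the log of the mean of the
    # exponentiated log-likelihood over posterior samples, computed stably with
    # logsumexp (b_inv=n_samples performs the division by the number of samples).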
    lppd_i = _wrap_xarray_ufunc(
        _logsumexp,
        log_likelihood,
        func_kwargs={"b_inv": n_samples},
        ufunc_kwargs=ufunc_kwargs,
        **kwargs,
    )

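    # Pointwise effective number of parameters: the posterior variance of the
    # log-likelihood for each observation; values above 0.4 trigger the warning below.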
    vars_lpd = log_likelihood.var(dim="sample")
    warn_mg = False
    if np.any(vars_lpd > 0.4):
        warnings.warn(
            (
                "For one or more samples the posterior variance of the log predictive "
                "densities exceeds 0.4. This could be an indication of WAIC starting to fail. \n"
                "See http://arxiv.org/abs/1507.04544 for details"
            )
        )
        warn_mg = True

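    # Pointwise elpd: lppd_i minus the variance penalty, mapped to the requested
    # output scale (log, negative_log or deviance) through scale_value.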
    waic_i = scale_value * (lppd_i - vars_lpd)
    waic_se = (n_data_points * np.var(waic_i.values)) ** 0.5
    waic_sum = np.sum(waic_i.values)
    p_waic = np.sum(vars_lpd.values)

    if pointwise:
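        # If every pointwise value equals the total, the log-likelihood was most
        # likely stored as a single aggregated value rather than element-wise.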
        if np.equal(waic_sum, waic_i).all():  # pylint: disable=no-member
            warnings.warn(
                """The point-wise WAIC is the same as the sum WAIC, please double check
            the Observed RV in your model to make sure it returns element-wise logp.
            """
            )
        return ELPDData(
            data=[
                waic_sum,
                waic_se,
                p_waic,
                n_samples,
                n_data_points,
                warn_mg,
                waic_i.rename("waic_i"),
                scale,
            ],
            index=[
                "waic",
                "waic_se",
                "p_waic",
                "n_samples",
                "n_data_points",
                "warning",
                "waic_i",
                "waic_scale",
            ],
        )
    else:
        return ELPDData(
            data=[waic_sum, waic_se, p_waic, n_samples, n_data_points, warn_mg, scale],
            index=[
                "waic",
                "waic_se",
                "p_waic",
                "n_samples",
                "n_data_points",
                "warning",
                "waic_scale",
            ],
        )


def loo_pit(idata=None, *, y=None, y_hat=None, log_weights=None):
    """Compute leave one out (PSIS-LOO) probability integral transform (PIT) values.

    Parameters
    ----------
    idata: InferenceData
        InferenceData object.
    y: array, DataArray or str
        Observed data. If str, idata must be present and contain the observed data group.
    y_hat: array, DataArray or str
        Posterior predictive samples for ``y``. It must have the same shape as y plus an
        extra dimension at the end of size n_samples (chains and draws stacked). If str or
        None, idata must contain the posterior predictive group. If None, y_hat is taken
        equal to y, thus y must be str too.
    log_weights: array or DataArray
        Smoothed log_weights. It must have the same shape as ``y_hat``.

    Returns
    -------
    loo_pit: array or DataArray
        Value of the LOO-PIT at each observed data point.

    Examples
    --------
    Calculate LOO-PIT values using as test quantity the observed values themselves.

    .. ipython::

        In [1]: import arviz as az
           ...: data = az.load_arviz_data("centered_eight")
           ...: az.loo_pit(idata=data, y="obs")

    Calculate LOO-PIT values using as test quantity the square of the difference between
    each observation and `mu`. Both ``y`` and ``y_hat`` inputs will be array-like,
    but ``idata`` will still be passed in order to calculate the ``log_weights`` from
    there.

    .. ipython::

        In [1]: T = data.observed_data.obs - data.posterior.mu.median(dim=("chain", "draw"))
           ...: T_hat = data.posterior_predictive.obs - data.posterior.mu
           ...: T_hat = T_hat.stack(sample=("chain", "draw"))
           ...: az.loo_pit(idata=data, y=T**2, y_hat=T_hat**2)

    """
    y_str = ""
    if idata is not None and not isinstance(idata, InferenceData):
        raise ValueError("idata must be of type InferenceData or None")

    if idata is None:
        if not all(isinstance(arg, (np.ndarray, xr.DataArray)) for arg in (y, y_hat, log_weights)):
            raise ValueError(
                "y, y_hat and log_weights must all be array or DataArray when idata is None, "
                "but they are of types {}".format([type(arg) for arg in (y, y_hat, log_weights)])
            )

    else:
        if y_hat is None and isinstance(y, str):
            y_hat = y
        elif y_hat is None:
            raise ValueError("y_hat cannot be None if y is not a str")
        if isinstance(y, str):
            y_str = y
            y = idata.observed_data[y].values
        elif not isinstance(y, (np.ndarray, xr.DataArray)):
            raise ValueError("y must be of types array, DataArray or str, not {}".format(type(y)))
        if isinstance(y_hat, str):
            y_hat = idata.posterior_predictive[y_hat].stack(sample=("chain", "draw")).values
        elif not isinstance(y_hat, (np.ndarray, xr.DataArray)):
            raise ValueError(
                "y_hat must be of types array, DataArray or str, not {}".format(type(y_hat))
            )
        if log_weights is None:
            if y_str:
                try:
                    log_likelihood = _get_log_likelihood(idata, var_name=y_str)
                except TypeError:
                    log_likelihood = _get_log_likelihood(idata)
            else:
                log_likelihood = _get_log_likelihood(idata)
            log_likelihood = log_likelihood.stack(sample=("chain", "draw"))
            posterior = convert_to_dataset(idata, group="posterior")
            n_chains = len(posterior.chain)
            n_samples = len(log_likelihood.sample)
            ess_p = ess(posterior, method="mean")
            # this mean is over all data variables
            reff = (
                (np.hstack([ess_p[v].values.flatten() for v in ess_p.data_vars]).mean() / n_samples)
                if n_chains > 1
                else 1
            )
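            # Smoothed log weights from Pareto smoothed importance sampling (PSIS),
            # one set of weights per observation.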
            log_weights = psislw(-log_likelihood, reff=reff)[0].values
        elif not isinstance(log_weights, (np.ndarray, xr.DataArray)):
            raise ValueError(
                "log_weights must be None or of types array or DataArray, not {}".format(
                    type(log_weights)
                )
            )

    if len(y.shape) + 1 != len(y_hat.shape):
        raise ValueError(
            "y_hat must have 1 more dimension than y, but y_hat has {} dims and y has "
            "{} dims".format(len(y_hat.shape), len(y.shape))
        )

    if y.shape != y_hat.shape[:-1]:
        raise ValueError(
            "y has shape: {} which should be equal to y_hat shape (omitting the last "
            "dimension): {}".format(y.shape, y_hat.shape[:-1])
        )

    if y_hat.shape != log_weights.shape:
        raise ValueError(
            "y_hat and log_weights must have the same shape but have shapes {} and {}".format(
                y_hat.shape, log_weights.shape
            )
        )

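    # y is passed elementwise while y_hat and log_weights are reduced over the
    # stacked "sample" dimension, giving one LOO-PIT value per observation.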
    kwargs = {
        "input_core_dims": [[], ["sample"], ["sample"]],
        "output_core_dims": [[]],
        "join": "left",
    }
    ufunc_kwargs = {"n_dims": 1}

    return _wrap_xarray_ufunc(_loo_pit, y, y_hat, log_weights, ufunc_kwargs=ufunc_kwargs, **kwargs)


def _loo_pit(y, y_hat, log_weights):
    """Compute LOO-PIT values."""
    sel = y_hat <= y
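    # sel marks the posterior predictive samples that do not exceed the observed
    # value; the LOO-PIT is their accumulated smoothed importance weight, i.e. the
    # leave-one-out predictive CDF evaluated at the observed value.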
    if np.sum(sel) > 0:
        value = np.exp(_logsumexp(log_weights[sel]))
        return min(1, value)
    else:
        return 0


def apply_test_function(
    idata,
    func,
    group="both",
    var_names=None,
    pointwise=False,
    out_data_shape=None,
    out_pp_shape=None,
    out_name_data="T",
    out_name_pp=None,
    func_args=None,
    func_kwargs=None,
    ufunc_kwargs=None,
    wrap_data_kwargs=None,
    wrap_pp_kwargs=None,
    inplace=True,
    overwrite=None,
):
    """Apply a Bayesian test function to an InferenceData object.

    Parameters
    ----------
    idata: InferenceData
        InferenceData object on which to apply the test function. This function will add
        new variables to the InferenceData object to store the result without modifying the
        existing ones.
    func: callable
        Callable that calculates the test function. It must have the following call signature
        ``func(y, theta, *args, **kwargs)`` (where ``y`` is the observed data or posterior
        predictive and ``theta`` the model parameters) even if not all the arguments are
        used.
    group: str, optional
        Group on which to apply the test function. Can be observed_data, posterior_predictive
        or both.
    var_names: dict group -> var_names, optional
        Mapping from group name to the variables to be passed to func. It can be a dict of
        strings or lists of strings. There is also the option of using ``both`` as key,
        in which case, the same variables are used in the observed data and posterior
        predictive groups.
    pointwise: bool, optional
        If True, apply the test function to each observation and sample, otherwise, apply
        the test function to each sample.
    out_data_shape, out_pp_shape: tuple, optional
        Output shape of the test function applied to the observed/posterior predictive data.
        If None, the default depends on the value of pointwise.
    out_name_data, out_name_pp: str, optional
        Name of the variables to add to the observed_data and posterior_predictive datasets
        respectively. ``out_name_pp`` can be ``None``, in which case it will be taken equal to
        ``out_name_data``.
    func_args: sequence, optional
        Passed as is to ``func``
    func_kwargs: mapping, optional
        Passed as is to ``func``
    wrap_data_kwargs, wrap_pp_kwargs: mapping, optional
        kwargs passed to ``az.stats.wrap_xarray_ufunc``. By default, some suitable input_core_dims
        are used.
    inplace: bool, optional
        If True, add the variables inplace, otherwise, return a copy of idata with the variables
        added.
    overwrite: bool, optional
        Overwrite data in case ``out_name_data`` or ``out_name_pp`` are already variables in
        dataset. If ``None`` it will be the opposite of inplace.

    Returns
    -------
    idata: InferenceData
        Output InferenceData object. If ``inplace=True``, it is the same input object modified
        inplace.

    Notes
    -----
    This function is provided for convenience, to wrap scalar functions or functions working
    on low-dimensional arrays so that they can be applied to an InferenceData object. It is
    not optimized and will not be as fast as equivalent vectorized computations.

    Examples
    --------
    Use ``apply_test_function`` to wrap ``np.min`` for illustration purposes, and plot the
    results.

    .. plot::
        :context: close-figs

        >>> import numpy as np
        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> az.apply_test_function(idata, lambda y, theta: np.min(y))
        >>> T = np.asscalar(idata.observed_data.T)
        >>> az.plot_posterior(idata, var_names=["T"], group="posterior_predictive", ref_val=T)

    """
    out = idata if inplace else deepcopy(idata)

    valid_groups = ("observed_data", "posterior_predictive", "both")
    if group not in valid_groups:
        raise ValueError(
            "Invalid group argument. Must be one of {} not {}.".format(valid_groups, group)
        )
    if overwrite is None:
        overwrite = not inplace

    if out_name_pp is None:
        out_name_pp = out_name_data

    if func_args is None:
        func_args = tuple()

    if func_kwargs is None:
        func_kwargs = {}

    if ufunc_kwargs is None:
        ufunc_kwargs = {}
    ufunc_kwargs.setdefault("check_shape", False)
    ufunc_kwargs.setdefault("ravel", False)

    if wrap_data_kwargs is None:
        wrap_data_kwargs = {}
    if wrap_pp_kwargs is None:
        wrap_pp_kwargs = {}
    if var_names is None:
        var_names = {}

    both_var_names = var_names.pop("both", None)
    var_names.setdefault("posterior", list(out.posterior.data_vars))

    in_posterior = out.posterior[var_names["posterior"]]
    if isinstance(in_posterior, xr.Dataset):
        in_posterior = in_posterior.to_array().squeeze()

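    # Apply the test function to each requested group, resolving the per-group
    # defaults (output shape, variable names and wrapper kwargs) inside the loop.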
    groups = ("posterior_predictive", "observed_data") if group == "both" else [group]
    for grp in groups:
        out_group_shape = out_data_shape if grp == "observed_data" else out_pp_shape
        out_name_group = out_name_data if grp == "observed_data" else out_name_pp
        wrap_group_kwargs = wrap_data_kwargs if grp == "observed_data" else wrap_pp_kwargs
        if not hasattr(out, grp):
            raise ValueError("InferenceData object must have {} group".format(grp))
        if not overwrite and out_name_group in getattr(out, grp).data_vars:
            raise ValueError(
                "Should overwrite: {} variable present in group {}, but overwrite is False".format(
                    out_name_group, grp
                )
            )
        var_names.setdefault(
            grp, list(getattr(out, grp).data_vars) if both_var_names is None else both_var_names
        )
        in_group = getattr(out, grp)[var_names[grp]]
        if isinstance(in_group, xr.Dataset):
            in_group = in_group.to_array(dim="{}_var".format(grp)).squeeze()

        if pointwise:
            out_group_shape = in_group.shape if out_group_shape is None else out_group_shape
        elif grp == "observed_data":
            out_group_shape = () if out_group_shape is None else out_group_shape
        elif grp == "posterior_predictive":
            out_group_shape = in_group.shape[:2] if out_group_shape is None else out_group_shape
        loop_dims = in_group.dims[: len(out_group_shape)]

        wrap_group_kwargs.setdefault(
            "input_core_dims",
            [
                [dim for dim in dataset.dims if dim not in loop_dims]
                for dataset in [in_group, in_posterior]
            ],
        )
        func_kwargs["out"] = np.empty(out_group_shape)

        out_group = getattr(out, grp)
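        # Apply func to the raw arrays first; if their shapes are not compatible
        # (IndexError), broadcast the xarray objects over the non-core dims and retry.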
        try:
            out_group[out_name_group] = _wrap_xarray_ufunc(
                func,
                in_group.values,
                in_posterior.values,
                func_args=func_args,
                func_kwargs=func_kwargs,
                ufunc_kwargs=ufunc_kwargs,
                **wrap_group_kwargs,
            )
        except IndexError:
            excluded_dims = set(
                wrap_group_kwargs["input_core_dims"][0] + wrap_group_kwargs["input_core_dims"][1]
            )
            out_group[out_name_group] = _wrap_xarray_ufunc(
                func,
                *xr.broadcast(in_group, in_posterior, exclude=excluded_dims),
                func_args=func_args,
                func_kwargs=func_kwargs,
                ufunc_kwargs=ufunc_kwargs,
                **wrap_group_kwargs,
            )
        setattr(out, grp, out_group)

    return out
