# pylint: disable=too-many-nested-blocks
"""General utilities."""
import functools
import importlib
import re
import warnings
from functools import lru_cache

import matplotlib.pyplot as plt
import numpy as np
import pkg_resources
from numpy import newaxis

from .rcparams import rcParams

STATIC_FILES = ("static/html/icons-svg-inline.html", "static/css/style.css")


def _var_names(var_names, data, filter_vars=None):
    """Handle var_names input across arviz.

    Parameters
    ----------
    var_names : str, list, or None
    data : xarray.Dataset
        Posterior data in an xarray.
    filter_vars : {None, "like", "regex"}, optional
        If `None` (default), interpret `var_names` as the real variable names.
        If "like", interpret `var_names` as substrings of the real variable names.
        If "regex", interpret `var_names` as regular expressions on the real
        variable names. A la `pandas.filter`.

    Returns
    -------
    var_names : list or None
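
    Examples
    --------
    Illustrative sketch only (the toy dataset below is not part of arviz):

    .. code::

        >>> import numpy as np
        >>> import xarray as xr
        >>> data = xr.Dataset({"mu": ("chain", np.zeros(4)), "theta": ("chain", np.ones(4))})
        >>> _var_names(["~mu"], data)
        ['theta']
        >>> _var_names("ta", data, filter_vars="like")
        ['theta']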
    """
    if var_names is not None:
        if isinstance(data, (list, tuple)):
            all_vars = []
            for dataset in data:
                dataset_vars = list(dataset.data_vars)
                for var in dataset_vars:
                    if var not in all_vars:
                        all_vars.append(var)
        else:
            all_vars = list(data.data_vars)

        all_vars_tilde = [var for var in all_vars if var.startswith("~")]
        if all_vars_tilde:
            warnings.warn(
                """ArviZ treats '~' as a negation character for variable selection.
                   Your model has variable names starting with '~', {0}. Please double check
                   your results to ensure all variables are included""".format(
                    ", ".join(all_vars_tilde)
                )
            )

        try:
            var_names = _subset_list(var_names, all_vars, filter_items=filter_vars, warn=False)
        except KeyError as err:
            msg = " ".join(("var names:", f"{err}", "in dataset"))
            raise KeyError(msg) from err
    return var_names


def _subset_list(subset, whole_list, filter_items=None, warn=True):
    """Handle list subsetting (var_names, groups...) across arviz.

    Parameters
    ----------
    subset : str, list, or None
    whole_list : list
        List from which to select a subset according to subset elements and
        filter_items value.
    filter_items : {None, "like", "regex"}, optional
        If `None` (default), interpret `subset` as the exact names of the elements in
        `whole_list`. If "like", interpret `subset` as substrings of the elements in
        `whole_list`. If "regex", interpret `subset` as regular expressions to match
        elements in `whole_list`. A la `pandas.filter`.

    Returns
    -------
    list or None
        A subset of ``whole_list`` fulfilling the requests imposed by ``subset``
        and ``filter_items``.
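
    Examples
    --------
    A rough sketch of the three selection modes (the item names are made up):

    .. code::

        >>> _subset_list(["~sigma"], ["mu", "theta", "sigma"])
        ['mu', 'theta']
        >>> _subset_list("theta", ["theta_one", "theta_two", "mu"], filter_items="like")
        ['theta_one', 'theta_two']
        >>> _subset_list("^the", ["theta", "mu"], filter_items="regex")
        ['theta']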
    """
    if subset is not None:

        if isinstance(subset, str):
            subset = [subset]

        whole_list_tilde = [item for item in whole_list if item.startswith("~")]
        if whole_list_tilde and warn:
            warnings.warn(
                "ArviZ treats '~' as a negation character for selection. There are "
                "elements in `whole_list` starting with '~', {0}. Please double check "
                "your results to ensure all elements are included".format(
                    ", ".join(whole_list_tilde)
                )
            )

        excluded_items = [
            item[1:] for item in subset if item.startswith("~") and item not in whole_list
        ]
        filter_items = str(filter_items).lower()
        not_found = []

        if excluded_items:
            if filter_items in ("like", "regex"):
                for pattern in excluded_items[:]:
                    excluded_items.remove(pattern)
                    if filter_items == "like":
                        real_items = [real_item for real_item in whole_list if pattern in real_item]
                    else:
                        # i.e. filter_items == "regex"
                        real_items = [
                            real_item for real_item in whole_list if re.search(pattern, real_item)
                        ]
                    if not real_items:
                        not_found.append(pattern)
                    excluded_items.extend(real_items)
            not_found.extend([item for item in excluded_items if item not in whole_list])
            if not_found:
                warnings.warn(
                    f"Items starting with ~: {not_found} have not been found and will be ignored"
                )
            subset = [item for item in whole_list if item not in excluded_items]

        else:
            if filter_items == "like":
                subset = [item for item in whole_list for name in subset if name in item]
            elif filter_items == "regex":
                subset = [item for item in whole_list for name in subset if re.search(name, item)]

        existing_items = np.isin(subset, whole_list)
        if not np.all(existing_items):
            raise KeyError("{} are not present".format(np.array(subset)[~existing_items]))

    return subset


class lazy_property:  # pylint: disable=invalid-name
    """Lazily evaluate a property; used to load numba only the first time it is needed."""

    def __init__(self, fget):
        """Lazy load a property with `fget`."""
        self.fget = fget

        # copy the getter function's docstring and other attributes
        functools.update_wrapper(self, fget)

    def __get__(self, obj, cls):
        """Call the function, set the attribute."""
        if obj is None:
            return self

        value = self.fget(obj)
        setattr(obj, self.fget.__name__, value)
        return value
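
# A minimal sketch of how ``lazy_property`` behaves (illustrative only, not used by arviz):
#
#     class Demo:
#         @lazy_property
#         def value(self):
#             return expensive_setup()  # ``expensive_setup`` is hypothetical; it runs once,
#                                       # then the result is cached as an instance attribute
#
#     demo = Demo()
#     demo.value  # computes and stores ``value`` on ``demo``
#     demo.value  # plain attribute lookup, ``expensive_setup`` is not called again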


class maybe_numba_fn:  # pylint: disable=invalid-name
    """Wrap a function to (maybe) use a (lazy) jit-compiled version."""

    def __init__(self, function, **kwargs):
        """Wrap a function and save compilation keywords."""
        self.function = function
        self.kwargs = kwargs

    @lazy_property
    def numba_fn(self):
        """Memoized compiled function."""
        try:
            numba = importlib.import_module("numba")
            numba_fn = numba.jit(**self.kwargs)(self.function)
        except ImportError:
            numba_fn = self.function
        return numba_fn

    def __call__(self, *args, **kwargs):
        """Call the jitted function or normal, depending on flag."""
        if Numba.numba_flag:
            return self.numba_fn(*args, **kwargs)
        else:
            return self.function(*args, **kwargs)


class interactive_backend:  # pylint: disable=invalid-name
    """Context manager to change backend temporarily in ipython session.

    It uses ipython magic to change temporarily from the ipython inline backend to
    an interactive backend of choice. It cannot be used outside ipython sessions nor
    to change backends other than inline -> interactive.

    Notes
    -----
    The first time the ``interactive_backend`` context manager is called, any of the available
    interactive backends can be chosen. The following times, this same backend must be used
    unless the kernel is restarted.

    Parameters
    ----------
    backend : str, optional
        Interactive backend to use. It will be passed to ``%matplotlib`` magic, refer to
        its docs to see available options.

    Examples
    --------
    Inside an ipython session (i.e. a jupyter notebook) with the inline backend set:

    .. code::

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> az.plot_posterior(idata) # inline
        >>> with az.interactive_backend():
        ...     az.plot_density(idata) # interactive
        >>> az.plot_trace(idata) # inline

    """

    # based on matplotlib.rc_context
    def __init__(self, backend=""):
        """Initialize context manager."""
        try:
            from IPython import get_ipython
        except ImportError as err:
            raise ImportError(
                "The exception below was raised while importing IPython, this "
                "context manager can only be used inside ipython sessions:\n{}".format(err)
            ) from err
        self.ipython = get_ipython()
        if self.ipython is None:
            raise EnvironmentError("This context manager can only be used inside ipython sessions")
        self.ipython.magic("matplotlib {}".format(backend))

    def __enter__(self):
        """Enter context manager."""
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Exit context manager."""
        plt.show(block=True)
        self.ipython.magic("matplotlib inline")


def conditional_jit(_func=None, **kwargs):
    """Use numba's jit decorator if numba is installed.

    Notes
    -----
    The decorator can be used bare::

        @conditional_jit
        def my_func():
            return

    or called with jit keyword arguments::

        @conditional_jit(nopython=True)
        def my_func():
            return

    """
    if _func is None:
        return lambda fn: functools.wraps(fn)(maybe_numba_fn(fn, **kwargs))
    else:
        lazy_numba = maybe_numba_fn(_func, **kwargs)
        return functools.wraps(_func)(lazy_numba)


def conditional_vect(function=None, **kwargs):  # noqa: D202
    """Use numba's vectorize decorator if numba is installed.

    Notes
    -----
    The decorator can be used bare::

        @conditional_vect
        def my_func():
            return

    or called with vectorize keyword arguments::

        @conditional_vect(nopython=True)
        def my_func():
            return

    """

    def wrapper(function):
        try:
            numba = importlib.import_module("numba")
            return numba.vectorize(**kwargs)(function)

        except ImportError:
            return function

    if function:
        return wrapper(function)
    else:
        return wrapper


def numba_check():
    """Check if numba is installed."""
    numba = importlib.util.find_spec("numba")
    return numba is not None


class Numba:
    """A class to toggle numba states."""

    numba_flag = numba_check()

    @classmethod
    def disable_numba(cls):
        """To disable numba."""
        cls.numba_flag = False

    @classmethod
    def enable_numba(cls):
        """To enable numba."""
        if numba_check():
            cls.numba_flag = True
        else:
            raise ValueError("Numba is not installed")
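
# Toggling sketch (illustrative): force the plain-numpy code paths, then restore jitting.
#
#     Numba.disable_numba()
#     assert Numba.numba_flag is False
#     Numba.enable_numba()  # raises ValueError when numba is not installed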


def _numba_var(numba_function, standard_numpy_func, data, axis=None, ddof=0):
    """Replace the numpy methods used to calculate variance.

    Parameters
    ----------
    numba_function : function
        Custom numba function included in stats/stats_utils.py.
    standard_numpy_func : function
        Standard function included in the numpy library.
    data : array
    axis : int, optional
        Axis along which the variance is calculated.
    ddof : int, optional
        Degrees of freedom allowed while calculating variance.

    Returns
    -------
    array
        Variance values calculated by the appropriate function, using the numba
        version for speedup when numba is installed and enabled.
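
    Examples
    --------
    Sketch only, using ``np.var`` as a stand-in for the numba helper that arviz
    normally passes in:

    .. code::

        >>> data = np.random.randn(4, 100)
        >>> _numba_var(np.var, np.var, data, axis=1, ddof=1)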

    """
    if Numba.numba_flag:
        return numba_function(data, axis=axis, ddof=ddof)
    else:
        return standard_numpy_func(data, axis=axis, ddof=ddof)


def _stack(x, y):
    """Stack x and y along their first dimension; trailing dimensions must match."""
    assert x.shape[1:] == y.shape[1:]
    return np.vstack((x, y))


def arange(x):
    """Jitting numpy arange."""
    return np.arange(x)


def one_de(x):
    """Jitting numpy atleast_1d."""
    if not isinstance(x, np.ndarray):
        return np.atleast_1d(x)
    if x.ndim == 0:
        result = x.reshape(1)
    else:
        result = x
    return result


def two_de(x):
    """Jitting numpy atleast_2d."""
    if not isinstance(x, np.ndarray):
        return np.atleast_2d(x)
    if x.ndim == 0:
        result = x.reshape(1, 1)
    elif x.ndim == 1:
        result = x[newaxis, :]
    else:
        result = x
    return result


def expand_dims(x):
    """Jitting numpy expand_dims."""
    if not isinstance(x, np.ndarray):
        return np.expand_dims(x, 0)
    shape = x.shape
    return x.reshape(shape[:0] + (1,) + shape[0:])
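
# Shape sketch for the three small helpers above (illustrative):
#
#     one_de(1.0).shape                # -> (1,)
#     two_de(np.arange(3)).shape       # -> (1, 3)
#     expand_dims(np.arange(3)).shape  # -> (1, 3)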


@conditional_jit(cache=True, nopython=True)
def _dot(x, y):
    return np.dot(x, y)


@conditional_jit(cache=True, nopython=True)
def _cov_1d(x):
    x = x - x.mean(axis=0)
    ddof = x.shape[0] - 1
    return np.dot(x.T, x.conj()) / ddof


@conditional_jit(cache=True)
def _cov(data):
    if data.ndim == 1:
        return _cov_1d(data)
    elif data.ndim == 2:
        x = data.astype(float)
        avg, _ = np.average(x, axis=1, weights=None, returned=True)
        ddof = x.shape[1] - 1
        if ddof <= 0:
            warnings.warn("Degrees of freedom <= 0 for slice", RuntimeWarning, stacklevel=2)
            ddof = 0.0
        x -= avg[:, None]
        prod = _dot(x, x.T.conj())
        prod *= np.true_divide(1, ddof)
        prod = prod.squeeze()
        prod += 1e-6 * np.eye(prod.shape[0])
        return prod
    else:
        raise ValueError("{} dimension arrays are not supported".format(data.ndim))
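
# Rough equivalence sketch for ``_cov`` on 2D data (the 1e-6 diagonal jitter is intentional):
#
#     x = np.random.randn(3, 500)
#     np.allclose(_cov(x), np.cov(x) + 1e-6 * np.eye(3))  # expected to be True, up to float error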


@conditional_jit(nopython=True)
def full(shape, x, dtype=None):
    """Jitting numpy full."""
    return np.full(shape, x, dtype=dtype)


def flatten_inference_data_to_dict(
    data,
    var_names=None,
    groups=None,
    dimensions=None,
    group_info=False,
    var_name_format=None,
    index_origin=None,
):
    """Transform data to dictionary.

    Parameters
    ----------
    data : obj
        Any object that can be converted to an az.InferenceData object.
        Refer to the documentation of az.convert_to_inference_data for details.
    var_names : str or list of str, optional
        Variables to be processed, if None all variables are processed.
    groups : str or list of str, optional
        Select groups for CDS. Default groups are
        {"posterior_groups", "prior_groups", "posterior_groups_warmup"}
            - posterior_groups: posterior, posterior_predictive, sample_stats
            - prior_groups: prior, prior_predictive, sample_stats_prior
            - posterior_groups_warmup: warmup_posterior, warmup_posterior_predictive,
                                       warmup_sample_stats
    dimensions : str or list of str, optional
        Select dimensions along which to slice the data. By default uses ("chain", "draw").
    group_info : bool
        Add group info for `var_name_format`.
    var_name_format : str or tuple of tuple of string, optional
        Select column name format for non-scalar input.
        Predefined options are {"brackets", "underscore", "cds"}
            "brackets":
                - add_group_info == False: theta[0,0]
                - add_group_info == True: theta_posterior[0,0]
            "underscore":
                - add_group_info == False: theta_0_0
                - add_group_info == True: theta_posterior_0_0_
            "cds":
                - add_group_info == False: theta_ARVIZ_CDS_SELECTION_0_0
                - add_group_info == True: theta_ARVIZ_GROUP_posterior__ARVIZ_CDS_SELECTION_0_0
            tuple:
                Structure:
                    tuple: (dim_info, group_info)
                        dim_info: (str: `.join` separator,
                                   str: dim_separator_start,
                                   str: dim_separator_end)
                        group_info: (str: group separator start, str: group separator end)
                Example: ((",", "[", "]"), ("_", ""))
                    - add_group_info == False: theta[0,0]
                    - add_group_info == True: theta_posterior[0,0]
    index_origin : int, optional
        Start parameter indices from `index_origin`. Either 0 or 1.

    Returns
    -------
    dict
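
    Examples
    --------
    Sketch of the flattened keys, assuming the bundled "centered_eight" dataset and the
    default ``data.index_origin`` of 0 (output trimmed):

    .. code::

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> flat = flatten_inference_data_to_dict(idata, var_names=["theta"])
        >>> sorted(key for key in flat if key.startswith("theta"))[:3]
        ['theta[0]', 'theta[1]', 'theta[2]']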
    """
    from .data import convert_to_inference_data

    data = convert_to_inference_data(data)

    if groups is None:
        groups = ["posterior", "posterior_predictive", "sample_stats"]
    elif isinstance(groups, str):
        if groups.lower() == "posterior_groups":
            groups = ["posterior", "posterior_predictive", "sample_stats"]
        elif groups.lower() == "prior_groups":
            groups = ["prior", "prior_predictive", "sample_stats_prior"]
        elif groups.lower() == "posterior_groups_warmup":
            groups = ["warmup_posterior", "warmup_posterior_predictive", "warmup_sample_stats"]
        else:
            raise TypeError(
                (
                    "Valid predefined groups are "
                    "{posterior_groups, prior_groups, posterior_groups_warmup}"
                )
            )

    if dimensions is None:
        dimensions = "chain", "draw"
    elif isinstance(dimensions, str):
        dimensions = (dimensions,)

    if var_name_format is None:
        var_name_format = "brackets"

    if isinstance(var_name_format, str):
        var_name_format = var_name_format.lower()

    if var_name_format == "brackets":
        dim_join_separator, dim_separator_start, dim_separator_end = ",", "[", "]"
        group_separator_start, group_separator_end = "_", ""
    elif var_name_format == "underscore":
        dim_join_separator, dim_separator_start, dim_separator_end = "_", "_", ""
        group_separator_start, group_separator_end = "_", ""
    elif var_name_format == "cds":
        dim_join_separator, dim_separator_start, dim_separator_end = (
            "_",
            "_ARVIZ_CDS_SELECTION_",
            "",
        )
        group_separator_start, group_separator_end = "_ARVIZ_GROUP_", ""
    elif isinstance(var_name_format, str):
        msg = 'Invalid predefined format. Select one {"brackets", "underscore", "cds"}'
        raise TypeError(msg)
    else:
        (
            (dim_join_separator, dim_separator_start, dim_separator_end),
            (group_separator_start, group_separator_end),
        ) = var_name_format

    if index_origin is None:
        index_origin = rcParams["data.index_origin"]

    data_dict = {}
    for group in groups:
        if hasattr(data, group):
            group_data = getattr(data, group).stack(stack_dimension=dimensions)
            for var_name, var in group_data.data_vars.items():
                var_values = var.values
                if var_names is not None and var_name not in var_names:
                    continue
                for dim_name in dimensions:
                    if dim_name not in data_dict:
                        data_dict[dim_name] = var.coords.get(dim_name).values
                if len(var.shape) == 1:
                    if group_info:
                        var_name_dim = (
                            "{var_name}" "{group_separator_start}{group}{group_separator_end}"
                        ).format(
                            var_name=var_name,
                            group_separator_start=group_separator_start,
                            group=group,
                            group_separator_end=group_separator_end,
                        )
                    else:
                        var_name_dim = "{var_name}".format(var_name=var_name)
                    data_dict[var_name_dim] = var.values
                else:
                    for loc in np.ndindex(var.shape[:-1]):
                        if group_info:
                            var_name_dim = (
                                "{var_name}"
                                "{group_separator_start}{group}{group_separator_end}"
                                "{dim_separator_start}{dim_join}{dim_separator_end}"
                            ).format(
                                var_name=var_name,
                                group_separator_start=group_separator_start,
                                group=group,
                                group_separator_end=group_separator_end,
                                dim_separator_start=dim_separator_start,
                                dim_join=dim_join_separator.join(
                                    (str(item + index_origin) for item in loc)
                                ),
                                dim_separator_end=dim_separator_end,
                            )
                        else:
                            var_name_dim = (
                                "{var_name}" "{dim_separator_start}{dim_join}{dim_separator_end}"
                            ).format(
                                var_name=var_name,
                                dim_separator_start=dim_separator_start,
                                dim_join=dim_join_separator.join(
                                    (str(item + index_origin) for item in loc)
                                ),
                                dim_separator_end=dim_separator_end,
                            )

                        data_dict[var_name_dim] = var_values[loc]
    return data_dict


def get_coords(data, coords):
    """Subselect xarray DataSet or DataArray to provided coords. Raise exception if it fails.

    Raises
    ------
    ValueError
        If coords names are not available in data

    KeyError
        If coords dims are not available in data

    Returns
    -------
    data : xarray
        xarray.DataSet or xarray.DataArray object, same type as input
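
    Examples
    --------
    Rough sketch (coordinate names depend on the data at hand):

    .. code::

        >>> import arviz as az
        >>> idata = az.load_arviz_data("centered_eight")
        >>> post = get_coords(idata.posterior, {"chain": [0, 1]})
        >>> post.sizes["chain"]
        2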
    """
    if not isinstance(data, (list, tuple)):
        try:
            return data.sel(**coords)

        except ValueError as err:
            invalid_coords = set(coords.keys()) - set(data.coords.keys())
            raise ValueError(
                "Coords {} are invalid coordinate keys".format(invalid_coords)
            ) from err

        except KeyError as err:
            raise KeyError(
                (
                    "Coords should follow mapping format {{coord_name:[dim1, dim2]}}. "
                    "Check that coords structure is correct and"
                    " dimensions are valid. {}"
                ).format(err)
            ) from err
    if not isinstance(coords, (list, tuple)):
        coords = [coords] * len(data)
    data_subset = []
    for idx, (datum, coords_dict) in enumerate(zip(data, coords)):
        try:
            data_subset.append(get_coords(datum, coords_dict))
        except ValueError as err:
            raise ValueError("Error in data[{}]: {}".format(idx, err)) from err
        except KeyError as err:
            raise KeyError("Error in data[{}]: {}".format(idx, err)) from err
    return data_subset


def credible_interval_warning(credible_interval, hdi_prob):
    """Replace credible_interval with hdi_prob and warn about its deprecation."""
    warnings.warn(
        ("Keyword argument credible_interval has been deprecated. " "Please replace with hdi_prob"),
    )

    if isinstance(credible_interval, str) and credible_interval == "auto":
        raise Exception("Argument value 'auto' has been renamed to 'hide'")

    if hdi_prob:
        raise Exception(
            "Both 'credible_interval' and 'hdi_prob' are in "
            "keyword arguments. Please remove 'credible_interval'"
        )

    hdi_prob = credible_interval
    return hdi_prob


@lru_cache(None)
def _load_static_files():
    """Lazily load the resource files into memory the first time they are needed.

    Cloned from xarray.core.formatted_html_template.
    """
    return [pkg_resources.resource_string("arviz", fname).decode("utf8") for fname in STATIC_FILES]


class HtmlTemplate:
    """Contain HTML templates for InferenceData repr."""

    html_template = """
            <div>
              <div class='xr-header'>
                <div class="xr-obj-type">arviz.InferenceData</div>
              </div>
              <ul class="xr-sections group-sections">
              {}
              </ul>
            </div>
            """
    element_template = """
            <li class = "xr-section-item">
                  <input id="idata_{group_id}" class="xr-section-summary-in" type="checkbox">
                  <label for="idata_{group_id}" class = "xr-section-summary">{group}</label>
                  <div class="xr-section-inline-details"></div>
                  <div class="xr-section-details">
                      <ul id="xr-dataset-coord-list" class="xr-var-list">
                          <div style="padding-left:2rem;">{xr_data}<br></div>
                      </ul>
                  </div>
            </li>
            """
    _, css_style = _load_static_files()  # pylint: disable=protected-access
    specific_style = ".xr-wrap{width:700px!important;}"
    css_template = f"<style> {css_style}{specific_style} </style>"


def either_dict_or_kwargs(
    pos_kwargs,
    kw_kwargs,
    func_name,
):
    """Cloned from xarray.core.utils."""
    if pos_kwargs is not None:
        if not hasattr(pos_kwargs, "keys") and hasattr(pos_kwargs, "__getitem__"):
            raise ValueError("the first argument to .%s must be a dictionary" % func_name)
        if kw_kwargs:
            raise ValueError(
                "cannot specify both keyword and positional " "arguments to .%s" % func_name
            )
        return pos_kwargs
    else:
        return kw_kwargs
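
# Selection sketch (hypothetical values): ``either_dict_or_kwargs({"chain": [0]}, {}, "sel")``
# returns the positional dict, while ``either_dict_or_kwargs(None, {"chain": [0]}, "sel")``
# falls back to the keyword arguments.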


class Dask:
    """Class to toggle Dask states.

    Warnings
    --------
    Dask integration is an experimental feature still in progress. It can already be used,
    but it doesn't work with all stats and diagnostics yet.
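
    Examples
    --------
    Rough usage sketch (the kwargs shown are just an assumption of a typical setup, see
    the ArviZ Dask guide for details):

    .. code::

        >>> Dask.enable_dask(dask_kwargs={"dask": "parallelized", "output_dtypes": [float]})
        >>> # ... call stats or diagnostics that support dask ...
        >>> Dask.disable_dask()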
    """

    dask_flag = False
    dask_kwargs = None

    @classmethod
    def enable_dask(cls, dask_kwargs=None):
        """To enable Dask.

        Parameters
        ----------
        dask_kwargs : dict
            Dask related kwargs passed to :func:`~arviz.wrap_xarray_ufunc`.
        """
        cls.dask_flag = True
        cls.dask_kwargs = dask_kwargs

    @classmethod
    def disable_dask(cls):
        """To disable Dask."""
        cls.dask_flag = False
        cls.dask_kwargs = None


def conditional_dask(func):
    """Conditionally pass dask kwargs to `wrap_xarray_ufunc`."""

    @functools.wraps(func)
    def wrapper(*args, **kwargs):

        if Dask.dask_flag:
            user_kwargs = kwargs.pop("dask_kwargs", None)
            if user_kwargs is None:
                user_kwargs = {}
            default_kwargs = Dask.dask_kwargs
            return func(dask_kwargs={**default_kwargs, **user_kwargs}, *args, **kwargs)
        else:
            return func(*args, **kwargs)

    return wrapper
