1
# pylint: disable=too-many-lines
2 6
"""Data structure for using netcdf groups with xarray."""
3 6
import uuid
4 6
import warnings
5 6
from collections import OrderedDict, defaultdict
6 6
from collections.abc import Sequence
7 6
from copy import copy as ccopy
8 6
from copy import deepcopy
9 6
from datetime import datetime
10 6
from html import escape
11

12 6
import netCDF4 as nc
13 6
import numpy as np
14 6
import xarray as xr
15 4
from xarray.core.options import OPTIONS
16 6
from xarray.core.utils import either_dict_or_kwargs
17 2

18 6
from ..rcparams import rcParams
19 4
from ..utils import HtmlTemplate, _subset_list
20 6
from .base import _extend_xr_method, dict_to_dataset, _make_json_serializable
21 2

22 4
try:
23 4
    import ujson as json
24 0
except ImportError:
25 2
    import json
26

27 4
# Group names recognized by the InferenceData scheme (post-warmup samples).
SUPPORTED_GROUPS = [
    "posterior",
    "posterior_predictive",
    "predictions",
    "log_likelihood",
    "sample_stats",
    "prior",
    "prior_predictive",
    "sample_stats_prior",
    "observed_data",
    "constant_data",
    "predictions_constant_data",
]

# Prefix identifying groups that store warmup (tuning) iterations.
WARMUP_TAG = "warmup_"

# Warmup counterparts of the groups that can hold per-iteration samples.
SUPPORTED_GROUPS_WARMUP = [
    WARMUP_TAG + group
    for group in (
        "posterior",
        "posterior_predictive",
        "predictions",
        "sample_stats",
        "log_likelihood",
    )
]

SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP

53

54 4
class InferenceData:
55
    """Container for inference data storage using xarray.
56

57
    For a detailed introduction to ``InferenceData`` objects and their usage, see
58
    :doc:`/notebooks/XarrayforArviZ`. This page provides help and documentation
59
    on ``InferenceData`` methods and their low level implementation.
60 2
    """
61

62 4
    def __init__(self, attrs=None, **kwargs):
        """Initialize InferenceData object from keyword xarray datasets.

        Parameters
        ----------
        attrs : dict
            sets global attribute for InferenceData object.
        kwargs :
            Keyword arguments of xarray datasets

        Examples
        --------
        Initiate an InferenceData object from scratch, not recommended. InferenceData
        objects should be initialized using ``from_xyz`` methods, see :ref:`data_api` for more
        details.

        .. ipython::

            In [1]: import arviz as az
               ...: import numpy as np
               ...: import xarray as xr
               ...: dataset = xr.Dataset(
               ...:     {
               ...:         "a": (["chain", "draw", "a_dim"], np.random.normal(size=(4, 100, 3))),
               ...:         "b": (["chain", "draw"], np.random.normal(size=(4, 100))),
               ...:     },
               ...:     coords={
               ...:         "chain": (["chain"], np.arange(4)),
               ...:         "draw": (["draw"], np.arange(100)),
               ...:         "a_dim": (["a_dim"], ["x", "y", "z"]),
               ...:     }
               ...: )
               ...: idata = az.InferenceData(posterior=dataset, prior=dataset)
               ...: idata

        We have created an ``InferenceData`` object with two groups. Now we can check its
        contents:

        .. ipython::

            In [1]: idata.posterior

        """
        self._groups = []
        self._groups_warmup = []
        self._attrs = dict(attrs) if attrs is not None else None
        save_warmup = kwargs.pop("save_warmup", False)
        # Process schema groups first (in canonical order), then any extra keys,
        # warning about the latter.
        key_list = [key for key in SUPPORTED_GROUPS_ALL if key in kwargs]
        for key in kwargs:
            if key not in SUPPORTED_GROUPS_ALL:
                key_list.append(key)
                warnings.warn(
                    "{} group is not defined in the InferenceData scheme".format(key), UserWarning
                )
        for key in key_list:
            dataset = kwargs[key]
            dataset_warmup = None
            if dataset is None:
                continue
            elif isinstance(dataset, (list, tuple)):
                # A (dataset, dataset_warmup) pair was provided for this group;
                # reuse the already-bound value instead of re-reading kwargs.
                dataset, dataset_warmup = dataset
            elif not isinstance(dataset, xr.Dataset):
                raise ValueError(
                    "Arguments to InferenceData must be xarray Datasets "
                    "(argument '{}' was type '{}')".format(key, type(dataset))
                )
            # Empty datasets are silently skipped; non-empty ones are stored as
            # attributes and registered in the matching group list.
            if dataset:
                setattr(self, key, dataset)
                if key.startswith(WARMUP_TAG):
                    self._groups_warmup.append(key)
                else:
                    self._groups.append(key)
            # Non-empty warmup companions are stored under a warmup-tagged name.
            if save_warmup and dataset_warmup is not None and dataset_warmup:
                key = "{}{}".format(WARMUP_TAG, key)
                setattr(self, key, dataset_warmup)
                self._groups_warmup.append(key)

142 4
    @property
143 4
    def attrs(self):
144 2
        """Attributes of InferenceData object."""
145 6
        if self._attrs is None:
146 4
            self._attrs = {}
147 6
        return self._attrs
148 1

149 4
    @attrs.setter
    def attrs(self, value):
        """Replace global attributes with a dict copy of ``value``."""
        self._attrs = dict(value)
152

153 6
    def __repr__(self):
154
        """Make string representation of InferenceData object."""
155 3
        msg = "Inference data with groups:\n\t> {options}".format(
156 2
            options="\n\t> ".join(self._groups)
157 0
        )
158 5
        if self._groups_warmup:
159 0
            msg += "\n\nWarmup iterations saved ({}*).".format(WARMUP_TAG)
160 5
        return msg
161

162 6
    def _repr_html_(self):
        """Return an HTML representation, or a ``<pre>`` text repr when the
        xarray display style is set to plain text."""
        if OPTIONS["display_style"] == "text":
            return f"<pre>{escape(repr(self))}</pre>"
        # One collapsible element per group; the uuid keeps element ids unique
        # when several objects are rendered in the same notebook.
        elements = "".join(
            HtmlTemplate.element_template.format(
                group_id=group + str(uuid.uuid4()),
                group=group,
                xr_data=getattr(self, group)._repr_html_(),  # pylint: disable=protected-access
            )
            for group in self._groups_all
        )
        formatted_html_template = HtmlTemplate.html_template.format(elements)
        return f"{formatted_html_template}{HtmlTemplate.css_template}"
186 0

187 4
    def __delattr__(self, group):
188 2
        """Delete a group from the InferenceData object."""
189 3
        if group in self._groups:
190 5
            self._groups.remove(group)
191 0
        elif group in self._groups_warmup:
192 2
            self._groups_warmup.remove(group)
193 5
        object.__delattr__(self, group)
194 0

195 4
    @property
196 4
    def _groups_all(self):
197 4
        return self._groups + self._groups_warmup
198 2

199 5
    @staticmethod
    def from_netcdf(filename):
        """Initialize object from a netcdf file.

        Expects that the file will have groups, each of which can be loaded by xarray.
        By default, the datasets of the InferenceData object will be lazily loaded instead
        of being loaded into memory. This
        behaviour is regulated by the value of ``az.rcParams["data.load"]``.

        Parameters
        ----------
        filename : str
            location of netcdf file

        Returns
        -------
        InferenceData object
        """
        groups = {}
        # Open with netCDF4 only to discover the group names; the handle is
        # closed again before xarray reopens the file group by group.
        with nc.Dataset(filename, mode="r") as data:
            data_groups = list(data.groups)

        for group in data_groups:
            with xr.open_dataset(filename, group=group) as data:
                if rcParams["data.load"] == "eager":
                    # Load everything into memory so the file handle can be released.
                    groups[group] = data.load()
                else:
                    # NOTE(review): the lazily-loaded dataset escapes the ``with``
                    # block here; presumably xarray keeps the underlying file
                    # readable after __exit__ -- confirm this is intended.
                    groups[group] = data
        return InferenceData(**groups)
228

229 6
    def to_netcdf(self, filename, compress=True, groups=None):
        """Write InferenceData to file using netcdf4.

        Parameters
        ----------
        filename : str
            Location to write to
        compress : bool, optional
            Whether to compress result. Note this saves disk space, but may make
            saving and loading somewhat slower (default: True).
        groups : list, optional
            Write only these groups to netcdf file.

        Returns
        -------
        str
            Location of netcdf file
        """
        mode = "w"  # overwrite first, then append
        if self._groups_all:  # checks whether a group is present or not.
            if groups is None:
                groups = self._groups_all
            else:
                # Keep only the requested groups that actually exist,
                # preserving the canonical ordering of _groups_all.
                groups = [group for group in self._groups_all if group in groups]

            for group in groups:
                data = getattr(self, group)
                kwargs = {}
                if compress:
                    # zlib compression is requested per variable via the encoding dict.
                    kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
                data.to_netcdf(filename, mode=mode, group=group, **kwargs)
                data.close()
                # After the first group is written, append instead of overwriting.
                mode = "a"
        else:  # creates a netcdf file for an empty InferenceData object.
            empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
            empty_netcdf_file.close()
        return filename
266 2

267 6
    def to_dict(self, groups=None, filter_groups=None):
        """Convert InferenceData to a dictionary following xarray naming conventions.

        Parameters
        ----------
        groups : list, optional
            Groups to include in the output dictionary; by default all groups.
        filter_groups : {None, "like", "regex"}, optional
            How ``groups`` is interpreted when resolving group names.
            NOTE(review): this is only forwarded to ``_group_names`` when
            ``groups`` is None; an explicit ``groups`` list is matched
            exactly -- confirm whether this asymmetry is intended.

        Returns
        -------
        dict
            A dictionary containing all groups of InferenceData object.
        """
        ret = defaultdict(dict)
        if self._groups_all:  # checks whether a group is present or not.
            if groups is None:
                groups = self._group_names(groups, filter_groups)
            else:
                # Exact matching against existing groups, preserving order.
                groups = [group for group in self._groups_all if group in groups]

            for group in groups:
                dataset = getattr(self, group)
                data = {}
                for var_name, dataarray in dataset.items():
                    data[var_name] = dataarray.values
                    dims = []
                    for coord_name, coord_values in dataarray.coords.items():
                        # Skip the sampling dims and default per-variable dims
                        # ("<var>_dim_N"); everything else is a user coordinate.
                        if coord_name not in ("chain", "draw") and not coord_name.startswith(
                            var_name + "_dim_"
                        ):
                            dims.append(coord_name)
                            ret["coords"][coord_name] = coord_values.values

                    # Prediction groups record their dims under a separate key.
                    if group in (
                        "predictions",
                        "predictions_constant_data",
                    ):
                        dims_key = "pred_dims"
                    else:
                        dims_key = "dims"
                    if len(dims) > 0:
                        ret[dims_key][var_name] = dims
                    ret[group] = data
                ret[group + "_attrs"] = dataset.attrs

        ret["attrs"] = self.attrs
        return ret
315

316 6
    def to_json(self, filename, **kwargs):
        """Write InferenceData to a json file.

        Parameters
        ----------
        filename : str
            Location to write to
        kwargs : dict
            kwargs passed to json.dump()

        Returns
        -------
        str
            Location of json file
        """
        # Convert to plain python containers first so json can serialize them.
        serializable = _make_json_serializable(self.to_dict())
        with open(filename, "w") as file_handle:
            json.dump(serializable, file_handle, **kwargs)
        return filename
337 2

338 4
    def __add__(self, other):
        """Concatenate two InferenceData objects."""
        # Delegates to module-level ``concat``; copy=True leaves both operands untouched.
        return concat(self, other, copy=True, inplace=False)
341 2

342 4
    def sel(
        self,
        groups=None,
        filter_groups=None,
        inplace=False,
        chain_prior=None,
        **kwargs,
    ):
        """Perform an xarray selection on all groups.

        For every kwarg whose key is a dimension of a group's dataset, applies
        ``Dataset.sel(key=item)`` to that group. Typical uses are cutting
        burn-in draws or discarding a chain. Only relevant groups (posterior,
        prior, sample stats, ...) are affected; non relevant groups such as
        observed data are omitted. See
        :meth:`xarray.Dataset.sel <xarray:xarray.Dataset.sel>`

        Parameters
        ----------
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        chain_prior: bool, optional, deprecated
            If ``False``, do not select prior related groups using ``chain`` dim.
            Otherwise, use selection on ``chain`` if present. Default=False
        **kwargs: mapping
            It must be accepted by Dataset.sel().

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in-place and return `None`

        Examples
        --------
        Use ``sel`` to discard one chain of the InferenceData object. We first check the
        dimensions of the original object:

        .. ipython::

            In [1]: import arviz as az
               ...: idata = az.load_arviz_data("centered_eight")
               ...: del idata.prior  # prior group only has 1 chain currently
               ...: print(idata.posterior.coords)
               ...: print(idata.posterior_predictive.coords)
               ...: print(idata.observed_data.coords)

        In order to remove the third chain:

        .. ipython::

            In [1]: idata_subset = idata.sel(chain=[0, 1, 3])
               ...: print(idata_subset.posterior.coords)
               ...: print(idata_subset.posterior_predictive.coords)
               ...: print(idata_subset.observed_data.coords)

        """
        if chain_prior is None:
            chain_prior = False
        else:
            warnings.warn(
                "chain_prior has been deprecated. Use groups argument and "
                "rcParams['data.metagroups'] instead.",
                DeprecationWarning,
            )
        group_names = self._group_names(groups, filter_groups)

        out = deepcopy(self) if not inplace else self
        for group in group_names:
            dataset = getattr(self, group)
            # Only forward kwargs that name an actual dimension of this dataset.
            selection = {dim: kwargs[dim] for dim in kwargs if dim in dataset.dims}
            if "prior" in group and not chain_prior:
                # By default the chain dim is not selected on prior related groups.
                selection.pop("chain", None)
            setattr(out, group, dataset.sel(**selection))
        return None if inplace else out
430 2

431 4
    def isel(
        self,
        groups=None,
        filter_groups=None,
        inplace=False,
        **kwargs,
    ):
        """Perform an xarray index selection on all groups.

        For every kwarg whose key is a dimension of a group's dataset, applies
        ``Dataset.isel(key=item)`` to that group. Typical uses are cutting
        burn-in draws or discarding a chain. Only relevant groups (posterior,
        prior, sample stats, ...) are affected; non relevant groups such as
        observed data are omitted. See :meth:`xarray:xarray.Dataset.isel`

        Parameters
        ----------
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        **kwargs: mapping
            It must be accepted by :meth:`xarray:xarray.Dataset.isel`.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in-place and return `None`

        """
        group_names = self._group_names(groups, filter_groups)

        out = deepcopy(self) if not inplace else self
        for group in group_names:
            dataset = getattr(self, group)
            # Only forward kwargs that name an actual dimension of this dataset.
            selection = {dim: kwargs[dim] for dim in kwargs if dim in dataset.dims}
            setattr(out, group, dataset.isel(**selection))
        return None if inplace else out
482 2

483 4
    def stack(
        self,
        dimensions=None,
        groups=None,
        filter_groups=None,
        inplace=False,
        **kwargs,
    ):
        """Perform an xarray stacking on all groups.

        Stack any number of existing dimensions into a single new dimension.
        Loops groups to perform Dataset.stack(key=value)
        for every kwarg if value is a dimension of the dataset.
        The selection is performed on all relevant groups (like
        posterior, prior, sample stats) while non relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.stack`

        Parameters
        ----------
        dimensions: dict
            Names of new dimensions, and the existing dimensions that they replace.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        **kwargs: mapping
            It must be accepted by :meth:`xarray:xarray.Dataset.stack`.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in-place and return `None`

        """
        groups = self._group_names(groups, filter_groups)

        # Merge kwargs into a fresh dict: the previous implementation called
        # ``dimensions.update(kwargs)``, mutating the caller's dict in place.
        dimensions = {**({} if dimensions is None else dimensions), **kwargs}
        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            # Only stack specs whose source dims all exist in this dataset.
            kwarg_dict = {
                key: value
                for key, value in dimensions.items()
                if not set(value).difference(dataset.dims)
            }
            setattr(out, group, dataset.stack(**kwarg_dict))
        return None if inplace else out
542 0

543 4
    def unstack(self, dim=None, groups=None, filter_groups=None, inplace=False):
        """Perform an xarray unstacking on all groups.

        Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions.
        Loops groups to perform Dataset.unstack(key=value).
        The selection is performed on all relevant groups (like posterior, prior,
        sample stats) while non relevant groups like observed data are omitted.
        See :meth:`xarray:xarray.Dataset.unstack`

        Parameters
        ----------
        dim: Hashable or iterable of Hashable, optional
            Dimension(s) over which to unstack. By default unstacks all MultiIndexes.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in place and return `None`

        """
        group_names = self._group_names(groups, filter_groups)
        if isinstance(dim, str):
            # Accept a single dimension name as shorthand for a one-element list.
            dim = [dim]

        out = deepcopy(self) if not inplace else self
        for group in group_names:
            dataset = getattr(self, group)
            if dim is None:
                # None means "unstack every MultiIndex" for xarray.
                valid_dims = None
            else:
                valid_dims = set(dim).intersection(dataset.dims)
            setattr(out, group, dataset.unstack(dim=valid_dims))
        return None if inplace else out
589 0

590 4
    def rename(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of variable and dimensions on all groups.

        Applies ``Dataset.rename(name_dict)`` to every selected group, keeping
        only the keys of ``name_dict`` that are data_vars or dimensions of each
        dataset. Non relevant groups like observed data are omitted.
        See :meth:`xarray:xarray.Dataset.rename`

        Parameters
        ----------
        name_dict: dict
            Dictionary whose keys are current variable or dimension names
            and whose values are the desired names.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform renaming in-place and return `None`

        """
        group_names = self._group_names(groups, filter_groups)
        # The sampling dims are part of the InferenceData contract and stay fixed.
        if "chain" in name_dict or "draw" in name_dict:
            raise KeyError("'chain' or 'draw' dimensions can't be renamed")
        out = deepcopy(self) if not inplace else self

        for group in group_names:
            dataset = getattr(self, group)
            renameable = set(dataset.data_vars).union(dataset.dims)
            mapping = {key: new_name for key, new_name in name_dict.items() if key in renameable}
            setattr(out, group, dataset.rename(mapping))
        return None if inplace else out
639 0

640 4
    def rename_vars(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of variable or coordinate names on all groups.

        Applies ``Dataset.rename_vars(name_dict)`` to every selected group,
        keeping only the keys of ``name_dict`` that are data_vars of each
        dataset. Non relevant groups like observed data are omitted.
        See :meth:`xarray:xarray.Dataset.rename_vars`

        Parameters
        ----------
        name_dict: dict
            Dictionary whose keys are current variable or coordinate names
            and whose values are the desired names.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object with renamed variables including coordinates by default.
            When `inplace==True` perform renaming in-place and return `None`

        """
        group_names = self._group_names(groups, filter_groups)

        out = deepcopy(self) if not inplace else self
        for group in group_names:
            dataset = getattr(self, group)
            # NOTE: only data_vars keys are matched here, mirroring the original filter.
            mapping = {
                key: new_name for key, new_name in name_dict.items() if key in dataset.data_vars
            }
            setattr(out, group, dataset.rename_vars(mapping))
        return None if inplace else out
686 0

687 4
    def rename_dims(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of dimensions on all groups.

        Applies ``Dataset.rename_dims(name_dict)`` to every selected group,
        keeping only the keys of ``name_dict`` that are dimensions of each
        dataset. Non relevant groups like observed data are omitted.
        See :meth:`xarray:xarray.Dataset.rename_dims`

        Parameters
        ----------
        name_dict: dict
            Dictionary whose keys are current dimension names and whose values are the desired
            names.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object with renamed dimension by default.
            When `inplace==True` perform renaming in-place and return `None`

        """
        group_names = self._group_names(groups, filter_groups)
        # The sampling dims are part of the InferenceData contract and stay fixed.
        if "chain" in name_dict or "draw" in name_dict:
            raise KeyError("'chain' or 'draw' dimensions can't be renamed")

        out = deepcopy(self) if not inplace else self
        for group in group_names:
            dataset = getattr(self, group)
            mapping = {key: new_name for key, new_name in name_dict.items() if key in dataset.dims}
            setattr(out, group, dataset.rename_dims(mapping))
        return None if inplace else out
735 0

736 4
    def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
        """Add new groups to InferenceData object.

        Parameters
        ----------
        group_dict: dict of {str : dict or xarray.Dataset}, optional
            Groups to be added
        coords : dict[str] -> ndarray
            Coordinates for the dataset
        dims : dict[str] -> list[str]
            Dimensions of each variable. The keys are variable names, values are lists of
            coordinates.
        **kwargs: mapping
            The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.

        See Also
        --------
        extend : Extend InferenceData with groups from another InferenceData.
        concat : Concatenate InferenceData objects.
        """
        group_dict = either_dict_or_kwargs(group_dict, kwargs, "add_groups")
        if not group_dict:
            raise ValueError("One of group_dict or kwargs must be provided.")
        # refuse to silently overwrite groups already stored on this object
        repeated_groups = [name for name in group_dict if name in self._groups]
        if repeated_groups:
            raise ValueError("{} group(s) already exists.".format(repeated_groups))
        for name, data in group_dict.items():
            if name not in SUPPORTED_GROUPS_ALL:
                warnings.warn(
                    "The group {} is not defined in the InferenceData scheme".format(name),
                    UserWarning,
                )
            if data is None:
                continue
            if isinstance(data, dict):
                # raw dicts are converted via dict_to_dataset, which always adds
                # chain/draw dims -- warn when the target group normally has none
                is_data_group = name in (
                    "observed_data",
                    "constant_data",
                    "predictions_constant_data",
                )
                if is_data_group or name not in SUPPORTED_GROUPS_ALL:
                    warnings.warn(
                        "the default dims 'chain' and 'draw' will be added automatically",
                        UserWarning,
                    )
                data = dict_to_dataset(data, coords=coords, dims=dims)
            elif isinstance(data, xr.DataArray):
                # DataArrays need a name before promotion to a Dataset
                if data.name is None:
                    data.name = "x"
                data = data.to_dataset()
            elif not isinstance(data, xr.Dataset):
                raise ValueError(
                    "Arguments to add_groups() must be xr.Dataset, xr.Dataarray or dicts\
                    (argument '{}' was type '{}')".format(
                        name, type(data)
                    )
                )
            # empty datasets are falsy and are intentionally not attached
            if data:
                setattr(self, name, data)
                if name.startswith(WARMUP_TAG):
                    self._groups_warmup.append(name)
                else:
                    self._groups.append(name)
797 2

798 4
    def extend(self, other, join="left"):
        """Extend InferenceData with groups from another InferenceData.

        Parameters
        ----------
        other : InferenceData
            InferenceData to be added
        join : {'left', 'right'}, default 'left'
            Defines how the two decide which group to keep when the same group is
            present in both objects. 'left' will discard the group in ``other`` whereas 'right'
            will keep the group in ``other`` and discard the one in ``self``.

        See Also
        --------
        add_groups : Add new groups to InferenceData object.
        concat : Concatenate InferenceData objects.

        """
        if not isinstance(other, InferenceData):
            raise ValueError("Extending is possible between two InferenceData objects only.")
        if join not in ("left", "right"):
            raise ValueError("join must be either 'left' or 'right', found {}".format(join))
        for group in other._groups_all:  # pylint: disable=protected-access
            if hasattr(self, group):
                if join == "left":
                    continue
            if group not in SUPPORTED_GROUPS_ALL:
                warnings.warn(
                    "{} group is not defined in the InferenceData scheme".format(group), UserWarning
                )
            dataset = getattr(other, group)
            # BUG FIX: previously a group that was new to ``self`` was attached via
            # setattr but never registered in ``_groups``/``_groups_warmup``, so it
            # never showed up in ``_groups_all`` (and therefore was invisible to
            # methods that iterate groups, e.g. ``map``). Register it the same way
            # ``add_groups`` does.
            if group not in self._groups_all:
                if group.startswith(WARMUP_TAG):
                    self._groups_warmup.append(group)
                else:
                    self._groups.append(group)
            setattr(self, group, dataset)
830

831 6
    # Re-exported xarray.Dataset methods. ``_extend_xr_method`` (from .base) wraps
    # each Dataset method so that it is applied group-wise across the InferenceData
    # object instead of on a single Dataset.

    # Index / coordinate / chunking / evaluation helpers.
    set_index = _extend_xr_method(xr.Dataset.set_index)
    get_index = _extend_xr_method(xr.Dataset.get_index)
    reset_index = _extend_xr_method(xr.Dataset.reset_index)
    set_coords = _extend_xr_method(xr.Dataset.set_coords)
    reset_coords = _extend_xr_method(xr.Dataset.reset_coords)
    assign = _extend_xr_method(xr.Dataset.assign)
    assign_coords = _extend_xr_method(xr.Dataset.assign_coords)
    sortby = _extend_xr_method(xr.Dataset.sortby)
    chunk = _extend_xr_method(xr.Dataset.chunk)
    unify_chunks = _extend_xr_method(xr.Dataset.unify_chunks)
    load = _extend_xr_method(xr.Dataset.load)
    compute = _extend_xr_method(xr.Dataset.compute)
    persist = _extend_xr_method(xr.Dataset.persist)

    # Group-wise reductions. NOTE: ``min``/``max``/``sum`` intentionally shadow the
    # builtins at class scope; they are only reachable as attributes.
    mean = _extend_xr_method(xr.Dataset.mean)
    median = _extend_xr_method(xr.Dataset.median)
    min = _extend_xr_method(xr.Dataset.min)
    max = _extend_xr_method(xr.Dataset.max)
    cumsum = _extend_xr_method(xr.Dataset.cumsum)
    sum = _extend_xr_method(xr.Dataset.sum)
    quantile = _extend_xr_method(xr.Dataset.quantile)
852 2

853 6
    def _group_names(self, groups, filter_groups=None):
        """Handle expansion of group names input across arviz.

        Parameters
        ----------
        groups: str, list of str or None
            group or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.

        Returns
        -------
        groups: list
        """
        all_groups = self._groups_all
        if groups is None:
            return all_groups
        if isinstance(groups, str):
            groups = [groups]
        metagroups = rcParams["data.metagroups"]
        sel_groups = []
        for entry in groups:
            # a leading "~" negates the group/metagroup selection
            if entry[0] == "~":
                negated = True
                key = entry[1:]
            else:
                negated = False
                key = entry
            if key in metagroups:
                # expand a metagroup to its member groups present on this object,
                # re-applying the negation marker to each member
                expanded = [item for item in metagroups[key] if item in all_groups]
                if negated:
                    expanded = [f"~{item}" for item in expanded]
            else:
                expanded = [entry]
            sel_groups.extend(expanded)

        try:
            group_names = _subset_list(sel_groups, all_groups, filter_items=filter_groups)
        except KeyError as err:
            msg = " ".join(("groups:", f"{err}", "in InferenceData"))
            raise KeyError(msg) from err
        return group_names
899

900 6
    def map(self, fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs):
        """Apply a function to multiple groups.

        ``fun`` is called on each selected group's dataset and the group is
        overwritten with whatever the function returns.

        Parameters
        ----------
        fun : callable
            Function to be applied to each group. Assumes the function is called as
            ``fun(dataset, *args, **kwargs)``.
        groups : str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret var_names as the real variables names. If "like",
            interpret var_names as substrings of the real variables names. If "regex",
            interpret var_names as regular expressions on the real variables names. A la
            `pandas.filter`.
        inplace : bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        args : array_like, optional
            Positional arguments passed to ``fun``.
        **kwargs : mapping, optional
            Keyword arguments passed to ``fun``.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in place and return `None`

        Examples
        --------
        Shift observed_data, prior_predictive and posterior_predictive.

        .. ipython::

            In [1]: import arviz as az
               ...: idata = az.load_arviz_data("non_centered_eight")
               ...: idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
               ...: print(idata_shifted_obs.observed_data)
               ...: print(idata_shifted_obs.posterior_predictive)

        Rename and update the coordinate values in both posterior and prior groups.

        .. ipython::

            In [1]: idata = az.load_arviz_data("radon")
               ...: idata = idata.map(
               ...:     lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
               ...:         uranium_coefs=["intercept", "u_slope"]
               ...:     ),
               ...:     groups=["posterior", "prior"]
               ...: )
               ...: idata.posterior
        """
        positional = [] if args is None else args
        selected = self._group_names(groups, filter_groups)

        # operate on self directly only when requested; otherwise on a deep copy
        target = self if inplace else deepcopy(self)
        for group_name in selected:
            updated = fun(getattr(self, group_name), *positional, **kwargs)
            setattr(target, group_name, updated)
        return None if inplace else target
989 2

990 6
    def _wrap_xarray_method(
        self, method, groups=None, filter_groups=None, inplace=False, args=None, **kwargs
    ):
        """Extend and xarray.Dataset method to InferenceData object.

        Looks up ``method`` on ``xarray.Dataset`` and applies it to every
        selected group, storing the result back on the (possibly copied) object.

        Parameters
        ----------
        method: str
            Method to be extended. Must be a ``xarray.Dataset`` method.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional
            How to interpret ``groups``; see ``_group_names``.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        args : array_like, optional
            Positional arguments passed to the Dataset method.
        **kwargs: mapping, optional
            Keyword arguments passed to the xarray Dataset method.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in place and return `None`

        Examples
        --------
        Compute the mean of `posterior_groups`:

        .. ipython::

            In [1]: import arviz as az
               ...: idata = az.load_arviz_data("non_centered_eight")
               ...: idata_means = idata._wrap_xarray_method("mean", groups="latent_vars")
               ...: print(idata_means.posterior)
               ...: print(idata_means.observed_data)
        """
        positional = [] if args is None else args
        selected = self._group_names(groups, filter_groups)

        # resolve the name to the unbound xarray.Dataset method
        dataset_method = getattr(xr.Dataset, method)

        target = self if inplace else deepcopy(self)
        for group_name in selected:
            result = dataset_method(getattr(self, group_name), *positional, **kwargs)
            setattr(target, group_name, result)
        return None if inplace else target
1053 0

1054 0

1055 0
# pylint: disable=protected-access, inconsistent-return-statements
def _merge_group_attrs(group0_attrs, group_attrs, current_time):
    """Fold ``group_attrs`` into ``group0_attrs`` in place.

    Values that compare equal are left untouched. ``created_at`` /
    ``previous_created_at`` keys are handled specially: the merged group gets
    ``current_time`` as its ``created_at`` and every older timestamp is
    accumulated in the ``previous_created_at`` list. Any other conflicting key
    is collected into a ``combined_<key>`` list.

    Parameters
    ----------
    group0_attrs : dict
        Attrs of the base group; mutated in place.
    group_attrs : dict
        Attrs of the group being merged in; not modified.
    current_time : str
        Timestamp used as the new ``created_at`` value.
    """
    for attr_key, attr_values in group_attrs.items():
        group0_attr_values = group0_attrs.get(attr_key, None)
        equality = attr_values == group0_attr_values
        # array-valued attrs compare element-wise; reduce to a single bool
        if hasattr(equality, "__iter__"):
            equality = np.all(equality)
        if equality:
            continue
        # handle special cases:
        if attr_key in ("created_at", "previous_created_at"):
            # initialize the history list exactly once.
            # BUG FIX: the original used ``hasattr(group0_attrs, "previous_created_at")``
            # which is always False for a dict, so the list was reset on every key.
            if "previous_created_at" not in group0_attrs:
                group0_attrs["previous_created_at"] = []
                if group0_attr_values is not None:
                    group0_attrs["previous_created_at"].append(group0_attr_values)
            # merge an incoming history list wholesale
            if attr_key == "previous_created_at":
                if not isinstance(attr_values, list):
                    attr_values = [attr_values]
                group0_attrs["previous_created_at"].extend(attr_values)
                continue
            # update "created_at" and archive the incoming timestamp
            if group0_attr_values != current_time:
                group0_attrs[attr_key] = current_time
            group0_attrs["previous_created_at"].append(attr_values)
        elif attr_key in group0_attrs:
            combined_key = "combined_{}".format(attr_key)
            if combined_key not in group0_attrs:
                group0_attrs[combined_key] = [group0_attr_values]
            group0_attrs[combined_key].append(attr_values)
        else:
            group0_attrs[attr_key] = attr_values


def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
    """Concatenate InferenceData objects.

    Concatenates over `group`, `chain` or `draw`.
    By default concatenates over unique groups.
    To concatenate over `chain` or `draw` function
    needs identical groups and variables.

    The `variables` in the `data` -group are merged if `dim` are not found.

    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    dim : str, optional
        If defined, concatenated over the defined dimension.
        Dimension which is concatenated. If None, concatenates over
        unique groups.
    copy : bool
        If True, groups are copied to the new InferenceData object.
        Used only if `dim` is None.
    inplace : bool
        If True, merge args to first object.
    reset_dim : bool
        Valid only if dim is not None.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`

    See Also
    --------
    add_groups : Add new groups to InferenceData object.
    extend : Extend InferenceData with groups from another InferenceData.

    Examples
    --------
    Use ``concat`` method to concatenate InferenceData objects. This will concatenates over
    unique groups by default. We first create an ``InferenceData`` object:

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: data = {
           ...:     "a": np.random.normal(size=(4, 100, 3)),
           ...:     "b": np.random.normal(size=(4, 100)),
           ...: }
           ...: coords = {"a_dim": ["x", "y", "z"]}
           ...: dataA = az.from_dict(data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataA

    We have created an ``InferenceData`` object with default group 'posterior'. Now, we will
    create another ``InferenceData`` object:

    .. ipython::

        In [1]: dataB = az.from_dict(prior=data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataB

    We have created another ``InferenceData`` object with group 'prior'. Now, we will concatenate
    these two ``InferenceData`` objects:

    .. ipython::

        In [1]: az.concat(dataA, dataB)

    Now, we will concatenate over chain (or draw). It requires identical groups and variables.
    Here we are concatenating two identical ``InferenceData`` objects over dimension chain:

    .. ipython::

        In [1]: az.concat(dataA, dataA, dim="chain")

    It will create an ``InferenceData`` with the original group 'posterior'. In similar way,
    we can also concatenate over draws.

    """
    # pylint: disable=too-many-nested-blocks
    if len(args) == 0:
        if inplace:
            return
        return InferenceData()

    if len(args) == 1 and isinstance(args[0], Sequence):
        args = args[0]

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                # BUG FIX: the two adjacent literals previously concatenated
                # without a space ("...onlybetween...")
                "Concatenating is supported only "
                "between InferenceData objects. Input arg {} is {}".format(i, type(arg))
            )

    if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
        msg = "Invalid `dim`: {}. Valid `dim` are {}".format(dim, '{"group", "chain", "draw"}')
        raise TypeError(msg)
    dim = dim.lower() if dim is not None else dim

    # a single InferenceData needs no merging
    if len(args) == 1 and isinstance(args[0], InferenceData):
        if inplace:
            return None
        return deepcopy(args[0]) if copy else args[0]

    current_time = str(datetime.now())

    # collect top-level attrs from every object; keep a single value when all agree
    combined_attr = defaultdict(list)
    for idata in args:
        for key, val in idata.attrs.items():
            combined_attr[key].append(val)
    for key, val in combined_attr.items():
        all_same = all(val[indx] == val[indx + 1] for indx in range(len(val) - 1))
        if all_same:
            combined_attr[key] = val[0]
    if inplace:
        setattr(args[0], "_attrs", dict(combined_attr))

    if not inplace:
        # Keep order for python 3.5
        inference_data_dict = OrderedDict()

    if dim is None:
        # concat over unique groups: every group may appear in exactly one object
        arg0 = args[0]
        arg0_groups = ccopy(arg0._groups_all)
        args_groups = dict()
        for arg in args[1:]:
            for group in arg._groups_all:
                if group in args_groups or group in arg0_groups:
                    msg = (
                        "Concatenating overlapping groups is not supported unless `dim` is defined."
                        " Valid dimensions are `chain` and `draw`. Alternatively, use extend to"
                        " combine InferenceData with overlapping groups"
                    )
                    raise TypeError(msg)
                group_data = getattr(arg, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data
        # add arg0 to args_groups if inplace is False,
        # otherwise args_groups will be merged onto the arg0 object itself
        if not inplace:
            for group in arg0_groups:
                group_data = getattr(arg0, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data

        other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS_ALL]

        # attach groups in canonical order, unknown groups last
        for group in SUPPORTED_GROUPS_ALL + other_groups:
            if group not in args_groups:
                continue
            if inplace:
                if group.startswith(WARMUP_TAG):
                    arg0._groups_warmup.append(group)
                else:
                    arg0._groups.append(group)
                setattr(arg0, group, args_groups[group])
            else:
                inference_data_dict[group] = args_groups[group]
        if inplace:
            # re-sort the group registries into canonical order
            other_groups = [
                group for group in arg0_groups if group not in SUPPORTED_GROUPS_ALL
            ] + other_groups
            sorted_groups = [
                group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups
            ]
            setattr(arg0, "_groups", sorted_groups)
            sorted_groups_warmup = [
                group
                for group in SUPPORTED_GROUPS_WARMUP + other_groups
                if group in arg0._groups_warmup
            ]
            setattr(arg0, "_groups_warmup", sorted_groups_warmup)
    else:
        # concat along chain/draw: groups and variables must match
        arg0 = args[0]
        arg0_groups = arg0._groups_all
        for arg in args[1:]:
            for group0 in arg0_groups:
                if group0 not in arg._groups_all:
                    if group0 == "observed_data":
                        continue
                    raise TypeError("Mismatch between the groups.")
            for group in arg._groups_all:
                # handle data groups separately below
                if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
                    # assert that groups are equal
                    if group not in arg0_groups:
                        raise TypeError("Mismatch between the groups.")

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    # keep accumulating onto the already-concatenated result
                    if not inplace and group in inference_data_dict:
                        group0_data = inference_data_dict[group]
                    else:
                        group0_data = getattr(arg0, group)
                    group0_vars = group0_data.data_vars

                    for var in group0_vars:
                        if var not in group_vars:
                            raise TypeError("Mismatch between the variables.")

                    for var in group_vars:
                        if var not in group0_vars:
                            raise TypeError("Mismatch between the variables.")
                        var_dims = getattr(group_data, var).dims
                        var0_dims = getattr(group0_data, var).dims
                        if var_dims != var0_dims:
                            raise TypeError("Mismatch between the dimensions.")
                        if dim not in var_dims or dim not in var0_dims:
                            raise TypeError("Dimension {} missing.".format(dim))

                    concatenated_group = xr.concat((group0_data, group_data), dim=dim)
                    if reset_dim:
                        concatenated_group[dim] = range(concatenated_group[dim].size)

                    # merge group attrs (deepcopy so the source group is untouched)
                    group0_attrs = deepcopy(getattr(group0_data, "attrs", OrderedDict()))
                    group_attrs = getattr(group_data, "attrs", {})
                    _merge_group_attrs(group0_attrs, group_attrs, current_time)
                    setattr(concatenated_group, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, concatenated_group)
                    else:
                        inference_data_dict[group] = concatenated_group
                else:
                    # observed_data, constant_data, predictions_constant_data
                    # BUG FIX: ``group_data`` must be read before the "new group"
                    # check; the original referenced a stale value left over from
                    # the previous loop iteration.
                    group_data = getattr(arg, group)
                    if group not in arg0_groups:
                        new_data = deepcopy(group_data) if copy else group_data
                        # BUG FIX: the original always mutated ``arg0`` here and,
                        # when not inplace, dropped the group from the result.
                        if inplace:
                            setattr(arg0, group, new_data)
                            arg0._groups.append(group)
                        else:
                            inference_data_dict[group] = new_data
                        continue

                    group_vars = group_data.data_vars

                    group0_data = getattr(arg0, group)
                    if not inplace:
                        group0_data = deepcopy(group0_data)
                    group0_vars = group0_data.data_vars

                    for var in group_vars:
                        if var not in group0_vars:
                            # BUG FIX: the original wrote new variables onto
                            # ``arg0`` even when not inplace; write to the working
                            # (possibly copied) dataset instead.
                            group0_data[var] = getattr(group_data, var)
                        else:
                            var_data = getattr(group_data, var)
                            var0_data = getattr(group0_data, var)
                            if dim in var_data.dims and dim in var0_data.dims:
                                # NOTE(review): concatenates the whole datasets, not
                                # just the matching variables -- preserved as-is.
                                concatenated_var = xr.concat((group_data, group0_data), dim=dim)
                                group0_data[var] = concatenated_var

                    # merge group attrs in place on the working dataset
                    group0_attrs = getattr(group0_data, "attrs", OrderedDict())
                    group_attrs = getattr(group_data, "attrs", {})
                    _merge_group_attrs(group0_attrs, group_attrs, current_time)
                    setattr(group0_data, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, group0_data)
                    else:
                        inference_data_dict[group] = group0_data

    if not inplace:
        inference_data_dict["attrs"] = combined_attr

    return None if inplace else InferenceData(**inference_data_dict)

Read our documentation on viewing source code .

Loading