1
# pylint: disable=too-many-lines
2 2
"""Data structure for using netcdf groups with xarray."""
3 2
import uuid
4 2
import warnings
5 2
from collections import OrderedDict, defaultdict
6 2
from collections.abc import Sequence
7 2
from copy import copy as ccopy
8 2
from copy import deepcopy
9 2
from datetime import datetime
10 2
from html import escape
11

12 2
import netCDF4 as nc
13 2
import numpy as np
14 2
import xarray as xr
15 2
from xarray.core.options import OPTIONS
16 2
from xarray.core.utils import either_dict_or_kwargs
17

18 2
from ..rcparams import rcParams
19 2
from ..utils import HtmlTemplate, _subset_list
20 2
from .base import _extend_xr_method, dict_to_dataset, _make_json_serializable
21

22 2
try:
23 2
    import ujson as json
24 0
except ImportError:
25 0
    import json
26

27 2
# Canonical group names an InferenceData object may contain, in storage order.
SUPPORTED_GROUPS = [
    "posterior",
    "posterior_predictive",
    "predictions",
    "log_likelihood",
    "sample_stats",
    "prior",
    "prior_predictive",
    "sample_stats_prior",
    "observed_data",
    "constant_data",
    "predictions_constant_data",
]

# Prefix that marks groups storing warmup (tuning) iterations.
WARMUP_TAG = "warmup_"

# Groups that may also exist in a warmup variant.
SUPPORTED_GROUPS_WARMUP = [
    WARMUP_TAG + group
    for group in (
        "posterior",
        "posterior_predictive",
        "predictions",
        "sample_stats",
        "log_likelihood",
    )
]

SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
52

53

54 2
class InferenceData:
55
    """Container for inference data storage using xarray.
56

57
    For a detailed introduction to ``InferenceData`` objects and their usage, see
58
    :doc:`/notebooks/XarrayforArviZ`. This page provides help and documentation
59
    on ``InferenceData`` methods and their low level implementation.
60
    """
61

62 2
    def __init__(self, **kwargs):
        """Initialize InferenceData object from keyword xarray datasets.

        Parameters
        ----------
        kwargs :
            Keyword arguments of xarray datasets

        Examples
        --------
        Initiate an InferenceData object from scratch, not recommended. InferenceData
        objects should be initialized using ``from_xyz`` methods, see :ref:`data_api` for more
        details.

        .. ipython::

            In [1]: import arviz as az
               ...: import numpy as np
               ...: import xarray as xr
               ...: dataset = xr.Dataset(
               ...:     {
               ...:         "a": (["chain", "draw", "a_dim"], np.random.normal(size=(4, 100, 3))),
               ...:         "b": (["chain", "draw"], np.random.normal(size=(4, 100))),
               ...:     },
               ...:     coords={
               ...:         "chain": (["chain"], np.arange(4)),
               ...:         "draw": (["draw"], np.arange(100)),
               ...:         "a_dim": (["a_dim"], ["x", "y", "z"]),
               ...:     }
               ...: )
               ...: idata = az.InferenceData(posterior=dataset, prior=dataset)
               ...: idata

        We have created an ``InferenceData`` object with two groups. Now we can check its
        contents:

        .. ipython::

            In [1]: idata.posterior

        """
        # Registries of stored group names; attributes themselves hold the datasets.
        self._groups = []
        self._groups_warmup = []
        # "save_warmup" is a flag, not a group: pop it before group handling.
        save_warmup = kwargs.pop("save_warmup", False)
        # Process recognized groups first, in canonical SUPPORTED_GROUPS_ALL order ...
        key_list = [key for key in SUPPORTED_GROUPS_ALL if key in kwargs]
        for key in kwargs:
            if key not in SUPPORTED_GROUPS_ALL:
                # ... then unknown groups, which are still stored but warned about.
                key_list.append(key)
                warnings.warn(
                    "{} group is not defined in the InferenceData scheme".format(key), UserWarning
                )
        for key in key_list:
            dataset = kwargs[key]
            dataset_warmup = None
            if dataset is None:
                continue
            elif isinstance(dataset, (list, tuple)):
                # A (dataset, warmup_dataset) pair was passed for this group.
                dataset, dataset_warmup = kwargs[key]
            elif not isinstance(dataset, xr.Dataset):
                raise ValueError(
                    "Arguments to InferenceData must be xarray Datasets "
                    "(argument '{}' was type '{}')".format(key, type(dataset))
                )
            if not key.startswith(WARMUP_TAG):
                if dataset:  # skip empty datasets
                    setattr(self, key, dataset)
                    self._groups.append(key)
            elif key.startswith(WARMUP_TAG):
                # Keys already carrying the warmup prefix go straight to the warmup registry.
                if dataset:
                    setattr(self, key, dataset)
                    self._groups_warmup.append(key)
            if save_warmup and dataset_warmup is not None:
                if dataset_warmup:
                    # Store the warmup half of the pair under the prefixed name.
                    key = "{}{}".format(WARMUP_TAG, key)
                    setattr(self, key, dataset_warmup)
                    self._groups_warmup.append(key)
138

139 2
    def __repr__(self):
140
        """Make string representation of InferenceData object."""
141 2
        msg = "Inference data with groups:\n\t> {options}".format(
142
            options="\n\t> ".join(self._groups)
143
        )
144 2
        if self._groups_warmup:
145 0
            msg += "\n\nWarmup iterations saved ({}*).".format(WARMUP_TAG)
146 2
        return msg
147

148 2
    def _repr_html_(self):
        """Return an HTML summary, honoring xarray's ``display_style`` option."""
        if OPTIONS["display_style"] == "text":
            # Plain-text fallback, wrapped in <pre> so notebooks render it verbatim.
            return f"<pre>{escape(repr(self))}</pre>"
        elements = "".join(
            HtmlTemplate.element_template.format(
                # uuid suffix keeps collapsible element ids unique across repeats.
                group_id=group + str(uuid.uuid4()),
                group=group,
                xr_data=getattr(self, group)._repr_html_(),  # pylint: disable=protected-access
            )
            for group in self._groups_all
        )
        formatted_html_template = HtmlTemplate.html_template.format(elements)
        return formatted_html_template + HtmlTemplate.css_template
172

173 2
    def __delattr__(self, group):
174
        """Delete a group from the InferenceData object."""
175 2
        if group in self._groups:
176 2
            self._groups.remove(group)
177 0
        elif group in self._groups_warmup:
178 0
            self._groups_warmup.remove(group)
179 2
        object.__delattr__(self, group)
180

181 2
    @property
182
    def _groups_all(self):
183 2
        return self._groups + self._groups_warmup
184

185 2
    def __iter__(self):
186
        """Iterate over groups in InferenceData object."""
187 2
        for group in self._groups_all:
188 2
            yield group
189

190 2
    def groups(self):
191
        """Return all groups present in InferenceData object."""
192 2
        return self._groups_all
193

194 2
    def values(self):
195
        """Xarray Datasets present in InferenceData object."""
196 2
        for group in self._groups_all:
197 2
            yield getattr(self, group)
198

199 2
    def items(self):
200
        """Yield groups and corresponding datasets present in InferenceData object."""
201 2
        for group in self._groups_all:
202 2
            yield (group, getattr(self, group))
203

204 2
    @staticmethod
    def from_netcdf(filename):
        """Initialize object from a netcdf file.

        Expects that the file will have groups, each of which can be loaded by xarray.
        By default, the datasets of the InferenceData object will be lazily loaded instead
        of being loaded into memory. This
        behaviour is regulated by the value of ``az.rcParams["data.load"]``.

        Parameters
        ----------
        filename : str
            location of netcdf file

        Returns
        -------
        InferenceData object
        """
        groups = {}
        # Open once with netCDF4 just to discover the group names present in the file.
        with nc.Dataset(filename, mode="r") as data:
            data_groups = list(data.groups)

        for group in data_groups:
            with xr.open_dataset(filename, group=group) as data:
                if rcParams["data.load"] == "eager":
                    # Pull all values into memory before the context closes the file.
                    groups[group] = data.load()
                else:
                    # NOTE(review): the dataset object outlives this `with` block;
                    # lazy reads rely on xarray keeping the underlying store usable
                    # after __exit__ — confirm this is intended.
                    groups[group] = data
        return InferenceData(**groups)
233

234 2
    def to_netcdf(self, filename, compress=True, groups=None):
        """Write InferenceData to file using netcdf4.

        Parameters
        ----------
        filename : str
            Location to write to
        compress : bool, optional
            Whether to compress result. Note this saves disk space, but may make
            saving and loading somewhat slower (default: True).
        groups : list, optional
            Write only these groups to netcdf file.

        Returns
        -------
        str
            Location of netcdf file
        """
        mode = "w"  # overwrite first, then append
        if self._groups_all:  # checks whether a group is present or not.
            if groups is None:
                groups = self._groups_all
            else:
                # Keep only requested groups that actually exist, preserving storage order.
                groups = [group for group in self._groups_all if group in groups]

            for group in groups:
                data = getattr(self, group)
                kwargs = {}
                if compress:
                    # Enable zlib compression for every variable in the group.
                    kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
                data.to_netcdf(filename, mode=mode, group=group, **kwargs)
                data.close()
                mode = "a"  # subsequent groups are appended to the same file
        else:  # creates a netcdf file for an empty InferenceData object.
            empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
            empty_netcdf_file.close()
        return filename
271

272 2
    def to_dict(self, groups=None, filter_groups=None):
        """Convert InferenceData to a dictionary following xarray naming conventions.

        Parameters
        ----------
        groups : list, optional
            Groups to include in the resulting dictionary. By default all
            groups are converted.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.

        Returns
        -------
        dict
            A dictionary containing all groups of InferenceData object.
        """
        ret = defaultdict(dict)
        attrs = None
        if self._groups_all:  # checks whether a group is present or not.
            if groups is None:
                groups = self._group_names(groups, filter_groups)
            else:
                # Explicit list given: keep only groups that actually exist.
                groups = [group for group in self._groups_all if group in groups]

            for group in groups:
                dataset = getattr(self, group)
                data = {}
                for var_name, dataarray in dataset.items():
                    data[var_name] = dataarray.values
                    dims = []
                    for coord_name, coord_values in dataarray.coords.items():
                        # "chain"/"draw" and auto-generated per-variable dims
                        # ("<var>_dim_N") are implicit in the converters, so
                        # only user-defined coordinates are exported.
                        if coord_name not in ("chain", "draw") and not coord_name.startswith(
                            var_name + "_dim_"
                        ):
                            dims.append(coord_name)
                            ret["coords"][coord_name] = coord_values.values

                    # Predictions-related groups record their dims under a separate key.
                    if group in (
                        "predictions",
                        "predictions_constant_data",
                    ):
                        dims_key = "pred_dims"
                    else:
                        dims_key = "dims"
                    if len(dims) > 0:
                        ret[dims_key][var_name] = dims
                    ret[group] = data
                if attrs is None:
                    # Attributes of the first converted group win.
                    attrs = dataset.attrs
                elif attrs != dataset.attrs:
                    warnings.warn(
                        "The attributes are not same for all groups."
                        " Considering only the first group `attrs`"
                    )

        ret["attrs"] = attrs
        return ret
327

328 2
    def to_json(self, filename, **kwargs):
        """Write InferenceData to a json file.

        Parameters
        ----------
        filename : str
            Location to write to
        kwargs : dict
            kwargs passed to json.dump()

        Returns
        -------
        str
            Location of json file
        """
        # Convert to a plain dict first, then coerce values to JSON-friendly types.
        serializable = _make_json_serializable(self.to_dict())
        with open(filename, "w") as fobj:
            json.dump(serializable, fobj, **kwargs)
        return filename
349

350 2
    def __add__(self, other):
        """Concatenate two InferenceData objects.

        Delegates to module-level ``concat`` with ``copy=True`` so neither
        operand is modified; returns the new combined object.
        """
        return concat(self, other, copy=True, inplace=False)
353

354 2
    def sel(
        self,
        groups=None,
        filter_groups=None,
        inplace=False,
        chain_prior=None,
        **kwargs,
    ):
        """Perform an xarray selection on all groups.

        Loops groups to perform Dataset.sel(key=item)
        for every kwarg if key is a dimension of the dataset.
        One example could be performing a burn in cut on the InferenceData object
        or discarding a chain. The selection is performed on all relevant groups (like
        posterior, prior, sample stats) while non relevant groups like observed data are
        omitted. See :meth:`xarray.Dataset.sel <xarray:xarray.Dataset.sel>`

        Parameters
        ----------
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        chain_prior: bool, optional, deprecated
            If ``False``, do not select prior related groups using ``chain`` dim.
            Otherwise, use selection on ``chain`` if present. Default=False
        **kwargs: mapping
            It must be accepted by Dataset.sel().

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in-place and return `None`

        Examples
        --------
        Use ``sel`` to discard one chain of the InferenceData object. We first check the
        dimensions of the original object:

        .. ipython::

            In [1]: import arviz as az
               ...: idata = az.load_arviz_data("centered_eight")
               ...: del idata.prior  # prior group only has 1 chain currently
               ...: print(idata.posterior.coords)
               ...: print(idata.posterior_predictive.coords)
               ...: print(idata.observed_data.coords)

        In order to remove the third chain:

        .. ipython::

            In [1]: idata_subset = idata.sel(chain=[0, 1, 3])
               ...: print(idata_subset.posterior.coords)
               ...: print(idata_subset.posterior_predictive.coords)
               ...: print(idata_subset.observed_data.coords)

        """
        # ``chain_prior`` is deprecated: any explicit value (True or False) warns;
        # unset (None) silently falls back to the default of False.
        if chain_prior is not None:
            warnings.warn(
                "chain_prior has been deprecated. Use groups argument and "
                "rcParams['data.metagroups'] instead.",
                DeprecationWarning,
            )
        else:
            chain_prior = False
        groups = self._group_names(groups, filter_groups)

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            # Apply only the kwargs that name an actual dimension of this dataset.
            valid_keys = set(kwargs.keys()).intersection(dataset.dims)
            if not chain_prior and "prior" in group:
                # By default, prior-related groups are not subset along "chain".
                valid_keys -= {"chain"}
            dataset = dataset.sel(**{key: kwargs[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out
442

443 2
    def isel(
        self,
        groups=None,
        filter_groups=None,
        inplace=False,
        **kwargs,
    ):
        """Perform an xarray selection on all groups.

        Loops groups to perform Dataset.isel(key=item)
        for every kwarg if key is a dimension of the dataset.
        One example could be performing a burn in cut on the InferenceData object
        or discarding a chain. The selection is performed on all relevant groups (like
        posterior, prior, sample stats) while non relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.isel`

        Parameters
        ----------
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        **kwargs: mapping
            It must be accepted by :meth:`xarray:xarray.Dataset.isel`.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in-place and return `None`

        """
        groups = self._group_names(groups, filter_groups)

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            # Apply only the kwargs that name an actual dimension of this dataset.
            applicable = {key: value for key, value in kwargs.items() if key in dataset.dims}
            setattr(out, group, dataset.isel(**applicable))
        return None if inplace else out
494

495 2
    def stack(
        self,
        dimensions=None,
        groups=None,
        filter_groups=None,
        inplace=False,
        **kwargs,
    ):
        """Perform an xarray stacking on all groups.

        Stack any number of existing dimensions into a single new dimension.
        Loops groups to perform Dataset.stack(key=value)
        for every kwarg if value is a dimension of the dataset.
        The selection is performed on all relevant groups (like
        posterior, prior, sample stats) while non relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.stack`

        Parameters
        ----------
        dimensions: dict
            Names of new dimensions, and the existing dimensions that they replace.
            The mapping passed by the caller is not modified.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        **kwargs: mapping
            It must be accepted by :meth:`xarray:xarray.Dataset.stack`.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in-place and return `None`

        """
        groups = self._group_names(groups, filter_groups)

        # Merge the positional mapping with kwargs into a *new* dict: the previous
        # in-place ``dimensions.update(kwargs)`` mutated the caller's dictionary.
        dimensions = {**({} if dimensions is None else dimensions), **kwargs}
        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            # Only stack specs whose source dimensions all exist in this dataset.
            kwarg_dict = {
                key: value
                for key, value in dimensions.items()
                if not set(value).difference(dataset.dims)
            }
            setattr(out, group, dataset.stack(**kwarg_dict))
        return None if inplace else out
554

555 2
    def unstack(self, dim=None, groups=None, filter_groups=None, inplace=False):
        """Perform an xarray unstacking on all groups.

        Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions.
        Loops groups to perform Dataset.unstack(key=value).
        The selection is performed on all relevant groups (like posterior, prior,
        sample stats) while non relevant groups like observed data are omitted.
        See :meth:`xarray:xarray.Dataset.unstack`

        Parameters
        ----------
        dim: Hashable or iterable of Hashable, optional
            Dimension(s) over which to unstack. By default unstacks all MultiIndexes.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in place and return `None`

        """
        groups = self._group_names(groups, filter_groups)
        if isinstance(dim, str):
            # Normalize a single dimension name to a one-element list.
            dim = [dim]

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            if dim is None:
                # None means "unstack every MultiIndex" for xarray.
                valid_dims = None
            else:
                valid_dims = set(dim).intersection(dataset.dims)
            setattr(out, group, dataset.unstack(dim=valid_dims))
        return None if inplace else out
601

602 2
    def rename(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of variable and dimensions on all groups.

        Loops groups to perform Dataset.rename(name_dict)
        for every key in name_dict if key is a dimension/data_vars of the dataset.
        The renaming is performed on all relevant groups (like
        posterior, prior, sample stats) while non relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.rename`

        Parameters
        ----------
        name_dict: dict, optional
            Dictionary whose keys are current variable or dimension names
            and whose values are the desired names. Defaults to an empty
            mapping, in which case nothing is renamed.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform renaming in-place and return `None`

        Raises
        ------
        KeyError
            If ``name_dict`` tries to rename the "chain" or "draw" dimension.
        """
        # Treat a missing mapping as "rename nothing" instead of crashing with
        # AttributeError on ``None.keys()``.
        name_dict = {} if name_dict is None else name_dict
        groups = self._group_names(groups, filter_groups)
        if "chain" in name_dict or "draw" in name_dict:
            # These dimensions define the InferenceData layout and must keep their names.
            raise KeyError("'chain' or 'draw' dimensions can't be renamed")
        out = self if inplace else deepcopy(self)

        for group in groups:
            dataset = getattr(self, group)
            # Only rename keys that exist as variables or dimensions in this dataset.
            expected_keys = list(dataset.data_vars) + list(dataset.dims)
            valid_keys = set(name_dict.keys()).intersection(expected_keys)
            dataset = dataset.rename({key: name_dict[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        return out
651

652 2
    def rename_vars(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of variable or coordinate names on all groups.

        Loops groups to perform Dataset.rename_vars(name_dict)
        for every key in name_dict if key is a variable or coordinate names of the dataset.
        The renaming is performed on all relevant groups (like
        posterior, prior, sample stats) while non relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.rename_vars`

        Parameters
        ----------
        name_dict: dict, optional
            Dictionary whose keys are current variable or coordinate names
            and whose values are the desired names. Defaults to an empty
            mapping, in which case nothing is renamed.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object with renamed variables including coordinates by default.
            When `inplace==True` perform renaming in-place and return `None`

        """
        # Treat a missing mapping as "rename nothing" instead of crashing with
        # AttributeError on ``None.keys()`` (consistent with ``rename``).
        name_dict = {} if name_dict is None else name_dict
        groups = self._group_names(groups, filter_groups)

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            # Only rename keys that exist as data variables in this dataset.
            valid_keys = set(name_dict.keys()).intersection(dataset.data_vars)
            dataset = dataset.rename_vars({key: name_dict[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        return out
698

699 2
    def rename_dims(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
        """Perform xarray renaming of dimensions on all groups.

        Loops groups to perform Dataset.rename_dims(name_dict)
        for every key in name_dict if key is a dimension of the dataset.
        The renaming is performed on all relevant groups (like
        posterior, prior, sample stats) while non relevant groups like observed data are
        omitted. See :meth:`xarray:xarray.Dataset.rename_dims`

        Parameters
        ----------
        name_dict: dict, optional
            Dictionary whose keys are current dimension names and whose values are the desired
            names. Defaults to an empty mapping, in which case nothing is renamed.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.

        Returns
        -------
        InferenceData
            A new InferenceData object with renamed dimension by default.
            When `inplace==True` perform renaming in-place and return `None`

        Raises
        ------
        KeyError
            If ``name_dict`` tries to rename the "chain" or "draw" dimension.
        """
        # Treat a missing mapping as "rename nothing" instead of crashing with
        # AttributeError on ``None.keys()`` (consistent with ``rename``).
        name_dict = {} if name_dict is None else name_dict
        groups = self._group_names(groups, filter_groups)
        if "chain" in name_dict or "draw" in name_dict:
            # These dimensions define the InferenceData layout and must keep their names.
            raise KeyError("'chain' or 'draw' dimensions can't be renamed")

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            # Only rename keys that exist as dimensions in this dataset.
            valid_keys = set(name_dict.keys()).intersection(dataset.dims)
            dataset = dataset.rename_dims({key: name_dict[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        return out
747

748 2
    def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
        """Add new groups to InferenceData object.

        Parameters
        ----------
        group_dict: dict of {str : dict or xarray.Dataset}, optional
            Groups to be added
        coords : dict[str] -> ndarray
            Coordinates for the dataset
        dims : dict[str] -> list[str]
            Dimensions of each variable. The keys are variable names, values are lists of
            coordinates.
        **kwargs: mapping
            The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.

        See Also
        --------
        extend : Extend InferenceData with groups from another InferenceData.
        concat : Concatenate InferenceData objects.
        """
        group_dict = either_dict_or_kwargs(group_dict, kwargs, "add_groups")
        if not group_dict:
            raise ValueError("One of group_dict or kwargs must be provided.")
        # Check against *all* existing groups (regular and warmup) so that an
        # already present warmup group cannot be silently overwritten.
        repeated_groups = [group for group in group_dict.keys() if group in self._groups_all]
        if repeated_groups:
            raise ValueError("{} group(s) already exists.".format(repeated_groups))
        for group, dataset in group_dict.items():
            if group not in SUPPORTED_GROUPS_ALL:
                warnings.warn(
                    "The group {} is not defined in the InferenceData scheme".format(group),
                    UserWarning,
                )
            if dataset is None:
                continue
            elif isinstance(dataset, dict):
                # dict input is converted with dict_to_dataset, which always adds
                # chain/draw dims; warn when that is probably not what the user wants.
                if (
                    group in ("observed_data", "constant_data", "predictions_constant_data")
                    or group not in SUPPORTED_GROUPS_ALL
                ):
                    warnings.warn(
                        "the default dims 'chain' and 'draw' will be added automatically",
                        UserWarning,
                    )
                dataset = dict_to_dataset(dataset, coords=coords, dims=dims)
            elif isinstance(dataset, xr.DataArray):
                # DataArrays need a name before conversion to a Dataset.
                if dataset.name is None:
                    dataset.name = "x"
                dataset = dataset.to_dataset()
            elif not isinstance(dataset, xr.Dataset):
                raise ValueError(
                    "Arguments to add_groups() must be xr.Dataset, xr.Dataarray or dicts\
                    (argument '{}' was type '{}')".format(
                        group, type(dataset)
                    )
                )
            # Empty datasets are silently skipped.
            if dataset:
                setattr(self, group, dataset)
                # Keep the bookkeeping lists in sync with the new attribute.
                if group.startswith(WARMUP_TAG):
                    self._groups_warmup.append(group)
                else:
                    self._groups.append(group)
809

810 2
    def extend(self, other, join="left"):
        """Extend InferenceData with groups from another InferenceData.

        Parameters
        ----------
        other : InferenceData
            InferenceData to be added
        join : {'left', 'right'}, default 'left'
            Defines how the two decide which group to keep when the same group is
            present in both objects. 'left' will discard the group in ``other`` whereas 'right'
            will keep the group in ``other`` and discard the one in ``self``.

        See Also
        --------
        add_groups : Add new groups to InferenceData object.
        concat : Concatenate InferenceData objects.

        """
        if not isinstance(other, InferenceData):
            raise ValueError("Extending is possible between two InferenceData objects only.")
        if join not in ("left", "right"):
            raise ValueError("join must be either 'left' or 'right', found {}".format(join))
        for group in other._groups_all:  # pylint: disable=protected-access
            # With a "left" join, groups already present in self take precedence.
            if join == "left" and hasattr(self, group):
                continue
            if group not in SUPPORTED_GROUPS_ALL:
                warnings.warn(
                    "{} group is not defined in the InferenceData scheme".format(group), UserWarning
                )
            setattr(self, group, getattr(other, group))
842

843 2
    # Common xarray.Dataset methods extended to operate groupwise on an
    # InferenceData object via the _extend_xr_method helper (see .base).
    set_index = _extend_xr_method(xr.Dataset.set_index)
    get_index = _extend_xr_method(xr.Dataset.get_index)
    reset_index = _extend_xr_method(xr.Dataset.reset_index)
    set_coords = _extend_xr_method(xr.Dataset.set_coords)
    reset_coords = _extend_xr_method(xr.Dataset.reset_coords)
    assign = _extend_xr_method(xr.Dataset.assign)
    assign_coords = _extend_xr_method(xr.Dataset.assign_coords)
    sortby = _extend_xr_method(xr.Dataset.sortby)
    chunk = _extend_xr_method(xr.Dataset.chunk)
    unify_chunks = _extend_xr_method(xr.Dataset.unify_chunks)
    load = _extend_xr_method(xr.Dataset.load)
    compute = _extend_xr_method(xr.Dataset.compute)
    persist = _extend_xr_method(xr.Dataset.persist)

    # Groupwise reductions/statistics, likewise delegated to xarray.Dataset.
    mean = _extend_xr_method(xr.Dataset.mean)
    median = _extend_xr_method(xr.Dataset.median)
    min = _extend_xr_method(xr.Dataset.min)
    max = _extend_xr_method(xr.Dataset.max)
    cumsum = _extend_xr_method(xr.Dataset.cumsum)
    sum = _extend_xr_method(xr.Dataset.sum)
    quantile = _extend_xr_method(xr.Dataset.quantile)
864

865 2
    def _group_names(self, groups, filter_groups=None):
        """Handle expansion of group names input across arviz.

        Parameters
        ----------
        groups: str, list of str or None
            group or metagroup names. Names prefixed with ``~`` are negations.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.

        Returns
        -------
        groups: list
        """
        all_groups = self._groups_all
        if groups is None:
            return all_groups
        if isinstance(groups, str):
            groups = [groups]
        metagroups = rcParams["data.metagroups"]
        sel_groups = []
        for group in groups:
            # A leading "~" marks a negation; it must be re-applied to every
            # group a metagroup expands to.
            negated = group[0] == "~"
            key = group[1:] if negated else group
            if key in metagroups:
                prefix = "~" if negated else ""
                sel_groups.extend(
                    prefix + item for item in metagroups[key] if item in all_groups
                )
            else:
                sel_groups.append(group)

        try:
            group_names = _subset_list(sel_groups, all_groups, filter_items=filter_groups)
        except KeyError as err:
            msg = " ".join(("groups:", f"{err}", "in InferenceData"))
            raise KeyError(msg) from err
        return group_names
911

912 2
    def map(self, fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs):
        """Apply a function to multiple groups.

        Applies ``fun`` groupwise to the selected ``InferenceData`` groups and overwrites the
        group with the result of the function.

        Parameters
        ----------
        fun : callable
            Function to be applied to each group. Assumes the function is called as
            ``fun(dataset, *args, **kwargs)``.
        groups : str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups : {None, "like", "regex"}, optional
            If `None` (default), interpret var_names as the real variables names. If "like",
            interpret var_names as substrings of the real variables names. If "regex",
            interpret var_names as regular expressions on the real variables names. A la
            `pandas.filter`.
        inplace : bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        args : array_like, optional
            Positional arguments passed to ``fun``.
        **kwargs : mapping, optional
            Keyword arguments passed to ``fun``.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in place and return `None`

        Examples
        --------
        Shift observed_data, prior_predictive and posterior_predictive.

        .. ipython::

            In [1]: import arviz as az
               ...: idata = az.load_arviz_data("non_centered_eight")
               ...: idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
               ...: print(idata_shifted_obs.observed_data)
               ...: print(idata_shifted_obs.posterior_predictive)

        Rename and update the coordinate values in both posterior and prior groups.

        .. ipython::

            In [1]: idata = az.load_arviz_data("radon")
               ...: idata = idata.map(
               ...:     lambda ds: ds.rename({"g_coef": "uranium_coefs"}).assign(
               ...:         uranium_coefs=["intercept", "u_slope"]
               ...:     ),
               ...:     groups=["posterior", "prior"]
               ...: )
               ...: idata.posterior

        """
        extra_args = [] if args is None else args
        selected_groups = self._group_names(groups, filter_groups)

        out = self if inplace else deepcopy(self)
        for group_name in selected_groups:
            # Read the original dataset from self; store the result on out.
            transformed = fun(getattr(self, group_name), *extra_args, **kwargs)
            setattr(out, group_name, transformed)
        return None if inplace else out
1001

1002 2
    def _wrap_xarray_method(
        self, method, groups=None, filter_groups=None, inplace=False, args=None, **kwargs
    ):
        """Extend an ``xarray.Dataset`` method to the InferenceData object.

        Parameters
        ----------
        method: str
            Method to be extended. Must be a ``xarray.Dataset`` method.
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        **kwargs: mapping, optional
            Keyword arguments passed to the xarray Dataset method.

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in place and return `None`

        Examples
        --------
        Compute the mean of `posterior_groups`:

        .. ipython::

            In [1]: import arviz as az
               ...: idata = az.load_arviz_data("non_centered_eight")
               ...: idata_means = idata._wrap_xarray_method("mean", groups="latent_vars")
               ...: print(idata_means.posterior)
               ...: print(idata_means.observed_data)

        """
        extra_args = [] if args is None else args
        selected_groups = self._group_names(groups, filter_groups)

        # Resolve the method name to the unbound xarray.Dataset function.
        ds_method = getattr(xr.Dataset, method)

        out = self if inplace else deepcopy(self)
        for group_name in selected_groups:
            result = ds_method(getattr(self, group_name), *extra_args, **kwargs)
            setattr(out, group_name, result)
        return None if inplace else out
1065

1066

1067
# pylint: disable=protected-access, inconsistent-return-statements
1068 2
def _merge_concat_attrs(group0_attrs, group_attrs, current_time):
    """Merge ``group_attrs`` into ``group0_attrs`` in place, tracking history.

    Conflicting ``created_at``/``previous_created_at`` values are folded into a
    ``previous_created_at`` list and ``created_at`` is updated to
    ``current_time``. Any other conflicting key is accumulated under a
    ``combined_<key>`` list; brand new keys are copied over.
    """
    for attr_key, attr_values in group_attrs.items():
        group0_attr_values = group0_attrs.get(attr_key, None)
        equality = attr_values == group0_attr_values
        if hasattr(equality, "__iter__"):
            # array-valued attrs compare elementwise; collapse to a scalar
            equality = np.all(equality)
        if equality:
            continue
        # handle special cases:
        if attr_key in ("created_at", "previous_created_at"):
            # Initialize the history list only once. NOTE: this must be a dict
            # membership test -- hasattr() on a dict never sees its keys, which
            # previously reset the history on every conflicting key.
            if "previous_created_at" not in group0_attrs:
                group0_attrs["previous_created_at"] = []
                if group0_attr_values is not None:
                    group0_attrs["previous_created_at"].append(group0_attr_values)
            # check previous values
            if attr_key == "previous_created_at":
                if not isinstance(attr_values, list):
                    attr_values = [attr_values]
                group0_attrs["previous_created_at"].extend(attr_values)
                continue
            # update "created_at"
            if group0_attr_values != current_time:
                group0_attrs[attr_key] = current_time
            group0_attrs["previous_created_at"].append(attr_values)

        elif attr_key in group0_attrs:
            combined_key = "combined_{}".format(attr_key)
            if combined_key not in group0_attrs:
                group0_attrs[combined_key] = [group0_attr_values]
            group0_attrs[combined_key].append(attr_values)
        else:
            group0_attrs[attr_key] = attr_values


def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
    """Concatenate InferenceData objects.

    Concatenates over `group`, `chain` or `draw`.
    By default concatenates over unique groups.
    To concatenate over `chain` or `draw` function
    needs identical groups and variables.

    The `variables` in the `data` -group are merged if `dim` are not found.

    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    dim : str, optional
        If defined, concatenated over the defined dimension.
        Dimension which is concatenated. If None, concatenates over
        unique groups.
    copy : bool
        If True, groups are copied to the new InferenceData object.
        Used only if `dim` is None.
    inplace : bool
        If True, merge args to first object.
    reset_dim : bool
        Valid only if dim is not None.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`

    See Also
    --------
    add_groups : Add new groups to InferenceData object.
    extend : Extend InferenceData with groups from another InferenceData.

    Examples
    --------
    Use ``concat`` method to concatenate InferenceData objects. This will concatenate over
    unique groups by default. We first create an ``InferenceData`` object:

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: data = {
           ...:     "a": np.random.normal(size=(4, 100, 3)),
           ...:     "b": np.random.normal(size=(4, 100)),
           ...: }
           ...: coords = {"a_dim": ["x", "y", "z"]}
           ...: dataA = az.from_dict(data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataB = az.from_dict(prior=data, coords=coords, dims={"a": ["a_dim"]})
           ...: az.concat(dataA, dataB)

    Now, we will concatenate over chain (or draw). It requires identical groups and variables.
    Here we are concatenating two identical ``InferenceData`` objects over dimension chain:

    .. ipython::

        In [1]: az.concat(dataA, dataA, dim="chain")

    It will create an ``InferenceData`` with the original group 'posterior'. In similar way,
    we can also concatenate over draws.

    """
    # pylint: disable=too-many-nested-blocks
    if len(args) == 0:
        if inplace:
            return
        return InferenceData()

    # A single sequence argument is unpacked into the args tuple.
    if len(args) == 1 and isinstance(args[0], Sequence):
        args = args[0]

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                "Concatenating is supported only"
                "between InferenceData objects. Input arg {} is {}".format(i, type(arg))
            )

    if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
        msg = "Invalid `dim`: {}. Valid `dim` are {}".format(dim, '{"group", "chain", "draw"}')
        raise TypeError(msg)
    dim = dim.lower() if dim is not None else dim

    # Single InferenceData: nothing to concatenate.
    if len(args) == 1 and isinstance(args[0], InferenceData):
        if inplace:
            return None
        else:
            if copy:
                return deepcopy(args[0])
            else:
                return args[0]

    current_time = str(datetime.now())

    if not inplace:
        # Keep order for python 3.5
        inference_data_dict = OrderedDict()

    if dim is None:
        arg0 = args[0]
        arg0_groups = ccopy(arg0._groups_all)
        args_groups = dict()
        # check if groups are independent
        # Concat over unique groups
        for arg in args[1:]:
            for group in arg._groups_all:
                if group in args_groups or group in arg0_groups:
                    msg = (
                        "Concatenating overlapping groups is not supported unless `dim` is defined."
                        " Valid dimensions are `chain` and `draw`. Alternatively, use extend to"
                        " combine InferenceData with overlapping groups"
                    )
                    raise TypeError(msg)
                group_data = getattr(arg, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data
        # add arg0 to args_groups if inplace is False
        # otherwise it will merge args_groups to arg0
        # inference data object
        if not inplace:
            for group in arg0_groups:
                group_data = getattr(arg0, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data

        other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS_ALL]

        # Iterate in canonical order so the result keeps the scheme ordering.
        for group in SUPPORTED_GROUPS_ALL + other_groups:
            if group not in args_groups:
                continue
            if inplace:
                if group.startswith(WARMUP_TAG):
                    arg0._groups_warmup.append(group)
                else:
                    arg0._groups.append(group)
                setattr(arg0, group, args_groups[group])
            else:
                inference_data_dict[group] = args_groups[group]
        if inplace:
            # Re-sort arg0's bookkeeping lists into canonical order.
            other_groups = [
                group for group in arg0_groups if group not in SUPPORTED_GROUPS_ALL
            ] + other_groups
            sorted_groups = [
                group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups
            ]
            setattr(arg0, "_groups", sorted_groups)
            sorted_groups_warmup = [
                group
                for group in SUPPORTED_GROUPS_WARMUP + other_groups
                if group in arg0._groups_warmup
            ]
            setattr(arg0, "_groups_warmup", sorted_groups_warmup)
    else:
        arg0 = args[0]
        arg0_groups = arg0._groups_all
        for arg in args[1:]:
            for group0 in arg0_groups:
                if group0 not in arg._groups_all:
                    if group0 == "observed_data":
                        continue
                    msg = "Mismatch between the groups."
                    raise TypeError(msg)
            for group in arg._groups_all:
                # handle data groups separately
                if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
                    # assert that groups are equal
                    if group not in arg0_groups:
                        msg = "Mismatch between the groups."
                        raise TypeError(msg)

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    # Accumulate onto the already-concatenated dataset when
                    # concatenating more than two InferenceData objects.
                    if not inplace and group in inference_data_dict:
                        group0_data = inference_data_dict[group]
                    else:
                        group0_data = getattr(arg0, group)
                    group0_vars = group0_data.data_vars

                    for var in group0_vars:
                        if var not in group_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)

                    for var in group_vars:
                        if var not in group0_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)
                        var_dims = getattr(group_data, var).dims
                        var0_dims = getattr(group0_data, var).dims
                        if var_dims != var0_dims:
                            msg = "Mismatch between the dimensions."
                            raise TypeError(msg)

                        if dim not in var_dims or dim not in var0_dims:
                            msg = "Dimension {} missing.".format(dim)
                            raise TypeError(msg)

                    # xr.concat
                    concatenated_group = xr.concat((group0_data, group_data), dim=dim)
                    if reset_dim:
                        concatenated_group[dim] = range(concatenated_group[dim].size)

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = deepcopy(getattr(group0_data, "attrs"))
                    else:
                        group0_attrs = OrderedDict()

                    if hasattr(group_data, "attrs"):
                        group_attrs = getattr(group_data, "attrs")
                    else:
                        group_attrs = dict()

                    # gather attrs results to group0_attrs
                    _merge_concat_attrs(group0_attrs, group_attrs, current_time)
                    # update attrs
                    setattr(concatenated_group, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, concatenated_group)
                    else:
                        inference_data_dict[group] = concatenated_group
                else:
                    # observed_data, "constant_data", "predictions_constant_data"
                    # Fetch this group's data first: the "missing from arg0"
                    # path below needs it (previously it read a stale value
                    # from the preceding loop iteration).
                    group_data = getattr(arg, group)
                    if group not in arg0_groups:
                        # NOTE(review): this mutates arg0 even when
                        # inplace=False -- preserved, but verify intent.
                        setattr(arg0, group, deepcopy(group_data) if copy else group_data)
                        arg0._groups.append(group)
                        continue

                    # assert that variables are equal
                    group_vars = group_data.data_vars

                    group0_data = getattr(arg0, group)
                    if not inplace:
                        group0_data = deepcopy(group0_data)
                    group0_vars = group0_data.data_vars

                    for var in group_vars:
                        if var not in group0_vars:
                            # variable only in arg: merge it into arg0's group
                            var_data = getattr(group_data, var)
                            getattr(arg0, group)[var] = var_data
                        else:
                            var_data = getattr(group_data, var)
                            var0_data = getattr(group0_data, var)
                            if dim in var_data.dims and dim in var0_data.dims:
                                # NOTE(review): this concatenates the whole
                                # datasets rather than the two variables, and
                                # uses the opposite operand order to the branch
                                # above -- confirm intended semantics.
                                concatenated_var = xr.concat((group_data, group0_data), dim=dim)
                                group0_data[var] = concatenated_var

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = getattr(group0_data, "attrs")
                    else:
                        group0_attrs = OrderedDict()

                    if hasattr(group_data, "attrs"):
                        group_attrs = getattr(group_data, "attrs")
                    else:
                        group_attrs = dict()

                    # gather attrs results to group0_attrs
                    _merge_concat_attrs(group0_attrs, group_attrs, current_time)
                    # update attrs
                    setattr(group0_data, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, group0_data)
                    else:
                        inference_data_dict[group] = group0_data

    return None if inplace else InferenceData(**inference_data_dict)
