1
|
|
# pylint: disable=too-many-lines
|
2
|
2
|
"""Data structure for using netcdf groups with xarray."""
|
3
|
2
|
import uuid
|
4
|
2
|
import warnings
|
5
|
2
|
from collections import OrderedDict, defaultdict
|
6
|
2
|
from collections.abc import Sequence
|
7
|
2
|
from copy import copy as ccopy
|
8
|
2
|
from copy import deepcopy
|
9
|
2
|
from datetime import datetime
|
10
|
2
|
from html import escape
|
11
|
|
|
12
|
2
|
import netCDF4 as nc
|
13
|
2
|
import numpy as np
|
14
|
2
|
import xarray as xr
|
15
|
2
|
from xarray.core.options import OPTIONS
|
16
|
2
|
from xarray.core.utils import either_dict_or_kwargs
|
17
|
|
|
18
|
2
|
from ..rcparams import rcParams
|
19
|
2
|
from ..utils import HtmlTemplate, _subset_list
|
20
|
2
|
from .base import _extend_xr_method, dict_to_dataset, _make_json_serializable
|
21
|
|
|
22
|
2
|
try:
|
23
|
2
|
import ujson as json
|
24
|
0
|
except ImportError:
|
25
|
0
|
import json
|
26
|
|
|
27
|
2
|
# Group names recognized by the InferenceData schema. Any other name is
# accepted at construction time but triggers a UserWarning.
SUPPORTED_GROUPS = [
    "posterior",
    "posterior_predictive",
    "predictions",
    "log_likelihood",
    "sample_stats",
    "prior",
    "prior_predictive",
    "sample_stats_prior",
    "observed_data",
    "constant_data",
    "predictions_constant_data",
]

# Prefix marking a group as holding warmup (tuning) iterations.
WARMUP_TAG = "warmup_"

# Warmup counterparts of the sampling-phase groups that can carry tuning draws.
SUPPORTED_GROUPS_WARMUP = [
    WARMUP_TAG + group
    for group in (
        "posterior",
        "posterior_predictive",
        "predictions",
        "sample_stats",
        "log_likelihood",
    )
]

SUPPORTED_GROUPS_ALL = SUPPORTED_GROUPS + SUPPORTED_GROUPS_WARMUP
|
52
|
|
|
53
|
|
|
54
|
2
|
class InferenceData:
    """Container for inference data storage using xarray.

    Each group (posterior, prior, observed_data, ...) is stored as an
    ``xarray.Dataset`` attribute; group names are tracked in ``_groups`` and
    ``_groups_warmup``.

    For a detailed introduction to ``InferenceData`` objects and their usage, see
    :doc:`/notebooks/XarrayforArviZ`. This page provides help and documentation
    on ``InferenceData`` methods and their low level implementation.
    """
|
61
|
|
|
62
|
2
|
    def __init__(self, **kwargs):
        """Initialize InferenceData object from keyword xarray datasets.

        Parameters
        ----------
        kwargs :
            Keyword arguments of xarray datasets. Each value may be an
            ``xarray.Dataset``, ``None`` (skipped), or a 2-item list/tuple of
            ``(dataset, dataset_warmup)``. The special keyword ``save_warmup``
            (bool, default False) controls whether warmup datasets are stored.

        Examples
        --------
        Initiate an InferenceData object from scratch, not recommended. InferenceData
        objects should be initialized using ``from_xyz`` methods, see :ref:`data_api` for more
        details.

        .. ipython::

            In [1]: import arviz as az
               ...: import numpy as np
               ...: import xarray as xr
               ...: dataset = xr.Dataset(
               ...:     {
               ...:         "a": (["chain", "draw", "a_dim"], np.random.normal(size=(4, 100, 3))),
               ...:         "b": (["chain", "draw"], np.random.normal(size=(4, 100))),
               ...:     },
               ...:     coords={
               ...:         "chain": (["chain"], np.arange(4)),
               ...:         "draw": (["draw"], np.arange(100)),
               ...:         "a_dim": (["a_dim"], ["x", "y", "z"]),
               ...:     }
               ...: )
               ...: idata = az.InferenceData(posterior=dataset, prior=dataset)
               ...: idata

        We have created an ``InferenceData`` object with two groups. Now we can check its
        contents:

        .. ipython::

            In [1]: idata.posterior

        """
        self._groups = []
        self._groups_warmup = []
        # ``save_warmup`` is popped so it is never treated as a group name.
        save_warmup = kwargs.pop("save_warmup", False)
        # Process schema groups first, in schema order; unknown groups are
        # appended afterwards (with a warning) so iteration order is stable.
        key_list = [key for key in SUPPORTED_GROUPS_ALL if key in kwargs]
        for key in kwargs:
            if key not in SUPPORTED_GROUPS_ALL:
                key_list.append(key)
                warnings.warn(
                    "{} group is not defined in the InferenceData scheme".format(key), UserWarning
                )
        for key in key_list:
            dataset = kwargs[key]
            dataset_warmup = None
            if dataset is None:
                continue
            elif isinstance(dataset, (list, tuple)):
                # A pair means (sampling-phase dataset, warmup dataset).
                dataset, dataset_warmup = kwargs[key]
            elif not isinstance(dataset, xr.Dataset):
                raise ValueError(
                    "Arguments to InferenceData must be xarray Datasets "
                    "(argument '{}' was type '{}')".format(key, type(dataset))
                )
            # Empty datasets (falsy) are silently dropped in both branches.
            if not key.startswith(WARMUP_TAG):
                if dataset:
                    setattr(self, key, dataset)
                    self._groups.append(key)
            elif key.startswith(WARMUP_TAG):
                if dataset:
                    setattr(self, key, dataset)
                    self._groups_warmup.append(key)
            if save_warmup and dataset_warmup is not None:
                if dataset_warmup:
                    # NOTE: ``key`` is rebound here; any later use in this
                    # iteration would see the warmup-prefixed name.
                    key = "{}{}".format(WARMUP_TAG, key)
                    setattr(self, key, dataset_warmup)
                    self._groups_warmup.append(key)
|
138
|
|
|
139
|
2
|
def __repr__(self):
|
140
|
|
"""Make string representation of InferenceData object."""
|
141
|
2
|
msg = "Inference data with groups:\n\t> {options}".format(
|
142
|
|
options="\n\t> ".join(self._groups)
|
143
|
|
)
|
144
|
2
|
if self._groups_warmup:
|
145
|
0
|
msg += "\n\nWarmup iterations saved ({}*).".format(WARMUP_TAG)
|
146
|
2
|
return msg
|
147
|
|
|
148
|
2
|
def _repr_html_(self):
|
149
|
|
"""Make html representation of InferenceData object."""
|
150
|
2
|
display_style = OPTIONS["display_style"]
|
151
|
2
|
if display_style == "text":
|
152
|
2
|
html_repr = f"<pre>{escape(repr(self))}</pre>"
|
153
|
|
else:
|
154
|
2
|
elements = "".join(
|
155
|
|
[
|
156
|
|
HtmlTemplate.element_template.format(
|
157
|
|
group_id=group + str(uuid.uuid4()),
|
158
|
|
group=group,
|
159
|
|
xr_data=getattr( # pylint: disable=protected-access
|
160
|
|
self, group
|
161
|
|
)._repr_html_(),
|
162
|
|
)
|
163
|
|
for group in self._groups_all
|
164
|
|
]
|
165
|
|
)
|
166
|
2
|
formatted_html_template = ( # pylint: disable=possibly-unused-variable
|
167
|
|
HtmlTemplate.html_template.format(elements)
|
168
|
|
)
|
169
|
2
|
css_template = HtmlTemplate.css_template # pylint: disable=possibly-unused-variable
|
170
|
2
|
html_repr = "%(formatted_html_template)s%(css_template)s" % locals()
|
171
|
2
|
return html_repr
|
172
|
|
|
173
|
2
|
def __delattr__(self, group):
|
174
|
|
"""Delete a group from the InferenceData object."""
|
175
|
2
|
if group in self._groups:
|
176
|
2
|
self._groups.remove(group)
|
177
|
0
|
elif group in self._groups_warmup:
|
178
|
0
|
self._groups_warmup.remove(group)
|
179
|
2
|
object.__delattr__(self, group)
|
180
|
|
|
181
|
2
|
@property
|
182
|
|
def _groups_all(self):
|
183
|
2
|
return self._groups + self._groups_warmup
|
184
|
|
|
185
|
2
|
def __iter__(self):
|
186
|
|
"""Iterate over groups in InferenceData object."""
|
187
|
2
|
for group in self._groups_all:
|
188
|
2
|
yield group
|
189
|
|
|
190
|
2
|
def groups(self):
|
191
|
|
"""Return all groups present in InferenceData object."""
|
192
|
2
|
return self._groups_all
|
193
|
|
|
194
|
2
|
def values(self):
|
195
|
|
"""Xarray Datasets present in InferenceData object."""
|
196
|
2
|
for group in self._groups_all:
|
197
|
2
|
yield getattr(self, group)
|
198
|
|
|
199
|
2
|
def items(self):
|
200
|
|
"""Yield groups and corresponding datasets present in InferenceData object."""
|
201
|
2
|
for group in self._groups_all:
|
202
|
2
|
yield (group, getattr(self, group))
|
203
|
|
|
204
|
2
|
    @staticmethod
    def from_netcdf(filename):
        """Initialize object from a netcdf file.

        Expects that the file will have groups, each of which can be loaded by xarray.
        By default, the datasets of the InferenceData object will be lazily loaded instead
        of being loaded into memory. This
        behaviour is regulated by the value of ``az.rcParams["data.load"]``.

        Parameters
        ----------
        filename : str
            location of netcdf file

        Returns
        -------
        InferenceData object
        """
        groups = {}
        # netCDF4 is used only to enumerate the group names; the file handle is
        # closed again before any data is read.
        with nc.Dataset(filename, mode="r") as data:
            data_groups = list(data.groups)

        for group in data_groups:
            # NOTE: ``data`` deliberately shadows the (already closed) netCDF4
            # handle above; here it is the xarray dataset of one group.
            with xr.open_dataset(filename, group=group) as data:
                if rcParams["data.load"] == "eager":
                    # Pull all values into memory before the file is closed.
                    groups[group] = data.load()
                else:
                    # Keep the dataset lazy; values are read on demand.
                    groups[group] = data
        return InferenceData(**groups)
|
233
|
|
|
234
|
2
|
    def to_netcdf(self, filename, compress=True, groups=None):
        """Write InferenceData to file using netcdf4.

        Parameters
        ----------
        filename : str
            Location to write to
        compress : bool, optional
            Whether to compress result. Note this saves disk space, but may make
            saving and loading somewhat slower (default: True).
        groups : list, optional
            Write only these groups to netcdf file.

        Returns
        -------
        str
            Location of netcdf file
        """
        mode = "w"  # overwrite first, then append
        if self._groups_all:  # checks whether any group is present
            if groups is None:
                groups = self._groups_all
            else:
                # Preserve internal ordering; ignore requested names that are
                # not actual groups of this object.
                groups = [group for group in self._groups_all if group in groups]

            for group in groups:
                data = getattr(self, group)
                kwargs = {}
                if compress:
                    # zlib-compress every variable in the group.
                    kwargs["encoding"] = {var_name: {"zlib": True} for var_name in data.variables}
                data.to_netcdf(filename, mode=mode, group=group, **kwargs)
                data.close()
                # After the first group the file exists; append further groups.
                mode = "a"
        else:  # creates a netcdf file for an empty InferenceData object.
            empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
            empty_netcdf_file.close()
        return filename
|
271
|
|
|
272
|
2
|
    def to_dict(self, groups=None, filter_groups=None):
        """Convert InferenceData to a dictionary following xarray naming conventions.

        Parameters
        ----------
        groups : list, optional
            Include only these groups in the resulting dictionary; by default
            all groups are included.
        filter_groups : {None, "like", "regex"}, optional
            How to interpret ``groups`` when it is None: passed through to
            ``self._group_names``. Ignored when an explicit ``groups`` list is
            given.

        Returns
        -------
        dict
            A dictionary containing all groups of the InferenceData object, plus
            shared ``"coords"``, ``"dims"``/``"pred_dims"`` and ``"attrs"`` entries.
        """
        ret = defaultdict(dict)
        attrs = None
        if self._groups_all:  # checks whether any group is present
            if groups is None:
                groups = self._group_names(groups, filter_groups)
            else:
                groups = [group for group in self._groups_all if group in groups]

        for group in groups:
            dataset = getattr(self, group)
            data = {}
            for var_name, dataarray in dataset.items():
                data[var_name] = dataarray.values
                dims = []
                for coord_name, coord_values in dataarray.coords.items():
                    # Skip the sampling dims and xarray's auto-generated
                    # default dims ("<var>_dim_0", ...); everything else is a
                    # user-defined coordinate worth exporting.
                    if coord_name not in ("chain", "draw") and not coord_name.startswith(
                        var_name + "_dim_"
                    ):
                        dims.append(coord_name)
                        ret["coords"][coord_name] = coord_values.values

                # Prediction groups record their dims separately so that
                # from_dict can rebuild them under ``pred_dims``.
                if group in (
                    "predictions",
                    "predictions_constant_data",
                ):
                    dims_key = "pred_dims"
                else:
                    dims_key = "dims"
                if len(dims) > 0:
                    ret[dims_key][var_name] = dims
            ret[group] = data
            # Only the first group's attrs are kept; differing attrs warn.
            if attrs is None:
                attrs = dataset.attrs
            elif attrs != dataset.attrs:
                warnings.warn(
                    "The attributes are not same for all groups."
                    " Considering only the first group `attrs`"
                )

        ret["attrs"] = attrs
        return ret
|
327
|
|
|
328
|
2
|
def to_json(self, filename, **kwargs):
|
329
|
|
"""Write InferenceData to a json file.
|
330
|
|
|
331
|
|
Parameters
|
332
|
|
----------
|
333
|
|
filename : str
|
334
|
|
Location to write to
|
335
|
|
kwargs : dict
|
336
|
|
kwargs passed to json.dump()
|
337
|
|
|
338
|
|
Returns
|
339
|
|
-------
|
340
|
|
str
|
341
|
|
Location of json file
|
342
|
|
"""
|
343
|
2
|
idata_dict = _make_json_serializable(self.to_dict())
|
344
|
|
|
345
|
2
|
with open(filename, "w") as file:
|
346
|
2
|
json.dump(idata_dict, file, **kwargs)
|
347
|
|
|
348
|
2
|
return filename
|
349
|
|
|
350
|
2
|
    def __add__(self, other):
        """Concatenate two InferenceData objects."""
        # Delegates to module-level ``concat``; always returns a new object
        # (copy=True, inplace=False), so ``self`` and ``other`` are untouched.
        return concat(self, other, copy=True, inplace=False)
|
353
|
|
|
354
|
2
|
    def sel(
        self,
        groups=None,
        filter_groups=None,
        inplace=False,
        chain_prior=None,
        **kwargs,
    ):
        """Perform an xarray selection on all groups.

        Loops groups to perform Dataset.sel(key=item)
        for every kwarg if key is a dimension of the dataset.
        One example could be performing a burn in cut on the InferenceData object
        or discarding a chain. The selection is performed on all relevant groups (like
        posterior, prior, sample stats) while non relevant groups like observed data are
        omitted. See :meth:`xarray.Dataset.sel <xarray:xarray.Dataset.sel>`

        Parameters
        ----------
        groups: str or list of str, optional
            Groups where the selection is to be applied. Can either be group names
            or metagroup names.
        filter_groups: {None, "like", "regex"}, optional, default=None
            If `None` (default), interpret groups as the real group or metagroup names.
            If "like", interpret groups as substrings of the real group or metagroup names.
            If "regex", interpret groups as regular expressions on the real group or
            metagroup names. A la `pandas.filter`.
        inplace: bool, optional
            If ``True``, modify the InferenceData object inplace,
            otherwise, return the modified copy.
        chain_prior: bool, optional, deprecated
            If ``False``, do not select prior related groups using ``chain`` dim.
            Otherwise, use selection on ``chain`` if present. Default=False
        **kwargs: mapping
            It must be accepted by Dataset.sel().

        Returns
        -------
        InferenceData
            A new InferenceData object by default.
            When `inplace==True` perform selection in-place and return `None`

        Examples
        --------
        Use ``sel`` to discard one chain of the InferenceData object. We first check the
        dimensions of the original object:

        .. ipython::

            In [1]: import arviz as az
               ...: idata = az.load_arviz_data("centered_eight")
               ...: del idata.prior  # prior group only has 1 chain currently
               ...: print(idata.posterior.coords)
               ...: print(idata.posterior_predictive.coords)
               ...: print(idata.observed_data.coords)

        In order to remove the third chain:

        .. ipython::

            In [1]: idata_subset = idata.sel(chain=[0, 1, 3])
               ...: print(idata_subset.posterior.coords)
               ...: print(idata_subset.posterior_predictive.coords)
               ...: print(idata_subset.observed_data.coords)

        """
        # Deprecation handling: an explicitly passed chain_prior keeps its
        # value (truthy or falsy) but warns; only an omitted argument is
        # replaced by the False default.
        if chain_prior is not None:
            warnings.warn(
                "chain_prior has been deprecated. Use groups argument and "
                "rcParams['data.metagroups'] instead.",
                DeprecationWarning,
            )
        else:
            chain_prior = False
        groups = self._group_names(groups, filter_groups)

        out = self if inplace else deepcopy(self)
        for group in groups:
            dataset = getattr(self, group)
            # Only apply kwargs that name an actual dimension of this group.
            valid_keys = set(kwargs.keys()).intersection(dataset.dims)
            if not chain_prior and "prior" in group:
                # Prior groups keep all chains unless chain_prior is truthy.
                valid_keys -= {"chain"}
            dataset = dataset.sel(**{key: kwargs[key] for key in valid_keys})
            setattr(out, group, dataset)
        if inplace:
            return None
        else:
            return out
|
442
|
|
|
443
|
2
|
def isel(
|
444
|
|
self,
|
445
|
|
groups=None,
|
446
|
|
filter_groups=None,
|
447
|
|
inplace=False,
|
448
|
|
**kwargs,
|
449
|
|
):
|
450
|
|
"""Perform an xarray selection on all groups.
|
451
|
|
|
452
|
|
Loops groups to perform Dataset.isel(key=item)
|
453
|
|
for every kwarg if key is a dimension of the dataset.
|
454
|
|
One example could be performing a burn in cut on the InferenceData object
|
455
|
|
or discarding a chain. The selection is performed on all relevant groups (like
|
456
|
|
posterior, prior, sample stats) while non relevant groups like observed data are
|
457
|
|
omitted. See :meth:`xarray:xarray.Dataset.isel`
|
458
|
|
|
459
|
|
Parameters
|
460
|
|
----------
|
461
|
|
groups: str or list of str, optional
|
462
|
|
Groups where the selection is to be applied. Can either be group names
|
463
|
|
or metagroup names.
|
464
|
|
filter_groups: {None, "like", "regex"}, optional, default=None
|
465
|
|
If `None` (default), interpret groups as the real group or metagroup names.
|
466
|
|
If "like", interpret groups as substrings of the real group or metagroup names.
|
467
|
|
If "regex", interpret groups as regular expressions on the real group or
|
468
|
|
metagroup names. A la `pandas.filter`.
|
469
|
|
inplace: bool, optional
|
470
|
|
If ``True``, modify the InferenceData object inplace,
|
471
|
|
otherwise, return the modified copy.
|
472
|
|
**kwargs: mapping
|
473
|
|
It must be accepted by :meth:`xarray:xarray.Dataset.isel`.
|
474
|
|
|
475
|
|
Returns
|
476
|
|
-------
|
477
|
|
InferenceData
|
478
|
|
A new InferenceData object by default.
|
479
|
|
When `inplace==True` perform selection in-place and return `None`
|
480
|
|
|
481
|
|
"""
|
482
|
2
|
groups = self._group_names(groups, filter_groups)
|
483
|
|
|
484
|
2
|
out = self if inplace else deepcopy(self)
|
485
|
2
|
for group in groups:
|
486
|
2
|
dataset = getattr(self, group)
|
487
|
2
|
valid_keys = set(kwargs.keys()).intersection(dataset.dims)
|
488
|
2
|
dataset = dataset.isel(**{key: kwargs[key] for key in valid_keys})
|
489
|
2
|
setattr(out, group, dataset)
|
490
|
2
|
if inplace:
|
491
|
2
|
return None
|
492
|
|
else:
|
493
|
2
|
return out
|
494
|
|
|
495
|
2
|
def stack(
|
496
|
|
self,
|
497
|
|
dimensions=None,
|
498
|
|
groups=None,
|
499
|
|
filter_groups=None,
|
500
|
|
inplace=False,
|
501
|
|
**kwargs,
|
502
|
|
):
|
503
|
|
"""Perform an xarray stacking on all groups.
|
504
|
|
|
505
|
|
Stack any number of existing dimensions into a single new dimension.
|
506
|
|
Loops groups to perform Dataset.stack(key=value)
|
507
|
|
for every kwarg if value is a dimension of the dataset.
|
508
|
|
The selection is performed on all relevant groups (like
|
509
|
|
posterior, prior, sample stats) while non relevant groups like observed data are
|
510
|
|
omitted. See :meth:`xarray:xarray.Dataset.stack`
|
511
|
|
|
512
|
|
Parameters
|
513
|
|
----------
|
514
|
|
dimensions: dict
|
515
|
|
Names of new dimensions, and the existing dimensions that they replace.
|
516
|
|
groups: str or list of str, optional
|
517
|
|
Groups where the selection is to be applied. Can either be group names
|
518
|
|
or metagroup names.
|
519
|
|
filter_groups: {None, "like", "regex"}, optional, default=None
|
520
|
|
If `None` (default), interpret groups as the real group or metagroup names.
|
521
|
|
If "like", interpret groups as substrings of the real group or metagroup names.
|
522
|
|
If "regex", interpret groups as regular expressions on the real group or
|
523
|
|
metagroup names. A la `pandas.filter`.
|
524
|
|
inplace: bool, optional
|
525
|
|
If ``True``, modify the InferenceData object inplace,
|
526
|
|
otherwise, return the modified copy.
|
527
|
|
**kwargs: mapping
|
528
|
|
It must be accepted by :meth:`xarray:xarray.Dataset.stack`.
|
529
|
|
|
530
|
|
Returns
|
531
|
|
-------
|
532
|
|
InferenceData
|
533
|
|
A new InferenceData object by default.
|
534
|
|
When `inplace==True` perform selection in-place and return `None`
|
535
|
|
|
536
|
|
"""
|
537
|
2
|
groups = self._group_names(groups, filter_groups)
|
538
|
|
|
539
|
2
|
dimensions = {} if dimensions is None else dimensions
|
540
|
2
|
dimensions.update(kwargs)
|
541
|
2
|
out = self if inplace else deepcopy(self)
|
542
|
2
|
for group in groups:
|
543
|
2
|
dataset = getattr(self, group)
|
544
|
2
|
kwarg_dict = {}
|
545
|
2
|
for key, value in dimensions.items():
|
546
|
2
|
if not set(value).difference(dataset.dims):
|
547
|
2
|
kwarg_dict[key] = value
|
548
|
2
|
dataset = dataset.stack(**kwarg_dict)
|
549
|
2
|
setattr(out, group, dataset)
|
550
|
2
|
if inplace:
|
551
|
0
|
return None
|
552
|
|
else:
|
553
|
2
|
return out
|
554
|
|
|
555
|
2
|
def unstack(self, dim=None, groups=None, filter_groups=None, inplace=False):
|
556
|
|
"""Perform an xarray unstacking on all groups.
|
557
|
|
|
558
|
|
Unstack existing dimensions corresponding to MultiIndexes into multiple new dimensions.
|
559
|
|
Loops groups to perform Dataset.unstack(key=value).
|
560
|
|
The selection is performed on all relevant groups (like posterior, prior,
|
561
|
|
sample stats) while non relevant groups like observed data are omitted.
|
562
|
|
See :meth:`xarray:xarray.Dataset.unstack`
|
563
|
|
|
564
|
|
Parameters
|
565
|
|
----------
|
566
|
|
dim: Hashable or iterable of Hashable, optional
|
567
|
|
Dimension(s) over which to unstack. By default unstacks all MultiIndexes.
|
568
|
|
groups: str or list of str, optional
|
569
|
|
Groups where the selection is to be applied. Can either be group names
|
570
|
|
or metagroup names.
|
571
|
|
filter_groups: {None, "like", "regex"}, optional, default=None
|
572
|
|
If `None` (default), interpret groups as the real group or metagroup names.
|
573
|
|
If "like", interpret groups as substrings of the real group or metagroup names.
|
574
|
|
If "regex", interpret groups as regular expressions on the real group or
|
575
|
|
metagroup names. A la `pandas.filter`.
|
576
|
|
inplace: bool, optional
|
577
|
|
If ``True``, modify the InferenceData object inplace,
|
578
|
|
otherwise, return the modified copy.
|
579
|
|
|
580
|
|
Returns
|
581
|
|
-------
|
582
|
|
InferenceData
|
583
|
|
A new InferenceData object by default.
|
584
|
|
When `inplace==True` perform selection in place and return `None`
|
585
|
|
|
586
|
|
"""
|
587
|
2
|
groups = self._group_names(groups, filter_groups)
|
588
|
2
|
if isinstance(dim, str):
|
589
|
2
|
dim = [dim]
|
590
|
|
|
591
|
2
|
out = self if inplace else deepcopy(self)
|
592
|
2
|
for group in groups:
|
593
|
2
|
dataset = getattr(self, group)
|
594
|
2
|
valid_dims = set(dim).intersection(dataset.dims) if dim is not None else dim
|
595
|
2
|
dataset = dataset.unstack(dim=valid_dims)
|
596
|
2
|
setattr(out, group, dataset)
|
597
|
2
|
if inplace:
|
598
|
0
|
return None
|
599
|
|
else:
|
600
|
2
|
return out
|
601
|
|
|
602
|
2
|
def rename(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
|
603
|
|
"""Perform xarray renaming of variable and dimensions on all groups.
|
604
|
|
|
605
|
|
Loops groups to perform Dataset.rename(name_dict)
|
606
|
|
for every key in name_dict if key is a dimension/data_vars of the dataset.
|
607
|
|
The renaming is performed on all relevant groups (like
|
608
|
|
posterior, prior, sample stats) while non relevant groups like observed data are
|
609
|
|
omitted. See :meth:`xarray:xarray.Dataset.rename`
|
610
|
|
|
611
|
|
Parameters
|
612
|
|
----------
|
613
|
|
name_dict: dict
|
614
|
|
Dictionary whose keys are current variable or dimension names
|
615
|
|
and whose values are the desired names.
|
616
|
|
groups: str or list of str, optional
|
617
|
|
Groups where the selection is to be applied. Can either be group names
|
618
|
|
or metagroup names.
|
619
|
|
filter_groups: {None, "like", "regex"}, optional, default=None
|
620
|
|
If `None` (default), interpret groups as the real group or metagroup names.
|
621
|
|
If "like", interpret groups as substrings of the real group or metagroup names.
|
622
|
|
If "regex", interpret groups as regular expressions on the real group or
|
623
|
|
metagroup names. A la `pandas.filter`.
|
624
|
|
inplace: bool, optional
|
625
|
|
If ``True``, modify the InferenceData object inplace,
|
626
|
|
otherwise, return the modified copy.
|
627
|
|
|
628
|
|
|
629
|
|
Returns
|
630
|
|
-------
|
631
|
|
InferenceData
|
632
|
|
A new InferenceData object by default.
|
633
|
|
When `inplace==True` perform renaming in-place and return `None`
|
634
|
|
|
635
|
|
"""
|
636
|
2
|
groups = self._group_names(groups, filter_groups)
|
637
|
2
|
if "chain" in name_dict.keys() or "draw" in name_dict.keys():
|
638
|
0
|
raise KeyError("'chain' or 'draw' dimensions can't be renamed")
|
639
|
2
|
out = self if inplace else deepcopy(self)
|
640
|
|
|
641
|
2
|
for group in groups:
|
642
|
2
|
dataset = getattr(self, group)
|
643
|
2
|
expected_keys = list(dataset.data_vars) + list(dataset.dims)
|
644
|
2
|
valid_keys = set(name_dict.keys()).intersection(expected_keys)
|
645
|
2
|
dataset = dataset.rename({key: name_dict[key] for key in valid_keys})
|
646
|
2
|
setattr(out, group, dataset)
|
647
|
2
|
if inplace:
|
648
|
0
|
return None
|
649
|
|
else:
|
650
|
2
|
return out
|
651
|
|
|
652
|
2
|
def rename_vars(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
|
653
|
|
"""Perform xarray renaming of variable or coordinate names on all groups.
|
654
|
|
|
655
|
|
Loops groups to perform Dataset.rename_vars(name_dict)
|
656
|
|
for every key in name_dict if key is a variable or coordinate names of the dataset.
|
657
|
|
The renaming is performed on all relevant groups (like
|
658
|
|
posterior, prior, sample stats) while non relevant groups like observed data are
|
659
|
|
omitted. See :meth:`xarray:xarray.Dataset.rename_vars`
|
660
|
|
|
661
|
|
Parameters
|
662
|
|
----------
|
663
|
|
name_dict: dict
|
664
|
|
Dictionary whose keys are current variable or coordinate names
|
665
|
|
and whose values are the desired names.
|
666
|
|
groups: str or list of str, optional
|
667
|
|
Groups where the selection is to be applied. Can either be group names
|
668
|
|
or metagroup names.
|
669
|
|
filter_groups: {None, "like", "regex"}, optional, default=None
|
670
|
|
If `None` (default), interpret groups as the real group or metagroup names.
|
671
|
|
If "like", interpret groups as substrings of the real group or metagroup names.
|
672
|
|
If "regex", interpret groups as regular expressions on the real group or
|
673
|
|
metagroup names. A la `pandas.filter`.
|
674
|
|
inplace: bool, optional
|
675
|
|
If ``True``, modify the InferenceData object inplace,
|
676
|
|
otherwise, return the modified copy.
|
677
|
|
|
678
|
|
|
679
|
|
Returns
|
680
|
|
-------
|
681
|
|
InferenceData
|
682
|
|
A new InferenceData object with renamed variables including coordinates by default.
|
683
|
|
When `inplace==True` perform renaming in-place and return `None`
|
684
|
|
|
685
|
|
"""
|
686
|
2
|
groups = self._group_names(groups, filter_groups)
|
687
|
|
|
688
|
2
|
out = self if inplace else deepcopy(self)
|
689
|
2
|
for group in groups:
|
690
|
2
|
dataset = getattr(self, group)
|
691
|
2
|
valid_keys = set(name_dict.keys()).intersection(dataset.data_vars)
|
692
|
2
|
dataset = dataset.rename_vars({key: name_dict[key] for key in valid_keys})
|
693
|
2
|
setattr(out, group, dataset)
|
694
|
2
|
if inplace:
|
695
|
0
|
return None
|
696
|
|
else:
|
697
|
2
|
return out
|
698
|
|
|
699
|
2
|
def rename_dims(self, name_dict=None, groups=None, filter_groups=None, inplace=False):
|
700
|
|
"""Perform xarray renaming of dimensions on all groups.
|
701
|
|
|
702
|
|
Loops groups to perform Dataset.rename_dims(name_dict)
|
703
|
|
for every key in name_dict if key is a dimension of the dataset.
|
704
|
|
The renaming is performed on all relevant groups (like
|
705
|
|
posterior, prior, sample stats) while non relevant groups like observed data are
|
706
|
|
omitted. See :meth:`xarray:xarray.Dataset.rename_dims`
|
707
|
|
|
708
|
|
Parameters
|
709
|
|
----------
|
710
|
|
name_dict: dict
|
711
|
|
Dictionary whose keys are current dimension names and whose values are the desired
|
712
|
|
names.
|
713
|
|
groups: str or list of str, optional
|
714
|
|
Groups where the selection is to be applied. Can either be group names
|
715
|
|
or metagroup names.
|
716
|
|
filter_groups: {None, "like", "regex"}, optional, default=None
|
717
|
|
If `None` (default), interpret groups as the real group or metagroup names.
|
718
|
|
If "like", interpret groups as substrings of the real group or metagroup names.
|
719
|
|
If "regex", interpret groups as regular expressions on the real group or
|
720
|
|
metagroup names. A la `pandas.filter`.
|
721
|
|
inplace: bool, optional
|
722
|
|
If ``True``, modify the InferenceData object inplace,
|
723
|
|
otherwise, return the modified copy.
|
724
|
|
|
725
|
|
|
726
|
|
Returns
|
727
|
|
-------
|
728
|
|
InferenceData
|
729
|
|
A new InferenceData object with renamed dimension by default.
|
730
|
|
When `inplace==True` perform renaming in-place and return `None`
|
731
|
|
|
732
|
|
"""
|
733
|
2
|
groups = self._group_names(groups, filter_groups)
|
734
|
2
|
if "chain" in name_dict.keys() or "draw" in name_dict.keys():
|
735
|
0
|
raise KeyError("'chain' or 'draw' dimensions can't be renamed")
|
736
|
|
|
737
|
2
|
out = self if inplace else deepcopy(self)
|
738
|
2
|
for group in groups:
|
739
|
2
|
dataset = getattr(self, group)
|
740
|
2
|
valid_keys = set(name_dict.keys()).intersection(dataset.dims)
|
741
|
2
|
dataset = dataset.rename_dims({key: name_dict[key] for key in valid_keys})
|
742
|
2
|
setattr(out, group, dataset)
|
743
|
2
|
if inplace:
|
744
|
0
|
return None
|
745
|
|
else:
|
746
|
2
|
return out
|
747
|
|
|
748
|
2
|
    def add_groups(self, group_dict=None, coords=None, dims=None, **kwargs):
        """Add new groups to InferenceData object.

        Parameters
        ----------
        group_dict: dict of {str : dict or xarray.Dataset}, optional
            Groups to be added
        coords : dict[str] -> ndarray
            Coordinates for the dataset
        dims : dict[str] -> list[str]
            Dimensions of each variable. The keys are variable names, values are lists of
            coordinates.
        **kwargs: mapping
            The keyword arguments form of group_dict. One of group_dict or kwargs must be provided.

        Raises
        ------
        ValueError
            If neither ``group_dict`` nor kwargs are given, if a group already
            exists, or if a value is not a dict, DataArray or Dataset.

        See Also
        --------
        extend : Extend InferenceData with groups from another InferenceData.
        concat : Concatenate InferenceData objects.
        """
        group_dict = either_dict_or_kwargs(group_dict, kwargs, "add_groups")
        if not group_dict:
            raise ValueError("One of group_dict or kwargs must be provided.")
        repeated_groups = [group for group in group_dict.keys() if group in self._groups]
        if repeated_groups:
            raise ValueError("{} group(s) already exists.".format(repeated_groups))
        for group, dataset in group_dict.items():
            if group not in SUPPORTED_GROUPS_ALL:
                warnings.warn(
                    "The group {} is not defined in the InferenceData scheme".format(group),
                    UserWarning,
                )
            if dataset is None:
                continue
            elif isinstance(dataset, dict):
                # dict inputs are converted through dict_to_dataset, which
                # adds the default "chain"/"draw" dims; warn when that is
                # probably not wanted (constant-data-like or unknown groups).
                if (
                    group in ("observed_data", "constant_data", "predictions_constant_data")
                    or group not in SUPPORTED_GROUPS_ALL
                ):
                    warnings.warn(
                        "the default dims 'chain' and 'draw' will be added automatically",
                        UserWarning,
                    )
                dataset = dict_to_dataset(dataset, coords=coords, dims=dims)
            elif isinstance(dataset, xr.DataArray):
                # A bare DataArray needs a name before it can become a Dataset.
                if dataset.name is None:
                    dataset.name = "x"
                dataset = dataset.to_dataset()
            elif not isinstance(dataset, xr.Dataset):
                raise ValueError(
                    "Arguments to add_groups() must be xr.Dataset, xr.Dataarray or dicts\
                    (argument '{}' was type '{}')".format(
                        group, type(dataset)
                    )
                )
            # Empty (falsy) datasets are silently dropped.
            if dataset:
                setattr(self, group, dataset)
                if group.startswith(WARMUP_TAG):
                    self._groups_warmup.append(group)
                else:
                    self._groups.append(group)
|
809
|
|
|
810
|
2
|
def extend(self, other, join="left"):
    """Extend InferenceData with groups from another InferenceData.

    Parameters
    ----------
    other : InferenceData
        InferenceData to be added
    join : {'left', 'right'}, default 'left'
        Defines how the two decide which group to keep when the same group is
        present in both objects. 'left' will discard the group in ``other`` whereas 'right'
        will keep the group in ``other`` and discard the one in ``self``.

    See Also
    --------
    add_groups : Add new groups to InferenceData object.
    concat : Concatenate InferenceData objects.

    """
    if not isinstance(other, InferenceData):
        raise ValueError("Extending is possible between two InferenceData objects only.")
    if join not in ("left", "right"):
        raise ValueError("join must be either 'left' or 'right', found {}".format(join))
    for group in other._groups_all:  # pylint: disable=protected-access
        if hasattr(self, group):
            # Group present on both sides: 'left' keeps ours, 'right' overwrites below.
            if join == "left":
                continue
        if group not in SUPPORTED_GROUPS_ALL:
            warnings.warn(
                "{} group is not defined in the InferenceData scheme".format(group), UserWarning
            )
        dataset = getattr(other, group)
        setattr(self, group, dataset)
        # Fix: register newly added groups in the bookkeeping lists so that they
        # appear in ``_groups_all`` (and therefore in repr, _group_names, concat,
        # to_netcdf, ...). Previously only the attribute was set, leaving groups
        # copied from ``other`` untracked — unlike ``add_groups``, which appends
        # to these lists.
        if group.startswith(WARMUP_TAG):
            if group not in self._groups_warmup:
                self._groups_warmup.append(group)
        elif group not in self._groups:
            self._groups.append(group)
|
842
|
|
|
843
|
2
|
# Methods delegated group-wise to each group's underlying ``xarray.Dataset``
# via ``_extend_xr_method``. Each wrapper applies the Dataset method to the
# selected groups and returns a new InferenceData (or None when inplace).

# Index / coordinate manipulation and (lazy-)loading helpers.
set_index = _extend_xr_method(xr.Dataset.set_index)
get_index = _extend_xr_method(xr.Dataset.get_index)
reset_index = _extend_xr_method(xr.Dataset.reset_index)
set_coords = _extend_xr_method(xr.Dataset.set_coords)
reset_coords = _extend_xr_method(xr.Dataset.reset_coords)
assign = _extend_xr_method(xr.Dataset.assign)
assign_coords = _extend_xr_method(xr.Dataset.assign_coords)
sortby = _extend_xr_method(xr.Dataset.sortby)
chunk = _extend_xr_method(xr.Dataset.chunk)
unify_chunks = _extend_xr_method(xr.Dataset.unify_chunks)
load = _extend_xr_method(xr.Dataset.load)
compute = _extend_xr_method(xr.Dataset.compute)
persist = _extend_xr_method(xr.Dataset.persist)

# Reduction / aggregation wrappers. NOTE: ``min``, ``max`` and ``sum`` shadow
# the builtins inside the class body, which is intentional here.
mean = _extend_xr_method(xr.Dataset.mean)
median = _extend_xr_method(xr.Dataset.median)
min = _extend_xr_method(xr.Dataset.min)
max = _extend_xr_method(xr.Dataset.max)
cumsum = _extend_xr_method(xr.Dataset.cumsum)
sum = _extend_xr_method(xr.Dataset.sum)
quantile = _extend_xr_method(xr.Dataset.quantile)
|
864
|
|
|
865
|
2
|
def _group_names(self, groups, filter_groups=None):
    """Expand group and metagroup selections into concrete group names.

    Parameters
    ----------
    groups : str, list of str or None
        Group or metagroup names. A leading ``~`` marks a group (or every
        member of a metagroup) for exclusion.
    filter_groups : {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret groups as the real group or metagroup names.
        If "like", interpret groups as substrings of the real group or metagroup names.
        If "regex", interpret groups as regular expressions on the real group or
        metagroup names. A la `pandas.filter`.

    Returns
    -------
    groups : list
    """
    all_groups = self._groups_all
    if groups is None:
        return all_groups
    if isinstance(groups, str):
        groups = [groups]

    metagroups = rcParams["data.metagroups"]
    expanded = []
    for entry in groups:
        # A leading "~" negates the selection; metagroup members inherit it.
        negated = entry[0] == "~"
        key = entry[1:] if negated else entry
        if key in metagroups:
            prefix = "~" if negated else ""
            expanded.extend(
                prefix + member for member in metagroups[key] if member in all_groups
            )
        else:
            expanded.append(entry)

    try:
        group_names = _subset_list(expanded, all_groups, filter_items=filter_groups)
    except KeyError as err:
        msg = " ".join(("groups:", f"{err}", "in InferenceData"))
        raise KeyError(msg) from err
    return group_names
|
911
|
|
|
912
|
2
|
def map(self, fun, groups=None, filter_groups=None, inplace=False, args=None, **kwargs):
    """Apply a function to multiple groups.

    Applies ``fun`` groupwise to the selected ``InferenceData`` groups and overwrites the
    group with the result of the function.

    Parameters
    ----------
    fun : callable
        Function applied to each selected group as ``fun(dataset, *args, **kwargs)``.
    groups : str or list of str, optional
        Group or metagroup names to operate on.
    filter_groups : {None, "like", "regex"}, optional
        How ``groups`` is interpreted: exact names (`None`), substrings
        ("like") or regular expressions ("regex"), a la `pandas.filter`.
    inplace : bool, optional
        If ``True``, modify this InferenceData object in place; otherwise
        return a modified copy.
    args : array_like, optional
        Positional arguments forwarded to ``fun``.
    **kwargs : mapping, optional
        Keyword arguments forwarded to ``fun``.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` perform selection in place and return `None`.

    Examples
    --------
    Shift observed_data, prior_predictive and posterior_predictive:

    .. ipython::

        In [1]: import arviz as az
           ...: idata = az.load_arviz_data("non_centered_eight")
           ...: idata_shifted_obs = idata.map(lambda x: x + 3, groups="observed_vars")
           ...: print(idata_shifted_obs.observed_data)
           ...: print(idata_shifted_obs.posterior_predictive)
    """
    positional = [] if args is None else args
    selected = self._group_names(groups, filter_groups)

    target = self if inplace else deepcopy(self)
    for name in selected:
        # Always read from self; writes go to the (possibly copied) target.
        transformed = fun(getattr(self, name), *positional, **kwargs)
        setattr(target, name, transformed)
    return None if inplace else target
|
1001
|
|
|
1002
|
2
|
def _wrap_xarray_method(
    self, method, groups=None, filter_groups=None, inplace=False, args=None, **kwargs
):
    """Apply an ``xarray.Dataset`` method (looked up by name) to selected groups.

    Parameters
    ----------
    method : str
        Method to be extended. Must be a ``xarray.Dataset`` method.
    groups : str or list of str, optional
        Group or metagroup names to operate on.
    filter_groups : {None, "like", "regex"}, optional
        How ``groups`` is interpreted.
    inplace : bool, optional
        If ``True``, modify the InferenceData object inplace,
        otherwise, return the modified copy.
    args : array_like, optional
        Positional arguments forwarded to the Dataset method.
    **kwargs : mapping, optional
        Keyword arguments passed to the xarray Dataset method.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` perform selection in place and return `None`.

    Examples
    --------
    Compute the mean of `posterior_groups`:

    .. ipython::

        In [1]: import arviz as az
           ...: idata = az.load_arviz_data("non_centered_eight")
           ...: idata_means = idata._wrap_xarray_method("mean", groups="latent_vars")
           ...: print(idata_means.posterior)
    """
    positional = [] if args is None else args
    selected = self._group_names(groups, filter_groups)

    # Resolve the unbound Dataset method once, then call it per group.
    ds_method = getattr(xr.Dataset, method)

    target = self if inplace else deepcopy(self)
    for name in selected:
        transformed = ds_method(getattr(self, name), *positional, **kwargs)
        setattr(target, name, transformed)
    return None if inplace else target
|
1065
|
|
|
1066
|
|
|
1067
|
|
# pylint: disable=protected-access, inconsistent-return-statements
|
1068
|
2
|
def concat(*args, dim=None, copy=True, inplace=False, reset_dim=True):
    """Concatenate InferenceData objects.

    Concatenates over `group`, `chain` or `draw`.
    By default concatenates over unique groups.
    To concatenate over `chain` or `draw` function
    needs identical groups and variables.

    The `variables` in the `data` -group are merged if `dim` are not found.


    Parameters
    ----------
    *args : InferenceData
        Variable length InferenceData list or
        Sequence of InferenceData.
    dim : str, optional
        If defined, concatenated over the defined dimension.
        Dimension which is concatenated. If None, concatenates over
        unique groups.
    copy : bool
        If True, groups are copied to the new InferenceData object.
        Used only if `dim` is None.
    inplace : bool
        If True, merge args to first object.
    reset_dim : bool
        Valid only if dim is not None.

    Returns
    -------
    InferenceData
        A new InferenceData object by default.
        When `inplace==True` merge args to first arg and return `None`

    See Also
    --------
    add_groups : Add new groups to InferenceData object.
    extend : Extend InferenceData with groups from another InferenceData.

    Examples
    --------
    Use ``concat`` method to concatenate InferenceData objects. This will concatenates over
    unique groups by default. We first create an ``InferenceData`` object:

    .. ipython::

        In [1]: import arviz as az
           ...: import numpy as np
           ...: data = {
           ...:     "a": np.random.normal(size=(4, 100, 3)),
           ...:     "b": np.random.normal(size=(4, 100)),
           ...: }
           ...: coords = {"a_dim": ["x", "y", "z"]}
           ...: dataA = az.from_dict(data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataA

    We have created an ``InferenceData`` object with default group 'posterior'. Now, we will
    create another ``InferenceData`` object:

    .. ipython::

        In [1]: dataB = az.from_dict(prior=data, coords=coords, dims={"a": ["a_dim"]})
           ...: dataB

    We have created another ``InferenceData`` object with group 'prior'. Now, we will concatenate
    these two ``InferenceData`` objects:

    .. ipython::

        In [1]: az.concat(dataA, dataB)

    Now, we will concatenate over chain (or draw). It requires identical groups and variables.
    Here we are concatenating two identical ``InferenceData`` objects over dimension chain:

    .. ipython::

        In [1]: az.concat(dataA, dataA, dim="chain")

    It will create an ``InferenceData`` with the original group 'posterior'. In similar way,
    we can also concatenate over draws.

    """
    # pylint: disable=undefined-loop-variable, too-many-nested-blocks
    # Trivial case: nothing to concatenate.
    if len(args) == 0:
        if inplace:
            return
        return InferenceData()

    # A single sequence argument is unpacked into the individual objects.
    if len(args) == 1 and isinstance(args[0], Sequence):
        args = args[0]

    # assert that all args are InferenceData
    for i, arg in enumerate(args):
        if not isinstance(arg, InferenceData):
            raise TypeError(
                "Concatenating is supported only"
                "between InferenceData objects. Input arg {} is {}".format(i, type(arg))
            )

    # Normalize and validate `dim` ("group" behaves like None below).
    if dim is not None and dim.lower() not in {"group", "chain", "draw"}:
        msg = "Invalid `dim`: {}. Valid `dim` are {}".format(dim, '{"group", "chain", "draw"}')
        raise TypeError(msg)
    dim = dim.lower() if dim is not None else dim

    # Single InferenceData: return it (copied or not) — nothing to merge.
    if len(args) == 1 and isinstance(args[0], InferenceData):
        if inplace:
            return None
        else:
            if copy:
                return deepcopy(args[0])
            else:
                return args[0]

    # Timestamp shared by all "created_at" attr updates in this call.
    current_time = str(datetime.now())

    if not inplace:
        # Keep order for python 3.5
        inference_data_dict = OrderedDict()

    if dim is None:
        # --- Concatenate over unique groups: each group may appear in only one arg.
        arg0 = args[0]
        arg0_groups = ccopy(arg0._groups_all)
        args_groups = dict()
        # check if groups are independent
        # Concat over unique groups
        for arg in args[1:]:
            for group in arg._groups_all:
                if group in args_groups or group in arg0_groups:
                    msg = (
                        "Concatenating overlapping groups is not supported unless `dim` is defined."
                        " Valid dimensions are `chain` and `draw`. Alternatively, use extend to"
                        " combine InferenceData with overlapping groups"
                    )
                    raise TypeError(msg)
                group_data = getattr(arg, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data
        # add arg0 to args_groups if inplace is False
        # otherwise it will merge args_groups to arg0
        # inference data object
        if not inplace:
            for group in arg0_groups:
                group_data = getattr(arg0, group)
                args_groups[group] = deepcopy(group_data) if copy else group_data

        # Groups outside the known scheme are appended after the supported ones.
        other_groups = [group for group in args_groups if group not in SUPPORTED_GROUPS_ALL]

        for group in SUPPORTED_GROUPS_ALL + other_groups:
            if group not in args_groups:
                continue
            if inplace:
                if group.startswith(WARMUP_TAG):
                    arg0._groups_warmup.append(group)
                else:
                    arg0._groups.append(group)
                setattr(arg0, group, args_groups[group])
            else:
                inference_data_dict[group] = args_groups[group]
        if inplace:
            # Re-sort arg0's bookkeeping lists into canonical scheme order.
            other_groups = [
                group for group in arg0_groups if group not in SUPPORTED_GROUPS_ALL
            ] + other_groups
            sorted_groups = [
                group for group in SUPPORTED_GROUPS + other_groups if group in arg0._groups
            ]
            setattr(arg0, "_groups", sorted_groups)
            sorted_groups_warmup = [
                group
                for group in SUPPORTED_GROUPS_WARMUP + other_groups
                if group in arg0._groups_warmup
            ]
            setattr(arg0, "_groups_warmup", sorted_groups_warmup)
    else:
        # --- Concatenate over "chain" or "draw": groups and variables must match.
        arg0 = args[0]
        arg0_groups = arg0._groups_all
        for arg in args[1:]:
            # Every group of arg0 (except observed_data) must exist in arg too.
            for group0 in arg0_groups:
                if group0 not in arg._groups_all:
                    if group0 == "observed_data":
                        continue
                    msg = "Mismatch between the groups."
                    raise TypeError(msg)
            for group in arg._groups_all:
                # handle data groups seperately
                if group not in ["observed_data", "constant_data", "predictions_constant_data"]:
                    # assert that groups are equal
                    if group not in arg0_groups:
                        msg = "Mismatch between the groups."
                        raise TypeError(msg)

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    # With >2 args, accumulate onto the already-concatenated result.
                    if not inplace and group in inference_data_dict:
                        group0_data = inference_data_dict[group]
                    else:
                        group0_data = getattr(arg0, group)
                    group0_vars = group0_data.data_vars

                    for var in group0_vars:
                        if var not in group_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)

                    for var in group_vars:
                        if var not in group0_vars:
                            msg = "Mismatch between the variables."
                            raise TypeError(msg)
                        var_dims = getattr(group_data, var).dims
                        var0_dims = getattr(group0_data, var).dims
                        if var_dims != var0_dims:
                            msg = "Mismatch between the dimensions."
                            raise TypeError(msg)

                        if dim not in var_dims or dim not in var0_dims:
                            msg = "Dimension {} missing.".format(dim)
                            raise TypeError(msg)

                    # xr.concat
                    concatenated_group = xr.concat((group0_data, group_data), dim=dim)
                    if reset_dim:
                        # Renumber the concatenated dimension 0..size-1.
                        concatenated_group[dim] = range(concatenated_group[dim].size)

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = deepcopy(getattr(group0_data, "attrs"))
                    else:
                        group0_attrs = OrderedDict()

                    if hasattr(group_data, "attrs"):
                        group_attrs = getattr(group_data, "attrs")
                    else:
                        group_attrs = dict()

                    # gather attrs results to group0_attrs
                    for attr_key, attr_values in group_attrs.items():
                        group0_attr_values = group0_attrs.get(attr_key, None)
                        equality = attr_values == group0_attr_values
                        if hasattr(equality, "__iter__"):
                            # Array-valued attrs compare elementwise; reduce to a scalar.
                            equality = np.all(equality)
                        if equality:
                            continue
                        # handle special cases:
                        if attr_key in ("created_at", "previous_created_at"):
                            # check the defaults
                            # NOTE(review): hasattr on a dict checks attributes, not
                            # keys, so for a plain dict this condition is always True
                            # and the list is re-initialized on every mismatching
                            # attr_key. Presumably '"previous_created_at" not in
                            # group0_attrs' was intended — verify before changing.
                            if not hasattr(group0_attrs, "previous_created_at"):
                                group0_attrs["previous_created_at"] = []
                                if group0_attr_values is not None:
                                    group0_attrs["previous_created_at"].append(group0_attr_values)
                            # check previous values
                            if attr_key == "previous_created_at":
                                if not isinstance(attr_values, list):
                                    attr_values = [attr_values]
                                group0_attrs["previous_created_at"].extend(attr_values)
                                continue
                            # update "created_at"
                            if group0_attr_values != current_time:
                                group0_attrs[attr_key] = current_time
                            group0_attrs["previous_created_at"].append(attr_values)

                        elif attr_key in group0_attrs:
                            # Conflicting scalar attrs are collected under combined_<key>.
                            combined_key = "combined_{}".format(attr_key)
                            if combined_key not in group0_attrs:
                                group0_attrs[combined_key] = [group0_attr_values]
                            group0_attrs[combined_key].append(attr_values)
                        else:
                            group0_attrs[attr_key] = attr_values
                    # update attrs
                    setattr(concatenated_group, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, concatenated_group)
                    else:
                        inference_data_dict[group] = concatenated_group
                else:
                    # observed_data, "constant_data", "predictions_constant_data",
                    # NOTE: group_data here still holds the value from the previous
                    # loop iteration (see the pylint disable at the top of the body).
                    if group not in arg0_groups:
                        setattr(arg0, group, deepcopy(group_data) if copy else group_data)
                        arg0._groups.append(group)
                        continue

                    # assert that variables are equal
                    group_data = getattr(arg, group)
                    group_vars = group_data.data_vars

                    group0_data = getattr(arg0, group)
                    if not inplace:
                        group0_data = deepcopy(group0_data)
                    group0_vars = group0_data.data_vars

                    for var in group_vars:
                        if var not in group0_vars:
                            # New variable: merged into arg0's group as-is.
                            var_data = getattr(group_data, var)
                            getattr(arg0, group)[var] = var_data
                        else:
                            var_data = getattr(group_data, var)
                            var0_data = getattr(group0_data, var)
                            if dim in var_data.dims and dim in var0_data.dims:
                                # NOTE(review): this concatenates the whole Dataset,
                                # not just `var`, and assigns it as the variable —
                                # looks suspicious; confirm intended behavior.
                                concatenated_var = xr.concat((group_data, group0_data), dim=dim)
                                group0_data[var] = concatenated_var

                    # handle attrs
                    if hasattr(group0_data, "attrs"):
                        group0_attrs = getattr(group0_data, "attrs")
                    else:
                        group0_attrs = OrderedDict()

                    if hasattr(group_data, "attrs"):
                        group_attrs = getattr(group_data, "attrs")
                    else:
                        group_attrs = dict()

                    # gather attrs results to group0_attrs
                    for attr_key, attr_values in group_attrs.items():
                        group0_attr_values = group0_attrs.get(attr_key, None)
                        equality = attr_values == group0_attr_values
                        if hasattr(equality, "__iter__"):
                            equality = np.all(equality)
                        if equality:
                            continue
                        # handle special cases:
                        if attr_key in ("created_at", "previous_created_at"):
                            # check the defaults
                            # NOTE(review): same dict-vs-attribute hasattr issue as in
                            # the branch above — verify intent.
                            if not hasattr(group0_attrs, "previous_created_at"):
                                group0_attrs["previous_created_at"] = []
                                if group0_attr_values is not None:
                                    group0_attrs["previous_created_at"].append(group0_attr_values)
                            # check previous values
                            if attr_key == "previous_created_at":
                                if not isinstance(attr_values, list):
                                    attr_values = [attr_values]
                                group0_attrs["previous_created_at"].extend(attr_values)
                                continue
                            # update "created_at"
                            if group0_attr_values != current_time:
                                group0_attrs[attr_key] = current_time
                            group0_attrs["previous_created_at"].append(attr_values)

                        elif attr_key in group0_attrs:
                            combined_key = "combined_{}".format(attr_key)
                            if combined_key not in group0_attrs:
                                group0_attrs[combined_key] = [group0_attr_values]
                            group0_attrs[combined_key].append(attr_values)

                        else:
                            group0_attrs[attr_key] = attr_values
                    # update attrs
                    setattr(group0_data, "attrs", group0_attrs)

                    if inplace:
                        setattr(arg0, group, group0_data)
                    else:
                        inference_data_dict[group] = group0_data

    return None if inplace else InferenceData(**inference_data_dict)
|