1
"""Dictionary specific conversion code."""
2 6
import warnings
3

4 6
import xarray as xr
5

6 6
from .. import utils
7 6
from ..rcparams import rcParams
8 6
from .base import dict_to_dataset, generate_dims_coords, make_attrs, requires
9 6
from .inference_data import WARMUP_TAG, InferenceData
10

11

12
# pylint: disable=too-many-instance-attributes
13 6
class DictConverter:
    """Encapsulate Dictionary specific logic.

    Holds one dictionary per group (plus optional warmup counterparts for the
    sampled groups) together with the shared ``coords``, ``dims`` and ``attrs``,
    and exposes one ``*_to_xarray`` method per group as well as
    :meth:`to_inference_data` to convert everything at once.
    """

    def __init__(
        self,
        *,
        posterior=None,
        posterior_predictive=None,
        predictions=None,
        sample_stats=None,
        log_likelihood=None,
        prior=None,
        prior_predictive=None,
        sample_stats_prior=None,
        observed_data=None,
        constant_data=None,
        predictions_constant_data=None,
        warmup_posterior=None,
        warmup_posterior_predictive=None,
        warmup_predictions=None,
        warmup_log_likelihood=None,
        warmup_sample_stats=None,
        save_warmup=None,
        coords=None,
        dims=None,
        pred_dims=None,
        pred_coords=None,
        attrs=None,
    ):
        """Store the group dictionaries and the shared conversion options.

        Parameters
        ----------
        posterior, posterior_predictive, predictions, sample_stats, log_likelihood : dict, optional
            Group dictionaries that may come with a warmup counterpart.
        prior, prior_predictive, sample_stats_prior : dict, optional
            Prior related group dictionaries.
        observed_data, constant_data, predictions_constant_data : dict, optional
            Data group dictionaries.
        warmup_posterior, warmup_posterior_predictive, warmup_predictions : dict, optional
        warmup_log_likelihood, warmup_sample_stats : dict, optional
            Warmup counterparts of the sampled groups.
        save_warmup : bool, optional
            Falls back to ``rcParams["data.save_warmup"]`` when None.
        coords : dict, optional
        dims : dict, optional
        pred_dims : dict, optional
            Dims used for ``predictions`` and ``predictions_constant_data``.
            Falls back to ``dims`` when None.
        pred_coords : dict, optional
            Merged into ``coords``; on duplicated keys ``pred_coords`` wins.
        attrs : dict, optional
            Attributes handed to every generated dataset. The ``created_at`` and
            ``arviz_version`` keys are discarded. The caller's dictionary is
            copied and therefore never modified.
        """
        self.posterior = posterior
        self.posterior_predictive = posterior_predictive
        self.predictions = predictions
        self.sample_stats = sample_stats
        self.log_likelihood = log_likelihood
        self.prior = prior
        self.prior_predictive = prior_predictive
        self.sample_stats_prior = sample_stats_prior
        self.observed_data = observed_data
        self.constant_data = constant_data
        self.predictions_constant_data = predictions_constant_data
        self.warmup_posterior = warmup_posterior
        self.warmup_posterior_predictive = warmup_posterior_predictive
        self.warmup_predictions = warmup_predictions
        self.warmup_log_likelihood = warmup_log_likelihood
        self.warmup_sample_stats = warmup_sample_stats
        self.save_warmup = rcParams["data.save_warmup"] if save_warmup is None else save_warmup

        # A single coords mapping is shared by all groups.
        if pred_coords is None:
            self.coords = coords
        elif coords is None:
            self.coords = pred_coords
        else:
            self.coords = {**coords, **pred_coords}

        self.dims = dims
        self.pred_dims = dims if pred_dims is None else pred_dims

        # Shallow copy so that dropping keys below does not mutate the caller's dict.
        self.attrs = {} if attrs is None else dict(attrs)
        # NOTE(review): presumably dropped so make_attrs/dict_to_dataset regenerate
        # them for the new object -- confirm against .base.make_attrs.
        self.attrs.pop("created_at", None)
        self.attrs.pop("arviz_version", None)

    def _init_dict(self, attr_name):
        """Return the attribute ``attr_name``, or ``{}`` if it is missing or None."""
        dict_or_none = getattr(self, attr_name, {})
        return {} if dict_or_none is None else dict_or_none

    def _init_group_pair(self, group):
        """Return the validated ``(group, warmup group)`` dictionaries.

        Raises
        ------
        TypeError
            If either the group or its warmup counterpart is not a dictionary.
        """
        data = self._init_dict(group)
        data_warmup = self._init_dict(f"{WARMUP_TAG}{group}")
        if not isinstance(data, dict):
            raise TypeError(f"DictConverter.{group} is not a dictionary")
        if not isinstance(data_warmup, dict):
            raise TypeError(f"DictConverter.warmup_{group} is not a dictionary")
        return data, data_warmup

    def _required_dict(self, group):
        """Return the attribute ``group``, raising TypeError if it is not a dictionary."""
        data = getattr(self, group)
        if not isinstance(data, dict):
            raise TypeError(f"DictConverter.{group} is not a dictionary")
        return data

    def _to_dataset(self, data, dims, **kwargs):
        """Call ``dict_to_dataset`` on ``data`` with the shared coords and attrs."""
        return dict_to_dataset(
            data, library=None, coords=self.coords, dims=dims, attrs=self.attrs, **kwargs
        )

    @requires(["posterior", f"{WARMUP_TAG}posterior"])
    def posterior_to_xarray(self):
        """Convert posterior samples to xarray.

        Returns a tuple ``(posterior dataset, warmup posterior dataset)``. Warns
        if a ``log_likelihood`` variable is present in the posterior dictionary.
        """
        data, data_warmup = self._init_group_pair("posterior")

        if "log_likelihood" in data:
            warnings.warn(
                "log_likelihood variable found in posterior group."
                " For stats functions log likelihood data needs to be in log_likelihood group.",
                UserWarning,
            )

        return (
            self._to_dataset(data, self.dims),
            self._to_dataset(data_warmup, self.dims),
        )

    @requires(["sample_stats", f"{WARMUP_TAG}sample_stats"])
    def sample_stats_to_xarray(self):
        """Convert sample_stats samples to xarray.

        Returns a tuple ``(sample_stats dataset, warmup sample_stats dataset)``.
        Emits a PendingDeprecationWarning if a ``log_likelihood`` variable is
        present in the sample_stats dictionary.
        """
        data, data_warmup = self._init_group_pair("sample_stats")

        if "log_likelihood" in data:
            warnings.warn(
                "log_likelihood variable found in sample_stats."
                " Storing log_likelihood data in sample_stats group will be deprecated in "
                "favour of storing them in the log_likelihood group.",
                PendingDeprecationWarning,
            )

        return (
            self._to_dataset(data, self.dims),
            self._to_dataset(data_warmup, self.dims),
        )

    @requires(["log_likelihood", f"{WARMUP_TAG}log_likelihood"])
    def log_likelihood_to_xarray(self):
        """Convert log_likelihood samples to xarray.

        Returns a tuple ``(log_likelihood dataset, warmup log_likelihood dataset)``;
        both are built with ``skip_event_dims=True``.
        """
        data, data_warmup = self._init_group_pair("log_likelihood")

        return (
            self._to_dataset(data, self.dims, skip_event_dims=True),
            self._to_dataset(data_warmup, self.dims, skip_event_dims=True),
        )

    @requires(["posterior_predictive", f"{WARMUP_TAG}posterior_predictive"])
    def posterior_predictive_to_xarray(self):
        """Convert posterior_predictive samples to xarray.

        Returns a tuple ``(posterior_predictive dataset, warmup counterpart)``.
        """
        data, data_warmup = self._init_group_pair("posterior_predictive")

        return (
            self._to_dataset(data, self.dims),
            self._to_dataset(data_warmup, self.dims),
        )

    @requires(["predictions", f"{WARMUP_TAG}predictions"])
    def predictions_to_xarray(self):
        """Convert predictions to xarray.

        Returns a tuple ``(predictions dataset, warmup predictions dataset)``;
        ``pred_dims`` is used instead of ``dims``.
        """
        data, data_warmup = self._init_group_pair("predictions")

        return (
            self._to_dataset(data, self.pred_dims),
            self._to_dataset(data_warmup, self.pred_dims),
        )

    @requires("prior")
    def prior_to_xarray(self):
        """Convert prior samples to xarray."""
        return self._to_dataset(self._required_dict("prior"), self.dims)

    @requires("sample_stats_prior")
    def sample_stats_prior_to_xarray(self):
        """Convert sample_stats_prior samples to xarray."""
        return self._to_dataset(self._required_dict("sample_stats_prior"), self.dims)

    @requires("prior_predictive")
    def prior_predictive_to_xarray(self):
        """Convert prior_predictive samples to xarray."""
        return self._to_dataset(self._required_dict("prior_predictive"), self.dims)

    def data_to_xarray(self, dct, group, dims=None):
        """Convert data to xarray.

        Parameters
        ----------
        dct : dict
            Mapping of variable names to values.
        group : str
            Group name, only used in the error message.
        dims : dict, optional
            Mapping of variable names to dimension names. Falls back to
            ``self.dims`` (or an empty mapping) when None.

        Returns
        -------
        xarray.Dataset

        Raises
        ------
        TypeError
            If ``dct`` is not a dictionary.
        """
        if not isinstance(dct, dict):
            raise TypeError(f"DictConverter.{group} is not a dictionary")
        if dims is None:
            dims = {} if self.dims is None else self.dims
        new_data = {}
        for key, vals in dct.items():
            vals = utils.one_de(vals)
            val_dims, coords = generate_dims_coords(
                vals.shape, key, dims=dims.get(key), coords=self.coords
            )
            new_data[key] = xr.DataArray(vals, dims=val_dims, coords=coords)
        return xr.Dataset(data_vars=new_data, attrs=make_attrs(attrs=self.attrs, library=None))

    @requires("observed_data")
    def observed_data_to_xarray(self):
        """Convert observed_data to xarray."""
        return self.data_to_xarray(self.observed_data, group="observed_data", dims=self.dims)

    @requires("constant_data")
    def constant_data_to_xarray(self):
        """Convert constant_data to xarray."""
        return self.data_to_xarray(self.constant_data, group="constant_data")

    @requires("predictions_constant_data")
    def predictions_constant_data_to_xarray(self):
        """Convert predictions_constant_data to xarray."""
        return self.data_to_xarray(
            self.predictions_constant_data, group="predictions_constant_data", dims=self.pred_dims
        )

    def to_inference_data(self):
        """Convert all available data to an InferenceData object.

        Note that if groups can not be created, then the InferenceData
        will not have those groups.
        """
        return InferenceData(
            posterior=self.posterior_to_xarray(),
            sample_stats=self.sample_stats_to_xarray(),
            log_likelihood=self.log_likelihood_to_xarray(),
            posterior_predictive=self.posterior_predictive_to_xarray(),
            predictions=self.predictions_to_xarray(),
            prior=self.prior_to_xarray(),
            sample_stats_prior=self.sample_stats_prior_to_xarray(),
            prior_predictive=self.prior_predictive_to_xarray(),
            observed_data=self.observed_data_to_xarray(),
            constant_data=self.constant_data_to_xarray(),
            predictions_constant_data=self.predictions_constant_data_to_xarray(),
            save_warmup=self.save_warmup,
        )
285

286

287
# pylint: disable=too-many-instance-attributes
288 6
def from_dict(
    posterior=None,
    *,
    posterior_predictive=None,
    predictions=None,
    sample_stats=None,
    log_likelihood=None,
    prior=None,
    prior_predictive=None,
    sample_stats_prior=None,
    observed_data=None,
    constant_data=None,
    predictions_constant_data=None,
    warmup_posterior=None,
    warmup_posterior_predictive=None,
    warmup_predictions=None,
    warmup_log_likelihood=None,
    warmup_sample_stats=None,
    save_warmup=None,
    coords=None,
    dims=None,
    pred_dims=None,
    pred_coords=None,
    attrs=None,
):
    """Convert Dictionary data into an InferenceData object.

    For a usage example read the
    :ref:`Cookbook section on from_dict <cookbook>`

    Parameters
    ----------
    posterior : dict
    posterior_predictive : dict
    predictions: dict
    sample_stats : dict
    log_likelihood : dict
        For stats functions, log likelihood data should be stored here.
    prior : dict
    prior_predictive : dict
    sample_stats_prior : dict
    observed_data : dict
    constant_data : dict
    predictions_constant_data: dict
    warmup_posterior : dict
    warmup_posterior_predictive : dict
    warmup_predictions : dict
    warmup_log_likelihood : dict
    warmup_sample_stats : dict
    save_warmup : bool
        Save warmup iterations InferenceData object. If not defined, use default
        defined by the rcParams.
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable.
    pred_dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for predictions.
    pred_coords : dict[str, List(str)]
        A mapping from variables to a list of coordinate values for predictions.
    attrs : dict
        A dictionary containing attributes for different groups.

    Returns
    -------
    InferenceData object
    """
    # Dictionaries holding the actual data, one per group.
    group_kwargs = {
        "posterior": posterior,
        "posterior_predictive": posterior_predictive,
        "predictions": predictions,
        "sample_stats": sample_stats,
        "log_likelihood": log_likelihood,
        "prior": prior,
        "prior_predictive": prior_predictive,
        "sample_stats_prior": sample_stats_prior,
        "observed_data": observed_data,
        "constant_data": constant_data,
        "predictions_constant_data": predictions_constant_data,
        "warmup_posterior": warmup_posterior,
        "warmup_posterior_predictive": warmup_posterior_predictive,
        "warmup_predictions": warmup_predictions,
        "warmup_log_likelihood": warmup_log_likelihood,
        "warmup_sample_stats": warmup_sample_stats,
    }
    # Options shared by every group during the conversion.
    option_kwargs = {
        "save_warmup": save_warmup,
        "coords": coords,
        "dims": dims,
        "pred_dims": pred_dims,
        "pred_coords": pred_coords,
        "attrs": attrs,
    }
    converter = DictConverter(**group_kwargs, **option_kwargs)
    return converter.to_inference_data()

Read our documentation on viewing source code.

Loading