1
|
|
"""High level conversion functions."""
|
2
|
2
|
import numpy as np
|
3
|
2
|
import xarray as xr
|
4
|
|
|
5
|
2
|
from .base import dict_to_dataset
|
6
|
2
|
from .inference_data import InferenceData
|
7
|
2
|
from .io_cmdstan import from_cmdstan
|
8
|
2
|
from .io_cmdstanpy import from_cmdstanpy
|
9
|
2
|
from .io_emcee import from_emcee
|
10
|
2
|
from .io_numpyro import from_numpyro
|
11
|
2
|
from .io_pymc3 import from_pymc3
|
12
|
2
|
from .io_pyro import from_pyro
|
13
|
2
|
from .io_pystan import from_pystan
|
14
|
|
|
15
|
|
|
16
|
|
# pylint: disable=too-many-return-statements
|
17
|
2
|
def _remap_stan_sample_stats(kwargs, group):
    """Re-key a ``sample_stats`` / ``sample_stats_prior`` entry for the Stan converters.

    `convert_to_inference_data` stores the object to convert under ``kwargs[group]``.
    For the Stan converters (``from_cmdstan``, ``from_cmdstanpy``, ``from_pystan``)
    that entry is moved to ``posterior`` when ``group == "sample_stats"`` and to
    ``prior`` when ``group == "sample_stats_prior"``. Any other ``group`` leaves
    ``kwargs`` untouched.

    Parameters
    ----------
    kwargs : dict
        Keyword arguments destined for a Stan converter; modified in place.
    group : str
        The key the object was stored under in ``kwargs``.
    """
    if group == "sample_stats":
        kwargs["posterior"] = kwargs.pop(group)
    elif group == "sample_stats_prior":
        kwargs["prior"] = kwargs.pop(group)


def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, **kwargs):
    r"""Convert a supported object to an InferenceData object.

    This function sends `obj` to the right conversion function. It is idempotent,
    in that it will return arviz.InferenceData objects unchanged.

    Parameters
    ----------
    obj : dict, str, np.ndarray, xr.Dataset, pystan fit, pymc3 trace
        A supported object to convert to InferenceData:
         | InferenceData: returns unchanged
         | str: Attempts to load the cmdstan csv or netcdf dataset from disk
         | pystan fit: Automatically extracts data
         | cmdstanpy fit: Automatically extracts data
         | cmdstan csv-list: Automatically extracts data
         | pymc3 trace: Automatically extracts data
         | emcee sampler: Automatically extracts data
         | pyro MCMC: Automatically extracts data
         | numpyro MCMC: Automatically extracts data
         | xarray.Dataset: adds to InferenceData as only group
         | xarray.DataArray: creates an xarray dataset as the only group, named
           "x" if the array has no name (the input array is not modified)
         | dict: creates an xarray dataset as the only group
         | numpy array: creates an xarray dataset as the only group, with the
           array stored under the name "x"
    group : str
        If `obj` is a dict, numpy array, xarray.Dataset or xarray.DataArray,
        assigns the resulting xarray dataset to this group. For Stan inputs
        (cmdstan csv, cmdstanpy, pystan), "sample_stats" and
        "sample_stats_prior" make the object be passed to the converter as
        `posterior` and `prior` respectively. Default: "posterior".
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable
    kwargs
        Rest of the supported keyword arguments transferred to conversion function.

    Returns
    -------
    InferenceData

    Raises
    ------
    TypeError
        If `coords` or `dims` are given together with an InferenceData object
        or a netcdf filename.
    ValueError
        If `obj` is not of a supported type.
    """
    # Every library converter receives the object under the `group` keyword
    # together with coords/dims.
    kwargs[group] = obj
    kwargs["coords"] = coords
    kwargs["dims"] = dims

    # Cases that convert to InferenceData
    if isinstance(obj, InferenceData):
        if coords is not None or dims is not None:
            raise TypeError("Cannot use coords or dims arguments with InferenceData value.")
        return obj

    if isinstance(obj, str):
        if obj.endswith(".csv"):
            _remap_stan_sample_stats(kwargs, group)
            return from_cmdstan(**kwargs)
        if coords is not None or dims is not None:
            raise TypeError(
                "Cannot use coords or dims arguments reading InferenceData from netcdf."
            )
        return InferenceData.from_netcdf(obj)

    # Library objects are detected by class name/module so that none of the
    # sampling libraries becomes a hard requirement.
    class_name = obj.__class__.__name__
    if class_name in {"StanFit4Model", "CmdStanMCMC"} or obj.__class__.__module__ == "stan.fit":
        _remap_stan_sample_stats(kwargs, group)
        if class_name == "CmdStanMCMC":
            return from_cmdstanpy(**kwargs)
        return from_pystan(**kwargs)  # pystan or pystan3
    if class_name == "MultiTrace":
        return from_pymc3(trace=kwargs.pop(group), **kwargs)
    if class_name == "EnsembleSampler":
        return from_emcee(sampler=kwargs.pop(group), **kwargs)
    if class_name == "MCMC" and obj.__class__.__module__.startswith("pyro"):
        return from_pyro(posterior=kwargs.pop(group), **kwargs)
    if class_name == "MCMC" and obj.__class__.__module__.startswith("numpyro"):
        return from_numpyro(posterior=kwargs.pop(group), **kwargs)

    # Cases that convert to xarray
    if isinstance(obj, xr.Dataset):
        dataset = obj
    elif isinstance(obj, xr.DataArray):
        # Give unnamed arrays a placeholder name without mutating the caller's object.
        dataset = obj.to_dataset(name="x") if obj.name is None else obj.to_dataset()
    elif isinstance(obj, dict):
        dataset = dict_to_dataset(obj, coords=coords, dims=dims)
    elif isinstance(obj, np.ndarray):
        dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
    elif (
        isinstance(obj, (list, tuple))
        and obj  # guard: an empty sequence must reach the ValueError below, not IndexError
        and isinstance(obj[0], str)
        and obj[0].endswith(".csv")
    ):
        _remap_stan_sample_stats(kwargs, group)
        return from_cmdstan(**kwargs)
    else:
        allowable_types = (
            "xarray dataarray",
            "xarray dataset",
            "dict",
            "netcdf filename",
            "numpy array",
            "pystan fit",
            "pymc3 trace",
            "emcee fit",
            "pyro mcmc fit",
            "numpyro mcmc fit",
            "cmdstan fit csv filename",
            "cmdstanpy fit",
        )
        raise ValueError(
            "Can only convert {} to InferenceData, not {}".format(
                ", ".join(allowable_types), obj.__class__.__name__
            )
        )

    return InferenceData(**{group: dataset})
|
138
|
|
|
139
|
|
|
140
|
2
|
def convert_to_dataset(obj, *, group="posterior", coords=None, dims=None):
    """Convert a supported object to an xarray dataset.

    The object is first passed through `convert_to_inference_data`, and the
    requested `group` is then read back from the resulting InferenceData. An
    xarray.Dataset is therefore wrapped under `group` and extracted again,
    which makes this function idempotent for xarray.Dataset inputs.

    Parameters
    ----------
    obj : dict, str, np.ndarray, xr.Dataset, pystan fit, pymc3 trace
        A supported object to convert to an xarray dataset:
        InferenceData: the `group` attribute is returned
        str: Attempts to load the netcdf dataset from disk
        pystan fit: Automatically extracts data
        pymc3 trace: Automatically extracts data
        xarray.Dataset: adds to InferenceData as only group
        xarray.DataArray: creates an xarray dataset as the only group, gives the
            array an arbitrary name, if name not set
        dict: creates an xarray dataset as the only group
        numpy array: creates an xarray dataset as the only group, gives the
            array an arbitrary name
        See `convert_to_inference_data` for the full list of supported inputs.
    group : str
        If `obj` is a dict or numpy array, assigns the resulting xarray
        dataset to this group. This is also the group that is extracted and
        returned.
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    ValueError
        If the intermediate InferenceData has no `group` (or it is None), or
        if `convert_to_inference_data` cannot convert `obj`.
    """
    idata = convert_to_inference_data(obj, group=group, coords=coords, dims=dims)
    extracted = getattr(idata, group, None)
    if extracted is not None:
        return extracted

    message = (
        "Can not extract {group} from {obj}! See {filename} for other "
        "conversion utilities."
    ).format(obj=obj, group=group, filename=__file__)
    raise ValueError(message)
|