1
"""High level conversion functions."""
2 2
import numpy as np
3 2
import xarray as xr
4

5 2
from .base import dict_to_dataset
6 2
from .inference_data import InferenceData
7 2
from .io_cmdstan import from_cmdstan
8 2
from .io_cmdstanpy import from_cmdstanpy
9 2
from .io_emcee import from_emcee
10 2
from .io_numpyro import from_numpyro
11 2
from .io_pymc3 import from_pymc3
12 2
from .io_pyro import from_pyro
13 2
from .io_pystan import from_pystan
14

15

16
# pylint: disable=too-many-return-statements
17 2
def convert_to_inference_data(obj, *, group="posterior", coords=None, dims=None, **kwargs):
    r"""Convert a supported object to an InferenceData object.

    This function sends `obj` to the right conversion function. It is idempotent,
    in that it will return arviz.InferenceData objects unchanged.

    Parameters
    ----------
    obj : dict, str, np.ndarray, xr.Dataset, pystan fit, pymc3 trace
        A supported object to convert to InferenceData:
            | InferenceData: returns unchanged
            | str: Attempts to load the cmdstan csv or netcdf dataset from disk
            | pystan fit: Automatically extracts data
            | cmdstanpy fit: Automatically extracts data
            | cmdstan csv-list: Automatically extracts data
            | pymc3 trace: Automatically extracts data
            | emcee sampler: Automatically extracts data
            | pyro MCMC: Automatically extracts data
            | numpyro MCMC: Automatically extracts data
            | xarray.Dataset: adds to InferenceData as only group
            | xarray.DataArray: creates an xarray dataset as the only group, gives the
                         array an arbitrary name, if name not set (the input
                         DataArray itself is left unmodified)
            | dict: creates an xarray dataset as the only group
            | numpy array: creates an xarray dataset as the only group, gives the
                         array an arbitrary name
    group : str
        If `obj` is a dict, numpy array, xarray.Dataset or xarray.DataArray, assigns
        the resulting xarray dataset to this group. For Stan inputs (cmdstan csv,
        cmdstanpy or pystan fits) a group of "sample_stats" is passed on as
        "posterior" and "sample_stats_prior" as "prior". Default: "posterior".
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable
    kwargs
        Rest of the supported keyword arguments transferred to conversion function.

    Returns
    -------
    InferenceData

    Raises
    ------
    TypeError
        If `coords` or `dims` is given together with an InferenceData object or a
        netcdf filename, since those inputs are returned/loaded as-is.
    ValueError
        If `obj` is not of a supported type (this includes an empty list/tuple).
    """
    # The object is forwarded to the backend converter under the requested group.
    kwargs[group] = obj
    kwargs["coords"] = coords
    kwargs["dims"] = dims

    # Cases that convert to InferenceData
    if isinstance(obj, InferenceData):
        if coords is not None or dims is not None:
            raise TypeError("Cannot use coords or dims arguments with InferenceData value.")
        return obj
    elif isinstance(obj, str):
        if obj.endswith(".csv"):
            _redirect_stan_group(kwargs, group)
            return from_cmdstan(**kwargs)
        else:
            if coords is not None or dims is not None:
                raise TypeError(
                    "Cannot use coords or dims arguments reading InferenceData from netcdf."
                )
            return InferenceData.from_netcdf(obj)
    elif (
        obj.__class__.__name__ in {"StanFit4Model", "CmdStanMCMC"}
        or obj.__class__.__module__ == "stan.fit"
    ):
        _redirect_stan_group(kwargs, group)
        if obj.__class__.__name__ == "CmdStanMCMC":
            return from_cmdstanpy(**kwargs)
        else:  # pystan or pystan3
            return from_pystan(**kwargs)
    elif obj.__class__.__name__ == "MultiTrace":  # ugly, but doesn't make PyMC3 a requirement
        return from_pymc3(trace=kwargs.pop(group), **kwargs)
    elif obj.__class__.__name__ == "EnsembleSampler":  # ugly, but doesn't make emcee a requirement
        return from_emcee(sampler=kwargs.pop(group), **kwargs)
    elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("pyro"):
        return from_pyro(posterior=kwargs.pop(group), **kwargs)
    elif obj.__class__.__name__ == "MCMC" and obj.__class__.__module__.startswith("numpyro"):
        return from_numpyro(posterior=kwargs.pop(group), **kwargs)

    # Cases that convert to xarray
    if isinstance(obj, xr.Dataset):
        dataset = obj
    elif isinstance(obj, xr.DataArray):
        # Supply the placeholder name through to_dataset rather than assigning
        # obj.name, so the caller's DataArray is not mutated as a side effect.
        dataset = obj.to_dataset(name="x") if obj.name is None else obj.to_dataset()
    elif isinstance(obj, dict):
        dataset = dict_to_dataset(obj, coords=coords, dims=dims)
    elif isinstance(obj, np.ndarray):
        dataset = dict_to_dataset({"x": obj}, coords=coords, dims=dims)
    elif (
        isinstance(obj, (list, tuple))
        and obj  # guard: obj[0] on an empty sequence would raise IndexError
        and isinstance(obj[0], str)
        and obj[0].endswith(".csv")
    ):
        _redirect_stan_group(kwargs, group)
        return from_cmdstan(**kwargs)
    else:
        allowable_types = (
            "xarray dataarray",
            "xarray dataset",
            "dict",
            "netcdf filename",
            "numpy array",
            "pystan fit",
            "pymc3 trace",
            "emcee fit",
            "pyro mcmc fit",
            "numpyro mcmc fit",
            "cmdstan fit csv filename",
            "cmdstanpy fit",
        )
        raise ValueError(
            "Can only convert {} to InferenceData, not {}".format(
                ", ".join(allowable_types), obj.__class__.__name__
            )
        )

    return InferenceData(**{group: dataset})


def _redirect_stan_group(kwargs, group):
    """Rename the Stan input's keyword in ``kwargs`` (in place) for the Stan readers.

    A request for ``group="sample_stats"`` is passed on as ``posterior`` and
    ``"sample_stats_prior"`` as ``prior``. Any other group is left untouched.
    """
    if group == "sample_stats":
        kwargs["posterior"] = kwargs.pop(group)
    elif group == "sample_stats_prior":
        kwargs["prior"] = kwargs.pop(group)
138

139

140 2
def convert_to_dataset(obj, *, group="posterior", coords=None, dims=None):
    """Convert a supported object to an xarray.Dataset.

    The object is first turned into an InferenceData via
    `convert_to_inference_data`, and the requested group is then pulled out of
    it. See `convert_to_inference_data` for the full list of accepted inputs.

    Parameters
    ----------
    obj : dict, str, np.ndarray, xr.Dataset, pystan fit, pymc3 trace
        A supported object to convert to an xarray.Dataset:
            InferenceData: the requested group is extracted
            str: Attempts to load the netcdf dataset from disk
            pystan fit: Automatically extracts data
            pymc3 trace: Automatically extracts data
            xarray.Dataset: stored as the only group, then extracted again
            xarray.DataArray: creates an xarray dataset as the only group, gives the
                         array an arbitrary name, if name not set
            dict: creates an xarray dataset as the only group
            numpy array: creates an xarray dataset as the only group, gives the
                         array an arbitrary name
    group : str
        Name of the group to extract. For dict, numpy array and xarray inputs it
        is also the group the converted data is stored under.
    coords : dict[str, iterable]
        A dictionary containing the values that are used as index. The key
        is the name of the dimension, the values are the index values.
    dims : dict[str, List(str)]
        A mapping from variables to a list of coordinate names for the variable

    Returns
    -------
    xarray.Dataset

    Raises
    ------
    ValueError
        If the converted InferenceData has no `group`, or if `obj` cannot be
        converted at all (propagated from `convert_to_inference_data`).
    """
    idata = convert_to_inference_data(obj, group=group, coords=coords, dims=dims)
    extracted = getattr(idata, group, None)
    if extracted is not None:
        return extracted

    message = (
        "Can not extract {group} from {obj}! See {filename} for other "
        "conversion utilities."
    ).format(group=group, obj=obj, filename=__file__)
    raise ValueError(message)

Read our documentation on viewing source code .

Loading