1
"""Plot pointwise elpd estimations of inference data."""
2 6
import numpy as np
3

4 6
from ..data import convert_to_inference_data
5 6
from ..rcparams import rcParams
6 6
from ..stats import ELPDData, loo, waic
7 6
from ..utils import get_coords
8 6
from .plot_utils import format_coords_as_labels, get_plotting_function
9

10

11 6
def plot_elpd(
    compare_dict,
    color="C0",
    xlabels=False,
    figsize=None,
    textsize=None,
    coords=None,
    legend=False,
    threshold=None,
    ax=None,
    ic=None,
    scale=None,
    plot_kwargs=None,
    backend=None,
    backend_kwargs=None,
    show=None,
):
    """
    Plot pointwise elpd differences between two or more models.

    Parameters
    ----------
    compare_dict : mapping, str -> ELPDData or InferenceData
        A dictionary mapping the model name to the object containing inference data or the result
        of `loo`/`waic` functions. Values that are not ELPDData are converted with
        az.convert_to_inference_data (refer to it for details on possible dict items) and then
        passed to `loo`/`waic`. The mapping itself is not modified.
    color : str or array_like, optional
        Colors of the scatter plot, if color is a str all dots will have the same color,
        if it is the size of the observations, each dot will have the specified color,
        otherwise, it will be interpreted as a list of the dims to be used for the color code
    xlabels : bool, optional
        Use coords as xticklabels
    figsize : figure size tuple, optional
        If None, size is (8 + numvars, 8 + numvars)
    textsize : int, optional
        Text size for labels. If None it will be autoscaled based on figsize.
    coords : mapping, optional
        Coordinates of points to plot. **All** values are used for computation, but only a
        subset can be plotted for convenience.
    legend : bool, optional
        Add a legend to the plot. Only taken into account when color argument is a dim name.
    threshold : float, optional
        If some elpd difference is larger than `threshold * elpd.std()`, show its label. If
        `None`, no observations will be highlighted.
    ax : axes, optional
        Matplotlib axes or bokeh figures.
    ic : str, optional
        Information Criterion (PSIS-LOO `loo`, WAIC `waic`) used to compare models. Defaults to
        ``rcParams["stats.information_criterion"]``.
        Only taken into account when input is InferenceData.
    scale : str, optional
        scale argument passed to az.loo or az.waic, see their docs for details. Only taken
        into account when input is InferenceData.
    plot_kwargs : dict, optional
        Additional keywords passed to ax.scatter
    backend : str, optional
        Select plotting backend {"matplotlib","bokeh"}. Default "matplotlib".
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used. For additional documentation
        check the plotting method of the backend.
    show : bool, optional
        Call backend show function.

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If `ic` is not one of ``"loo"`` or ``"waic"``, or if fewer than two models are given.
    SyntaxError
        If the ELPDData objects mix information criteria (loo and waic) or scales.

    Examples
    --------
    Compare pointwise PSIS-LOO for centered and non centered models of the 8-schools problem
    using matplotlib.

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> idata1 = az.load_arviz_data("centered_eight")
        >>> idata2 = az.load_arviz_data("non_centered_eight")
        >>> az.plot_elpd(
        ...     {"centered model": idata1, "non centered model": idata2},
        ...     xlabels=True
        ... )

    .. bokeh-plot::
        :source-position: above

        import arviz as az
        idata1 = az.load_arviz_data("centered_eight")
        idata2 = az.load_arviz_data("non_centered_eight")
        az.plot_elpd(
            {"centered model": idata1, "non centered model": idata2},
            backend="bokeh"
        )

    """
    valid_ics = ["loo", "waic"]
    ic = rcParams["stats.information_criterion"] if ic is None else ic.lower()
    scale = rcParams["stats.ic_scale"] if scale is None else scale.lower()
    if ic not in valid_ics:
        raise ValueError(
            "Information Criteria type {} not recognized. IC must be in {}".format(ic, valid_ics)
        )
    ic_fun = loo if ic == "loo" else waic

    # Validate the model count before any (potentially expensive) IC computation. Doing it
    # here also avoids an IndexError on ``ics[0]`` below when ``compare_dict`` is empty.
    # ValueError subclasses Exception, so handlers catching the old generic Exception still work.
    if len(compare_dict) < 2:
        raise ValueError("Number of models to compare must be 2 or greater.")

    # Make sure all objects are ELPDData. Results go into a new dict so the caller's
    # mapping is left untouched (previously its values were overwritten in place).
    elpd_dict = {}
    for name, item in compare_dict.items():
        if isinstance(item, ELPDData):
            elpd_dict[name] = item
        else:
            elpd_dict[name] = ic_fun(convert_to_inference_data(item), pointwise=True, scale=scale)

    # The first index label of each ELPDData identifies its criterion ("loo"/"waic");
    # SyntaxError is kept for backward compatibility with existing callers.
    ics = [elpd_data.index[0] for elpd_data in elpd_dict.values()]
    if not all(x == ics[0] for x in ics):
        raise SyntaxError(
            "All Information Criteria must be of the same kind, but both loo and waic data present"
        )
    ic = ics[0]
    scales = [elpd_data["{}_scale".format(ic)] for elpd_data in elpd_dict.values()]
    if not all(x == scales[0] for x in scales):
        raise SyntaxError(
            "All Information Criteria must be on the same scale, but {} are present".format(
                set(scales)
            )
        )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    models = list(elpd_dict.keys())
    numvars = len(models)

    if coords is None:
        coords = {}

    # Pointwise values ("loo_i"/"waic_i") restricted to the requested coords, one per model.
    pointwise_data = [
        get_coords(elpd_dict[model]["{}_i".format(ic)], coords) for model in models
    ]
    xdata = np.arange(pointwise_data[0].size)
    coord_labels = format_coords_as_labels(pointwise_data[0]) if xlabels else None

    elpd_plot_kwargs = dict(
        ax=ax,
        models=models,
        pointwise_data=pointwise_data,
        numvars=numvars,
        figsize=figsize,
        textsize=textsize,
        plot_kwargs=plot_kwargs,
        xlabels=xlabels,
        coord_labels=coord_labels,
        xdata=xdata,
        threshold=threshold,
        legend=legend,
        color=color,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    plot = get_plotting_function("plot_elpd", "elpdplot", backend)
    ax = plot(**elpd_plot_kwargs)
    return ax

Read our documentation on viewing source code .

Loading