1
"""Plot highest density intervals for regression data."""
2 6
import warnings
3

4 6
import numpy as np
5 6
from scipy.interpolate import griddata
6 6
from scipy.signal import savgol_filter
7 6
from xarray import Dataset
8

9 6
from ..rcparams import rcParams
10 6
from ..stats import hdi
11 6
from ..utils import credible_interval_warning
12 6
from .plot_utils import get_plotting_function
13

14

15 6
def plot_hdi(
    x,
    y=None,
    hdi_prob=None,
    hdi_data=None,
    color="C1",
    circular=False,
    smooth=True,
    smooth_kwargs=None,
    figsize=None,
    fill_kwargs=None,
    plot_kwargs=None,
    hdi_kwargs=None,
    ax=None,
    backend=None,
    backend_kwargs=None,
    show=None,
    credible_interval=None,
):
    r"""
    Plot HDI intervals for regression data.

    Parameters
    ----------
    x : array-like
        Values to plot.
    y : array-like, optional
        Values from which to compute the HDI. Assumed shape ``(chain, draw, \*shape)``.
        Only optional if hdi_data is present.
    hdi_prob : float, optional
        Probability for the highest density interval. Defaults to ``stats.hdi_prob`` rcParam.
        Only used when the HDI is computed from ``y``; must lie in the interval (0, 1].
    hdi_data : array_like, optional
        Precomputed HDI values to use. Assumed shape is ``(*x.shape, 2)``. A single-variable
        :class:`xarray.Dataset` is also accepted. If both ``y`` and ``hdi_data`` are given,
        ``y`` is ignored (with a warning).
    color : str, optional
        Color used for the limits of the HDI and fill. Should be a valid matplotlib color.
    circular : bool, optional
        Whether to compute the HDI taking into account that `y` is a circular variable
        (in the range [-np.pi, np.pi]) or not. Defaults to False (i.e non-circular variables).
        Ignored if ``hdi_data`` is present.
    smooth : boolean, optional
        If True the result will be smoothed by first computing a linear interpolation of the data
        over a regular grid and then applying the Savitzky-Golay filter to the interpolated data.
        Defaults to True.
    smooth_kwargs : dict, optional
        Additional keywords modifying the Savitzky-Golay filter. See
        :func:`scipy:scipy.signal.savgol_filter` for details. ``window_length`` defaults to 55
        and ``polyorder`` to 2. The dict passed in is not modified.
    figsize : tuple
        Figure size. If None it will be defined automatically.
    fill_kwargs : dict, optional
        Keywords passed to :meth:`mpl:matplotlib.axes.Axes.fill_between`
        (use fill_kwargs={'alpha': 0} to disable fill) or to
        :meth:`bokeh:bokeh.plotting.figure.Figure.patch`.
    plot_kwargs : dict, optional
        HDI limits keyword arguments, passed to :meth:`mpl:matplotlib.axes.Axes.plot` or
        :meth:`bokeh:bokeh.plotting.figure.Figure.patch`.
    hdi_kwargs : dict, optional
        Keyword arguments passed to :func:`~arviz.hdi`. Ignored if ``hdi_data`` is present.
    ax : axes, optional
        Matplotlib axes or bokeh figures.
    backend : {"matplotlib","bokeh"}, optional
        Select plotting backend. Defaults to the ``plot.backend`` rcParam.
    backend_kwargs : dict, optional
        These are kwargs specific to the backend being used.
    show : bool, optional
        Call backend show function.
    credible_interval : float, optional
        Deprecated: Please see hdi_prob

    Returns
    -------
    axes : matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If neither ``y`` nor ``hdi_data`` is given, if ``hdi_prob`` is outside (0, 1], or if
        ``hdi_data`` is a Dataset that does not hold exactly one variable.
    TypeError
        If the leading dimensions of the HDI do not match the shape of ``x``.

    See Also
    --------
    hdi : Calculate highest density interval (HDI) of array for given probability.

    Examples
    --------
    Plot HDI interval of simulated regression data using `y` argument:

    .. plot::
        :context: close-figs

        >>> import numpy as np
        >>> import arviz as az
        >>> x_data = np.random.normal(0, 1, 100)
        >>> y_data = np.random.normal(2 + x_data * 0.5, 0.5, (2, 50, 100))
        >>> az.plot_hdi(x_data, y_data)

    ``plot_hdi`` can also be given precalculated values with the argument ``hdi_data``. This example
    shows how to use :func:`~arviz.hdi` to precalculate the values and pass these values to
    ``plot_hdi``. Similarly to an example in ``hdi`` we are using the ``input_core_dims``
    argument of :func:`~arviz.wrap_xarray_ufunc` to manually define the dimensions over which
    to calculate the HDI.

    .. plot::
        :context: close-figs

        >>> hdi_data = az.hdi(y_data, input_core_dims=[["draw"]])
        >>> ax = az.plot_hdi(x_data, hdi_data=hdi_data[0], color="r", fill_kwargs={"alpha": .2})
        >>> az.plot_hdi(x_data, hdi_data=hdi_data[1], color="k", ax=ax, fill_kwargs={"alpha": .2})

    """
    if credible_interval:
        hdi_prob = credible_interval_warning(credible_interval, hdi_prob)

    if hdi_kwargs is None:
        hdi_kwargs = {}

    x = np.asarray(x)
    x_shape = x.shape

    if y is None and hdi_data is None:
        raise ValueError("One of {y, hdi_data} is required")

    if hdi_data is not None:
        # Precomputed HDI takes precedence. Normalise it to an array-like with
        # ``.shape`` regardless of whether ``y`` was also passed. Previously a
        # Dataset was left un-unwrapped when ``y`` was present too.
        if y is not None:
            warnings.warn("Both y and hdi_data arguments present, ignoring y", stacklevel=2)
        if isinstance(hdi_data, Dataset):
            data_vars = list(hdi_data.data_vars)
            if len(data_vars) != 1:
                raise ValueError(
                    "Found several variables in hdi_data. Only single variable Datasets are "
                    "supported."
                )
            hdi_data = hdi_data[data_vars[0]]
        elif not hasattr(hdi_data, "shape"):
            # Plain sequences (e.g. nested lists) are accepted as array_like.
            hdi_data = np.asarray(hdi_data)
    else:
        y = np.asarray(y)
        if hdi_prob is None:
            hdi_prob = rcParams["stats.hdi_prob"]
        elif not 1 >= hdi_prob > 0:
            raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
        hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)

    # The last axis holds the (lower, upper) bounds; everything before it must line up with x.
    hdi_shape = hdi_data.shape
    if hdi_shape[:-1] != x_shape:
        msg = (
            "Dimension mismatch for x: {} and hdi: {}. Check the dimensions of y and "
            "hdi_kwargs to make sure they are compatible"
        )
        raise TypeError(msg.format(x_shape, hdi_shape))

    if smooth:
        # Work on a copy so the defaults below never leak into the caller's dict.
        smooth_kwargs = {} if smooth_kwargs is None else dict(smooth_kwargs)
        smooth_kwargs.setdefault("window_length", 55)
        smooth_kwargs.setdefault("polyorder", 2)
        x_data = np.linspace(x.min(), x.max(), 200)
        # Move the first grid point halfway towards the second. NOTE(review): presumably
        # this keeps the left edge strictly inside the data range for griddata; confirm.
        x_data[0] = (x_data[0] + x_data[1]) / 2
        hdi_interp = griddata(x, hdi_data, x_data)
        y_data = savgol_filter(hdi_interp, axis=0, **smooth_kwargs)
    else:
        # Without smoothing, plot the raw bounds ordered by x so the lines/fill are monotone in x.
        idx = np.argsort(x)
        x_data = x[idx]
        y_data = hdi_data[idx]

    hdiplot_kwargs = dict(
        ax=ax,
        x_data=x_data,
        y_data=y_data,
        color=color,
        figsize=figsize,
        plot_kwargs=plot_kwargs,
        fill_kwargs=fill_kwargs,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_hdi", "hdiplot", backend)
    ax = plot(**hdiplot_kwargs)
    return ax
192

193

194 6
def plot_hpd(*args, **kwargs):
    """Deprecated alias of :func:`plot_hdi`.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged to :func:`plot_hdi`, returning its result.
    """
    # stacklevel=2 attributes the warning to the caller's line, not this shim.
    warnings.warn(
        "plot_hpd has been deprecated, please use plot_hdi", DeprecationWarning, stacklevel=2
    )
    return plot_hdi(*args, **kwargs)

Read our documentation on viewing source code.

Loading