1
"""Plot kde or histograms and values from MCMC samples."""
2 2
import warnings
3 2
from typing import Any, Callable, List, Mapping, Optional, Tuple, Union
4

5 2
from ..data import CoordSpec, InferenceData, convert_to_dataset
6 2
from ..rcparams import rcParams
7 2
from ..utils import _var_names, get_coords
8 2
from .plot_utils import KwargSpec, get_plotting_function, xarray_var_iter
9

10

11 2
def plot_trace(
    data: InferenceData,
    var_names: Optional[List[str]] = None,
    filter_vars: Optional[str] = None,
    transform: Optional[Callable] = None,
    coords: Optional[CoordSpec] = None,
    divergences: Optional[str] = "auto",
    kind: Optional[str] = "trace",
    figsize: Optional[Tuple[float, float]] = None,
    rug: bool = False,
    lines: Optional[List[Tuple[str, CoordSpec, Any]]] = None,
    circ_var_names: Optional[List[str]] = None,
    circ_var_units: str = "radians",
    compact: bool = False,
    compact_prop: Optional[Union[str, Mapping[str, Any]]] = None,
    combined: bool = False,
    chain_prop: Optional[Union[str, Mapping[str, Any]]] = None,
    legend: bool = False,
    plot_kwargs: Optional[KwargSpec] = None,
    fill_kwargs: Optional[KwargSpec] = None,
    rug_kwargs: Optional[KwargSpec] = None,
    hist_kwargs: Optional[KwargSpec] = None,
    trace_kwargs: Optional[KwargSpec] = None,
    rank_kwargs: Optional[KwargSpec] = None,
    axes=None,
    backend: Optional[str] = None,
    backend_config: Optional[KwargSpec] = None,
    backend_kwargs: Optional[KwargSpec] = None,
    show: Optional[bool] = None,
):
    """Plot distribution (histogram or kernel density estimates) and sampled values or rank plot.

    If `divergences` data is available in `sample_stats`, will plot the location of divergences as
    dashed vertical lines.

    Parameters
    ----------
    data: obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
    var_names: str or list of str, optional
        One or more variables to be plotted. Prefix the variables by `~` when you want
        to exclude them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
    coords: dict of {str: slice or array_like}, optional
        Coordinates of var_names to be plotted. Passed to `Dataset.sel`. Only the "chain"
        and "draw" entries are also applied to the divergence data.
    divergences: {"auto", "bottom", "top", None}, optional, default="auto"
        Plot location of divergences on the traceplots. "auto" resolves to "top" when
        `rug` is True and to "bottom" otherwise. Divergences are silently skipped when the
        data has no `sample_stats` group or no `diverging` variable.
    kind: {"trace", "rank_bars", "rank_vlines"}, optional, default="trace"
        Choose between plotting sampled values per iteration and rank plots.
    transform: callable, optional
        Function to transform data (defaults to None i.e. the identity function)
    figsize: tuple of (float, float), optional
        If None, size is (12, variables * 2)
    rug: bool, optional
        If True adds a rugplot. Defaults to False. Ignored for 2D KDE.
        Only affects continuous variables.
    lines: list of tuple of (str, dict, array_like), optional
        List of (var_name, {'coord': selection}, [line, positions]) to be overplotted as
        vertical lines on the density and horizontal lines on the trace.
    circ_var_names : str or list of str, optional
        List of circular variables to account for when plotting KDE.
    circ_var_units : {"radians", "degrees"}, optional, default="radians"
        Whether the variables in `circ_var_names` are in "degrees" or "radians".
    compact: bool, optional
        Plot multidimensional variables in a single plot.
    compact_prop: str or dict {str: array_like}, optional
        Property name, or mapping of property name to property values, used to distinguish
        different dimensions with compact=True
    combined: bool, optional
        Flag for combining multiple chains into a single line. If False (default), chains will be
        plotted separately.
    chain_prop: str or dict {str: array_like}, optional
        Property name, or mapping of property name to property values, used to distinguish
        different chains
    legend: bool, optional
        Add a legend to the figure with the chain color code.
    plot_kwargs, fill_kwargs, rug_kwargs, hist_kwargs: dict, optional
        Extra keyword arguments passed to `arviz.plot_dist`. Only affects continuous variables.
    trace_kwargs: dict, optional
        Extra keyword arguments passed to `plt.plot`
    rank_kwargs: dict, optional
        Extra keyword arguments for the rank plot (used when `kind` is "rank_bars" or
        "rank_vlines"); forwarded to the backend.
    axes: axes, optional
        Axes (matplotlib) or figures (bokeh) to draw on; forwarded to the backend.
    backend: {"matplotlib", "bokeh"}, optional
        Select plotting backend. Defaults to rcParams["plot.backend"].
    backend_config: dict, optional
        Currently specifies the bounds to use for bokeh axes. Defaults to value set in rcParams.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used. For additional documentation
        check the plotting method of the backend.
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If `kind` is not one of "trace", "rank_vlines" or "rank_bars".

    Warns
    -----
    UserWarning
        If more variables are selected than rcParams["plot.max_subplots"] allows; the
        excess variables are dropped.

    Examples
    --------
    Plot a subset variables and select them with partial naming

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('non_centered_eight')
        >>> coords = {'school': ['Choate', 'Lawrenceville']}
        >>> az.plot_trace(data, var_names=('theta'), filter_vars="like", coords=coords)

    Show all dimensions of multidimensional variables in the same plot

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, compact=True)

    Display a rank plot instead of trace

    .. plot::
        :context: close-figs

        >>> az.plot_trace(data, var_names=["mu", "tau"], kind="rank_bars")

    Combine all chains into one distribution and select variables with regular expressions

    .. plot::
        :context: close-figs

        >>> az.plot_trace(
        ...     data, var_names=('^theta'), filter_vars="regex", coords=coords, combined=True
        ... )


    Plot reference lines against distribution and trace

    .. plot::
        :context: close-figs

        >>> lines = (('theta_t',{'school': "Choate"}, [-1]),)
        >>> az.plot_trace(data, var_names=('theta_t', 'theta'), coords=coords, lines=lines)

    """
    if kind not in {"trace", "rank_vlines", "rank_bars"}:
        raise ValueError("The value of kind must be either trace, rank_vlines or rank_bars.")

    # "auto" moves divergence markers to the top when a rug occupies the bottom of the plot.
    if divergences == "auto":
        divergences = "top" if rug else "bottom"
    if divergences:
        try:
            divergence_data = convert_to_dataset(data, group="sample_stats").diverging
        except (ValueError, AttributeError):  # No sample_stats, or no `.diverging`
            divergences = False

    if coords is None:
        coords = {}

    if divergences:
        # Only chain/draw selections are applied to the divergence data; the remaining
        # coords presumably refer to posterior-variable dimensions absent from sample_stats.
        divergence_data = get_coords(
            divergence_data, {k: v for k, v in coords.items() if k in ("chain", "draw")}
        )
    else:
        divergence_data = False

    data = get_coords(convert_to_dataset(data, group="posterior"), coords)

    if transform is not None:
        data = transform(data)

    var_names = _var_names(var_names, data, filter_vars)

    # In compact mode every non-sample dimension is folded into one plot per variable.
    if compact:
        skip_dims = set(data.dims) - {"chain", "draw"}
    else:
        skip_dims = set()

    plotters = list(xarray_var_iter(data, var_names=var_names, combined=True, skip_dims=skip_dims))
    # The subplot budget is halved because each plotted variable uses two panels
    # (distribution and trace/rank); at least one variable is always plotted.
    max_plots = rcParams["plot.max_subplots"]
    max_plots = len(plotters) if max_plots is None else max(max_plots // 2, 1)
    if len(plotters) > max_plots:
        warnings.warn(
            "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
            "of variables to plot ({len_plotters}), generating only {max_plots} "
            "plots".format(max_plots=max_plots, len_plotters=len(plotters)),
            UserWarning,
        )
        plotters = plotters[:max_plots]

    trace_plot_args = dict(
        # User kwargs
        data=data,
        var_names=var_names,
        divergences=divergences,
        kind=kind,
        figsize=figsize,
        rug=rug,
        lines=lines,
        circ_var_names=circ_var_names,
        circ_var_units=circ_var_units,
        plot_kwargs=plot_kwargs,
        fill_kwargs=fill_kwargs,
        rug_kwargs=rug_kwargs,
        hist_kwargs=hist_kwargs,
        trace_kwargs=trace_kwargs,
        rank_kwargs=rank_kwargs,
        compact=compact,
        compact_prop=compact_prop,
        combined=combined,
        chain_prop=chain_prop,
        legend=legend,
        # Generated kwargs
        divergence_data=divergence_data,
        plotters=plotters,
        axes=axes,
        backend_config=backend_config,
        backend_kwargs=backend_kwargs,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_trace", "traceplot", backend)
    axes = plot(**trace_plot_args)

    return axes

Read our documentation on viewing source code .

Loading