1
"""Parallel coordinates plot showing posterior points with and without divergences marked."""
2 6
import numpy as np
3 6
from scipy.stats import rankdata
4

5 6
from ..data import convert_to_dataset
6 6
from ..rcparams import rcParams
7 6
from ..stats.stats_utils import stats_variance_2d as svar
8 6
from ..utils import _numba_var, _var_names, get_coords
9 6
from .plot_utils import get_plotting_function, xarray_to_ndarray
10

11

12 6
def plot_parallel(
    data,
    var_names=None,
    filter_vars=None,
    coords=None,
    figsize=None,
    textsize=None,
    legend=True,
    colornd="k",
    colord="C1",
    shadend=0.025,
    ax=None,
    norm_method=None,
    backend=None,
    backend_config=None,
    backend_kwargs=None,
    show=None,
):
    """
    Plot parallel coordinates plot showing posterior points with and without divergences.

    Described by https://arxiv.org/abs/1709.01449

    Parameters
    ----------
    data: obj
        Any object that can be converted to an az.InferenceData object
        Refer to documentation of az.convert_to_dataset for details
    var_names: list of variable names
        Variables to be plotted, if `None` all variables are plotted. Can be used to change the
        order of the plotted variables. Prefix the variables by `~` when you want to exclude
        them from the plot.
    filter_vars: {None, "like", "regex"}, optional, default=None
        If `None` (default), interpret var_names as the real variables names. If "like",
        interpret var_names as substrings of the real variables names. If "regex",
        interpret var_names as regular expressions on the real variables names. A la
        `pandas.filter`.
    coords: mapping, optional
        Coordinates of var_names to be plotted. Passed to `Dataset.sel`
    figsize: tuple
        Figure size. If None it will be defined automatically.
    textsize: float
        Text size scaling factor for labels, titles and lines. If None it will be autoscaled based
        on figsize.
    legend: bool
        Flag for plotting legend (defaults to True)
    colornd: valid matplotlib color
        color for non-divergent points. Defaults to 'k'
    colord: valid matplotlib color
        color for divergent points. Defaults to 'C1'
    shadend: float
        Alpha blending value for non-divergent points, between 0 (invisible) and 1 (opaque).
        Defaults to .025
    ax: axes, optional
        Matplotlib axes or bokeh figures.
    norm_method: str, optional
        Method for normalizing the data, applied independently to each plotted variable:
        "normal" (subtract the mean, divide by the standard deviation), "minmax" (subtract
        the minimum, divide by the range) or "rank" (replace values by their average rank).
        For "normal" and "minmax", a zero spread is replaced by 1 to avoid dividing by zero.
        Defaults to None (no normalization).
    backend: str, optional
        Select plotting backend {"matplotlib","bokeh"}. Defaults to the value of
        ``rcParams["plot.backend"]``.
    backend_config: dict, optional
        Currently specifies the bounds to use for bokeh axes. Defaults to value set in rcParams.
    backend_kwargs: dict, optional
        These are kwargs specific to the backend being used. For additional documentation
        check the plotting method of the backend.
    show: bool, optional
        Call backend show function.

    Returns
    -------
    axes: matplotlib axes or bokeh figures

    Raises
    ------
    ValueError
        If fewer than two variables are selected, or if `norm_method` is not one of
        None, "normal", "minmax" or "rank".

    Examples
    --------
    Plot default parallel plot

    .. plot::
        :context: close-figs

        >>> import arviz as az
        >>> data = az.load_arviz_data('centered_eight')
        >>> az.plot_parallel(data, var_names=["mu", "tau"])


    Plot parallel plot with normalization

    .. plot::
        :context: close-figs

        >>> az.plot_parallel(data, var_names=["mu", "tau"], norm_method='normal')

    """
    if coords is None:
        coords = {}

    # Get diverging draws and combine chains
    divergent_data = convert_to_dataset(data, group="sample_stats")
    _, diverging_mask = xarray_to_ndarray(divergent_data, var_names=("diverging",), combined=True)
    diverging_mask = np.squeeze(diverging_mask)

    # Get posterior draws and combine chains
    posterior_data = convert_to_dataset(data, group="posterior")
    var_names = _var_names(var_names, posterior_data, filter_vars)
    var_names, _posterior = xarray_to_ndarray(
        get_coords(posterior_data, coords), var_names=var_names, combined=True
    )
    if len(var_names) < 2:
        raise ValueError("This plot needs at least two variables")
    _posterior = _normalize_posterior(_posterior, norm_method)

    parallel_kwargs = dict(
        ax=ax,
        colornd=colornd,
        colord=colord,
        shadend=shadend,
        diverging_mask=diverging_mask,
        posterior=_posterior,
        textsize=textsize,
        var_names=var_names,
        legend=legend,
        figsize=figsize,
        backend_kwargs=backend_kwargs,
        backend_config=backend_config,
        show=show,
    )

    if backend is None:
        backend = rcParams["plot.backend"]
    backend = backend.lower()

    plot = get_plotting_function("plot_parallel", "parallelplot", backend)
    ax = plot(**parallel_kwargs)

    return ax


def _normalize_posterior(posterior, norm_method):
    """Return `posterior` normalized row by row (one row per plotted variable).

    Results are computed out of place with broadcasting, so integer-valued posteriors yield
    float output instead of being truncated by in-place assignment into an integer array.

    Parameters
    ----------
    posterior: numpy.ndarray
        Array whose first axis indexes variables and second axis indexes draws.
    norm_method: {None, "normal", "minmax", "rank"}
        Normalization to apply. None returns `posterior` unchanged.

    Returns
    -------
    numpy.ndarray

    Raises
    ------
    ValueError
        If `norm_method` is not one of the supported values.
    """
    if norm_method is None:
        return posterior
    if norm_method == "normal":
        mean = np.mean(posterior, axis=1)
        if posterior.ndim <= 2:
            standard_deviation = np.sqrt(_numba_var(svar, np.var, posterior, axis=1))
        else:
            standard_deviation = np.std(posterior, axis=1)
        # A constant variable has zero spread; dividing by it would give NaN/inf.
        standard_deviation = np.where(standard_deviation == 0, 1, standard_deviation)
        # Insert a draw axis so per-variable statistics broadcast over all draws.
        return (posterior - mean[:, np.newaxis]) / standard_deviation[:, np.newaxis]
    if norm_method == "minmax":
        min_elem = np.min(posterior, axis=1)
        value_range = np.max(posterior, axis=1) - min_elem
        # A constant variable has zero range; dividing by it would give NaN/inf.
        value_range = np.where(value_range == 0, 1, value_range)
        return (posterior - min_elem[:, np.newaxis]) / value_range[:, np.newaxis]
    if norm_method == "rank":
        return rankdata(posterior, axis=1, method="average")
    raise ValueError("{} is not supported. Use normal, minmax or rank.".format(norm_method))

Read our documentation on viewing source code .

Loading