#1421 [WIP] geweke inspired diagnostic and plot

Open Osvaldo Martin aloctavodia
Coverage Reach
plots/backends/matplotlib/forestplot.py plots/backends/matplotlib/ppcplot.py plots/backends/matplotlib/pairplot.py plots/backends/matplotlib/traceplot.py plots/backends/matplotlib/khatplot.py plots/backends/matplotlib/posteriorplot.py plots/backends/matplotlib/bpvplot.py plots/backends/matplotlib/elpdplot.py plots/backends/matplotlib/essplot.py plots/backends/matplotlib/mcseplot.py plots/backends/matplotlib/kdeplot.py plots/backends/matplotlib/loopitplot.py plots/backends/matplotlib/violinplot.py plots/backends/matplotlib/densityplot.py plots/backends/matplotlib/distplot.py plots/backends/matplotlib/__init__.py plots/backends/matplotlib/distcomparisonplot.py plots/backends/matplotlib/jointplot.py plots/backends/matplotlib/energyplot.py plots/backends/matplotlib/rankplot.py plots/backends/matplotlib/separationplot.py plots/backends/matplotlib/compareplot.py plots/backends/matplotlib/autocorrplot.py plots/backends/matplotlib/parallelplot.py plots/backends/matplotlib/hdiplot.py plots/backends/bokeh/forestplot.py plots/backends/bokeh/pairplot.py plots/backends/bokeh/traceplot.py plots/backends/bokeh/kdeplot.py plots/backends/bokeh/ppcplot.py plots/backends/bokeh/bpvplot.py plots/backends/bokeh/posteriorplot.py plots/backends/bokeh/elpdplot.py plots/backends/bokeh/densityplot.py plots/backends/bokeh/__init__.py plots/backends/bokeh/loopitplot.py plots/backends/bokeh/mcseplot.py plots/backends/bokeh/essplot.py plots/backends/bokeh/khatplot.py plots/backends/bokeh/distplot.py plots/backends/bokeh/violinplot.py plots/backends/bokeh/energyplot.py plots/backends/bokeh/rankplot.py plots/backends/bokeh/jointplot.py plots/backends/bokeh/compareplot.py plots/backends/bokeh/separationplot.py plots/backends/bokeh/autocorrplot.py plots/backends/bokeh/parallelplot.py plots/backends/bokeh/hdiplot.py plots/backends/bokeh/distcomparisonplot.py plots/backends/__init__.py plots/plot_utils.py plots/geweke.py plots/ppcplot.py plots/hdiplot.py plots/densityplot.py plots/essplot.py plots/bpvplot.py plots/pairplot.py plots/loopitplot.py plots/traceplot.py plots/parallelplot.py plots/distcomparisonplot.py plots/elpdplot.py plots/separationplot.py plots/khatplot.py plots/forestplot.py plots/kdeplot.py plots/posteriorplot.py plots/mcseplot.py plots/rankplot.py plots/jointplot.py plots/__init__.py plots/violinplot.py plots/compareplot.py plots/distplot.py plots/autocorrplot.py plots/energyplot.py data/inference_data.py data/io_pystan.py data/io_cmdstan.py data/io_pymc3.py data/io_cmdstanpy.py data/io_pyro.py data/io_numpyro.py data/io_dict.py data/io_emcee.py data/base.py data/io_tfp.py data/io_pyjags.py data/converters.py data/datasets.py data/__init__.py data/io_json.py data/io_netcdf.py stats/stats.py stats/diagnostics.py stats/density_utils.py stats/stats_utils.py stats/stats_refitting.py stats/__init__.py utils.py rcparams.py wrappers/base.py wrappers/wrap_pystan.py wrappers/__init__.py

No flags found

Use flags to group coverage reports by test type, project and/or folders.
Then setup custom commit statuses and notifications for each flag.

e.g., #unittest #integration

#production #enterprise

#frontend #backend

Learn more about Codecov Flags here.

Showing 1 of 1 files from the diff.
Newly tracked file
arviz/plots/geweke.py created.

@@ -0,0 +1,172 @@
Loading
1 +
import numpy as np
2 +
import pandas as pd
3 +
import matplotlib.pyplot as plt
4 +
from arviz.plots.plot_utils import xarray_var_iter, default_grid, _scale_fig_size, make_label
5 +
from arviz.data import convert_to_dataset
6 +
from arviz.plots.backends.matplotlib import create_axes_grid
7 +
from arviz.utils import _var_names
8 +
9 +
10 +
def geweke_like(data, var_names=None, splits=10, round_to=2):
11 +
    r"""Compute z-scores for convergence diagnostics.
12 +
13 +
    Concatenates all chains and split them in equal size portions. Them compare them pairwise by computing the
14 +
    difference of the mean divided by their pooled variances. This is esentially a Welch's t statistic.
15 +
    The computed z_scores are expected to be distributed as a standard normal distribution.
16 +
17 +
    Parameters
18 +
    ----------
19 +
    data : obj
20 +
        Any object that can be converted to an az.InferenceData object
21 +
        Refer to documentation of az.convert_to_dataset for details
22 +
    var_names : list
23 +
        Names of variables to include. Prefix the variables by `~` when you
24 +
        want to exclude them from the analysis: `["~beta"]` instead of `["beta"]` (see
25 +
        examples below).
26 +
    splits : int:
27 +
        Number of portions to split the concatenated chains. It must lead to portions of the same size.
28 +
    round_to: int
29 +
        Number of decimals used to round results. Defaults to 2. Use "none" to return raw numbers.
30 +
31 +
    Returns
32 +
    -------
33 +
    pandas.DataFrame
34 +
        Return value will contain summary statistics for each variable. The summaries are the mean of the z_scores,
35 +
        the standard deviation, and the proportion of z_scores in absolute value larger than 2.
36 +
    """
37 +
38 +
    posterior_data = convert_to_dataset(data, group="posterior")
39 +
    iterator = list(xarray_var_iter(
40 +
        posterior_data, var_names=var_names, combined=True))
41 +
    #var_names = _var_names(var_names, posterior_data, filter_vars)
42 +
    summary = {}
43 +
    for var_name, selection, var_data in iterator:
44 +
        z_scores = _geweke_like(var_data, splits=splits)
45 +
        v_name = make_label(var_name, selection, position="beside")
46 +
        summary[v_name] = (z_scores.mean(),  z_scores.std(),
47 +
                           np.mean(np.abs(z_scores) > 2))
48 +
    df_summary = pd.DataFrame.from_dict(
49 +
        summary, orient="index", columns=["mean", "std", ">|2|"])
50 +
    return df_summary.round(round_to)
51 +
52 +
53 +
def _geweke_like(ary, splits=10):
54 +
    ary_flat = np.ravel(ary)
55 +
    ary_split = np.array(np.split(ary_flat, splits))
56 +
    n_ary_s = len(ary_flat) / splits
57 +
    ary_means = ary_split.mean(1)
58 +
    ary_vars = ary_split.var(axis=1, ddof=1) / n_ary_s
59 +
60 +
    z_scores = np.zeros((splits**2-splits)//2)
61 +
    idx = 0
62 +
    for i in range(splits):
63 +
        for j in range(i+1, splits):
64 +
            z_scores[idx] = (ary_means[i] - ary_means[j]) / (ary_vars[i] + ary_vars[j])**0.5
65 +
            idx += 1
66 +
    return z_scores
67 +
68 +
69 +
def plot_geweke_like(data, var_names=None, filter_vars=None, splits=10, kind="scatter",
70 +
                     figsize=None, axes=None, backend_kwargs=None):
71 +
    """Compute and plot z-scores for convergence diagnostics.
72 +
73 +
    Concatenates all chains and split them in equal size portions. Then compare them pairwise by computing the
74 +
    difference of the mean divided by their pooled variances. This is esentially a Welch's t statistic.
75 +
    The computed z_scores are expected to be distributed as a standard normal distribution.
76 +
77 +
    Parameters
78 +
    ----------
79 +
    data : obj
80 +
        Any object that can be converted to an az.InferenceData object
81 +
        Refer to documentation of az.convert_to_dataset for details
82 +
    var_names : list
83 +
        Names of variables to include. Prefix the variables by `~` when you
84 +
        want to exclude them from the analysis: `["~beta"]` instead of `["beta"]` (see
85 +
        examples below).
86 +
    filter_vars: {None, "like", "regex"}, optional, default=None
87 +
        If `None` (default), interpret var_names as the real variables names. If "like",
88 +
        interpret var_names as substrings of the real variables names. If "regex",
89 +
        interpret var_names as regular expressions on the real variables names. A la
90 +
        `pandas.filter`.
91 +
    splits : int:
92 +
        Number of portions to split the concatenated chains. It must lead to portions of the same size.
93 +
    kind : str:
94 +
        Available options are `scatter` or `forest`.
95 +
    figsize: tuple
96 +
        Figure size. If None it will be defined automatically.
97 +
    ax: numpy array-like of matplotlib axes or bokeh figures, optional
98 +
         If not supplied, Arviz will create its own array of plot areas (and return it).
99 +
    backend_kwargs: bool, optional
100 +
        These are kwargs specific to the backend being used. For additional documentation
101 +
        check the plotting method of the backend.
102 +
    """
103 +
    posterior_data = convert_to_dataset(data, group="posterior")
104 +
    var_names = _var_names(var_names, posterior_data, filter_vars)
105 +
106 +
    if backend_kwargs is None:
107 +
        backend_kwargs = {}
108 +
109 +
    plotters = list(xarray_var_iter(
110 +
        posterior_data, var_names=var_names, combined=True))
111 +
    length_plotters = len(plotters)
112 +
113 +
    if kind == "scatter":
114 +
115 +
        rows, cols = default_grid(length_plotters)
116 +
        figsize, ax_labelsize, titlesize, xt_labelsize, _, _ = _scale_fig_size(
117 +
            figsize, None, rows=rows, cols=cols)
118 +
119 +
        backend_kwargs.setdefault("figsize", figsize)
120 +
        backend_kwargs.setdefault("sharex", True)
121 +
        backend_kwargs.setdefault("sharey", True)
122 +
123 +
        if axes is None:
124 +
            _, axes = create_axes_grid(
125 +
                length_plotters,
126 +
                rows,
127 +
                cols,
128 +
                backend_kwargs=backend_kwargs,
129 +
            )
130 +
131 +
        for ax, (var_name, selection, var_data) in zip(np.ravel(axes), plotters):
132 +
            z_scores = _geweke_like(var_data, splits=splits)
133 +
            ax.plot(z_scores, 'o')
134 +
            ax.axhline(-2, color='k', ls='--')
135 +
            ax.axhline(2, color='k', ls='--')
136 +
137 +
            ax.set_title(make_label(var_name, selection),
138 +
                         fontsize=titlesize*1.5, wrap=True)
139 +
            ax.tick_params(labelsize=xt_labelsize*1.5)
140 +
141 +
    elif kind == "forest":
142 +
143 +
        figsize, ax_labelsize, titlesize, xt_labelsize, linewidth, _ = _scale_fig_size(figsize, None,
144 +
                                                                                       rows=length_plotters,
145 +
                                                                                       cols=1)
146 +
        figsize = (figsize[0], figsize[1]/9)
147 +
        if axes is None:
148 +
            fig, axes = plt.subplots(
149 +
                length_plotters, 1, figsize=figsize, sharex=True)
150 +
            axes = np.ravel(axes)
151 +
        backend_kwargs.setdefault("squeeze", True)
152 +
153 +
        fig.set_constrained_layout(False)
154 +
        fig.subplots_adjust(hspace=0)
155 +
        for ax, (var_name, selection, var_data) in zip(axes, plotters):
156 +
            z_scores = _geweke_like(var_data, splits=splits)
157 +
            quant = np.quantile(z_scores, (.05, .16, .5, .84, .95))
158 +
            ax.plot((quant[1], quant[3]), [0, 0],  lw=linewidth *
159 +
                    4, color="C0", solid_capstyle="round")
160 +
            ax.plot((quant[0], quant[4]), [0, 0],  lw=linewidth *
161 +
                    1, color="C0", solid_capstyle="round")
162 +
            ax.plot(quant[2], 0, 'ko')
163 +
            for i in [-2, -1, 1, 2]:
164 +
                ax.axvline(i, ls="--", color="k")
165 +
            x_label = make_label(var_name, selection, "beside")
166 +
            ax.set_ylabel(x_label, rotation=0, labelpad=40 +
167 +
                          len(x_label)*2, fontsize=ax_labelsize*0.9)
168 +
            ax.set_yticks([])
169 +
            if ax != axes[0]:
170 +
                ax.spines['top'].set_visible(False)
171 +
172 +
    return axes

Learn more Showing 12 files with coverage changes found.

Changes in arviz/data/io_emcee.py
-103
+103
Loading file...
Changes in arviz/data/io_tfp.py
-88
+88
Loading file...
Changes in arviz/data/io_pystan.py
-429
+429
Loading file...
Changes in arviz/data/io_cmdstanpy.py
-176
+176
Loading file...
Changes in arviz/data/io_pymc3.py
-202
+202
Loading file...
Changes in arviz/data/io_pyro.py
-109
+109
Loading file...
Changes in arviz/data/io_numpyro.py
-105
+105
Loading file...
Changes in arviz/data/io_pyjags.py
-64
+64
Loading file...
Changes in arviz/data/io_cmdstan.py
-188
+188
Loading file...
Changes in arviz/data/base.py
-9
+9
Loading file...
Changes in arviz/utils.py
-1
+1
Loading file...
New file arviz/plots/geweke.py
New
Loading file...
Files Coverage
arviz -13.81% 77.76%
Project Totals (106 files) 77.76%
Loading