1
"""Matplotib Bayesian p-value Posterior predictive plot."""
2 4
import matplotlib.pyplot as plt
3 4
import numpy as np
4 4
from scipy import stats
5 4
from scipy.interpolate import CubicSpline
6

7 4
from ....stats.density_utils import kde
8 4
from ...kdeplot import plot_kde
9 4
from ...plot_utils import (
10
    _scale_fig_size,
11
    is_valid_quantile,
12
    make_label,
13
    sample_reference_distribution,
14
)
15 4
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
16

17

18 4
def plot_bpv(
    ax,
    length_plotters,
    rows,
    cols,
    obs_plotters,
    pp_plotters,
    total_pp_samples,
    kind,
    t_stat,
    bpv,
    plot_mean,
    reference,
    mse,
    n_ref,
    hdi_prob,
    color,
    figsize,
    textsize,
    plot_ref_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib bpv plot.

    Draw one panel per variable comparing observed data against posterior
    predictive samples, in one of three modes selected by ``kind``.

    Parameters
    ----------
    ax : array-like of matplotlib axes or None
        Axes to draw on. When None a grid is created with ``create_axes_grid``.
    length_plotters : int
        Number of variables (panels) to draw.
    rows, cols : int
        Grid layout, forwarded to ``_scale_fig_size`` and ``create_axes_grid``.
    obs_plotters : sequence
        Per-variable ``(var_name, selection, values)`` triples of observed data.
    pp_plotters : sequence
        Per-variable ``(var_name, <unused>, values)`` triples of posterior
        predictive data. ``values`` is reshaped to ``(total_pp_samples, -1)``.
    total_pp_samples : int
        Leading dimension used when reshaping the posterior predictive values.
    kind : str
        ``"p_value"``: KDE of, for each posterior predictive sample, the
        proportion of its values that are ``<=`` the observed ones.
        ``"u_value"``: KDE of, for each observation, the proportion of
        posterior predictive samples that are ``<=`` it.
        Anything else: KDE of a test statistic of the posterior predictive
        samples (see ``t_stat``).
    t_stat : str, float or callable
        Test statistic for the third mode: ``"mean"``, ``"median"``, ``"std"``,
        a callable, or a value accepted by ``is_valid_quantile``.
    bpv : bool
        If true (test-statistic mode), add a legend entry with the proportion
        of predictive statistics ``<=`` the observed statistic.
    plot_mean : bool
        If true (test-statistic mode), mark the observed statistic with a dot.
    reference : str or None
        ``"analytical"``, ``"samples"`` or None; selects the reference drawn in
        the ``"p_value"`` and ``"u_value"`` modes.
    mse : bool
        If true (``"u_value"`` mode), add a legend entry with the mean squared
        deviation of the estimated density from 1, multiplied by 100.
    n_ref : int
        Second entry of the shape passed to ``sample_reference_distribution``
        (presumably the number of reference draws; confirm against the helper).
    hdi_prob : float
        Probability mass used to size the analytical band in ``"u_value"`` mode.
    color : matplotlib color
        Color of the main curve and, by default, of the sampled reference.
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    plot_ref_kwargs : dict or None
        Extra keyword arguments for the reference artist.
    backend_kwargs : dict or None
        Extra keyword arguments for ``create_axes_grid``; they override
        ``backend_kwarg_defaults()``.
    show : bool or None
        Passed to ``backend_show`` to decide whether ``plt.show()`` is called.

    Returns
    -------
    axes
        The axes that were drawn on (created grid, or ``np.asarray(ax)``).

    Raises
    ------
    ValueError
        If fewer axes than variables are supplied, or if ``t_stat`` is not a
        supported statistic.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    # Caller-supplied options take precedence over the backend defaults.
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    figsize, ax_labelsize, _, _, linewidth, markersize = _scale_fig_size(
        figsize, textsize, rows, cols
    )

    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("squeeze", True)

    if (kind == "u_value") and (reference == "analytical"):
        plot_ref_kwargs = matplotlib_kwarg_dealiaser(plot_ref_kwargs, "fill_between")
    else:
        plot_ref_kwargs = matplotlib_kwarg_dealiaser(plot_ref_kwargs, "plot")

    # Defaults for the reference artist; user-provided values are kept.
    if kind == "p_value" and reference == "analytical":
        plot_ref_kwargs.setdefault("color", "k")
        plot_ref_kwargs.setdefault("linestyle", "--")
    elif kind == "u_value" and reference == "analytical":
        plot_ref_kwargs.setdefault("color", "k")
        plot_ref_kwargs.setdefault("alpha", 0.2)
    else:
        plot_ref_kwargs.setdefault("alpha", 0.1)
        plot_ref_kwargs.setdefault("color", color)

    if ax is None:
        _, axes = create_axes_grid(length_plotters, rows, cols, backend_kwargs=backend_kwargs)
    else:
        axes = np.asarray(ax)
        if axes.size < length_plotters:
            raise ValueError(
                f"Found {length_plotters} variables to plot but {axes.size} axes instances. "
                "Axes instances must at minimum be equal to variables."
            )

    for i, ax_i in enumerate(np.ravel(axes)[:length_plotters]):
        var_name, selection, obs_vals = obs_plotters[i]
        pp_var_name, _, pp_vals = pp_plotters[i]

        obs_vals = obs_vals.flatten()
        pp_vals = pp_vals.reshape(total_pp_samples, -1)

        if obs_vals.dtype.kind == "i" or pp_vals.dtype.kind == "i":
            # Integer data: pass both arrays through a cubic spline evaluated
            # on a slightly shrunken grid so the values become continuous.
            x = np.linspace(0, 1, len(obs_vals))
            csi = CubicSpline(x, obs_vals)
            obs_vals = csi(np.linspace(0.001, 0.999, len(obs_vals)))

            x = np.linspace(0, 1, pp_vals.shape[1])
            csi = CubicSpline(x, pp_vals, axis=1)
            pp_vals = csi(np.linspace(0.001, 0.999, pp_vals.shape[1]))

        if kind == "p_value":
            # One value per posterior predictive sample (mean over observations).
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.plot(x_s, tstat_pit_dens, linewidth=linewidth, color=color)
            ax_i.set_yticks([])
            if reference is not None:
                dist = stats.beta(obs_vals.size / 2, obs_vals.size / 2)
                if reference == "analytical":
                    # Plot the reference pdf over its central 99.99% interval;
                    # the upper bound mirrors the lower one (symmetric beta).
                    lwb = dist.ppf((1 - 0.9999) / 2)
                    upb = 1 - lwb
                    x = np.linspace(lwb, upb, 500)
                    dens_ref = dist.pdf(x)
                    ax_i.plot(x, dens_ref, zorder=1, **plot_ref_kwargs)
                elif reference == "samples":
                    x_ss, u_dens = sample_reference_distribution(
                        dist,
                        (
                            tstat_pit_dens.size,
                            n_ref,
                        ),
                    )
                    ax_i.plot(x_ss, u_dens, linewidth=linewidth, **plot_ref_kwargs)

        elif kind == "u_value":
            # One value per observation (mean over posterior predictive samples).
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.plot(x_s, tstat_pit_dens, color=color)
            if reference is not None:
                if reference == "analytical":
                    n_obs = obs_vals.size
                    hdi_ = stats.beta(n_obs / 2, n_obs / 2).ppf((1 - hdi_prob) / 2)
                    # Band limits are the ratio q / (1 - q) and its reciprocal.
                    hdi_odds = (hdi_ / (1 - hdi_), (1 - hdi_) / hdi_)
                    ax_i.axhspan(*hdi_odds, **plot_ref_kwargs)
                    ax_i.axhline(1, color="w", zorder=1)
                elif reference == "samples":
                    dist = stats.uniform(0, 1)
                    x_ss, u_dens = sample_reference_distribution(dist, (tstat_pit_dens.size, n_ref))
                    ax_i.plot(x_ss, u_dens, linewidth=linewidth, **plot_ref_kwargs)
            if mse:
                # Single point at the origin used only to carry the legend label.
                ax_i.plot(0, 0, label=f"mse={np.mean((1 - tstat_pit_dens)**2) * 100:.2f}")
                ax_i.legend()

            ax_i.set_ylim(0, None)
            ax_i.set_xlim(0, 1)
        else:
            if t_stat in ("mean", "median", "std"):
                tfunc = {"mean": np.mean, "median": np.median, "std": np.std}[t_stat]
                obs_vals = tfunc(obs_vals)
                pp_vals = tfunc(pp_vals, axis=1)
            elif callable(t_stat):
                # obs_vals is already flat at this point.
                obs_vals = t_stat(obs_vals)
                pp_vals = t_stat(pp_vals)
            elif is_valid_quantile(t_stat):
                # Use a local so the caller-visible parameter is not rebound.
                quantile = float(t_stat)
                obs_vals = np.quantile(obs_vals, q=quantile)
                pp_vals = np.quantile(pp_vals, q=quantile, axis=1)
            else:
                raise ValueError(f"T statistics {t_stat} not implemented")

            plot_kde(pp_vals, ax=ax_i, plot_kwargs={"color": color})
            ax_i.set_yticks([])
            if bpv:
                p_value = np.mean(pp_vals <= obs_vals)
                # Fully transparent artist: present only to add the legend entry.
                ax_i.plot(obs_vals, 0, label=f"bpv={p_value:.2f}", alpha=0)
                ax_i.legend()

            if plot_mean:
                # np.mean also accepts plain Python scalars, which a user
                # callable may return (they have no ``.mean()`` method).
                ax_i.plot(
                    np.mean(obs_vals), 0, "o", color=color, markeredgecolor="k", markersize=markersize
                )

        if var_name != pp_var_name:
            xlabel = f"{var_name} / {pp_var_name}"
        else:
            xlabel = var_name
        ax_i.set_title(make_label(xlabel, selection), fontsize=ax_labelsize)

    if backend_show(show):
        plt.show()

    return axes

Read our documentation on viewing source code.

Loading