1
"""Bokeh Bayesian p-value Posterior predictive plot."""
2 2
import numpy as np
3 2
from bokeh.models import BoxAnnotation
4 2
from bokeh.models.annotations import Title
5 2
from scipy import stats
6

7 2
from ....stats.density_utils import kde
8 2
from ...kdeplot import plot_kde
9 2
from ...plot_utils import (
10
    _scale_fig_size,
11
    is_valid_quantile,
12
    sample_reference_distribution,
13
    vectorized_to_hex,
14
)
15 2
from .. import show_layout
16 2
from . import backend_kwarg_defaults, create_axes_grid
17

18

19 2
def plot_bpv(
    ax,
    length_plotters,
    rows,
    cols,
    obs_plotters,
    pp_plotters,
    total_pp_samples,
    kind,
    t_stat,
    bpv,
    plot_mean,
    reference,
    n_ref,
    hdi_prob,
    color,
    figsize,
    textsize,
    plot_ref_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh bpv plot.

    Draws one panel per (observed, posterior predictive) variable pair and returns
    the axes grid.

    Parameters
    ----------
    ax : bokeh figure, array of figures, or None
        Existing axes to draw on. When None, a ``rows`` x ``cols`` grid is created with
        ``create_axes_grid``; otherwise the number of non-None entries must equal
        ``length_plotters``.
    length_plotters, rows, cols : int
        Number of panels and the grid layout.
    obs_plotters, pp_plotters : sequence
        Per-panel ``(var_name, _, values)`` triples for observed and posterior
        predictive data.
    total_pp_samples : int
        Number of posterior predictive draws; each panel's predictive values are
        reshaped to ``(total_pp_samples, -1)``.
    kind : str
        ``"p_value"`` or ``"u_value"``; any other value plots the distribution of
        ``t_stat`` over the predictive draws.
    t_stat : str, float or callable
        ``"mean"``, ``"median"``, ``"std"``, a quantile accepted by
        ``is_valid_quantile``, or a function applied to the data.
    bpv : bool
        In t_stat mode, add a legend entry with the Bayesian p-value.
    plot_mean : bool
        In t_stat mode, mark the observed statistic at y=0.
    reference : {"analytical", "samples", None}
        Reference distribution drawn in "p_value"/"u_value" modes.
    n_ref : int
        Number of sampled reference curves for ``reference="samples"``.
    hdi_prob : float
        Probability mass of the analytical reference band in "u_value" mode.
    color : color spec
        Main color (converted with ``vectorized_to_hex``).
    figsize, textsize : optional
        Forwarded to ``_scale_fig_size``.
    plot_ref_kwargs : dict, optional
        Extra keyword arguments for the reference glyphs. The caller's dict is
        never modified.
    backend_kwargs : dict, optional
        Overrides for ``backend_kwarg_defaults()`` used when creating the grid.
    show : bool
        Forwarded to ``show_layout``.

    Returns
    -------
    axes
        The axes grid that was drawn on.

    Raises
    ------
    ValueError
        If the supplied axes do not match ``length_plotters`` or ``t_stat`` is not
        supported.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    color = vectorized_to_hex(color)

    # Work on a private copy so neither the defaults below nor the per-panel
    # processing leak into the caller's dict.
    plot_ref_kwargs = {} if plot_ref_kwargs is None else dict(plot_ref_kwargs)
    if kind == "p_value" and reference == "analytical":
        plot_ref_kwargs.setdefault("line_color", "black")
        plot_ref_kwargs.setdefault("line_dash", "dashed")
    else:
        plot_ref_kwargs.setdefault("alpha", 0.1)
        plot_ref_kwargs.setdefault("line_color", color)

    (figsize, ax_labelsize, _, _, linewidth, markersize) = _scale_fig_size(
        figsize, textsize, rows, cols
    )

    if ax is None:
        axes = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            backend_kwargs=backend_kwargs,
        )
    else:
        axes = np.atleast_2d(ax)

        # Count only real axes: None placeholders are skipped by the plotting loop.
        n_axes = sum(item is not None for item in axes.ravel())
        if n_axes != length_plotters:
            raise ValueError(
                "Found {} variables to plot but {} axes instances. They must be equal.".format(
                    length_plotters, n_axes
                )
            )

    for i, ax_i in enumerate(item for item in axes.flatten() if item is not None):
        var_name, _, obs_vals = obs_plotters[i]
        pp_var_name, _, pp_vals = pp_plotters[i]

        obs_vals = obs_vals.flatten()
        # One row per posterior predictive draw.
        pp_vals = pp_vals.reshape(total_pp_samples, -1)

        if kind == "p_value":
            # Per draw: fraction of observations with pp_val <= obs_val.
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=-1)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.line(x_s, tstat_pit_dens, line_width=linewidth, line_color=color)
            if reference is not None:
                dist = stats.beta(obs_vals.size / 2, obs_vals.size / 2)
                if reference == "analytical":
                    # Cover the central 99.99% of the reference density.
                    lwb = dist.ppf((1 - 0.9999) / 2)
                    upb = 1 - lwb
                    x = np.linspace(lwb, upb, 500)
                    dens_ref = dist.pdf(x)
                    ax_i.line(x, dens_ref, **plot_ref_kwargs)
                elif reference == "samples":
                    # Same (size, n_ref) layout as the u_value branch below, which
                    # treats each column of the returned arrays as one reference curve.
                    x_ss, u_dens = sample_reference_distribution(
                        dist, (tstat_pit_dens.size, n_ref)
                    )
                    ax_i.multi_line(
                        list(x_ss.T), list(u_dens.T), line_width=linewidth, **plot_ref_kwargs
                    )

        elif kind == "u_value":
            # Per observation: fraction of predictive draws with pp_val <= obs_val.
            tstat_pit = np.mean(pp_vals <= obs_vals, axis=0)
            x_s, tstat_pit_dens = kde(tstat_pit)
            ax_i.line(x_s, tstat_pit_dens, line_color=color)
            if reference is not None:
                if reference == "analytical":
                    n_obs = obs_vals.size
                    # Lower tail quantile of a symmetric Beta, so hdi <= 0.5 and
                    # hdi_odds[0] <= 1 <= hdi_odds[1].
                    hdi = stats.beta(n_obs / 2, n_obs / 2).ppf((1 - hdi_prob) / 2)
                    hdi_odds = (hdi / (1 - hdi), (1 - hdi) / hdi)
                    # Per-panel copy: popping from the shared dict would raise
                    # KeyError on the second variable.
                    box_kwargs = dict(plot_ref_kwargs)
                    ax_i.add_layout(
                        BoxAnnotation(
                            bottom=hdi_odds[0],
                            top=hdi_odds[1],
                            fill_alpha=box_kwargs.pop("alpha"),
                            fill_color=box_kwargs.pop("line_color"),
                            **box_kwargs,
                        )
                    )
                    ax_i.line([0, 1], [1, 1], line_color="white")
                elif reference == "samples":
                    dist = stats.uniform(0, 1)
                    x_ss, u_dens = sample_reference_distribution(dist, (tstat_pit_dens.size, n_ref))
                    for x_ss_i, u_dens_i in zip(x_ss.T, u_dens.T):
                        ax_i.line(x_ss_i, u_dens_i, line_width=linewidth, **plot_ref_kwargs)
            # Degenerate glyph at the origin; presumably keeps 0 inside the auto y-range.
            ax_i.line(0, 0)
        else:
            named_stats = {"mean": np.mean, "median": np.median, "std": np.std}
            if t_stat in ("mean", "median", "std"):
                tfunc = named_stats[t_stat]
                obs_vals = tfunc(obs_vals)
                pp_vals = tfunc(pp_vals, axis=1)
            elif callable(t_stat):
                obs_vals = t_stat(obs_vals)
                pp_vals = t_stat(pp_vals)
            elif is_valid_quantile(t_stat):
                quantile = float(t_stat)
                obs_vals = np.quantile(obs_vals, q=quantile)
                pp_vals = np.quantile(pp_vals, q=quantile, axis=1)
            else:
                raise ValueError(f"T statistics {t_stat} not implemented")

            plot_kde(pp_vals, ax=ax_i, plot_kwargs={"color": color}, backend="bokeh", show=False)
            if bpv:
                p_value = np.mean(pp_vals <= obs_vals)
                # Invisible glyph whose only purpose is to carry the legend entry.
                ax_i.line(0, 0, legend_label=f"bpv={p_value:.2f}", alpha=0)

            if plot_mean:
                # np.mean also accepts a plain Python scalar returned by a custom t_stat.
                ax_i.circle(
                    np.mean(obs_vals), 0, fill_color=color, line_color="black", size=markersize
                )

        xlabel = var_name if var_name == pp_var_name else f"{var_name} / {pp_var_name}"
        panel_title = Title()
        panel_title.text = xlabel
        ax_i.title = panel_title
        ax_i.title.text_font_size = f"{int(ax_labelsize)}pt"

    show_layout(axes, show)

    return axes

Read our documentation on viewing source code .

Loading