1
"""Bokeh Posterior predictive plot."""
2 2
import numpy as np
3

4 2
from ....stats.density_utils import get_bins, histogram, kde
5 2
from ...kdeplot import plot_kde
6 2
from ...plot_utils import _scale_fig_size
7 2
from .. import show_layout
8 2
from . import backend_kwarg_defaults, create_axes_grid
9

10

11 2
def plot_ppc(
    ax,
    length_plotters,
    rows,
    cols,
    figsize,
    animated,
    obs_plotters,
    pp_plotters,
    predictive_dataset,
    pp_sample_ix,
    kind,
    alpha,
    color,  # pylint: disable=unused-argument
    textsize,
    mean,
    jitter,
    total_pp_samples,
    legend,  # pylint: disable=unused-argument
    group,  # pylint: disable=unused-argument
    animation_kwargs,  # pylint: disable=unused-argument
    num_pp_samples,
    backend_kwargs,
    show,
):
    """Bokeh ppc plot.

    Draw, for every variable, the observed values (black), the selected
    posterior predictive samples (red or pink) and, optionally, a dashed blue
    summary of all posterior predictive samples.

    Parameters
    ----------
    ax : array-like of bokeh figures or None
        Figures to draw on. ``None`` entries are treated as empty grid cells.
        When ``ax`` is None a new grid is built with ``create_axes_grid``.
    length_plotters : int
        Number of variables to plot; must equal the number of non-None axes.
    rows, cols : int
        Grid shape, forwarded to ``_scale_fig_size`` and ``create_axes_grid``.
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    animated : bool
        Only used here to choose the default ``alpha``.
    obs_plotters, pp_plotters : sequence
        One entry per axis; each entry unpacks as ``(name, _, values)``.
    predictive_dataset
        Indexed by the predictive variable name; only ``.dtype.kind`` is read
        to choose between continuous ("f") and discrete drawing.
    pp_sample_ix
        Index applied to the first axis of the reshaped predictive values to
        select the samples that are drawn individually.
    kind : str
        One of "kde", "cumulative" or "scatter", matched case-insensitively
        (the default-alpha selection already lower-cased it).
    alpha : float or None
        Opacity of the predictive samples. None selects 1 if ``animated``,
        0.7 for "scatter" and 0.2 otherwise.
    color, legend, group, animation_kwargs
        Unused by this backend.
    mean : bool
        Whether to draw the dashed blue summary curve.
    jitter : float or None
        Non-negative vertical jitter factor for "scatter"; None means 0.
    total_pp_samples : int
        Size of the first axis the predictive values are reshaped to.
    num_pp_samples : int
        Number of rows the scatter plot spreads the predictive samples over.
    backend_kwargs : dict or None
        Merged over the backend defaults and forwarded to ``create_axes_grid``.
    show
        Forwarded to ``show_layout``.

    Returns
    -------
    axes
        The 2d grid of bokeh figures that was drawn on.

    Raises
    ------
    ValueError
        If the number of non-None axes in ``ax`` differs from
        ``length_plotters`` or if ``jitter`` is negative.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    # User-provided values take precedence over the backend defaults.
    backend_kwargs = {
        **backend_kwarg_defaults(
            ("dpi", "plot.bokeh.figure.dpi"),
        ),
        **backend_kwargs,
    }

    (figsize, *_, linewidth, markersize) = _scale_fig_size(figsize, textsize, rows, cols)
    if ax is None:
        axes = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            backend_kwargs=backend_kwargs,
        )
    else:
        axes = np.atleast_2d(ax)

        # Count only real axes: None entries are grid placeholders and are
        # skipped by the drawing loop below as well.
        n_axes = sum(item is not None for item in axes.ravel())
        if n_axes != length_plotters:
            raise ValueError(
                "Found {} variables to plot but {} axes instances. They must be equal.".format(
                    length_plotters, n_axes
                )
            )

    # Normalise once so every comparison below agrees on the spelling.
    kind = kind.lower()

    if alpha is None:
        if animated:
            alpha = 1
        else:
            if kind == "scatter":
                alpha = 0.7
            else:
                alpha = 0.2

    if jitter is None:
        jitter = 0.0
    if jitter < 0.0:
        raise ValueError("jitter must be >=0.")

    for i, ax_i in enumerate((item for item in axes.flatten() if item is not None)):
        var_name, _, obs_vals = obs_plotters[i]
        pp_var_name, _, pp_vals = pp_plotters[i]
        dtype = predictive_dataset[pp_var_name].dtype.kind

        # flatten non-specified dimensions
        obs_vals = obs_vals.flatten()
        pp_vals = pp_vals.reshape(total_pp_samples, -1)
        pp_sampled_vals = pp_vals[pp_sample_ix]

        if kind == "kde":
            plot_kwargs = {"line_color": "red", "line_alpha": alpha, "line_width": 0.5 * linewidth}

            # One density (float data) or histogram (otherwise) per sample.
            pp_densities = []
            pp_xs = []
            for vals in pp_sampled_vals:
                vals = np.array([vals]).flatten()
                if dtype == "f":
                    pp_x, pp_density = kde(vals)
                    pp_densities.append(pp_density)
                    pp_xs.append(pp_x)
                else:
                    bins = get_bins(vals)
                    _, hist, bin_edges = histogram(vals, bins=bins)
                    # Repeat the first height so heights pair up with the
                    # edges (presumably one more edge than bins).
                    hist = np.concatenate((hist[:1], hist))
                    pp_densities.append(hist)
                    pp_xs.append(bin_edges)

            if dtype == "f":
                ax_i.multi_line(pp_xs, pp_densities, **plot_kwargs)
            else:
                for x_s, y_s in zip(pp_xs, pp_densities):
                    ax_i.step(x_s, y_s, **plot_kwargs)

            # Observed data on top of the predictive samples.
            if dtype == "f":
                plot_kde(
                    obs_vals,
                    plot_kwargs={"line_color": "black", "line_width": linewidth},
                    fill_kwargs={"alpha": 0},
                    ax=ax_i,
                    backend="bokeh",
                    backend_kwargs={},
                    show=False,
                )
            else:
                bins = get_bins(obs_vals)
                _, hist, bin_edges = histogram(obs_vals, bins=bins)
                hist = np.concatenate((hist[:1], hist))
                ax_i.step(
                    bin_edges,
                    hist,
                    line_color="black",
                    line_width=linewidth,
                    mode="center",
                )

            if mean:
                if dtype == "f":
                    rep = len(pp_densities)
                    len_density = len(pp_densities[0])

                    # Place every sampled density on one shared x grid and
                    # average them point-wise.
                    new_x = np.linspace(np.min(pp_xs), np.max(pp_xs), len_density)
                    new_d = np.zeros((rep, len_density))
                    bins = np.digitize(pp_xs, new_x, right=True)
                    # Shift the grid by half a step after binning.
                    new_x -= (new_x[1] - new_x[0]) / 2
                    for irep in range(rep):
                        new_d[irep][bins[irep]] = pp_densities[irep]
                    ax_i.line(
                        new_x,
                        new_d.mean(0),
                        color="blue",
                        line_dash="dashed",
                        line_width=linewidth,
                    )
                else:
                    # Single histogram over all predictive draws pooled.
                    vals = pp_vals.flatten()
                    bins = get_bins(vals)
                    _, hist, bin_edges = histogram(vals, bins=bins)
                    hist = np.concatenate((hist[:1], hist))
                    ax_i.step(
                        bin_edges,
                        hist,
                        line_color="blue",
                        line_width=linewidth,
                        line_dash="dashed",
                        mode="center",
                    )
            # Hide the y axis ticks and tick labels.
            ax_i.yaxis.major_tick_line_color = None
            ax_i.yaxis.minor_tick_line_color = None
            ax_i.yaxis.major_label_text_font_size = "0pt"

        elif kind == "cumulative":
            if dtype == "f":
                ax_i.line(
                    *_empirical_cdf(obs_vals),
                    line_color="black",
                    line_width=linewidth,
                )
            else:
                ax_i.step(
                    *_empirical_cdf(obs_vals),
                    line_color="black",
                    line_width=linewidth,
                    mode="center",
                )
            # Even rows hold the x values and odd rows the matching cdf
            # values of each sampled predictive draw.
            pp_densities = np.empty((2 * len(pp_sampled_vals), pp_sampled_vals[0].size))
            for idx, vals in enumerate(pp_sampled_vals):
                vals = np.array([vals]).flatten()
                pp_x, pp_density = _empirical_cdf(vals)
                pp_densities[2 * idx] = pp_x
                pp_densities[2 * idx + 1] = pp_density
            ax_i.multi_line(
                list(pp_densities[::2]),
                list(pp_densities[1::2]),
                line_alpha=alpha,
                line_color="pink",
                line_width=linewidth,
            )
            if mean:
                ax_i.line(
                    *_empirical_cdf(pp_vals.flatten()),
                    color="blue",
                    line_dash="dashed",
                    line_width=linewidth,
                )

        elif kind == "scatter":
            if mean:
                if dtype == "f":
                    plot_kde(
                        pp_vals.flatten(),
                        plot_kwargs={
                            "line_color": "blue",
                            "line_dash": "dashed",
                            "line_width": linewidth,
                        },
                        ax=ax_i,
                        backend="bokeh",
                        backend_kwargs={},
                        show=False,
                    )
                else:
                    vals = pp_vals.flatten()
                    bins = get_bins(vals)
                    _, hist, bin_edges = histogram(vals, bins=bins)
                    hist = np.concatenate((hist[:1], hist))
                    ax_i.step(
                        bin_edges,
                        hist,
                        color="blue",
                        line_width=linewidth,
                        line_dash="dashed",
                        mode="center",
                    )

            # Observed points sit at y=0; each predictive sample gets its own
            # row between 0 and 0.1. Jitter adds a uniform vertical offset.
            jitter_scale = 0.1
            y_rows = np.linspace(0, 0.1, num_pp_samples + 1)
            scale_low = 0
            scale_high = jitter_scale * jitter

            obs_yvals = np.zeros_like(obs_vals, dtype=np.float64)
            if jitter:
                obs_yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(obs_vals))
            ax_i.circle(
                obs_vals,
                obs_yvals,
                fill_color="black",
                size=markersize,
                line_alpha=alpha,
            )

            for vals, y in zip(pp_sampled_vals, y_rows[1:]):
                vals = np.ravel(vals)
                yvals = np.full_like(vals, y, dtype=np.float64)
                if jitter:
                    yvals += np.random.uniform(low=scale_low, high=scale_high, size=len(vals))
                ax_i.circle(vals, yvals, fill_color="red", size=markersize, fill_alpha=alpha)

            # Hide the y axis ticks and tick labels.
            ax_i.yaxis.major_tick_line_color = None
            ax_i.yaxis.minor_tick_line_color = None
            ax_i.yaxis.major_label_text_font_size = "0pt"

        # Show both names when observed and predictive variables differ.
        if var_name != pp_var_name:
            xlabel = "{} / {}".format(var_name, pp_var_name)
        else:
            xlabel = var_name
        ax_i.xaxis.axis_label = xlabel

    show_layout(axes, show)

    return axes
271

272

273 2
def _empirical_cdf(data):
    """Compute empirical cdf of a numpy array.

    Parameters
    ----------
    data : array-like
        Sample values. Input with more than one dimension is flattened, so
        every element counts as one observation.

    Returns
    -------
    np.array, np.array
        x and y coordinates for the empirical cdf of the data: the sorted
        values and an evenly spaced grid from 0 to 1 of the same length.
    """
    # Flatten first so the x and y coordinates always have matching lengths.
    values = np.ravel(data)
    return np.sort(values), np.linspace(0, 1, values.size)

Read our documentation on viewing source code.

Loading