1
"""Bokeh loopitplot."""
2 2
import numpy as np
3 2
from bokeh.models import BoxAnnotation
4 2
from matplotlib.colors import hsv_to_rgb, rgb_to_hsv, to_hex, to_rgb
5 2
from xarray import DataArray
6

7 2
from ....stats.density_utils import kde
8 2
from ...plot_utils import _scale_fig_size
9 2
from .. import show_layout
10 2
from . import backend_kwarg_defaults, create_axes_grid
11

12

13 2
def _draw_curve(ax, x_data, y_data, style, stepped=False, alpha=1.0, width=1.0):
    """Draw one curve on ``ax`` as a line glyph, or as a centered step glyph when ``stepped``.

    ``style`` is a matplotlib-flavoured dict; only its ``color``, ``alpha`` and ``linewidth``
    entries are read, falling back to ``"black"``, ``alpha`` and ``width`` respectively.
    """
    glyph_kwargs = {
        "line_color": style.get("color", "black"),
        "line_alpha": style.get("alpha", alpha),
        "line_width": style.get("linewidth", width),
    }
    if stepped:
        ax.step(x_data, y_data, mode="center", **glyph_kwargs)
    else:
        ax.line(x_data, y_data, **glyph_kwargs)


def plot_loo_pit(
    ax,
    figsize,
    ecdf,
    loo_pit,
    loo_pit_ecdf,
    unif_ecdf,
    p975,
    p025,
    fill_kwargs,
    ecdf_fill,
    use_hdi,
    x_vals,
    hdi_kwargs,
    hdi_odds,
    n_unif,
    unif,
    plot_unif_kwargs,
    loo_pit_kde,
    legend,  # pylint: disable=unused-argument
    y_hat,
    y,
    color,
    textsize,
    credible_interval,
    plot_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh loo pit plot.

    Draws either the LOO-PIT ECDF difference curve (``ecdf`` truthy) or the LOO-PIT KDE
    (``ecdf`` falsy) on a Bokeh figure, together with a reference envelope.

    Parameters
    ----------
    ax : bokeh figure or None
        Figure to draw on. When None a new one is built with ``create_axes_grid`` and its
        ``x_range`` defaults to ``(0, 1)``.
    figsize, textsize
        Forwarded to ``_scale_fig_size``; the scaled line width it returns drives the default
        line widths used here.
    ecdf : bool
        Select the ECDF-difference plot instead of the KDE plot.
    loo_pit, loo_pit_ecdf : array-like
        In ECDF mode the curve ``loo_pit - loo_pit_ecdf`` is plotted against ``loo_pit``,
        padded with the points ``(0, 0)`` and ``(1, 0)``.
    unif_ecdf, p975, p025 : array-like
        In ECDF mode ``p975 - unif_ecdf`` and ``p025 - unif_ecdf`` are plotted against
        ``unif_ecdf`` (presumably the upper/lower envelope bounds; confirm against caller).
    fill_kwargs : dict or None
        Style of the filled envelope (``color`` and ``alpha`` are used), only if ``ecdf_fill``.
    ecdf_fill : bool
        In ECDF mode, draw the envelope as a filled patch rather than two boundary curves.
    use_hdi : bool
        In KDE mode, draw a ``BoxAnnotation`` band instead of the simulated uniform KDEs.
    x_vals, loo_pit_kde : array-like
        Coordinates of the main curve in KDE mode.
    hdi_kwargs : dict or None
        Extra ``BoxAnnotation`` keywords; ``alpha`` and ``color`` map to ``fill_alpha`` and
        ``fill_color``.
    hdi_odds : sequence
        ``hdi_odds[1]`` is used as the band bottom and ``hdi_odds[0]`` as its top.
    n_unif : int
        Number of rows of ``unif`` whose KDE is drawn in KDE mode without ``use_hdi``.
    unif : 2D array-like
        Row ``unif[idx, :]`` is passed to ``kde`` for each ``idx`` in ``range(n_unif)``.
    plot_unif_kwargs : dict or None
        Style of the reference curves (``color``, ``alpha``, ``linewidth``, ``drawstyle``).
    legend
        Unused by this backend.
    y_hat, y : str, DataArray or other
        Only used to build the default label: the first available name among ``y`` and
        ``y_hat`` (a string, or the ``name`` of a named ``DataArray``) prefixes the label.
    color
        Default color of the main curve; a lighter shade of the effective main color is the
        default for the reference curves, fill and HDI band.
    credible_interval : float
        Formatted into the default fill label.
    plot_kwargs : dict or None
        Style of the main curve (``color``, ``alpha``, ``linewidth``, ``drawstyle``).
    backend_kwargs : dict or None
        Merged over ``backend_kwarg_defaults()`` and passed to ``create_axes_grid``.
    show
        Forwarded to ``show_layout``.

    Returns
    -------
    ax
        The figure that was drawn on.

    Notes
    -----
    The keyword dictionaries supplied by the caller are copied before defaults are filled
    in, so they are never modified by this function.
    """
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **({} if backend_kwargs is None else backend_kwargs),
    }

    (figsize, *_, linewidth, _) = _scale_fig_size(figsize, textsize, 1, 1)

    if ax is None:
        backend_kwargs.setdefault("x_range", (0, 1))
        ax = create_axes_grid(
            1,
            figsize=figsize,
            squeeze=True,
            backend_kwargs=backend_kwargs,
        )

    # Work on copies: the defaults below (and the ``pop`` calls on ``hdi_kwargs``) must not
    # leak into dictionaries owned by the caller.
    plot_kwargs = {} if plot_kwargs is None else dict(plot_kwargs)
    plot_kwargs.setdefault("color", to_hex(color))
    plot_kwargs.setdefault("linewidth", linewidth * 1.4)

    # Default label: prefix with the first usable name found in ``y`` and then ``y_hat``.
    base_label = "LOO-PIT ECDF" if ecdf else "LOO-PIT"
    var_name = None
    for candidate in (y, y_hat):
        if isinstance(candidate, str):
            var_name = candidate
        elif isinstance(candidate, DataArray) and candidate.name is not None:
            var_name = candidate.name
        if var_name is not None:
            break
    label = base_label if var_name is None else "{} {}".format(var_name, base_label)
    # NOTE: ``legend_label`` is recorded for parity with the other keyword defaults but is
    # not forwarded to any glyph below.
    plot_kwargs.setdefault("legend_label", label)

    # Lighter shade of the main color: half the saturation, value moved halfway to 1.
    light_color = rgb_to_hsv(to_rgb(plot_kwargs.get("color")))
    light_color[1] /= 2  # pylint: disable=unsupported-assignment-operation
    light_color[2] += (1 - light_color[2]) / 2  # pylint: disable=unsupported-assignment-operation
    light_hex = to_hex(hsv_to_rgb(light_color))

    plot_unif_kwargs = {} if plot_unif_kwargs is None else dict(plot_unif_kwargs)
    plot_unif_kwargs.setdefault("color", light_hex)
    plot_unif_kwargs.setdefault("alpha", 0.5)
    plot_unif_kwargs.setdefault("linewidth", 0.6 * linewidth)

    if ecdf:
        # Small samples are drawn as steps, larger ones as ordinary lines.
        n_data_points = loo_pit.size
        default_drawstyle = "steps-mid" if n_data_points < 100 else "default"
        plot_kwargs.setdefault("drawstyle", default_drawstyle)
        plot_unif_kwargs.setdefault("drawstyle", default_drawstyle)

        _draw_curve(
            ax,
            np.hstack((0, loo_pit, 1)),
            np.hstack((0, loo_pit - loo_pit_ecdf, 0)),
            plot_kwargs,
            stepped=plot_kwargs.get("drawstyle") == "steps-mid",
            width=3.0,
        )

        if ecdf_fill:
            fill_kwargs = {} if fill_kwargs is None else dict(fill_kwargs)
            fill_kwargs.setdefault("color", light_hex)
            fill_kwargs.setdefault("alpha", 0.5)
            fill_kwargs.setdefault(
                "step", "mid" if plot_kwargs["drawstyle"] == "steps-mid" else None
            )
            fill_kwargs.setdefault(
                "legend_label", "{:.3g}% credible interval".format(credible_interval)
            )
            # A stepped patch is not implemented: the envelope is always a plain polygon
            # (upper bound left-to-right, then lower bound right-to-left).
            ax.patch(
                np.concatenate((unif_ecdf, unif_ecdf[::-1])),
                np.concatenate((p975 - unif_ecdf, (p025 - unif_ecdf)[::-1])),
                fill_color=fill_kwargs.get("color"),
                fill_alpha=fill_kwargs.get("alpha", 1.0),
            )
        else:
            # The envelope bounds follow the draw style and width of ``plot_unif_kwargs``
            # (the draw style was previously looked up in ``fill_kwargs``, where it is
            # never set, so the bounds could not be drawn as steps).
            stepped_bounds = plot_unif_kwargs.get("drawstyle") == "steps-mid"
            for bound in (p975, p025):
                _draw_curve(
                    ax, unif_ecdf, bound - unif_ecdf, plot_unif_kwargs, stepped=stepped_bounds
                )
    else:
        if use_hdi:
            hdi_kwargs = {} if hdi_kwargs is None else dict(hdi_kwargs)
            hdi_kwargs.setdefault("color", light_hex)
            hdi_kwargs.setdefault("alpha", 0.35)
            ax.add_layout(
                BoxAnnotation(
                    bottom=hdi_odds[1],
                    top=hdi_odds[0],
                    fill_alpha=hdi_kwargs.pop("alpha"),
                    fill_color=hdi_kwargs.pop("color"),
                    **hdi_kwargs
                )
            )
        else:
            # One KDE curve per simulated uniform sample.
            for idx in range(n_unif):
                x_s, unif_density = kde(unif[idx, :])
                _draw_curve(ax, x_s, unif_density, plot_unif_kwargs, alpha=0.1)
        # Main curve goes last so it is rendered above the reference band/curves.
        _draw_curve(ax, x_vals, loo_pit_kde, plot_kwargs, width=3.0)

    show_layout(ax, show)

    return ax

Read our documentation on viewing source code .

Loading