1
"""Bokeh Densityplot."""
2 2
from collections import defaultdict
3 2
from itertools import cycle
4

5 2
import matplotlib.pyplot as plt
6 2
import numpy as np
7 2
from bokeh.models.annotations import Legend, Title
8

9 2
from ....stats import hdi
10 2
from ....stats.density_utils import get_bins, histogram, kde
11 2
from ...plot_utils import _scale_fig_size, calculate_point_estimate, make_label, vectorized_to_hex
12 2
from .. import show_layout
13 2
from . import backend_kwarg_defaults, create_axes_grid
14

15

16 2
def plot_density(
    ax,
    all_labels,
    to_plot,
    colors,
    bw,
    circular,
    figsize,
    length_plotters,
    rows,
    cols,
    textsize,
    hdi_prob,
    point_estimate,
    hdi_markers,
    outline,
    shade,
    n_data,
    data_labels,
    backend_kwargs,
    show,
):
    """Bokeh density plot.

    Resolves one colour per dataset, builds the axes grid when ``ax`` is
    ``None`` (otherwise wraps ``ax`` with ``np.atleast_2d``), draws every
    ``(var_name, selection, values)`` entry of each element of ``to_plot``
    through ``_d_helper`` and, when ``data_labels`` is non-empty, adds a
    horizontal legend above each figure whose entries hide their glyphs on
    click.  The grid is finally handed to ``show_layout`` together with
    ``show``.

    Returns
    -------
    ax
        The created axes grid, or the ``np.atleast_2d`` view of the ``ax``
        that was passed in.
    """
    user_kwargs = {} if backend_kwargs is None else backend_kwargs
    # User supplied values take precedence over the backend defaults.
    backend_kwargs = {**backend_kwarg_defaults(), **user_kwargs}

    if colors == "cycle":
        # Take the first ``n_data`` colours of matplotlib's property cycle.
        palette = cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
        colors = [tint for _, tint in zip(range(n_data), palette)]
    elif isinstance(colors, str):
        # A single colour name is shared by every dataset.
        colors = [colors] * n_data
    colors = vectorized_to_hex(colors)

    figsize, _, _, _, line_width, markersize = _scale_fig_size(figsize, textsize, rows, cols)

    if ax is not None:
        ax = np.atleast_2d(ax)
    else:
        ax = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            squeeze=True,
            backend_kwargs=backend_kwargs,
        )

    # Pair each label with the next non-empty slot of the flattened grid.
    usable_axes = (panel for panel in ax.flatten() if panel is not None)
    axis_map = dict(zip(all_labels, usable_axes))

    if data_labels is None:
        data_labels = {}

    legend_items = defaultdict(list)
    for m_idx, plotters in enumerate(to_plot):
        for var_name, selection, values in plotters:
            label = make_label(var_name, selection)
            data_label = data_labels[m_idx] if data_labels else None

            plotted = _d_helper(
                vec=values.flatten(),
                vname=label,
                color=colors[m_idx],
                bw=bw,
                circular=circular,
                line_width=line_width,
                markersize=markersize,
                hdi_prob=hdi_prob,
                point_estimate=point_estimate,
                hdi_markers=hdi_markers,
                outline=outline,
                shade=shade,
                ax=axis_map[label],
            )
            if data_label is not None:
                # Group the glyphs per figure so each figure gets one legend.
                legend_items[axis_map[label]].append((data_label, plotted))

    for figure, entries in legend_items.items():
        figure.add_layout(
            Legend(
                items=entries,
                location="center_right",
                orientation="horizontal",
            ),
            "above",
        )
        figure.legend.click_policy = "hide"

    show_layout(ax, show)

    return ax
119

120

121 2
def _d_helper(
    vec,
    vname,
    color,
    bw,
    circular,
    line_width,
    markersize,
    hdi_prob,
    point_estimate,
    hdi_markers,
    outline,
    shade,
    ax,
):
    """Draw one variable on ``ax`` and return the list of created glyphs.

    Float data (``vec.dtype.kind == "f"``) is drawn as a KDE computed on the
    values inside the HDI (all values when ``hdi_prob`` is 1) and scaled by
    ``hdi_prob``; any other dtype is drawn as a histogram.  HDI markers and
    a point estimate are added on the zero line when requested, and the
    figure title is set to ``vname``.
    """
    extra = {}
    plotted = []
    # Appearance shared by the glyphs that can be muted.
    muted = {"muted_color": color, "muted_alpha": 0.2}

    if vec.dtype.kind == "f":
        if hdi_prob == 1:
            kept = vec
        else:
            bounds = hdi(vec, hdi_prob, multimodal=False)
            inside = (vec >= bounds[0]) & (vec <= bounds[1])
            kept = vec[inside]

        x, density = kde(kept, circular=circular, bw=bw)
        density *= hdi_prob
        xmin, xmax = x[0], x[-1]
        ymin, ymax = density[0], density[-1]

        if outline:
            plotted.append(ax.line(x, density, line_color=color, line_width=line_width, **extra))
            # Vertical segments closing the curve at its left and right ends.
            for x_end, y_end in ((xmin, ymin), (xmax, ymax)):
                plotted.append(
                    ax.line(
                        [x_end, x_end],
                        [-y_end / 100, y_end],
                        line_color=color,
                        line_dash="solid",
                        line_width=line_width,
                        **muted
                    )
                )

        if shade:
            # Polygon: back along the zero line, forward along the density
            # curve, then down to zero at the last x value.
            outline_x = np.r_[x[::-1], x, x[-1:]]
            outline_y = np.r_[np.zeros_like(x), density, [0]]
            plotted.append(
                ax.patch(
                    outline_x,
                    outline_y,
                    fill_color=color,
                    fill_alpha=shade,
                    **muted,
                    **extra
                )
            )

    else:
        xmin, xmax = hdi(vec, hdi_prob, multimodal=False)
        bins = get_bins(vec)

        _, hist, edges = histogram(vec, bins=bins)

        # With ``outline`` the bars are left unfilled; otherwise they are
        # filled with ``color`` at opacity ``shade``.
        if outline:
            fill = {"fill_color": None}
        else:
            fill = {"fill_color": color, "fill_alpha": shade}
        plotted.append(
            ax.quad(
                top=hist,
                bottom=0,
                left=edges[:-1],
                right=edges[1:],
                line_color=color,
                **fill,
                **muted,
                **extra
            )
        )

    if hdi_markers:
        for bound in (xmin, xmax):
            plotted.append(
                ax.diamond(bound, 0, line_color="black", fill_color=color, size=markersize)
            )

    if point_estimate is not None:
        est = calculate_point_estimate(point_estimate, vec, bw, circular)
        plotted.append(ax.circle(est, 0, fill_color=color, line_color="black", size=markersize))

    heading = Title()
    heading.text = vname
    ax.title = heading
    ax.title.text_font_size = "13pt"

    return plotted

Read our documentation on viewing source code.

Loading