1
"""Bokeh rankplot."""
2 4
import numpy as np
3

4 4
from bokeh.models import Span
5 4
from bokeh.models.annotations import Title
6 4
from bokeh.models.tickers import FixedTicker
7

8 4
from ....stats.density_utils import histogram
9 4
from ...plot_utils import _scale_fig_size, make_label, compute_ranks
10 4
from .. import show_layout
11 4
from . import backend_kwarg_defaults, create_axes_grid
12

13

14 4
def _kwargs_with_defaults(kwargs, **defaults):
    """Return a new dict of ``defaults`` overridden by ``kwargs`` (``None`` means no overrides).

    The dict passed as ``kwargs`` is never modified.
    """
    return {**defaults, **(kwargs if kwargs is not None else {})}


def plot_rank(
    axes,
    length_plotters,
    rows,
    cols,
    figsize,
    plotters,
    bins,
    kind,
    colors,
    ref_line,
    labels,
    ref_line_kwargs,
    bar_kwargs,
    vlines_kwargs,
    marker_vlines_kwargs,
    backend_kwargs,
    show,
):
    """Bokeh rank plot.

    Draws one rank histogram panel per entry of ``plotters`` onto a grid of
    Bokeh figures.

    Parameters
    ----------
    axes : array-like of bokeh figures or None
        Target figures. If None, a grid is created with ``create_axes_grid``
        (shared x and y). Otherwise it is coerced with ``np.atleast_2d``.
        ``None`` entries are skipped.
    length_plotters, rows, cols : int
        Grid layout, forwarded to ``create_axes_grid`` / ``_scale_fig_size``.
    figsize : tuple or None
        Figure size, rescaled through ``_scale_fig_size``.
    plotters : iterable of (var_name, selection, var_data)
        One panel is drawn per item. Iteration stops at the shorter of
        ``plotters`` and the available figures.
    bins : int, str or sequence
        Forwarded to ``np.histogram_bin_edges`` over ``range=(0, ranks.size)``.
    kind : {"bars", "vlines"}
        Drawing style. Any other value draws no data, only labels and title.
    colors : sequence
        Colour per row of the computed ranks, i.e. per chain (the y axis is
        labelled "Chain").
    ref_line : bool
        Whether to add a horizontal reference ``Span`` at the mean count.
    labels : bool
        Whether to show axis labels and tick labels.
    ref_line_kwargs, bar_kwargs, vlines_kwargs, marker_vlines_kwargs : dict or None
        Extra options for ``Span``, ``vbar``, ``multi_line`` and ``scatter``
        respectively. They are copied before defaults are added, so the
        caller's dicts are left unchanged.
    backend_kwargs : dict or None
        Merged over the ``dpi`` default. Only used when ``axes`` is None.
    show : bool
        Forwarded to ``show_layout``.

    Returns
    -------
    numpy.ndarray
        The 2D array of Bokeh figures.
    """
    # Build fresh option dicts rather than calling ``setdefault`` on the
    # caller's objects, which would leak these defaults back to the caller.
    ref_line_kwargs = _kwargs_with_defaults(
        ref_line_kwargs, line_dash="dashed", line_color="black"
    )
    bar_kwargs = _kwargs_with_defaults(bar_kwargs, line_color="white")
    vlines_kwargs = _kwargs_with_defaults(vlines_kwargs, line_width=2, line_dash="solid")
    # ``scatter(marker="circle")`` replaces ``circle()``: marker-style use of
    # ``circle()`` is deprecated in newer Bokeh releases.
    marker_vlines_kwargs = _kwargs_with_defaults(marker_vlines_kwargs, marker="circle")

    backend_kwargs = {
        **backend_kwarg_defaults(
            ("dpi", "plot.bokeh.figure.dpi"),
        ),
        **_kwargs_with_defaults(backend_kwargs),
    }
    figsize, *_ = _scale_fig_size(figsize, None, rows=rows, cols=cols)
    if axes is None:
        axes = create_axes_grid(
            length_plotters,
            rows,
            cols,
            figsize=figsize,
            sharex=True,
            sharey=True,
            backend_kwargs=backend_kwargs,
        )
    else:
        axes = np.atleast_2d(axes)

    for ax, (var_name, selection, var_data) in zip(
        (item for item in axes.flatten() if item is not None), plotters
    ):
        ranks = compute_ranks(var_data)
        bin_edges = np.histogram_bin_edges(ranks, bins=bins, range=(0, ranks.size))
        # One histogram of rank counts per row of ``ranks`` (per chain).
        all_counts = np.empty((len(ranks), len(bin_edges) - 1))
        for idx, row in enumerate(ranks):
            _, all_counts[idx], _ = histogram(row, bins=bin_edges)
        # Scale so the tallest bar reaches 95% of the unit gap between stacked
        # chains, leaving a small margin between neighbouring histograms.
        counts_normalizer = all_counts.max() / 0.95
        gap = 1

        # Default bar width is this variable's bin width; an explicit "width"
        # in bar_kwargs still wins. The dict is rebuilt per variable so one
        # variable's width is not carried over to the next.
        var_bar_kwargs = {"width": bin_edges[1] - bin_edges[0], **bar_kwargs}
        # Center the bins
        bin_centers = (bin_edges[1:] + bin_edges[:-1]) / 2

        y_ticks = []
        if kind == "bars":
            for idx, counts in enumerate(all_counts):
                counts = counts / counts_normalizer
                # Each chain's histogram sits on its own baseline, idx * gap.
                y_ticks.append(idx * gap)
                ax.vbar(
                    x=bin_centers,
                    top=y_ticks[-1] + counts,
                    bottom=y_ticks[-1],
                    fill_color=colors[idx],
                    **var_bar_kwargs,
                )
                if ref_line:
                    hline = Span(location=y_ticks[-1] + counts.mean(), **ref_line_kwargs)
                    ax.add_layout(hline)
            if labels:
                ax.yaxis.axis_label = "Chain"
        elif kind == "vlines":
            # Stems start at the mean count over all chains and bins.
            ymin = all_counts.mean()
            x_locations = [(x_val, x_val) for x_val in bin_centers]
            for idx, counts in enumerate(all_counts):
                ax.scatter(
                    bin_centers,
                    counts,
                    fill_color=colors[idx],
                    line_color=colors[idx],
                    **marker_vlines_kwargs,
                )
                y_locations = [(ymin, count) for count in counts]
                ax.multi_line(x_locations, y_locations, line_color=colors[idx], **vlines_kwargs)

            if ref_line:
                hline = Span(location=ymin, **ref_line_kwargs)
                ax.add_layout(hline)

        if labels:
            ax.xaxis.axis_label = "Rank (all chains)"

            # y_ticks holds the chain baselines (filled only for kind="bars");
            # when it is empty the fixed ticker leaves the y axis without ticks.
            ax.yaxis.ticker = FixedTicker(ticks=y_ticks)
            # Label each baseline on the y axis with its chain index.
            ax.yaxis.major_label_overrides = dict(
                zip(map(str, y_ticks), map(str, range(len(y_ticks))))
            )

        else:
            ax.yaxis.major_tick_line_color = None
            ax.yaxis.minor_tick_line_color = None

            ax.xaxis.major_label_text_font_size = "0pt"
            ax.yaxis.major_label_text_font_size = "0pt"

        _title = Title()
        _title.text = make_label(var_name, selection)
        ax.title = _title

    show_layout(axes, show)

    return axes

Read our documentation on viewing source code.

Loading