1
"""Matplotlib Plot posterior densities."""
2 2
from numbers import Number
3

4 2
import matplotlib.pyplot as plt
5 2
import numpy as np
6

7 2
from ....stats import hdi
8 2
from ....stats.density_utils import get_bins
9 2
from ...kdeplot import plot_kde
10 2
from ...plot_utils import (
11
    _scale_fig_size,
12
    calculate_point_estimate,
13
    format_sig_figs,
14
    make_label,
15
    round_num,
16
)
17 2
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
18

19

20 2
def plot_posterior(
    ax,
    length_plotters,
    rows,
    cols,
    figsize,
    plotters,
    bw,
    circular,
    bins,
    kind,
    point_estimate,
    round_to,
    hdi_prob,
    multimodal,
    textsize,
    ref_val,
    rope,
    kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib posterior plot.

    Draws one posterior panel per entry of ``plotters`` (via
    ``_plot_posterior_op``) and titles each panel with
    ``make_label(var_name, selection)``.

    Parameters
    ----------
    ax : axes, array of axes, or None
        Axes to draw on. When None, a grid is created with
        ``create_axes_grid(length_plotters, rows, cols, ...)``. Plotters and
        axes are paired with ``zip``, so whichever runs out first ends the loop.
    length_plotters, rows, cols : int
        Number of panels and grid shape; ``rows``/``cols`` are also used by
        ``_scale_fig_size``.
    figsize, textsize
        Passed to ``_scale_fig_size`` to derive the figure size and font sizes.
    plotters : iterable of (var_name, selection, x)
        One tuple per panel; ``x`` is flattened before plotting.
    kwargs : dict
        Extra artist keyword arguments, de-aliased for ``"hist"`` when
        ``kind == "hist"`` and for ``"plot"`` otherwise.
    backend_kwargs : dict or None
        Merged over ``backend_kwarg_defaults()`` and forwarded to
        ``create_axes_grid``; ``figsize`` and ``squeeze`` get defaults.
    show : bool
        Interpreted by ``backend_show``; when it returns True ``plt.show()``
        is called.
    bw, circular, bins, kind, point_estimate, round_to, hdi_prob, multimodal,
    ref_val, rope
        Forwarded unchanged to ``_plot_posterior_op`` for every panel.

    Returns
    -------
    ax
        The axes that were drawn on (the created grid when ``ax`` was None).
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, titlesize, xt_labelsize, _linewidth, _) = _scale_fig_size(
        figsize, textsize, rows, cols
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs.setdefault("squeeze", True)

    # De-alias user kwargs against the artist that will actually receive them.
    if kind == "hist":
        kwargs = matplotlib_kwarg_dealiaser(kwargs, "hist")
    else:
        kwargs = matplotlib_kwarg_dealiaser(kwargs, "plot")
    # Travels inside **kwargs and binds to _plot_posterior_op's named
    # `linewidth` parameter, which sizes the ref line, rope and HDI bars.
    kwargs.setdefault("linewidth", _linewidth)

    if ax is None:
        _, ax = create_axes_grid(
            length_plotters,
            rows,
            cols,
            backend_kwargs=backend_kwargs,
        )

    for idx, ((var_name, selection, x), ax_) in enumerate(zip(plotters, np.ravel(ax))):
        _plot_posterior_op(
            idx,
            x.flatten(),
            var_name,
            selection,
            ax=ax_,
            bw=bw,
            circular=circular,
            bins=bins,
            kind=kind,
            point_estimate=point_estimate,
            round_to=round_to,
            hdi_prob=hdi_prob,
            multimodal=multimodal,
            ref_val=ref_val,
            rope=rope,
            ax_labelsize=ax_labelsize,
            xt_labelsize=xt_labelsize,
            **kwargs,
        )
        ax_.set_title(make_label(var_name, selection), fontsize=titlesize, wrap=True)

    if backend_show(show):
        plt.show()

    return ax
99

100

101 2
def _plot_posterior_op(
    idx,
    values,
    var_name,
    selection,
    ax,
    bw,
    circular,
    linewidth,
    bins,
    kind,
    point_estimate,
    hdi_prob,
    multimodal,
    ref_val,
    rope,
    ax_labelsize,
    xt_labelsize,
    round_to=None,
    **kwargs,
):  # noqa: D202
    """Artist to draw the posterior of one variable/selection on ``ax``.

    Draws a KDE when ``kind == "kde"`` and ``values`` has a float dtype, and a
    histogram otherwise. It then styles the axes and overlays, in order: the
    HDI interval(s) (unless ``hdi_prob == "hide"``), the point estimate, the
    reference value and the ROPE. Overlays are positioned at fractions of the
    y-axis upper limit reached after the density is drawn.

    Parameters
    ----------
    idx : int
        Panel position; only used to pick ``ref_val[idx]`` when ``ref_val``
        is a list.
    values : np.ndarray
        Samples to plot. Its dtype kind selects KDE (``"f"``) vs histogram,
        and integer data (``"i"``) gets integer-aligned bins.
    var_name, selection
        Used to look up per-variable entries in dict-valued ``ref_val`` /
        ``rope``; an entry matches when every non-value key ``k`` is in
        ``selection`` with ``selection[k]`` equal to the entry's value.
    bw, circular
        Forwarded to ``plot_kde`` and ``calculate_point_estimate``.
    linewidth : float
        Base line width: reference line (x1), HDI bar (x2), ROPE band (x5)
        and default KDE line width.
    bins
        Histogram bins. When None, integer data uses ``get_bins(values)`` with
        x-limits padded by 0.5, and anything else uses ``"auto"``.
    kind : str
        ``"kde"`` for a density curve; any other value draws a histogram.
    point_estimate
        Passed to ``calculate_point_estimate``; a falsy value disables it.
    hdi_prob
        Passed to ``hdi``; the string ``"hide"`` skips the HDI overlay.
    multimodal
        Passed to ``hdi``.
    ref_val
        None, a number, a list indexed by ``idx``, or a dict like
        ``{"var_name": [{"ref_val": value, **coords}]}``.
    rope
        None, an iterable of length 2, or a dict like
        ``{"var_name": [{"rope": (lo, hi), **coords}]}``.
    ax_labelsize, xt_labelsize
        Font sizes for annotation text and x tick labels.
    round_to
        Passed to ``format_sig_figs`` and ``round_num``.
    **kwargs
        For the KDE, passed as ``plot_kwargs`` (``fill_alpha`` is popped and
        used for the fill). For the histogram, passed to ``ax.hist``.

    Raises
    ------
    ValueError
        If ``ref_val`` or ``rope`` is not one of the accepted forms.
    """

    def format_as_percent(x, round_to=0):
        return "{0:.{1:d}f}%".format(100 * x, round_to)

    def display_ref_val():
        if ref_val is None:
            return
        elif isinstance(ref_val, dict):
            val = None
            for sel in ref_val.get(var_name, []):
                if all(
                    k in selection and selection[k] == v for k, v in sel.items() if k != "ref_val"
                ):
                    val = sel["ref_val"]
                    break
            if val is None:
                return
        elif isinstance(ref_val, list):
            val = ref_val[idx]
        elif isinstance(ref_val, Number):
            val = ref_val
        else:
            raise ValueError(
                "Argument `ref_val` must be None, a constant, a list or a "
                'dictionary like {"var_name": [{"ref_val": ref_val}]}'
            )
        less_than_ref_probability = (values < val).mean()
        greater_than_ref_probability = (values >= val).mean()
        ref_in_posterior = "{} <{:g}< {}".format(
            format_as_percent(less_than_ref_probability, 1),
            val,
            format_as_percent(greater_than_ref_probability, 1),
        )
        ax.axvline(val, ymin=0.05, ymax=0.75, color="C1", lw=linewidth, alpha=0.65)
        ax.text(
            values.mean(),
            plot_height * 0.6,
            ref_in_posterior,
            size=ax_labelsize,
            color="C1",
            weight="semibold",
            horizontalalignment="center",
        )

    def display_rope():
        if rope is None:
            return
        if isinstance(rope, dict):
            vals = None
            for sel in rope.get(var_name, []):
                # pylint: disable=line-too-long
                if all(k in selection and selection[k] == v for k, v in sel.items() if k != "rope"):
                    vals = sel["rope"]
                    break
            if vals is None:
                return
        else:
            try:
                is_interval = len(rope) == 2
            except TypeError:
                # Scalars (and 0-d arrays) have no len(); report them through
                # the documented ValueError below instead of a bare TypeError.
                is_interval = False
            if not is_interval:
                raise ValueError(
                    "Argument `rope` must be None, a dictionary like "
                    '{"var_name": [{"rope": (lo, hi)}]}, or an '
                    "iterable of length 2"
                )
            vals = rope

        ax.plot(
            vals,
            (plot_height * 0.02, plot_height * 0.02),
            lw=linewidth * 5,
            color="C2",
            solid_capstyle="butt",
            zorder=0,
            alpha=0.7,
        )
        text_props = {"size": ax_labelsize, "color": "C2"}
        ax.text(
            vals[0],
            plot_height * 0.2,
            f"{vals[0]} ",
            weight="semibold",
            horizontalalignment="right",
            **text_props,
        )
        ax.text(
            vals[1],
            plot_height * 0.2,
            f" {vals[1]}",
            weight="semibold",
            horizontalalignment="left",
            **text_props,
        )

    def display_point_estimate():
        if not point_estimate:
            return
        point_value = calculate_point_estimate(point_estimate, values, bw, circular)
        sig_figs = format_sig_figs(point_value, round_to)
        point_text = "{point_estimate}={point_value:.{sig_figs}g}".format(
            point_estimate=point_estimate, point_value=point_value, sig_figs=sig_figs
        )
        ax.text(
            point_value,
            plot_height * 0.8,
            point_text,
            size=ax_labelsize,
            horizontalalignment="center",
        )

    def display_hdi():
        # np.ndarray with 2 entries, min and max
        # pylint: disable=line-too-long
        hdi_probs = hdi(values, hdi_prob=hdi_prob, multimodal=multimodal)  # type: np.ndarray

        # atleast_2d lets a single (lo, hi) pair and several intervals
        # (the multimodal case) share one loop.
        for hdi_i in np.atleast_2d(hdi_probs):
            ax.plot(
                hdi_i,
                (plot_height * 0.02, plot_height * 0.02),
                lw=linewidth * 2,
                color="k",
                solid_capstyle="butt",
            )
            ax.text(
                hdi_i[0],
                plot_height * 0.07,
                round_num(hdi_i[0], round_to) + " ",
                size=ax_labelsize,
                horizontalalignment="right",
            )
            ax.text(
                hdi_i[1],
                plot_height * 0.07,
                " " + round_num(hdi_i[1], round_to),
                size=ax_labelsize,
                horizontalalignment="left",
            )
            ax.text(
                (hdi_i[0] + hdi_i[1]) / 2,
                plot_height * 0.3,
                format_as_percent(hdi_prob) + " HDI",
                size=ax_labelsize,
                horizontalalignment="center",
            )

    def format_axes():
        ax.yaxis.set_ticks([])
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
        ax.spines["left"].set_visible(False)
        ax.spines["bottom"].set_visible(True)
        ax.xaxis.set_ticks_position("bottom")
        ax.tick_params(
            axis="x", direction="out", width=1, length=3, color="0.5", labelsize=xt_labelsize
        )
        ax.spines["bottom"].set_color("0.5")

    if kind == "kde" and values.dtype.kind == "f":
        kwargs.setdefault("linewidth", linewidth)
        plot_kde(
            values,
            bw=bw,
            circular=circular,
            fill_kwargs={"alpha": kwargs.pop("fill_alpha", 0)},
            plot_kwargs=kwargs,
            ax=ax,
            rug=False,
            show=False,
        )
    else:
        if bins is None:
            if values.dtype.kind == "i":
                xmin = values.min()
                xmax = values.max()
                bins = get_bins(values)
                ax.set_xlim(xmin - 0.5, xmax + 0.5)
            else:
                bins = "auto"
        kwargs.setdefault("align", "left")
        kwargs.setdefault("color", "C0")
        # Default transparency. A user-supplied alpha now overrides it instead
        # of colliding with it as a duplicate keyword argument.
        kwargs.setdefault("alpha", 0.35)
        ax.hist(values, bins=bins, **kwargs)

    # Read after the density is drawn so the overlays scale with the data.
    plot_height = ax.get_ylim()[1]

    format_axes()
    if hdi_prob != "hide":
        display_hdi()
    display_point_estimate()
    display_ref_val()
    display_rope()

Read our documentation on viewing source code .

Loading