1
"""Matplotlib Density Comparison plot."""
2 4
import matplotlib.pyplot as plt
3 4
import numpy as np
4

5 4
from ...distplot import plot_dist
6 4
from ...plot_utils import _scale_fig_size, make_label
7 4
from . import backend_kwarg_defaults, backend_show
8

9

10 4
def plot_dist_comparison(
    ax,
    nvars,
    ngroups,
    figsize,
    dc_plotters,
    legend,
    groups,
    textsize,
    prior_kwargs,
    posterior_kwargs,
    observed_kwargs,
    backend_kwargs,
    show,
):
    """Matplotlib Density Comparison plot.

    Every dataset is drawn twice with ``plot_dist``: once on the axes of its
    own group and once on a shared axes that overlays all groups of the same
    variable.

    Parameters
    ----------
    ax : object with a ``shape`` attribute, or None
        Existing axes indexed as ``ax[row, col]``; its ``shape`` must equal
        ``(nvars, ngroups + 1)``. When None a new figure is created.
    nvars : int
        Number of rows in the returned axes array.
    ngroups : int
        Number of per-group columns; the returned array has one extra column
        holding the shared axes.
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    dc_plotters : iterable
        One iterable per group, each yielding ``(var, selection, data)``.
    legend : bool
        When truthy, a ``"{group} {label}"`` label is passed to ``plot_dist``;
        otherwise ``label=None`` is passed.
    groups : sequence of str
        Group names, indexed in parallel with ``dc_plotters``. Names starting
        with ``"prior"`` use ``prior_kwargs``, names starting with
        ``"posterior"`` use ``posterior_kwargs``, anything else uses
        ``observed_kwargs``.
    prior_kwargs, posterior_kwargs, observed_kwargs : dict or None
        Extra keyword arguments for ``plot_dist``. They are copied before
        defaults are filled in, so the caller's dictionaries are not modified.
    backend_kwargs : dict or None
        Merged over ``backend_kwarg_defaults()`` and passed to ``plt.figure``
        when a new figure is created.
    show
        Passed to ``backend_show``; ``plt.show()`` is called when that
        returns a truthy value.

    Returns
    -------
    axes
        The ``(nvars, ngroups + 1)`` object array of new axes, or ``ax`` itself
        when it was supplied.

    Raises
    ------
    ValueError
        If ``ax`` is given and its shape is not ``(nvars, ngroups + 1)``.
    """
    # User supplied values take precedence over the backend defaults.
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **({} if backend_kwargs is None else backend_kwargs),
    }

    (figsize, _, _, _, linewidth, _) = _scale_fig_size(figsize, textsize, 2 * nvars, ngroups)
    backend_kwargs.setdefault("figsize", figsize)

    # Each group gets its own default color; explicit user settings win.
    posterior_kwargs = _with_dist_defaults(posterior_kwargs, "C0", linewidth)
    prior_kwargs = _with_dist_defaults(prior_kwargs, "C1", linewidth)
    observed_kwargs = _with_dist_defaults(observed_kwargs, "C2", linewidth)

    expected_shape = (nvars, ngroups + 1)
    if ax is None:
        axes = np.empty(expected_shape, dtype=object)
        fig = plt.figure(**backend_kwargs)
        # Two grid rows per variable: the first holds one axes per group, the
        # second a single axes spanning every column (stored in the last
        # column of ``axes``).
        gs = fig.add_gridspec(ncols=ngroups, nrows=nvars * 2)
        for i in range(nvars):
            for j in range(ngroups):
                axes[i, j] = fig.add_subplot(gs[2 * i, j])
            axes[i, -1] = fig.add_subplot(gs[2 * i + 1, :])
    else:
        axes = ax
        if ax.shape != expected_shape:
            raise ValueError(
                f"Found {axes.shape} shape of axes, which is not equal to data shape "
                f"{expected_shape}."
            )

    for idx, plotter in enumerate(dc_plotters):
        group = groups[idx]
        if group.startswith("prior"):
            kwargs = prior_kwargs
        elif group.startswith("posterior"):
            kwargs = posterior_kwargs
        else:
            kwargs = observed_kwargs

        for idx2, (var, selection, data) in enumerate(plotter):
            label = f"{group} {make_label(var, selection)}" if legend else None
            # Draw on the group's own axes, then on the shared axes in the
            # last column.
            for target in (axes[idx2, idx], axes[idx2, -1]):
                plot_dist(data, label=label, ax=target, **kwargs)

    if backend_show(show):
        plt.show()

    return axes


def _with_dist_defaults(kwargs, color, linewidth):
    """Return a copy of ``kwargs`` with default line and histogram styling filled in."""
    merged = {} if kwargs is None else dict(kwargs)

    # Copy the nested dictionaries too, so filling defaults never leaks back
    # into the caller's objects.
    plot_kwargs = dict(merged.get("plot_kwargs", {}))
    plot_kwargs.setdefault("color", color)
    plot_kwargs.setdefault("linewidth", linewidth)
    merged["plot_kwargs"] = plot_kwargs

    hist_kwargs = dict(merged.get("hist_kwargs", {}))
    hist_kwargs.setdefault("alpha", 0.5)
    merged["hist_kwargs"] = hist_kwargs

    return merged

Read our documentation on viewing source code.

Loading