1
"""Matplotlib ELPDPlot."""
2 4
import warnings
3

4 4
import matplotlib.cm as cm
5 4
import matplotlib.pyplot as plt
6 4
import numpy as np
7 4
from matplotlib.lines import Line2D
8

9 4
from ....rcparams import rcParams
10 4
from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels
11 4
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
12

13

14 4
def _label_outliers(target_ax, ydata, coord_labels, threshold, fontsize):
    """Write a text label next to every outlying point of ``ydata``.

    A point is labeled when its absolute deviation from the mean of ``ydata``
    is larger than ``threshold`` times the standard deviation of ``ydata``.

    Parameters
    ----------
    target_ax : matplotlib axes
        Axes on which ``text`` is called.
    ydata : 1D numpy array
        ELPD differences (already flattened by the caller).
    coord_labels : indexable
        Label to write for each position of ``ydata``.
    threshold : float
        Number of standard deviations beyond which a point is labeled.
    fontsize : float
        Font size forwarded to ``target_ax.text``.
    """
    deviation = np.abs(ydata - ydata.mean())
    # flatnonzero always returns a 1-D index array. The previous
    # ``np.argwhere(...).squeeze()`` collapsed to a 0-d array when exactly one
    # point was an outlier, which made the loop below raise TypeError.
    outliers = np.flatnonzero(deviation > threshold * ydata.std())
    for outlier in outliers:
        # NOTE(review): the label is anchored at the positional index, which only
        # coincides with the scatter x position if xdata is 0..n-1 -- confirm.
        target_ax.text(
            outlier,
            ydata[outlier],
            coord_labels[outlier],
            horizontalalignment="center",
            verticalalignment="bottom" if ydata[outlier] > 0 else "top",
            fontsize=fontsize,
        )


def plot_elpd(
    ax,
    models,
    pointwise_data,
    numvars,
    figsize,
    textsize,
    plot_kwargs,
    xlabels,
    coord_labels,
    xdata,
    threshold,
    legend,
    color,
    backend_kwargs,
    show,
):
    """Matplotlib elpd plot.

    Draws scatter plots of pointwise ELPD differences between models: a single
    plot when ``numvars == 2``, otherwise a grid of pairwise plots.

    Parameters
    ----------
    ax : matplotlib axes, 2D-indexable grid of axes, or None
        Axes to draw on. A single axes is used when ``numvars == 2``; otherwise
        it is indexed as ``ax[row, col]``. When None the axes are created here.
    models : sequence of str
        Model names, used to build the ``"{} - {}"`` titles.
    pointwise_data : sequence
        One entry per model; each must expose ``.dims`` and ``.values``
        (presumably xarray DataArrays -- verify against caller).
    numvars : int
        Number of models being compared.
    figsize, textsize
        Forwarded to ``_scale_fig_size``.
    plot_kwargs : dict
        Passed through ``matplotlib_kwarg_dealiaser`` for ``"scatter"`` and
        forwarded to ``scatter``.
    xlabels : bool
        If truthy, tick labels are set from ``coord_labels`` and the figure is
        reformatted with ``autofmt_xdate`` and ``tight_layout``.
    coord_labels : indexable or None
        Per-observation labels. When None and ``threshold`` is given, it
        defaults to ``xdata.astype(str)``.
    xdata : numpy array
        x coordinates of the scatter points.
    threshold : float or None
        If not None, points deviating from the mean ELPD difference by more
        than ``threshold`` standard deviations are labeled.
    legend : bool
        Whether to draw a legend; forced to False unless ``color`` names a
        dimension of ``pointwise_data[0]``.
    color : str or other
        Either the name of a dimension of ``pointwise_data[0]`` (points are then
        colored by coordinate) or a value passed to scatter as ``c``.
    backend_kwargs : dict or None
        Merged on top of ``backend_kwarg_defaults()``.
    show : bool or None
        Passed to ``backend_show`` to decide whether ``plt.show()`` is called.

    Returns
    -------
    ax
        The axes (or grid of axes) that were drawn on.
    """
    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(),
        **backend_kwargs,
    }
    backend_kwargs.setdefault("constrained_layout", not xlabels)

    plot_kwargs = matplotlib_kwarg_dealiaser(plot_kwargs, "scatter")

    markersize = None

    if isinstance(color, str):
        if color in pointwise_data[0].dims:
            colors, color_mapping = color_from_dim(pointwise_data[0], color)
            cmap_name = plot_kwargs.pop("cmap", plt.rcParams["image.cmap"])
            markersize = plot_kwargs.pop("s", plt.rcParams["lines.markersize"])
            cmap = getattr(cm, cmap_name)
            # One proxy artist per coordinate so the legend shows the color mapping.
            handles = [
                Line2D(
                    [], [], color=cmap(float_color), label=coord, ms=markersize, lw=0, **plot_kwargs
                )
                for coord, float_color in color_mapping.items()
            ]
            plot_kwargs.setdefault("cmap", cmap_name)
            plot_kwargs.setdefault("s", markersize ** 2)
            plot_kwargs.setdefault("c", colors)
        else:
            legend = False
    else:
        legend = False
    # No-op when "c" was already set from the dimension-based coloring above.
    plot_kwargs.setdefault("c", color)

    # flatten data (data must be flattened after selecting, labeling and coloring)
    pointwise_data = [pointwise.values.flatten() for pointwise in pointwise_data]

    if numvars == 2:
        (figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size(
            figsize, textsize, numvars - 1, numvars - 1
        )
        plot_kwargs.setdefault("s", markersize ** 2)
        backend_kwargs.setdefault("figsize", figsize)
        backend_kwargs["squeeze"] = True
        if ax is None:
            _, ax = create_axes_grid(
                1,
                backend_kwargs=backend_kwargs,
            )

        ydata = pointwise_data[0] - pointwise_data[1]
        ax.scatter(xdata, ydata, **plot_kwargs)
        if threshold is not None:
            if coord_labels is None:
                coord_labels = xdata.astype(str)
            _label_outliers(ax, ydata, coord_labels, threshold, 0.8 * xt_labelsize)

        ax.set_title("{} - {}".format(*models), fontsize=titlesize, wrap=True)
        ax.set_ylabel("ELPD difference", fontsize=ax_labelsize, wrap=True)
        ax.tick_params(labelsize=xt_labelsize)
        if xlabels:
            set_xticklabels(ax, coord_labels)
            # Take the figure from the axes so this also works when the caller
            # supplied ``ax`` (previously ``fig`` was unbound in that case).
            fig = ax.get_figure()
            fig.autofmt_xdate()
            fig.tight_layout()
        if legend:
            ncols = len(handles) // 6 + 1
            ax.legend(handles=handles, ncol=ncols, title=color)

    else:
        max_plots = (
            numvars ** 2 if rcParams["plot.max_subplots"] is None else rcParams["plot.max_subplots"]
        )
        vars_to_plot = np.sum(np.arange(numvars).cumsum() < max_plots)
        if vars_to_plot < numvars:
            warnings.warn(
                "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
                "of resulting ELPD pairwise plots with these variables, generating only a "
                "{side}x{side} grid".format(max_plots=max_plots, side=vars_to_plot),
                UserWarning,
            )
            numvars = vars_to_plot

        (figsize, ax_labelsize, titlesize, xt_labelsize, _, markersize) = _scale_fig_size(
            figsize, textsize, numvars - 2, numvars - 2
        )
        plot_kwargs.setdefault("s", markersize ** 2)

        if ax is None:
            _, ax = plt.subplots(
                numvars - 1,
                numvars - 1,
                figsize=figsize,
                squeeze=False,
                constrained_layout=not xlabels,
                sharey="row",
                sharex="all",
            )

        for i in range(0, numvars - 1):
            var1 = pointwise_data[i]

            for j in range(0, numvars - 1):
                if j < i:
                    # Only one triangle of the grid holds a model pair; hide the rest.
                    ax[j, i].axis("off")
                    continue

                var2 = pointwise_data[j + 1]
                ydata = var1 - var2
                ax[j, i].scatter(xdata, ydata, **plot_kwargs)
                if threshold is not None:
                    if coord_labels is None:
                        coord_labels = xdata.astype(str)
                    _label_outliers(ax[j, i], ydata, coord_labels, threshold, 0.8 * xt_labelsize)

                if i == 0:
                    ax[j, i].set_ylabel("ELPD difference", fontsize=ax_labelsize, wrap=True)

                ax[j, i].tick_params(labelsize=xt_labelsize)
                ax[j, i].set_title(
                    "{} - {}".format(models[i], models[j + 1]), fontsize=titlesize, wrap=True
                )
        if xlabels:
            set_xticklabels(ax[-1, -1], coord_labels)
            # Take the figure from the axes so a caller-supplied grid works too.
            fig = ax[-1, -1].get_figure()
            fig.autofmt_xdate()
            fig.tight_layout()
        if legend:
            ncols = len(handles) // 6 + 1
            ax[0, 1].legend(
                handles=handles, ncol=ncols, title=color, bbox_to_anchor=(0, 1), loc="upper left"
            )

    if backend_show(show):
        plt.show()

    return ax

Read our documentation on viewing source code.

Loading