1
"""Matplotlib pairplot."""
2 4
import warnings
3

4 4
import matplotlib.pyplot as plt
5 4
import numpy as np
6 4
from matplotlib.ticker import NullFormatter
7 4
from mpl_toolkits.axes_grid1 import make_axes_locatable
8

9 4
from ....rcparams import rcParams
10 4
from ...distplot import plot_dist
11 4
from ...kdeplot import plot_kde
12 4
from ...plot_utils import _scale_fig_size, calculate_point_estimate
13 4
from . import backend_kwarg_defaults, backend_show, matplotlib_kwarg_dealiaser
14

15

16 4
def _prepare_reference_values(reference_values, flat_var_names):
    """Re-key ``reference_values`` so they can be looked up by plotted variable name.

    The first space of every key is replaced by a newline, which is the form the
    names in ``flat_var_names`` are compared against. A ``UserWarning`` lists every
    name in ``flat_var_names`` that has no reference value.

    Parameters
    ----------
    reference_values : dict
        Mapping of variable name to reference value.
    flat_var_names : sequence of str
        Names of the plotted variables.

    Returns
    -------
    dict
        The reference values keyed by the newline-wrapped variable names.
    """
    reference_values_copy = {
        name.replace(" ", "\n", 1): value for name, value in reference_values.items()
    }
    missing = set(flat_var_names).difference(reference_values_copy)
    if missing:
        # Undo the newline wrapping so the warning shows the user-facing names.
        warn = [name.replace("\n", " ", 1) for name in missing]
        warnings.warn(
            "Argument reference_values does not include reference value for: {}".format(
                ", ".join(warn)
            ),
            UserWarning,
        )
    return reference_values_copy


def plot_pair(
    ax,
    plotters,
    numvars,
    figsize,
    textsize,
    kind,
    scatter_kwargs,
    kde_kwargs,
    hexbin_kwargs,
    gridsize,
    colorbar,
    divergences,
    diverging_mask,
    divergences_kwargs,
    flat_var_names,
    backend_kwargs,
    marginal_kwargs,
    show,
    marginals,
    point_estimate,
    point_estimate_kwargs,
    point_estimate_marker_kwargs,
    reference_values,
    reference_values_kwargs,
):
    """Matplotlib pairplot.

    Draws pairwise plots (any of ``"scatter"``, ``"kde"``, ``"hexbin"`` contained in
    ``kind``) of the variables in ``plotters``. With ``numvars == 2`` a single joint
    plot is drawn, optionally with marginal distributions above and to the right.
    Otherwise a lower-triangular grid is drawn, optionally with marginal
    distributions on the diagonal.

    Parameters
    ----------
    ax : matplotlib Axes, array of Axes or None
        Axes to draw on; a new figure is created when ``None``.
    plotters : sequence
        One entry per variable; the last element of each entry holds the values,
        which are flattened before plotting.
    numvars : int
        Number of variables in ``plotters``.
    kind : str or list of str
        Which plot types to draw.
    divergences : bool
        Whether to overlay the points selected by ``diverging_mask``.
    flat_var_names : sequence of str
        Axis labels, aligned with ``plotters``.
    marginals : bool
        Whether to draw marginal distributions.
    point_estimate : str or None
        Point estimate passed to ``calculate_point_estimate``; nothing is drawn when falsy.
    reference_values : dict or None
        Reference point per variable name; variables missing from it are warned about
        and skipped.
    *_kwargs : dict or None
        Styling options for the corresponding artists. Caller-supplied dicts are
        not modified.

    Returns
    -------
    numpy.ndarray or matplotlib Axes
        For ``numvars == 2`` with marginals, the 2x2 array of axes
        (``[[top, None], [joint, right]]`` when created here). Otherwise the
        axes that were drawn on.
    """
    backend_kwargs = {
        **backend_kwarg_defaults(),
        **(backend_kwargs or {}),
    }
    # NOTE(review): constrained_layout is deliberately dropped; presumably it clashes
    # with the manual GridSpec spacing and axes dividers used below.
    backend_kwargs.pop("constrained_layout", None)

    scatter_kwargs = matplotlib_kwarg_dealiaser(scatter_kwargs, "scatter")

    scatter_kwargs.setdefault("marker", ".")
    scatter_kwargs.setdefault("lw", 0)
    # Sets the default zorder higher than zorder of grid, which is 0.5
    scatter_kwargs.setdefault("zorder", 0.6)

    # Copy caller-supplied dicts before filling in defaults, so that repeated calls
    # sharing one kwargs dict do not inherit each other's defaults.
    kde_kwargs = {} if kde_kwargs is None else dict(kde_kwargs)
    hexbin_kwargs = {"mincnt": 1, **(hexbin_kwargs or {})}
    marginal_kwargs = {} if marginal_kwargs is None else dict(marginal_kwargs)

    divergences_kwargs = matplotlib_kwarg_dealiaser(divergences_kwargs, "plot")
    divergences_kwargs.setdefault("marker", "o")
    divergences_kwargs.setdefault("markeredgecolor", "k")
    divergences_kwargs.setdefault("color", "C1")
    divergences_kwargs.setdefault("lw", 0)

    point_estimate_kwargs = matplotlib_kwarg_dealiaser(point_estimate_kwargs, "fill_between")
    point_estimate_kwargs.setdefault("color", "k")

    if kind != "kde":
        # Unfilled black contours unless the caller overrides them (caller keys win).
        kde_kwargs["contourf_kwargs"] = {
            "alpha": 0,
            **(kde_kwargs.get("contourf_kwargs") or {}),
        }
        kde_kwargs["contour_kwargs"] = {
            "colors": "k",
            **(kde_kwargs.get("contour_kwargs") or {}),
        }

    reference_values_copy = {}
    if reference_values:
        reference_values_copy = _prepare_reference_values(reference_values, flat_var_names)

    reference_values_kwargs = matplotlib_kwarg_dealiaser(reference_values_kwargs, "plot")

    reference_values_kwargs.setdefault("color", "C2")
    reference_values_kwargs.setdefault("markeredgecolor", "k")
    reference_values_kwargs.setdefault("marker", "o")

    point_estimate_marker_kwargs = matplotlib_kwarg_dealiaser(
        point_estimate_marker_kwargs, "scatter"
    )
    point_estimate_marker_kwargs.setdefault("marker", "s")
    point_estimate_marker_kwargs.setdefault("color", "k")

    # pylint: disable=too-many-nested-blocks
    if numvars == 2:
        (figsize, ax_labelsize, _, xt_labelsize, linewidth, markersize) = _scale_fig_size(
            figsize, textsize, numvars - 1, numvars - 1
        )
        backend_kwargs.setdefault("figsize", figsize)

        marginal_kwargs["plot_kwargs"] = {
            "linewidth": linewidth,
            **(marginal_kwargs.get("plot_kwargs") or {}),
        }

        point_estimate_marker_kwargs.setdefault("s", markersize + 50)

        # Flatten data
        x = plotters[0][-1].flatten()
        y = plotters[1][-1].flatten()
        if ax is None:
            if marginals:
                # Instantiate figure and grid
                widths = [2, 2, 2, 1]
                heights = [1.4, 2, 2, 2]
                fig = plt.figure(**backend_kwargs)
                grid = plt.GridSpec(
                    4,
                    4,
                    hspace=0.1,
                    wspace=0.1,
                    figure=fig,
                    width_ratios=widths,
                    height_ratios=heights,
                )
                # Set up main plot
                ax = fig.add_subplot(grid[1:, :-1])
                # Set up top KDE
                ax_hist_x = fig.add_subplot(grid[0, :-1], sharex=ax)
                ax_hist_x.set_yticks([])
                # Set up right KDE
                ax_hist_y = fig.add_subplot(grid[1:, -1], sharey=ax)
                ax_hist_y.set_xticks([])
                ax_return = np.array([[ax_hist_x, None], [ax, ax_hist_y]])

                for val, ax_, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)):
                    plot_dist(val, textsize=xt_labelsize, rotated=rotate, ax=ax_, **marginal_kwargs)

                # Personalize axes
                ax_hist_x.tick_params(labelleft=False, labelbottom=False)
                ax_hist_y.tick_params(labelleft=False, labelbottom=False)
            else:
                _, ax = plt.subplots(numvars - 1, numvars - 1, **backend_kwargs)
        else:
            if marginals:
                assert ax.shape == (numvars, numvars)
                # The top-right cell is unused in the marginal layout.
                if ax[0, 1] is not None and ax[0, 1].get_figure() is not None:
                    ax[0, 1].remove()
                ax_return = ax
                ax_hist_x = ax[0, 0]
                ax_hist_y = ax[1, 1]
                ax = ax[1, 0]
                for val, ax_, rotate in ((x, ax_hist_x, False), (y, ax_hist_y, True)):
                    plot_dist(val, textsize=xt_labelsize, rotated=rotate, ax=ax_, **marginal_kwargs)
            else:
                ax = np.atleast_2d(ax)[0, 0]

        if "scatter" in kind:
            ax.plot(x, y, **scatter_kwargs)
        if "kde" in kind:
            plot_kde(x, y, ax=ax, **kde_kwargs)
        if "hexbin" in kind:
            hexbin = ax.hexbin(
                x,
                y,
                gridsize=gridsize,
                **hexbin_kwargs,
            )
            ax.grid(False)

        if kind == "hexbin" and colorbar:
            cbar = ax.figure.colorbar(hexbin, ticks=[hexbin.norm.vmin, hexbin.norm.vmax], ax=ax)
            cbar.ax.set_yticklabels(["low", "high"], fontsize=ax_labelsize)

        if divergences:
            ax.plot(
                x[diverging_mask],
                y[diverging_mask],
                **divergences_kwargs,
            )

        if point_estimate:
            pe_x = calculate_point_estimate(point_estimate, x)
            pe_y = calculate_point_estimate(point_estimate, y)
            if marginals:
                ax_hist_x.axvline(pe_x, **point_estimate_kwargs)
                ax_hist_y.axhline(pe_y, **point_estimate_kwargs)

            ax.axvline(pe_x, **point_estimate_kwargs)
            ax.axhline(pe_y, **point_estimate_kwargs)

            ax.scatter(pe_x, pe_y, **point_estimate_marker_kwargs)

        if reference_values:
            x_name, y_name = flat_var_names[0], flat_var_names[1]
            # Missing names were already reported by a warning; skip instead of KeyError.
            if x_name in reference_values_copy and y_name in reference_values_copy:
                ax.plot(
                    reference_values_copy[x_name],
                    reference_values_copy[y_name],
                    **reference_values_kwargs,
                )
        ax.set_xlabel("{}".format(flat_var_names[0]), fontsize=ax_labelsize, wrap=True)
        ax.set_ylabel("{}".format(flat_var_names[1]), fontsize=ax_labelsize, wrap=True)
        ax.tick_params(labelsize=xt_labelsize)

    else:
        # Without marginals the diagonal is not used, so the grid has one column less.
        not_marginals = int(not marginals)
        num_subplot_cols = numvars - not_marginals
        max_plots = (
            num_subplot_cols ** 2
            if rcParams["plot.max_subplots"] is None
            else rcParams["plot.max_subplots"]
        )
        # Largest k such that the lower triangle of a k x k grid (1 + 2 + ... + k
        # plots) fits within max_plots.
        cols_to_plot = np.sum(np.arange(1, num_subplot_cols + 1).cumsum() <= max_plots)
        if cols_to_plot < num_subplot_cols:
            vars_to_plot = cols_to_plot
            warnings.warn(
                "rcParams['plot.max_subplots'] ({max_plots}) is smaller than the number "
                "of resulting pair plots with these variables, generating only a "
                "{side}x{side} grid".format(max_plots=max_plots, side=vars_to_plot),
                UserWarning,
            )
        else:
            vars_to_plot = num_subplot_cols

        (figsize, ax_labelsize, _, xt_labelsize, _, markersize) = _scale_fig_size(
            figsize, textsize, vars_to_plot, vars_to_plot
        )
        backend_kwargs.setdefault("figsize", figsize)
        point_estimate_marker_kwargs.setdefault("s", markersize + 50)

        if ax is None:
            _, ax = plt.subplots(
                vars_to_plot,
                vars_to_plot,
                **backend_kwargs,
            )

        hexbin = None  # last hexbin drawn; used for the single shared colorbar
        for i in range(vars_to_plot):
            var1 = plotters[i][-1].flatten()

            for j in range(vars_to_plot):
                if i > j:
                    # Upper triangle is unused.
                    if ax[j, i].get_figure() is not None:
                        ax[j, i].remove()
                    continue

                if i == j and marginals:
                    plot_dist(var1, ax=ax[i, j], **marginal_kwargs)
                    if point_estimate:
                        # The marginal's x-axis is var1, so mark its point estimate here.
                        ax[i, j].axvline(
                            calculate_point_estimate(point_estimate, var1),
                            **point_estimate_kwargs,
                        )
                else:
                    var2 = plotters[j + not_marginals][-1].flatten()

                    if "scatter" in kind:
                        ax[j, i].plot(var1, var2, **scatter_kwargs)

                    if "kde" in kind:
                        plot_kde(
                            var1,
                            var2,
                            ax=ax[j, i],
                            **kde_kwargs,
                        )

                    if "hexbin" in kind:
                        ax[j, i].grid(False)
                        hexbin = ax[j, i].hexbin(var1, var2, gridsize=gridsize, **hexbin_kwargs)

                    if divergences:
                        ax[j, i].plot(
                            var1[diverging_mask], var2[diverging_mask], **divergences_kwargs
                        )

                    if point_estimate:
                        pe_x = calculate_point_estimate(point_estimate, var1)
                        pe_y = calculate_point_estimate(point_estimate, var2)
                        ax[j, i].axvline(pe_x, **point_estimate_kwargs)
                        ax[j, i].axhline(pe_y, **point_estimate_kwargs)
                        ax[j, i].scatter(pe_x, pe_y, **point_estimate_marker_kwargs)

                    if reference_values:
                        x_name = flat_var_names[i]
                        y_name = flat_var_names[j + not_marginals]
                        # Both names must have a reference value; missing ones were warned about.
                        if x_name in reference_values_copy and y_name in reference_values_copy:
                            ax[j, i].plot(
                                reference_values_copy[x_name],
                                reference_values_copy[y_name],
                                **reference_values_kwargs,
                            )

                # Only the bottom row shows x tick labels and only the first column y labels.
                if j != vars_to_plot - 1:
                    ax[j, i].axes.get_xaxis().set_major_formatter(NullFormatter())
                else:
                    ax[j, i].set_xlabel(
                        "{}".format(flat_var_names[i]), fontsize=ax_labelsize, wrap=True
                    )
                if i != 0:
                    ax[j, i].axes.get_yaxis().set_major_formatter(NullFormatter())
                else:
                    ax[j, i].set_ylabel(
                        "{}".format(flat_var_names[j + not_marginals]),
                        fontsize=ax_labelsize,
                        wrap=True,
                    )
                ax[j, i].tick_params(labelsize=xt_labelsize)

        if kind == "hexbin" and colorbar and hexbin is not None:
            # One colorbar attached to the bottom-right axes (created once, not per
            # subplot); the figure is taken from the axes so this also works when the
            # caller supplied ``ax``.
            divider = make_axes_locatable(ax[-1, -1])
            cax = divider.append_axes("right" if marginals else "left", size="7%", pad="5%")
            cbar = ax[-1, -1].figure.colorbar(
                hexbin, ticks=[hexbin.norm.vmin, hexbin.norm.vmax], cax=cax
            )
            cbar.ax.set_yticklabels(["low", "high"], fontsize=ax_labelsize)

    if backend_show(show):
        plt.show()

    if marginals and numvars == 2:
        return ax_return
    return ax

Read our documentation on viewing source code .

Loading