1
"""Matplotlib khatplot."""
2 2
import warnings
3

4 2
import matplotlib as mpl
5 2
import matplotlib.cm as cm
6 2
import matplotlib.pyplot as plt
7 2
import numpy as np
8 2
from matplotlib.colors import to_rgba_array
9

10 2
from ....stats.density_utils import histogram
11 2
from ...plot_utils import _scale_fig_size, color_from_dim, set_xticklabels, vectorized_to_hex
12 2
from . import backend_kwarg_defaults, backend_show, create_axes_grid, matplotlib_kwarg_dealiaser
13

14

15 2
def plot_khat(
    hover_label,
    hover_format,
    ax,
    figsize,
    xdata,
    khats,
    kwargs,
    annotate,
    coord_labels,
    show_bins,
    hlines_kwargs,
    xlabels,
    legend,
    color,
    dims,
    textsize,
    markersize,
    n_data_points,
    bin_format,
    backend_kwargs,
    show,
):
    """Matplotlib khat plot.

    Scatter the shape parameters ``khats`` against ``xdata`` and draw
    horizontal reference lines at k = 0, 0.5, 0.7 and 1. Optionally annotate
    points with k > 1, print per-bin counts, add a legend and attach hover
    labels.

    Parameters
    ----------
    hover_label : bool
        Show the label of the point under the mouse cursor. Disabled, with a
        ``UserWarning``, when the current backend is not interactive.
    hover_format : str
        Format string for the hover text, called as
        ``hover_format.format(index, coord_label)``.
    ax : matplotlib.axes.Axes or None
        Axes to draw on. A new figure with a single axes is created when None.
    figsize : tuple or None
        Passed to ``_scale_fig_size``. The scaled size becomes the default
        ``figsize`` backend kwarg.
    xdata : array-like
        x position of every point (presumably ``np.arange(n_data_points)``;
        the bin labels are placed relative to ``n_data_points - 1``).
    khats : numpy.ndarray or object with ``.values``
        Shape parameter per point. A non-ndarray input is also passed to
        ``color_from_dim`` (presumably an xarray DataArray).
    kwargs : dict or None
        Extra ``ax.scatter`` arguments. ``"cmap"`` is consumed here to map
        colors and is not forwarded, because scatter receives explicit colors.
    annotate : bool
        Write the coordinate label above every point with k > 1.
    coord_labels : sequence of str
        Label per point, indexed by point position.
    show_bins : bool
        Print the count and percentage of points in each k bin to the right
        of the data.
    hlines_kwargs : dict or None
        Extra ``ax.hlines`` arguments for the reference lines.
    xlabels : bool
        Use ``coord_labels`` as x tick labels.
    legend : bool
        Add a legend. Only honored when ``color`` names a dimension.
    color : str or array-like
        A dimension name in ``dims``, a single color, or per-point colors or
        values mapped through the colormap.
    dims : container of str
        Dimension names, checked with ``color in dims``.
    textsize : float or None
        Passed to ``_scale_fig_size``.
    markersize : float or None
        Scatter ``s`` value. Defaults to the squared scaled marker size.
    n_data_points : int
        Number of points plotted.
    bin_format : str
        Format string for bin labels, called with ``(count, percentage)``.
    backend_kwargs : dict or None
        Figure-creation kwargs merged over ``backend_kwarg_defaults``.
    show : bool or None
        Passed to ``backend_show``. ``plt.show()`` is called when it is truthy.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the plot was drawn on.
    """
    if hover_label and mpl.get_backend() not in mpl.rcsetup.interactive_bk:
        hover_label = False
        warnings.warn(
            "hover labels are only available with interactive backends. To switch to an "
            "interactive backend from ipython or jupyter, use `%matplotlib` there should be "
            "no need to restart the kernel. For other cases, see "
            "https://matplotlib.org/3.1.0/tutorials/introductory/usage.html#backends",
            UserWarning,
        )

    if backend_kwargs is None:
        backend_kwargs = {}

    backend_kwargs = {
        **backend_kwarg_defaults(constrained_layout=(not xlabels)),
        **backend_kwargs,
    }

    (figsize, ax_labelsize, _, xt_labelsize, linewidth, scaled_markersize) = _scale_fig_size(
        figsize, textsize
    )
    backend_kwargs.setdefault("figsize", figsize)
    backend_kwargs["squeeze"] = True

    hlines_kwargs = matplotlib_kwarg_dealiaser(hlines_kwargs, "hlines")
    hlines_kwargs.setdefault("linestyle", [":", "-.", "--", "-"])
    hlines_kwargs.setdefault("alpha", 0.7)
    hlines_kwargs.setdefault("zorder", -1)
    hlines_kwargs.setdefault("color", "C1")
    hlines_kwargs["color"] = vectorized_to_hex(hlines_kwargs["color"])

    if markersize is None:
        # scatter's ``s`` is an area: square the linear marker size so the
        # dots match the size of line markers.
        markersize = scaled_markersize ** 2

    kwargs = matplotlib_kwarg_dealiaser(kwargs, "scatter")
    kwargs.setdefault("s", markersize)
    kwargs.setdefault("marker", "+")
    # Colors are resolved to explicit RGBA below. A ``cmap`` forwarded to
    # scatter would be ignored (with a warning), so it is consumed here.
    cmap_spec = kwargs.pop("cmap", plt.rcParams["image.cmap"])

    color_by_dim = isinstance(color, str) and color in dims
    if not color_by_dim:
        # A legend only makes sense when colors encode a dimension.
        legend = False
    rgba_c, color_mapping, cmap = _khat_point_colors(
        khats, color, color_by_dim, n_data_points, cmap_spec
    )

    if isinstance(khats, np.ndarray):
        khats = khats.ravel()
    else:
        khats = khats.values.flatten()
    # Opacity grows with k: 0.5 for k <= 0.5, 0.7 for 0.5 < k <= 1, 1.0 for k > 1.
    alphas = 0.5 + 0.2 * (khats > 0.5) + 0.3 * (khats > 1)
    rgba_c[:, 3] = alphas
    rgba_c = vectorized_to_hex(rgba_c)

    if ax is None:
        fig, ax = create_axes_grid(
            1,
            backend_kwargs=backend_kwargs,
        )
    else:
        fig = ax.get_figure()

    sc_plot = ax.scatter(xdata, khats, c=rgba_c, **kwargs)

    if annotate:
        # Use array positions for khats/coord_labels and take the x coordinate
        # from xdata, so this works even when xdata is not 0..n-1.
        x_positions = np.asarray(xdata)
        for pos in np.flatnonzero(khats > 1):
            ax.text(
                x_positions[pos],
                khats[pos],
                coord_labels[pos],
                horizontalalignment="center",
                verticalalignment="bottom",
                fontsize=0.8 * xt_labelsize,
            )

    xmin, xmax = ax.get_xlim()
    if show_bins:
        # Extra room on the right for the bin count labels.
        xmax += n_data_points / 12
    ylims1 = ax.get_ylim()
    ax.hlines([0, 0.5, 0.7, 1], xmin=xmin, xmax=xmax, linewidth=linewidth, **hlines_kwargs)
    ylims2 = ax.get_ylim()
    # The lower limit may grow to include what the reference lines added. The
    # upper limit keeps the smaller value, so reference lines above the data
    # range are not forced into view.
    ymin = min(ylims1[0], ylims2[0])
    ymax = min(ylims1[1], ylims2[1])
    if show_bins:
        bin_edges = np.array([ymin, 0.5, 0.7, 1, ymax])
        bin_edges = bin_edges[(bin_edges >= ymin) & (bin_edges <= ymax)]
        hist, _, _ = histogram(khats, bin_edges)
        for idx, count in enumerate(hist):
            ax.text(
                (n_data_points - 1 + xmax) / 2,
                np.mean(bin_edges[idx : idx + 2]),
                bin_format.format(count, count / n_data_points * 100),
                horizontalalignment="center",
                verticalalignment="center",
            )
    ax.set_ylim(ymin, ymax)
    ax.set_xlim(xmin, xmax)

    ax.set_xlabel("Data Point", fontsize=ax_labelsize)
    ax.set_ylabel(r"Shape parameter k", fontsize=ax_labelsize)
    ax.tick_params(labelsize=xt_labelsize)
    if xlabels:
        set_xticklabels(ax, coord_labels)
        fig.autofmt_xdate()
        fig.tight_layout()
    if legend:
        ncols = len(color_mapping) // 6 + 1
        # Empty scatters act as legend proxies, one per dimension coordinate.
        for label, float_color in color_mapping.items():
            ax.scatter([], [], c=[cmap(float_color)], label=label, **kwargs)
        ax.legend(ncol=ncols, title=color)

    if hover_label and mpl.get_backend() in mpl.rcsetup.interactive_bk:
        _make_hover_annotation(fig, ax, sc_plot, coord_labels, rgba_c, hover_format)

    if backend_show(show):
        plt.show()

    return ax


def _khat_point_colors(khats, color, color_by_dim, n_data_points, cmap_spec):
    """Resolve ``color`` into a float RGBA array with one row per point.

    A single color, given as a string, RGB(A) tuple or scalar colormap value,
    is repeated ``n_data_points`` times, so the caller can assign per-point
    alphas to column 3.

    Returns
    -------
    tuple
        ``(rgba, color_mapping, cmap)``. ``color_mapping`` comes from
        ``color_from_dim`` and is None unless ``color_by_dim``. ``cmap`` is
        None unless a colormap was used to produce the colors.
    """
    color_mapping = None
    cmap = None
    if color_by_dim:
        colors, color_mapping = color_from_dim(khats, color)
        cmap = plt.get_cmap(cmap_spec)
        rgba_c = cmap(colors)
    elif isinstance(color, str):
        rgba_c = to_rgba_array(np.full(n_data_points, color))
    else:
        try:
            rgba_c = to_rgba_array(color)
        except ValueError:
            # Not interpretable as colors: treat the values as colormap input.
            cmap = plt.get_cmap(cmap_spec)
            rgba_c = cmap(color)
    # cmap(scalar) returns a tuple and a single color gives one row; make it a
    # 2-D array with a row per point so column 3 can take per-point alphas.
    rgba_c = np.atleast_2d(np.array(rgba_c, dtype=float))
    if rgba_c.shape[0] == 1:
        rgba_c = np.repeat(rgba_c, n_data_points, axis=0)
    return rgba_c, color_mapping, cmap
167

168

169 2
def _make_hover_annotation(fig, ax, sc_plot, coord_labels, rgba_c, hover_format):
    """Show a data point's label while the mouse hovers over it.

    Registers a ``motion_notify_event`` callback on ``fig.canvas``. The
    callback moves a single annotation, hidden at first, onto the scatter
    point under the cursor. It hides the annotation again once the cursor is
    no longer over a point, including when it leaves the axes.

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose canvas receives the mouse-motion callback.
    ax : matplotlib.axes.Axes
        Axes that holds the scatter plot and the annotation.
    sc_plot : matplotlib.collections.PathCollection
        Scatter artist, hit-tested with ``sc_plot.contains(event)``.
    coord_labels : sequence of str
        Label per point, indexed by point position.
    rgba_c : sequence of colors
        Per-point color, used as the face color of the annotation box.
    hover_format : str
        Format string called as ``hover_format.format(index, label)``.
    """
    annot = ax.annotate(
        "",
        xy=(0, 0),
        xytext=(0, 0),
        textcoords="offset points",
        bbox=dict(boxstyle="round", fc="w", alpha=0.4),
        arrowprops=dict(arrowstyle="->"),
    )
    annot.set_visible(False)
    offset = 10

    def update_annot(ind):
        """Move the annotation to the first hit point in ``ind`` and restyle it."""
        idx = ind["ind"][0]
        pos = sc_plot.get_offsets()[idx]
        # Read the view center at hover time rather than at setup, so the text
        # still points toward the middle of the view after a zoom or pan.
        right_half = pos[0] > np.mean(ax.get_xlim())
        upper_half = pos[1] > np.mean(ax.get_ylim())
        annot.xy = pos
        annot.set_position(
            (-offset if right_half else offset, -offset if upper_half else offset)
        )
        annot.set_text(hover_format.format(idx, coord_labels[idx]))
        annot.get_bbox_patch().set_facecolor(rgba_c[idx])
        annot.set_ha("right" if right_half else "left")
        annot.set_va("top" if upper_half else "bottom")

    def hover(event):
        """Show the annotation over a hovered point and hide it otherwise."""
        if event.inaxes == ax:
            cont, ind = sc_plot.contains(event)
            if cont:
                update_annot(ind)
                annot.set_visible(True)
                fig.canvas.draw_idle()
                return
        # The cursor is over empty space or has left the axes: hide any label
        # still showing, and redraw only when something actually changed.
        if annot.get_visible():
            annot.set_visible(False)
            fig.canvas.draw_idle()

    fig.canvas.mpl_connect("motion_notify_event", hover)

Read our documentation on viewing source code .

Loading