"""
Function for drawing and visualisations

Copyright (C) 2017-2019 Jiri Borovec <jiri.borovec@fel.cvut.cz>
"""
6

7 30
import collections
import logging
import os

import matplotlib.pylab as plt
import numpy as np
from PIL import ImageDraw
from matplotlib import colors as plt_colors, ticker as plt_ticker

from birl.utilities.data_io import convert_ndarray2image
from birl.utilities.dataset import scale_large_images_landmarks
from birl.utilities.evaluate import compute_matrix_user_ranking
19

20
#: default figure size for visualisations
MAX_FIGURE_SIZE = 18  # inches
22

23

24 30
def draw_image_points(image, points, color='green', marker_size=5, shape='o'):
    """ draw marker in the image and add to each landmark its index

    :param ndarray image: input image; when None a blank canvas is created
        whose size is the extent of the landmarks plus one pixel
    :param ndarray points: np.array<nb_points, dim>, each row being (x, y)
    :param str color: color of the marker
    :param int marker_size: radius of the circular marker
    :param str shape: marker shape: 'o' for circle, '.' for dot,
        's' for square; any other value draws a filled, outlined circle
    :return: np.ndarray, the drawn image divided by 255

    >>> image = np.zeros((10, 10, 3))
    >>> points = np.array([[9, 1], [2, 2], [5, 5]])
    >>> img = draw_image_points(image, points, marker_size=1, shape='s')
    >>> img.shape == (10, 10, 3)  # Windows x64 returns (10L, 10L, 3L)
    True
    >>> np.round(img[:, :, 1], 2)
    array([[ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0.5,  0.5],
           [ 0. ,  0.5,  0.5,  0.5,  0. ,  0. ,  0. ,  0. ,  0.5,  0. ],
           [ 0. ,  0.5,  0. ,  0.5,  0. ,  0. ,  0. ,  0. ,  0.5,  0.5],
           [ 0. ,  0.5,  0.5,  0.5,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
           [ 0. ,  0. ,  0. ,  0. ,  0.5,  0.5,  0.5,  0. ,  0. ,  0. ],
           [ 0. ,  0. ,  0. ,  0. ,  0.5,  0. ,  0.5,  0. ,  0. ,  0. ],
           [ 0. ,  0. ,  0. ,  0. ,  0.5,  0.5,  0.5,  0. ,  0. ,  0. ],
           [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
           [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ],
           [ 0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ,  0. ]])
    >>> img = draw_image_points(None, points, marker_size=1)
    """
    assert list(points), 'missing points'
    if image is None:
        # landmark range plus minimal offset to avoid zero image
        lnds_range = np.max(points, axis=0) - np.min(points, axis=0) + 1
        # points are (x, y) whereas an array shape is (rows, cols),
        # so the extent has to be reversed before it is used as a shape
        image = np.zeros(lnds_range.astype(int).tolist()[::-1] + [3])
    image = convert_ndarray2image(image)
    draw = ImageDraw.Draw(image)
    for i, (x, y) in enumerate(points):
        # bounding box of the marker centred on the landmark
        pos_marker = (x - marker_size, y - marker_size, x + marker_size, y + marker_size)
        # the index label is placed at the lower-right corner of the marker
        pos_text = tuple(points[i] + marker_size)
        if shape == 'o':
            draw.ellipse(pos_marker, outline=color)
        elif shape == '.':
            draw.ellipse(pos_marker, fill=color)
        elif shape == 's':
            draw.rectangle(pos_marker, outline=color)
        else:
            draw.ellipse(pos_marker, fill=color, outline=color)
        # landmarks are numbered from 1
        draw.text(pos_text, str(i + 1), fill=(0, 0, 0))
    image = np.array(image) / 255.
    return image
73

74

75 30
def draw_landmarks_origin_target_warped(ax, points_origin, points_target, points_warped=None, marker='o'):
    """ visualisation of transforming points, presenting 3 set of points:
    original points, targeting points, and the estimate of target points

    scenario 1:
    original - moving landmarks
    target - reference landmarks
    estimate - transformed landmarks

    scenario 2:
    original - reference landmarks
    target - moving landmarks
    estimate - transformed landmarks

    All given point sets are truncated to the length of the shortest one,
    so that landmarks can be paired by index.

    :param ax: matplotlib axes (anything exposing ``plot``)
    :param ndarray points_origin: np.array<nb_points, dim>
    :param ndarray points_target: np.array<nb_points, dim>
    :param ndarray points_warped: np.array<nb_points, dim>
    :param str marker: set the marker shape

    >>> points = np.array([[20, 30], [40, 10], [15, 25]])
    >>> draw_landmarks_origin_target_warped(plt.figure().gca(),
    ...                                     points, points + 1, points - 1)
    """
    pts_sizes = [len(pts) for pts in [points_origin, points_target, points_warped] if pts is not None]
    assert pts_sizes, 'no landmarks points given'
    min_pts = min(pts_sizes)
    assert min_pts > 0, 'no points given for sizes: %r' % pts_sizes
    points_origin = points_origin[:min_pts] if points_origin is not None else None
    points_target = points_target[:min_pts] if points_target is not None else None

    def _draw_lines(points1, points2, style, color, label):
        """ connect paired points by lines, skipped if either set is missing """
        if points1 is None or points2 is None:
            return
        for start, stop in zip(points1, points2):
            x, y = zip(start, stop)
            ax.plot(x, y, style, color=color, linewidth=2)
        # single extra zero-length line carrying the legend label,
        # so the legend gets one entry instead of one per landmark
        ax.plot([0, 0], [0, 0], style, color=color, linewidth=2, label=label)

    if points_origin is not None:
        ax.plot(points_origin[:, 0], points_origin[:, 1], marker, color='g', label='Original positions')
    # draw a dotted line between origin and target
    _draw_lines(points_target, points_origin, '-.', 'g', 'true shift')
    if points_target is not None:
        ax.plot(points_target[:, 0], points_target[:, 1], marker, color='m', label='Target positions')

    if points_warped is not None:
        points_warped = points_warped[:min_pts]
        # draw a dotted line between origin and warped
        _draw_lines(points_origin, points_warped, '-.', 'b', 'warped shift')
        # draw line that should be minimal between target and estimate
        _draw_lines(points_target, points_warped, '-', 'r', 'regist. error (TRE)')
        ax.plot(points_warped[:, 0], points_warped[:, 1], marker, color='b', label='Estimated positions')
129

130

131 30
def overlap_two_images(image1, image2, transparent=0.5):
    """ merge two images together with transparency level

    The result has, per axis, the larger of the two input sizes; both images
    are anchored at the top-left corner and regions covered by only one image
    keep just that image's weighted contribution.

    :param ndarray image1: np.array<height, width, dim>
    :param ndarray image2: np.array<height, width, dim>
    :param float transparent: level of transparency in range (0, 1)
        with 1 to see only first image and 0 to see the second one
    :return: np.array<height, width, dim>

    >>> img1 = np.ones((5, 6, 1)) * 0.2
    >>> img2 = np.ones((6, 5, 1)) * 0.8
    >>> overlap_two_images(img1, img2, transparent=0.5)[:, :, 0]
    array([[ 0.5,  0.5,  0.5,  0.5,  0.5,  0.1],
           [ 0.5,  0.5,  0.5,  0.5,  0.5,  0.1],
           [ 0.5,  0.5,  0.5,  0.5,  0.5,  0.1],
           [ 0.5,  0.5,  0.5,  0.5,  0.5,  0.1],
           [ 0.5,  0.5,  0.5,  0.5,  0.5,  0.1],
           [ 0.4,  0.4,  0.4,  0.4,  0.4,  0. ]])
    """
    assert image1.ndim == 3, 'required RGB images, got %i' % image1.ndim
    assert image1.ndim == image2.ndim, 'image dimension has to match, %r != %r' \
                                       % (image1.ndim, image2.ndim)
    size1, size2 = image1.shape, image2.shape
    # canvas large enough to hold both images along every axis
    max_size = np.max(np.array([size1, size2]), axis=0)
    image = np.zeros(max_size)
    image[0:size1[0], 0:size1[1], 0:size1[2]] += image1 * transparent
    image[0:size2[0], 0:size2[1], 0:size2[2]] += image2 * (1. - transparent)
    return image
160

161

162 30
def draw_images_warped_landmarks(
    image_target, image_source, points_init, points_target, points_warped, fig_size_max=MAX_FIGURE_SIZE
):
    """ composed form several functions - images overlap + landmarks + legend

    When both images are given they are blended half-and-half, when only one
    is given it is shown alone, and when none is given an empty figure sized
    by the landmarks is used.

    :param ndarray image_target: np.array<height, width, dim>
    :param ndarray image_source: np.array<height, width, dim>
    :param ndarray points_target: np.array<nb_points, dim>
    :param ndarray points_init: np.array<nb_points, dim>
    :param ndarray points_warped: np.array<nb_points, dim>
    :param float fig_size_max: maximal figure size for major image dimension
    :return: object

    >>> image = np.random.random((50, 50, 3))
    >>> points = np.array([[20, 30], [40, 10], [15, 25], [5, 50], [10, 60]])
    >>> fig = draw_images_warped_landmarks(image, 1 - image, points, points + 1, points - 1)
    >>> isinstance(fig, plt.Figure)
    True
    >>> fig = draw_images_warped_landmarks(None, None, points, points + 1, points - 1)
    >>> isinstance(fig, plt.Figure)
    True
    >>> draw_images_warped_landmarks(image, None, points, points + 1, points - 1)  # doctest: +ELLIPSIS
    <...>
    >>> draw_images_warped_landmarks(None, image, points, points + 1, points - 1)  # doctest: +ELLIPSIS
    <...>
    """
    # down-scale images and landmarks if they are too large
    (image_target, image_source), (points_init, points_target, points_warped) = \
        scale_large_images_landmarks([image_target, image_source],
                                     [points_init, points_target, points_warped])

    if image_target is not None and image_source is not None:
        image = overlap_two_images(image_target, image_source, transparent=0.5)
    elif image_target is not None:
        image = image_target
    elif image_source is not None:
        image = image_source
    else:
        image = None

    if image is not None:
        im_size = image.shape
        fig, ax = create_figure(im_size, fig_size_max)
        ax.imshow(image)
    else:
        # max + min leaves the same margin past the farthest landmark
        # as there is in front of the nearest one
        lnds_size = [
            np.max(pts, axis=0) + np.min(pts, axis=0) for pts in [points_init, points_target, points_warped]
            if pts is not None
        ]
        # landmarks are (x, y) whereas im_size is consumed below as
        # (height, width), hence the extent is reversed
        im_size = np.max(lnds_size, axis=0)[::-1].tolist() if lnds_size else (1, 1)
        fig, ax = create_figure(im_size, fig_size_max)

    draw_landmarks_origin_target_warped(ax, points_init, points_target, points_warped)
    ax.legend(loc='lower right', title='Legend')
    # image-like coordinates: origin in the top-left corner, y growing downwards
    ax.set(xlim=[0, im_size[1]], ylim=[im_size[0], 0])
    ax.axes.get_xaxis().set_ticklabels([])
    ax.axes.get_yaxis().set_ticklabels([])
    return fig
220

221

222 30
def create_figure(im_size, figsize_max=MAX_FIGURE_SIZE):
    """ create an empty figure with the image aspect ratio,
    scaled so that its larger side equals the maximal size

    :param tuple(int,int) im_size: image size as (height, width, ...),
        only the first two items are used
    :param float figsize_max: size of the larger figure side
    :return: tuple of the figure and its axes

    >>> fig, ax = create_figure((100, 150))
    >>> isinstance(fig, plt.Figure)
    True
    """
    assert len(im_size) >= 2, 'not valid image size - %r' % im_size
    size = np.array(im_size[:2])
    # figure size is (width, height) while image size is (height, width)
    fig_size = size[::-1] / float(size.max()) * figsize_max
    fig, ax = plt.subplots(figsize=fig_size)
    return fig, ax
238

239

240 30
def export_figure(path_fig, fig):
    """ export the figure and close it afterwards

    The subplot margins are removed before saving so the axes fill the
    whole exported image.

    :param str path_fig: path to the new figure image
    :param fig: object

    >>> path_fig = './sample_figure.jpg'
    >>> export_figure(path_fig, plt.figure())
    >>> os.remove(path_fig)
    """
    # a bare file name has an empty dirname, which means the current folder
    path_dir = os.path.dirname(path_fig) or '.'
    assert os.path.isdir(path_dir), 'missing folder "%s"' % path_dir
    fig.subplots_adjust(left=0., right=1., top=1., bottom=0.)
    logging.debug('exporting Figure: %s', path_fig)
    fig.savefig(path_fig)
    plt.close(fig)
256

257

258 30
def effective_decimals(num):
    """ find the first effective decimal

    The number is multiplied by ten until it leaves the open interval (0, 1);
    values outside that interval (zero, negative, or >= 1) give 0.

    :param float num: number
    :return int: number of the first effective decimals

    >>> effective_decimals(0.05)
    2
    >>> effective_decimals(3.7)
    0
    """
    dec = 0
    while 0. < num < 1.:
        dec += 1
        num *= 10
    return dec
269

270

271 30
class RadarChart(object):
    """ radar (spider) chart with one radial axis per DataFrame column
    and one closed curve per DataFrame row

    * https://stackoverflow.com/questions/24659005
    * https://datascience.stackexchange.com/questions/6084

    >>> import pandas as pd
    >>> df = pd.DataFrame(np.random.random((5, 3)), columns=list('abc'))
    >>> RadarChart(df)  # doctest: +ELLIPSIS
    <...>
    """

    def __init__(self, df, steps=5, fig=None, rect=None, fill_alpha=0.05, colors='nipy_spectral', *args, **kwargs):
        """ draw a dataFrame with scaled axis

        :param df: data
        :param int steps: number of steps per axis
        :param obj|None fig: Figure or None for a new one
        :param tuple(float,float,float,float) rect: rectangle inside figure
        :param float fill_alpha: transparency of filled region
        :param str|list colors: color map name, callable or list of colors,
            one color is used per row
        :param args: optional arguments passed to the curve plotting
        :param kwargs: optional key arguments passed to the curve plotting
        """
        if fig is None:
            fig = plt.figure()
        if rect is None:
            rect = [0.05, 0.05, 0.95, 0.95]

        self.titles = list(df.columns)
        self.nb_steps = steps
        self.data = df.copy()
        self.angles = np.linspace(0, 360, len(self.titles), endpoint=False)
        # one polar axes per column, all stacked over the same rectangle
        self.axes = [fig.add_axes(rect, projection="polar", label="axes%d" % i) for i in range(len(self.titles))]
        self.fig = fig

        # the first axes carries the angular grid and all the curves
        self.ax = self.axes[0]
        self.ax.set_thetagrids(self.angles, labels=self.titles, wrap=True)  # , fontsize=14

        for ax in self.axes[1:]:
            self.__ax_set_invisible(ax)

        for ax, angle, title in zip(self.axes, self.angles, self.titles):
            self.__draw_labels(ax, angle, title)

        self.maxs = np.array([self.data[title].max() for title in self.titles])

        # if just color space is given, sample colors
        colors = _list_colors(colors, len(self.data))

        for i, (idx, row) in enumerate(self.data.iterrows()):
            self.__draw_curve(idx, row, fill_alpha, color=colors[i], *args, **kwargs)

        self._labels = []
        for ax in self.axes:
            for theta, label in zip(ax.get_xticks(), ax.get_xticklabels()):
                self.__realign_polar_xtick(ax, theta, label)
                self._labels.append(label)

        self._legend = self.ax.legend(loc='center left', bbox_to_anchor=(1.2, 0.7))

    @staticmethod
    def __ax_set_invisible(ax):
        """ hide background patch, grid and angular axis of the given axes

        :param ax: axis
        """
        ax.patch.set_visible(False)
        ax.grid(False)
        ax.xaxis.set_visible(False)

    def __draw_labels(self, ax, angle, title):
        """ draw radial tick labels of a single column along its own spoke

        :param ax: axis
        :param float angle: angle in degree
        :param str title: name
        """
        vals = np.linspace(self.data[title].min(), self.data[title].max(), self.nb_steps + 1)
        # keep one more decimal than the first effective one of the maximum
        dec = effective_decimals(self.data[title].max()) + 1
        ticks = np.around(vals, dec)
        # NOTE(review): labels span [min, max] while `__draw_curve` scales
        # values as value / max, so both agree only if the column minimum is 0
        ax.set_rgrids(range(1, len(ticks) + 1), angle=angle, labels=ticks)
        ax.spines["polar"].set_visible(False)

    def __draw_curve(self, idx, row, fill_alpha=0.05, *args, **kw):
        """ draw particular curve

        :param str idx: name
        :param row: data with values
        :param fill_alpha: transparency of filled region
        :param args: optional arguments
        :param kw: optional key arguments
        """
        # normalise each value by its column maximum onto the radial grid 1..nb_steps+1
        vals = (row.values / self.maxs * self.nb_steps + 1).tolist()
        # repeat the first point to close the polygon
        vals.append(vals[0])
        angs = self.angles.tolist() + [self.angles[0]]
        self.ax.plot(np.deg2rad(angs), vals, label=idx, *args, **kw)
        self.ax.fill(np.deg2rad(angs), vals, alpha=fill_alpha)

    @staticmethod
    def __realign_polar_xtick(ax, theta, label):
        """ shift label for particular axis

        The alignment is chosen by the side of the circle the label sits on,
        so that the text grows away from the chart.

        :param ax: axis
        :param obj theta: angular position of the tick
        :param obj label: tick label to be aligned
        """
        # https://stackoverflow.com/questions/20222436
        theta = theta * ax.get_theta_direction() + ax.get_theta_offset()
        theta = np.pi / 2 - theta
        y, x = np.cos(theta), np.sin(theta)
        if x >= 0.1:
            label.set_horizontalalignment('left')
        elif x <= -0.1:
            label.set_horizontalalignment('right')
        if y >= 0.5:
            label.set_verticalalignment('bottom')
        elif y <= -0.5:
            label.set_verticalalignment('top')
386

387

388 30
def _list_colors(colors, nb):
    """ sample color space

    A color-map name is resolved to a color map with `nb` entries, a callable
    (e.g. a color map) is evaluated for indexes 0..nb-1, and anything else
    (e.g. an explicit list of colors) is returned untouched.

    :param str|list colors: name of a color map, callable or list of colors
    :param int nb: number of sampled colors
    :return list:

    >>> _list_colors('jet', 2)
    [(0.0, 0.0, 0.5, 1.0), (0.5, 0.0, 0.0, 1.0)]
    >>> _list_colors(plt.cm.jet, 3)  # doctest: +ELLIPSIS
    [(0.0, 0.0, 0.5, 1.0), (0.0, 0.0, 0.5..., 1.0), (0.0, 0.0, 0.5..., 1.0)]
    >>> _list_colors([(255, 0, 0), (0, 255, 0)], 1)
    [(255, 0, 0), (0, 255, 0)]
    """
    # if just color space is given, sample colors
    if isinstance(colors, str):
        colors = plt.get_cmap(colors, nb)
    # assume case that the color is callable plt.cm.jet;
    # the builtin is used as `collections.Callable` was removed in Python 3.10
    if callable(colors):
        colors = [colors(i) for i in range(nb)]
    return colors
409

410

411 30
def draw_heatmap(data, row_labels=None, col_labels=None, ax=None, cbar_kw=None, cbar_label="", **kwargs):
    """
    Create a draw_heatmap from a numpy array and two lists of labels.

    .. seealso:: https://matplotlib.org/gallery/images_contours_and_fields/image_annotated_heatmap.html

    :param data: A 2D numpy array of shape (N,M)
    :param row_labels: A list or array of length N with the labels for the rows,
     if None the row ticks are removed
    :param col_labels: A list or array of length M with the labels for the columns,
     if None the column ticks are removed
    :param ax: A matplotlib.axes.Axes instance to which the draw_heatmap is plotted.
     If not provided, a new figure of size (M, N) inches is created.
    :param cbar_kw: A dictionary with arguments to :meth:`matplotlib.Figure.colorbar`.
    :param cbar_label: The label for the colorbar
    :param kwargs: optional key arguments passed to `imshow`
    :return: tuple of the drawn image and the colorbar
    """
    cbar_kw = {} if cbar_kw is None else cbar_kw
    ax = plt.figure(figsize=data.shape[::-1]).gca() if ax is None else ax
    # Plot the draw_heatmap
    im = ax.imshow(data, **kwargs)

    # Create colorbar
    cbar = ax.figure.colorbar(im, ax=ax, **cbar_kw)
    cbar.ax.set_ylabel(cbar_label, rotation=-90, va='bottom')

    # We want to show all ticks and label them with the respective list entries.
    if col_labels is not None:
        ax.set_xticks(np.arange(data.shape[1]))
        ax.set_xticklabels(col_labels, va='center')
    else:
        ax.set_xticks([])

    if row_labels is not None:
        ax.set_yticks(np.arange(data.shape[0]))
        ax.set_yticklabels(row_labels, va='center')
    else:
        ax.set_yticks([])

    # Let the horizontal axes labeling appear on top.
    ax.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    # Rotate the tick labels and set their alignment.
    plt.setp(ax.get_xticklabels(), rotation=90, ha='left', rotation_mode='anchor')

    # Turn spines off and create white grid.
    for _, spine in ax.spines.items():
        spine.set_visible(False)

    ax.grid(False)  # for the general grid
    # grid splitting particular color-box, kind of padding
    ax.set_xticks(np.arange(data.shape[1] + 1) - 0.5, minor=True)
    ax.set_yticks(np.arange(data.shape[0] + 1) - 0.5, minor=True)
    ax.grid(which='minor', color='w', linestyle='-', linewidth=3)
    ax.tick_params(which='minor', bottom=False, left=False)

    return im, cbar
466

467

468 30
def draw_matrix_user_ranking(df_stat, higher_better=False, fig=None, cmap='tab20'):
    """ show matrix as image, sorted per column and unique colour per user

    :param DF df_stat: table where index are users and columns are scoring
    :param bool higher_better: ranking such that larger value is better
    :param fig: optional figure
    :param str cmap: color map
    :return Figure:

    >>> import pandas as pd
    >>> df = pd.DataFrame(np.random.random((5, 3)), columns=list('abc'))
    >>> draw_matrix_user_ranking(df)  # doctest: +ELLIPSIS
    <...>
    """
    ranking = compute_matrix_user_ranking(df_stat, higher_better)

    if fig is None:
        fig, _ = plt.subplots(figsize=np.array(df_stat.values.shape[::-1]) * 0.35)
    ax = fig.gca()
    # one colour bin per user, bin edges placed half-way between the integers
    arange = np.linspace(-0.5, len(df_stat) - 0.5, len(df_stat) + 1)
    norm = plt_colors.BoundaryNorm(arange, len(df_stat))

    def _format_user(val, pos=None):
        """ translate a colorbar tick value to the user name """
        # tick values may arrive as floats, which cannot index directly
        idx = int(round(val))
        return df_stat.index[idx] if 0 <= idx < len(df_stat) else ''

    fmt = plt_ticker.FuncFormatter(_format_user)

    _range = np.arange(1, len(df_stat) + 1)
    draw_heatmap(
        ranking,
        _range,
        df_stat.columns,
        ax=ax,
        cmap=plt.get_cmap(cmap, len(df_stat)),
        norm=norm,
        cbar_kw=dict(ticks=range(len(df_stat)), format=fmt),
        cbar_label='Methods',
    )
    ax.set_ylabel('Ranking')

    fig.tight_layout()
    return fig
506

507

508 30
def draw_scatter_double_scale(
    df,
    colors='nipy_spectral',
    ax_decs=None,
    idx_markers=('o', 'd'),
    xlabel='',
    figsize=None,
    legend_style=None,
    plot_style=None,
    x_spread=(0.4, 5),
):
    """Draw a scatter with double scales on left and right

    Columns are placed along the x axis and each row of the table is drawn
    as one marker per column. A column listed for the first axis name uses
    the left scale, any other column uses the right scale if there is one.

    :param DF df: dataframe
    :param str|list colors: color map name, callable or list of colors,
        one color is used per row
    :param dict ax_decs: dictionary with names of left and right axis
        mapped to their column names; an empty value is filled automatically.
        It is required, the passed dictionary is not modified.
    :param tuple idx_markers: markers cycled over the rows
    :param str xlabel: title of x axis
    :param tuple(float,float) figsize:
    :param dict legend_style: legend configuration
    :param dict plot_style: extra plot configuration
    :param tuple(float,int) x_spread: range of spreads and number of samples
    :return tuple: figure and dictionary with both axis and the legend

    >>> import pandas as pd
    >>> df = pd.DataFrame(np.random.random((10, 3)), columns=['col1', 'col2', 'col3'])
    >>> fig, axs = draw_scatter_double_scale(df, ax_decs={'name': None}, xlabel='X')
    >>> axs  # doctest: +ELLIPSIS
    {...}
    >>> # just the selected columns
    >>> fig, axs = draw_scatter_double_scale(df, ax_decs={'name1': ['col1', 'col2'],
    ...                                                   'name2': ['col3']})
    >>> fig  # doctest: +ELLIPSIS
    <...>
    >>> # for the "name2" use all remaining columns
    >>> fig, axs = draw_scatter_double_scale(df, ax_decs={'name1': ['col1', 'col2'],
    ...                                                   'name2': None})
    >>> fig  # doctest: +ELLIPSIS
    <...>
    """
    # validate before any figure is created so a failure leaves no open figure
    assert isinstance(ax_decs, dict)
    # shallow copy, the automatically filled column lists stay local
    ax_decs = dict(ax_decs)
    # https://matplotlib.org/gallery/api/two_scales.html
    fig, ax1 = plt.subplots(figsize=figsize)
    ax_names = list(ax_decs.keys())
    idx_names = list(df.index)

    # if just color space is given, sample colors
    colors = _list_colors(colors, len(idx_names))

    # https://matplotlib.org/3.1.0/gallery/lines_bars_and_markers/linestyles.html
    # https://matplotlib.org/3.1.1/_modules/matplotlib/colors.html
    tab_colors = ('tab:brown', 'tab:gray') if len(ax_names) > 1 else ('black', )

    ax1.set_ylabel(ax_names[0], color=tab_colors[0])
    ax1.grid(True, linestyle='dashed', color=tab_colors[0])
    ax1.tick_params(axis='y', labelcolor=tab_colors[0])
    # automatically fill missing names in the other collections
    if not ax_decs[ax_names[0]]:
        # if it is just one add all columns else just supplement
        if len(ax_names) > 1 and ax_decs.get(ax_names[1]):
            ax_decs[ax_names[0]] = [c for c in df.columns if c not in ax_decs[ax_names[1]]]
        else:
            ax_decs[ax_names[0]] = list(df.columns)

    # add second y-axes
    if len(ax_names) == 2:
        ax2 = ax1.twinx()  # instantiate a second axes that shares the same x-axis
        ax2.set_ylabel(ax_names[1], color=tab_colors[1])  # we already handled the x-label with ax1
        ax2.grid(True, linestyle='dotted', color=tab_colors[1])
        ax2.tick_params(axis='y', labelcolor=tab_colors[1])
        # automatically fill missing names in the other collections
        if not ax_decs[ax_names[1]] and ax_decs[ax_names[0]]:
            ax_decs[ax_names[1]] = [c for c in df.columns if c not in ax_decs[ax_names[0]]]
    else:
        ax2 = None

    plot_style = plot_style if plot_style else {}
    # if some spread over x around zero is defined
    if x_spread:
        x_offsets = np.linspace(-x_spread[0] / 2., x_spread[0] / 2., x_spread[1])
    else:
        x_offsets = [0]

    for i, col in enumerate(df.columns):
        ax = ax1 if col in ax_decs[ax_names[0]] else ax2
        if ax is None:
            # column not selected for the only axis, nothing to draw it on
            continue
        for j, idx in enumerate(idx_names):
            mkr = j % len(idx_markers)
            x_off = x_offsets[j % len(x_offsets)]
            ax.plot(i + x_off, df.loc[idx, col], idx_markers[mkr], color=colors[j], label=idx, **plot_style)

    if xlabel:
        ax1.set_xlabel(xlabel)
    # X label ticks - https://stackoverflow.com/questions/43152502
    ax1.set_xticks(range(len(df.columns)))
    ax1.set_xticklabels(df.columns, rotation=45, ha="right")

    # legend - https://matplotlib.org/3.1.1/gallery/text_labels_and_annotations/custom_legends.html
    if legend_style is None:
        legend_style = dict(loc='upper center', bbox_to_anchor=(1.25, 1.0), ncol=1)
    lgd = ax1.legend(idx_names, **legend_style)

    extras = {'ax1': ax1, 'ax2': ax2, 'legend': lgd}
    return fig, extras

Read our documentation on viewing source code .

Loading