microsoft / graspologic
1
# Copyright 2019 NeuroData (http://neurodata.io)
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4 2
# you may not use this file except in compliance with the License.
5 2
# You may obtain a copy of the License at
6 2
#
7 2
#     http://www.apache.org/licenses/LICENSE-2.0
8 2
#
9 2
# Unless required by applicable law or agreed to in writing, software
10 2
# distributed under the License is distributed on an "AS IS" BASIS,
11 2
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13 2
# limitations under the License.
14 2

15

16
import matplotlib.pyplot as plt
17 2
from matplotlib.colors import Colormap
18
import numpy as np
19
import pandas as pd
20
import seaborn as sns
21
from mpl_toolkits.axes_grid1 import make_axes_locatable
22
from sklearn.utils import check_array, check_consistent_length
23
from sklearn.preprocessing import Binarizer
24

25
from ..embed import selectSVD
26
from ..utils import import_graph, pass_to_ranks
27

28 2

29 2
def _check_common_inputs(
30 2
    figsize=None,
31 2
    height=None,
32
    title=None,
33
    context=None,
34 2
    font_scale=None,
35 2
    legend_name=None,
36 2
    title_pad=None,
37 2
    hier_label_fontsize=None,
38
):
39
    # Handle figsize
40 2
    if figsize is not None:
41 2
        if not isinstance(figsize, tuple):
42 2
            msg = "figsize must be a tuple, not {}.".format(type(figsize))
43 2
            raise TypeError(msg)
44

45
    # Handle heights
46 2
    if height is not None:
47 2
        if not isinstance(height, (int, float)):
48 2
            msg = "height must be an integer or float, not {}.".format(type(height))
49 2
            raise TypeError(msg)
50 2

51 2
    # Handle title
52
    if title is not None:
53
        if not isinstance(title, str):
54
            msg = "title must be a string, not {}.".format(type(title))
55 2
            raise TypeError(msg)
56

57
    # Handle context
58 2
    if context is not None:
59 2
        if not isinstance(context, str):
60 2
            msg = "context must be a string, not {}.".format(type(context))
61
            raise TypeError(msg)
62
        elif context not in ["paper", "notebook", "talk", "poster"]:
63 2
            msg = "context must be one of (paper, notebook, talk, poster), \
64
                not {}.".format(
65
                context
66 2
            )
67 0
            raise ValueError(msg)
68 0

69 0
    # Handle font_scale
70
    if font_scale is not None:
71 2
        if not isinstance(font_scale, (int, float)):
72 2
            msg = "font_scale must be an integer or float, not {}.".format(
73 2
                type(font_scale)
74
            )
75
            raise TypeError(msg)
76 2

77
    # Handle legend name
78 2
    if legend_name is not None:
79 2
        if not isinstance(legend_name, str):
80 2
            msg = "legend_name must be a string, not {}.".format(type(legend_name))
81 2
            raise TypeError(msg)
82

83
    if hier_label_fontsize is not None:
84 2
        if not isinstance(hier_label_fontsize, (int, float)):
85 2
            msg = "hier_label_fontsize must be a scalar, not {}.".format(
86 2
                type(legend_name)
87
            )
88
            raise TypeError(msg)
89 2

90 2
    if title_pad is not None:
91 2
        if not isinstance(title_pad, (int, float)):
92
            msg = "title_pad must be a scalar, not {}.".format(type(legend_name))
93 0
            raise TypeError(msg)
94 2

95 2

96 2
def _transform(arr, method):
97 2
    if method is not None:
98 2
        if method in ["log", "log10"]:
99
            # arr = np.log(arr, where=(arr > 0))
100 2
            # hacky, but np.log(arr, where=arr>0) is really buggy
101
            arr = arr.copy()
102
            if method == "log":
103
                arr[arr > 0] = np.log(arr[arr > 0])
104 0
            else:
105
                arr[arr > 0] = np.log10(arr[arr > 0])
106 2
        elif method in ["zero-boost", "simple-all", "simple-nonzero"]:
107
            arr = pass_to_ranks(arr, method=method)
108
        elif method == "binarize":
109 2
            transformer = Binarizer().fit(arr)
110
            arr = transformer.transform(arr)
111
        else:
112
            msg = "Transform must be one of {log, log10, binarize, zero-boost, simple-all, \
113 2
            simple-nonzero, not {}.".format(
114 2
                method
115
            )
116 2
            raise ValueError(msg)
117

118 2
    return arr
119 0

120 0

121 0
def _process_graphs(
    graphs, inner_hier_labels, outer_hier_labels, transform, sort_nodes
):
    """Transform each graph and sort its nodes so it is ready for plotting.

    Each graph is length-checked against the label arrays, passed through
    ``_transform`` and then through ``_sort_graph``. Missing label arrays are
    replaced by arrays of ones before sorting. Returns the list of processed
    graphs, in the same order as ``graphs``.
    """
    for adjacency in graphs:
        check_consistent_length(adjacency, inner_hier_labels, outer_hier_labels)

    transformed = [_transform(adjacency, transform) for adjacency in graphs]

    if inner_hier_labels is None:
        # No grouping requested: one group of ones covers every node, and any
        # outer labels that were passed are ignored.
        inner_hier_labels = np.ones(transformed[0].shape[0], dtype=int)
        outer_hier_labels = np.ones_like(inner_hier_labels)
    else:
        inner_hier_labels = np.array(inner_hier_labels)
        if outer_hier_labels is not None:
            outer_hier_labels = np.array(outer_hier_labels)
        else:
            outer_hier_labels = np.ones_like(inner_hier_labels)

    return [
        _sort_graph(adjacency, inner_hier_labels, outer_hier_labels, sort_nodes)
        for adjacency in transformed
    ]
145

146

147
def heatmap(
    X,
    transform=None,
    figsize=(10, 10),
    title=None,
    context="talk",
    font_scale=1,
    xticklabels=False,
    yticklabels=False,
    cmap="RdBu_r",
    vmin=None,
    vmax=None,
    center=0,
    cbar=True,
    inner_hier_labels=None,
    outer_hier_labels=None,
    hier_label_fontsize=30,
    ax=None,
    title_pad=None,
    sort_nodes=False,
    **kwargs
):
    r"""
    Plots a graph as a color-encoded matrix.

    Nodes can be grouped by providing `inner_hier_labels` or both
    `inner_hier_labels` and `outer_hier_labels`. Nodes can also
    be sorted by the degree from largest to smallest degree nodes.
    The nodes will be sorted within each group if labels are also
    provided.

    Read more in the :ref:`tutorials <plot_tutorials>`

    Parameters
    ----------
    X : nx.Graph or np.ndarray object
        Graph or numpy matrix to plot

    transform : None, or string {'log', 'log10', 'zero-boost', 'simple-all', 'simple-nonzero', 'binarize'}

        - 'log'
            Plots the natural log of all nonzero numbers
        - 'log10'
            Plots the base 10 log of all nonzero numbers
        - 'zero-boost'
            Pass to ranks method. preserves the edge weight for all 0s, but ranks
            the other edges as if the ranks of all 0 edges has been assigned.
        - 'simple-all'
            Pass to ranks method. Assigns ranks to all non-zero edges, settling
            ties using the average. Ranks are then scaled by
            :math:`\frac{rank(\text{non-zero edges})}{n^2 + 1}`
            where n is the number of nodes
        - 'simple-nonzero'
            Pass to ranks method. Same as simple-all, but ranks are scaled by
            :math:`\frac{rank(\text{non-zero edges})}{\text{# non-zero edges} + 1}`
        - 'binarize'
            Binarize input graph such that any edge weight greater than 0 becomes 1.

    figsize : tuple of integers, optional, default: (10, 10)
        Width, height in inches.

    title : str, optional, default: None
        Title of plot.

    context :  None, or one of {paper, notebook, talk (default), poster}
        The name of a preconfigured set.

    font_scale : float, optional, default: 1
        Separate scaling factor to independently scale the size of the font
        elements.

    xticklabels, yticklabels : bool or list, optional
        If list-like, plot these alternate labels as the ticklabels.

    cmap : str, list of colors, or matplotlib.colors.Colormap, default: 'RdBu_r'
        Valid matplotlib color map.

    vmin, vmax : floats, optional (default=None)
        Values to anchor the colormap, otherwise they are inferred from the data and
        other keyword arguments.

    center : float, default: 0
        The value at which to center the colormap

    cbar : bool, default: True
        Whether to draw a colorbar.

    inner_hier_labels : array-like, length of X's first dimension, default: None
        Categorical labeling of the nodes. If not None, will group the nodes
        according to these labels and plot the labels on the marginal

    outer_hier_labels : array-like, length of X's first dimension, default: None
        Categorical labeling of the nodes, ignored without ``inner_hier_labels``
        If not None, will plot these labels as the second level of a hierarchy on the
        marginals

    hier_label_fontsize : int
        Size (in points) of the text labels for the ``inner_hier_labels`` and
        ``outer_hier_labels``.

    ax : matplotlib Axes, optional
        Axes in which to draw the plot, otherwise will generate its own axes

    title_pad : int, float or None, optional (default=None)
        Custom padding to use for the distance of the title from the heatmap. Autoscales
        if ``None``

    sort_nodes : boolean, optional (default=False)
        Whether or not to sort the nodes of the graph by the sum of edge weights
        (degree for an unweighted graph). If ``inner_hier_labels`` is passed and
        ``sort_nodes`` is ``True``, will sort nodes this way within block.

    **kwargs : dict, optional
        additional plotting arguments passed to Seaborn's ``heatmap``

    Returns
    -------
    plot
        The object returned by Seaborn's ``heatmap`` for the drawn matrix.

    Raises
    ------
    TypeError
        If ``xticklabels``/``yticklabels`` is neither a bool nor a list, or if
        ``cmap``, ``center`` or ``cbar`` has an unsupported type.
    ValueError
        If a ticklabel list does not match the corresponding matrix dimension.
    """
    _check_common_inputs(
        figsize=figsize,
        title=title,
        context=context,
        font_scale=font_scale,
        hier_label_fontsize=hier_label_fontsize,
        title_pad=title_pad,
    )

    # Handle cmap
    if not isinstance(cmap, (str, list, Colormap)):
        msg = "cmap must be a string, list of colors, or matplotlib.colors.Colormap,"
        msg += " not {}.".format(type(cmap))
        raise TypeError(msg)

    # Handle center
    if center is not None:
        if not isinstance(center, (int, float)):
            msg = "center must be a integer or float, not {}.".format(type(center))
            raise TypeError(msg)

    # Handle cbar
    if not isinstance(cbar, bool):
        msg = "cbar must be a bool, not {}.".format(type(cbar))
        raise TypeError(msg)

    arr = import_graph(X)

    # Handle ticklabels. Lengths are compared against the imported array
    # rather than X itself, because X may be a graph object with no ``shape``.
    if isinstance(xticklabels, list):
        if len(xticklabels) != arr.shape[1]:
            msg = "xticklabels must have same length {}.".format(arr.shape[1])
            raise ValueError(msg)
    elif not isinstance(xticklabels, bool):
        msg = "xticklabels must be a bool or a list, not {}".format(type(xticklabels))
        raise TypeError(msg)

    if isinstance(yticklabels, list):
        if len(yticklabels) != arr.shape[0]:
            msg = "yticklabels must have same length {}.".format(arr.shape[0])
            raise ValueError(msg)
    elif not isinstance(yticklabels, bool):
        msg = "yticklabels must be a bool or a list, not {}".format(type(yticklabels))
        raise TypeError(msg)

    arr = _process_graphs(
        [arr], inner_hier_labels, outer_hier_labels, transform, sort_nodes
    )[0]

    # Global plotting settings
    CBAR_KWS = dict(shrink=0.7)  # norm=colors.Normalize(vmin=0, vmax=1))

    with sns.plotting_context(context, font_scale=font_scale):
        if ax is None:
            fig, ax = plt.subplots(figsize=figsize)
        plot = sns.heatmap(
            arr,
            cmap=cmap,
            square=True,
            xticklabels=xticklabels,
            yticklabels=yticklabels,
            cbar_kws=CBAR_KWS,
            center=center,
            cbar=cbar,
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            **kwargs
        )

        if title is not None:
            if title_pad is None:
                # Leave extra room above the matrix when group labels are drawn.
                if inner_hier_labels is not None:
                    title_pad = 1.5 * font_scale + 1 * hier_label_fontsize + 30
                else:
                    title_pad = 1.5 * font_scale + 15
            plot.set_title(title, pad=title_pad)
        if inner_hier_labels is not None:
            if outer_hier_labels is not None:
                plot.set_yticklabels([])
                plot.set_xticklabels([])
                _plot_groups(
                    plot,
                    arr,
                    inner_hier_labels,
                    outer_hier_labels,
                    fontsize=hier_label_fontsize,
                )
            else:
                _plot_groups(plot, arr, inner_hier_labels, fontsize=hier_label_fontsize)
    return plot
353

354

355
def gridplot(
    X,
    labels=None,
    transform=None,
    height=10,
    title=None,
    context="talk",
    font_scale=1,
    alpha=0.7,
    sizes=(10, 200),
    palette="Set1",
    legend_name="Type",
    inner_hier_labels=None,
    outer_hier_labels=None,
    hier_label_fontsize=30,
    title_pad=None,
    sort_nodes=False,
):
    r"""
    Plots multiple graphs on top of each other with dots as edges.

    This function is useful for visualizing multiple graphs simultaneously.
    The size of the dots correspond to the edge weights of the graphs, and
    colors represent input graphs.

    Read more in the :ref:`tutorials <plot_tutorials>`

    Parameters
    ----------
    X : list of nx.Graph or np.ndarray object
        List of nx.Graph or numpy arrays to plot
    labels : list of str
        List of strings, which are labels for each element in X.
        ``len(X) == len(labels)``. If ``None``, the positions
        ``0 .. len(X) - 1`` are used as labels.
    transform : None, or string {'log', 'log10', 'zero-boost', 'simple-all', 'simple-nonzero', 'binarize'}

        - 'log'
            Plots the natural log of all nonzero numbers
        - 'log10'
            Plots the base 10 log of all nonzero numbers
        - 'zero-boost'
            Pass to ranks method. preserves the edge weight for all 0s, but ranks
            the other edges as if the ranks of all 0 edges has been assigned.
        - 'simple-all'
            Pass to ranks method. Assigns ranks to all non-zero edges, settling
            ties using the average. Ranks are then scaled by
            :math:`\frac{rank(\text{non-zero edges})}{n^2 + 1}`
            where n is the number of nodes
        - 'simple-nonzero'
            Pass to ranks method. Same as simple-all, but ranks are scaled by
            :math:`\frac{rank(\text{non-zero edges})}{\text{# non-zero edges} + 1}`
        - 'binarize'
            Binarize input graph such that any edge weight greater than 0 becomes 1.
    height : int, optional, default: 10
        Height of figure in inches.
    title : str, optional, default: None
        Title of plot.
    context :  None, or one of {paper, notebook, talk (default), poster}
        The name of a preconfigured set.
    font_scale : float, optional, default: 1
        Separate scaling factor to independently scale the size of the font
        elements.
    palette : str, dict, optional, default: 'Set1'
        Set of colors for mapping the ``hue`` variable. If a dict, keys should
        be values in the hue variable
    alpha : float [0, 1], default : 0.7
        Alpha value of plotted gridplot points
    sizes : length 2 tuple, default: (10, 200)
        Min and max size to plot edge weights
    legend_name : string, default: 'Type'
        Name to plot above the legend
    inner_hier_labels : array-like, length of X's first dimension, default: None
        Categorical labeling of the nodes. If not None, will group the nodes
        according to these labels and plot the labels on the marginal
    outer_hier_labels : array-like, length of X's first dimension, default: None
        Categorical labeling of the nodes, ignored without ``inner_hier_labels``
        If not None, will plot these labels as the second level of a hierarchy on the
        marginals
    hier_label_fontsize : int
        Size (in points) of the text labels for the ``inner_hier_labels`` and
        ``outer_hier_labels``.
    title_pad : int, float or None, optional (default=None)
        Custom padding to use for the distance of the title from the heatmap. Autoscales
        if ``None``
    sort_nodes : boolean, optional (default=False)
        Whether or not to sort the nodes of the graph by the sum of edge weights
        (degree for an unweighted graph). If ``inner_hier_labels`` is passed and
        ``sort_nodes`` is ``True``, will sort nodes this way within block.

    Returns
    -------
    plot
        The object returned by Seaborn's ``relplot``.

    Raises
    ------
    TypeError
        If ``X`` is not a list.
    """
    _check_common_inputs(
        height=height,
        title=title,
        context=context,
        font_scale=font_scale,
        hier_label_fontsize=hier_label_fontsize,
        title_pad=title_pad,
    )

    if not isinstance(X, list):
        msg = "X must be a list, not {}.".format(type(X))
        raise TypeError(msg)
    graphs = [import_graph(x) for x in X]

    if labels is None:
        labels = np.arange(len(X))

    check_consistent_length(X, labels)

    # Process the imported arrays (not the raw inputs) so that graph objects
    # in X are handled as well as numpy arrays.
    graphs = _process_graphs(
        graphs, inner_hier_labels, outer_hier_labels, transform, sort_nodes
    )

    if isinstance(palette, str):
        palette = sns.color_palette(palette, desat=0.75, n_colors=len(labels))

    # One long-format frame per graph: a row per strictly positive entry,
    # offset by 0.5 so that each dot sits in the middle of its cell.
    dfs = []
    for idx, graph in enumerate(graphs):
        rdx, cdx = np.where(graph > 0)
        weights = graph[(rdx, cdx)]
        df = pd.DataFrame(
            np.vstack([rdx + 0.5, cdx + 0.5, weights]).T,
            columns=["rdx", "cdx", "Weights"],
        )
        df[legend_name] = [labels[idx]] * len(cdx)
        dfs.append(df)

    df = pd.concat(dfs, axis=0)

    # Axis limits are taken from the last graph in the list.
    n_nodes = graphs[-1].shape[0]

    with sns.plotting_context(context, font_scale=font_scale):
        sns.set_style("white")
        plot = sns.relplot(
            data=df,
            x="cdx",
            y="rdx",
            hue=legend_name,
            size="Weights",
            sizes=sizes,
            alpha=alpha,
            palette=palette,
            height=height,
            facet_kws={
                "sharex": True,
                "sharey": True,
                "xlim": (0, n_nodes + 1),
                "ylim": (0, n_nodes + 1),
            },
        )
        plot.ax.axis("off")
        plot.ax.invert_yaxis()
        if title is not None:
            if title_pad is None:
                # Leave extra room above the plot when group labels are drawn.
                if inner_hier_labels is not None:
                    title_pad = 1.5 * font_scale + 1 * hier_label_fontsize + 30
                else:
                    title_pad = 1.5 * font_scale + 15
            plt.title(title, pad=title_pad)
    if inner_hier_labels is not None:
        if outer_hier_labels is not None:
            _plot_groups(
                plot.ax,
                graphs[0],
                inner_hier_labels,
                outer_hier_labels,
                fontsize=hier_label_fontsize,
            )
        else:
            _plot_groups(
                plot.ax, graphs[0], inner_hier_labels, fontsize=hier_label_fontsize
            )
    return plot
526

527

528
def pairplot(
    X,
    labels=None,
    col_names=None,
    title=None,
    legend_name=None,
    variables=None,
    height=2.5,
    context="talk",
    font_scale=1,
    palette="Set1",
    alpha=0.7,
    size=50,
    marker=".",
    diag_kind="auto",
):
    r"""
    Plot pairwise relationships in a dataset.

    By default, this function will create a grid of Axes such that each dimension
    in data will by shared in the y-axis across a single row and in the x-axis
    across a single column.

    The off-diagonal Axes show the pairwise relationships displayed as scatterplot.
    The diagonal Axes show the univariate distribution of the data for that
    dimension displayed as either a histogram or kernel density estimates (KDEs).

    Read more in the :ref:`tutorials <plot_tutorials>`

    Parameters
    ----------
    X : array-like, shape (n_samples, n_features)
        Input data.
    labels : array-like or list, shape (n_samples), optional
        Labels that correspond to each sample in X.
    col_names : array-like or list, shape (n_features), optional
        Names or labels for each feature in X. If not provided, the default
        will be `Dimension 1, Dimension 2, etc`.
    title : str, optional, default: None
        Title of plot.
    legend_name : str, optional, default: None
        Title of the legend. When ``labels`` is given and this is ``None``,
        ``"Type"`` is used.
    variables : list of variable names, optional
        Variables to plot based on col_names, otherwise use every column with
        a numeric datatype.
    height : int or float, optional, default: 2.5
        Passed to Seaborn's ``pairplot`` as ``height``.
    context :  None, or one of {paper, notebook, talk (default), poster}
        The name of a preconfigured set.
    font_scale : float, optional, default: 1
        Separate scaling factor to independently scale the size of the font
        elements.
    palette : str, dict, optional, default: 'Set1'
        Set of colors for mapping the ``hue`` variable. If a dict, keys should
        be values in the hue variable.
    alpha : float, optional, default: 0.7
        Opacity value of plotter markers between 0 and 1
    size : float or int, optional, default: 50
        Size of plotted markers.
    marker : string, optional, default: '.'
        Matplotlib style marker specification
        https://matplotlib.org/api/markers_api.html
    diag_kind : string, optional, default: 'auto'
        Passed to Seaborn's ``pairplot``. Overridden with ``'hist'`` when
        ``labels`` is given and any label value occurs fewer than two times.

    Returns
    -------
    pairs
        The object returned by Seaborn's ``pairplot``.

    Raises
    ------
    TypeError
        If ``X``, ``labels`` or ``col_names`` has an unsupported type.
    ValueError
        If ``labels`` or ``col_names`` has the wrong length, or ``variables``
        has more elements than ``col_names``.
    KeyError
        If an element of ``variables`` is not in ``col_names``.
    """
    _check_common_inputs(
        height=height,
        title=title,
        context=context,
        font_scale=font_scale,
        legend_name=legend_name,
    )

    # Handle X
    if not isinstance(X, (list, np.ndarray)):
        msg = "X must be array-like, not {}.".format(type(X))
        raise TypeError(msg)
    # Lists pass the type check above but have no ``shape``; convert so the
    # dimension checks below work for them too.
    X = np.asarray(X)

    # Handle Y
    if labels is not None:
        if not isinstance(labels, (list, np.ndarray)):
            msg = "Y must be array-like or list, not {}.".format(type(labels))
            raise TypeError(msg)
        elif X.shape[0] != len(labels):
            msg = "Expected length {}, but got length {} instead for Y.".format(
                X.shape[0], len(labels)
            )
            raise ValueError(msg)

    # Handle col_names
    if col_names is None:
        col_names = ["Dimension {}".format(i) for i in range(1, X.shape[1] + 1)]
    elif not isinstance(col_names, list):
        msg = "col_names must be a list, not {}.".format(type(col_names))
        raise TypeError(msg)
    elif X.shape[1] != len(col_names):
        msg = "Expected length {}, but got length {} instead for col_names.".format(
            X.shape[1], len(col_names)
        )
        raise ValueError(msg)

    # Handle variables
    if variables is not None:
        if len(variables) > len(col_names):
            msg = "variables cannot contain more elements than col_names."
            raise ValueError(msg)
        else:
            for v in variables:
                if v not in col_names:
                    msg = "{} is not a valid key.".format(v)
                    raise KeyError(msg)
    else:
        variables = col_names

    df = pd.DataFrame(X, columns=col_names)
    if labels is not None:
        if legend_name is None:
            legend_name = "Type"
        df_labels = pd.DataFrame(labels, columns=[legend_name])
        df = pd.concat([df_labels, df], axis=1)

        # Fall back to histograms on the diagonal when some category has
        # fewer than two samples.
        names, counts = np.unique(labels, return_counts=True)
        if counts.min() < 2:
            diag_kind = "hist"
    plot_kws = dict(
        alpha=alpha,
        s=size,
        # edgecolor=None, # could add this latter
        linewidth=0,
        marker=marker,
    )
    with sns.plotting_context(context=context, font_scale=font_scale):
        if labels is not None:
            pairs = sns.pairplot(
                df,
                hue=legend_name,
                vars=variables,
                height=height,
                palette=palette,
                diag_kind=diag_kind,
                plot_kws=plot_kws,
            )
        else:
            pairs = sns.pairplot(
                df,
                vars=variables,
                height=height,
                palette=palette,
                diag_kind=diag_kind,
                plot_kws=plot_kws,
            )
        pairs.set(xticks=[], yticks=[])
        pairs.fig.subplots_adjust(top=0.945)
        pairs.fig.suptitle(title)

    return pairs
682

683

684
def _distplot(
    data,
    labels=None,
    direction="out",
    title="",
    context="talk",
    font_scale=1,
    figsize=(10, 5),
    palette="Set1",
    xlabel="",
    ylabel="Density",
):
    """Draw the distribution of ``data`` on a new figure and return its axes.

    With ``labels``, one empirical-CDF line is drawn per distinct label; a
    category with a single sample, or whose values are all equal, is drawn as
    a vertical line instead. Without ``labels``, a single cumulative curve is
    drawn with Seaborn's ``distplot`` (or a vertical line when every value is
    equal).

    Parameters
    ----------
    data : np.ndarray
        1-D values to plot.
    labels : array-like or None
        Category of each element of ``data``.
    direction : str
        Not used by this function.
    title, xlabel, ylabel : str
        Text for the title and axis labels.
    context, font_scale
        Passed to ``sns.plotting_context``.
    figsize : tuple
        Passed to ``plt.figure``.
    palette
        Passed to ``sns.color_palette``; colours are reused cyclically when
        there are more categories than colours.

    Returns
    -------
    ax : matplotlib axis object
    """
    plt.figure(figsize=figsize)
    ax = plt.gca()
    palette = sns.color_palette(palette)
    plt_kws = {"cumulative": True}
    with sns.plotting_context(context=context, font_scale=font_scale):
        if labels is not None:
            # A plain list compared with ``==`` yields one scalar rather than
            # an elementwise mask, so make sure labels is an array first.
            labels = np.asarray(labels)
            categories, counts = np.unique(labels, return_counts=True)
            for i, cat in enumerate(categories):
                # Wrap around rather than overrun the palette.
                color = palette[i % len(palette)]
                cat_data = data[np.where(labels == cat)]
                if counts[i] > 1 and cat_data.min() != cat_data.max():
                    # Empirical CDF of this category.
                    x = np.sort(cat_data)
                    y = np.arange(len(x)) / float(len(x))
                    plt.plot(x, y, label=cat, color=color)
                else:
                    ax.axvline(cat_data[0], label=cat, color=color)
            plt.legend()
        else:
            if data.min() != data.max():
                sns.distplot(data, hist=False, kde_kws=plt_kws)
            else:
                ax.axvline(data[0])

        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)

    return ax
724

725

726
def degreeplot(
    X,
    labels=None,
    direction="out",
    title="Degree plot",
    context="talk",
    font_scale=1,
    figsize=(10, 5),
    palette="Set1",
):
    r"""
    Plots the distribution of node degrees for the input graph.
    Allows for sets of node labels, will plot a distribution for each
    node category.

    Parameters
    ----------
    X : np.ndarray (2D)
        input graph
    labels : 1d np.ndarray or list, same length as dimensions of X
        Labels for different categories of graph nodes
    direction : string, ('out', 'in')
        Whether to plot out degree or in degree for a directed graph
    title : string, default : 'Degree plot'
        Plot title
    context :  None, or one of {talk (default), paper, notebook, poster}
        Seaborn plotting context
    font_scale : float, optional, default: 1
        Separate scaling factor to independently scale the size of the font
        elements.
    palette : str, dict, optional, default: 'Set1'
        Set of colors for mapping the ``hue`` variable. If a dict, keys should
        be values in the hue variable.
    figsize : tuple of length 2, default (10, 5)
        Size of the figure (width, height)

    Returns
    -------
    ax : matplotlib axis object
        Output plot

    Raises
    ------
    ValueError
        If ``direction`` is neither ``"out"`` nor ``"in"``, or if ``labels``
        does not have one entry per node.
    """
    _check_common_inputs(
        figsize=figsize, title=title, context=context, font_scale=font_scale
    )
    # Keep the validated array so that array-like input also supports ``.T``.
    X = check_array(X)
    # NOTE(review): "out" counts nonzeros down each column (axis=0) and "in"
    # along each row (axis=1); confirm this matches the adjacency convention
    # used by the rest of the package.
    if direction == "out":
        axis = 0
        # Pass the arrays as separate arguments; a single tuple argument
        # would be treated as one array and nothing would be compared.
        check_consistent_length(X, labels)
    elif direction == "in":
        axis = 1
        check_consistent_length(X.T, labels)
    else:
        raise ValueError('direction must be either "out" or "in"')
    degrees = np.count_nonzero(X, axis=axis)
    ax = _distplot(
        degrees,
        labels=labels,
        title=title,
        context=context,
        font_scale=font_scale,
        figsize=figsize,
        palette=palette,
        xlabel="Node degree",
    )
    return ax
791

792

793
def edgeplot(
    X,
    labels=None,
    nonzero=False,
    title="Edge plot",
    context="talk",
    font_scale=1,
    figsize=(10, 5),
    palette="Set1",
):
    r"""
    Plots the distribution of edge weights for the input graph.
    Allows for sets of node labels, will plot edge weight distribution
    for each node category.

    Parameters
    ----------
    X : np.ndarray (2D)
        Input graph
    labels : 1d np.ndarray or list, same length as dimensions of X
        Labels for different categories of graph nodes. If ``None``, a single
        distribution over all edges is plotted.
    nonzero : boolean, default: False
        Whether to restrict the edgeplot to only the non-zero edges
    title : string, default : 'Edge plot'
        Plot title
    context :  None, or one of {talk (default), paper, notebook, poster}
        Seaborn plotting context
    font_scale : float, optional, default: 1
        Separate scaling factor to independently scale the size of the font
        elements.
    palette : str, dict, optional, default: 'Set1'
        Set of colors for mapping the ``hue`` variable. If a dict, keys should
        be values in the hue variable.
    figsize : tuple of length 2, default (10, 5)
        Size of the figure (width, height)

    Returns
    -------
    ax : matplotlib axis object
        Output plot

    Raises
    ------
    ValueError
        If ``labels`` does not have one entry per row of ``X``.
    """
    _check_common_inputs(
        figsize=figsize, title=title, context=context, font_scale=font_scale
    )
    # Keep the validated array so that array-like input also supports
    # ``ravel`` and ``shape``.
    X = check_array(X)
    # Pass the arrays as separate arguments; a single tuple argument would be
    # treated as one array and nothing would be compared.
    check_consistent_length(X, labels)
    edges = X.ravel()
    if labels is not None:
        # Repeat the label vector once per column so it lines up with the
        # flattened matrix; for a square X, entry X[i, j] is paired with
        # labels[j].
        labels = np.tile(labels, (1, X.shape[1]))
        labels = labels.ravel()
    if nonzero:
        keep = edges != 0
        if labels is not None:
            labels = labels[keep]
        edges = edges[keep]
    ax = _distplot(
        edges,
        labels=labels,
        title=title,
        context=context,
        font_scale=font_scale,
        figsize=figsize,
        palette=palette,
        xlabel="Edge weight",
    )
    return ax
856

857

858
def screeplot(
    X,
    title="Scree plot",
    context="talk",
    font_scale=1,
    figsize=(10, 5),
    cumulative=True,
    show_first=None,
):
    r"""
    Plot the normalized singular values of a matrix.

    The singular values of ``X`` are divided by their sum and drawn either
    as-is or as a running total, depending on ``cumulative``.

    Parameters
    ----------
    X : np.ndarray (2D)
        Input matrix
    title : string, default : 'Scree plot'
        Plot title
    context :  None, or one of {talk (default), paper, notebook, poster}
        Seaborn plotting context
    font_scale : float, optional, default: 1
        Separate scaling factor to independently scale the size of the font
        elements.
    figsize : tuple of length 2, default (10, 5)
        Size of the figure (width, height)
    cumulative : boolean, default: True
        If True, plot the cumulative sum of the normalized singular values;
        otherwise plot the normalized singular values themselves
    show_first : int or None, default: None
        If given, only the first ``show_first`` components are plotted

    Returns
    -------
    ax : matplotlib axis object
        Output plot

    Raises
    ------
    TypeError
        If ``show_first`` is neither None nor an int, or if ``cumulative``
        is not a boolean.
    """
    _check_common_inputs(
        figsize=figsize, title=title, context=context, font_scale=font_scale
    )
    check_array(X)
    if show_first is not None and not isinstance(show_first, int):
        raise TypeError("show_first must be an int")
    if not isinstance(cumulative, bool):
        raise TypeError("cumulative must be a boolean")

    # Only the singular values are needed; the singular vectors are discarded.
    _, singular_values, _ = selectSVD(X, n_components=X.shape[1], algorithm="full")
    # Normalize in place so the values sum to one.
    singular_values /= singular_values.sum()

    # Slicing with None keeps every component.
    shown = singular_values[:show_first]
    y = np.cumsum(shown) if cumulative else shown

    plt.figure(figsize=figsize)
    ax = plt.gca()
    with sns.plotting_context(context=context, font_scale=font_scale):
        plt.plot(y)
        plt.title(title)
        plt.xlabel("Component")
        plt.ylabel("Variance explained")
    return ax
921

922 2

923 2
def _sort_inds(graph, inner_labels, outer_labels, sort_nodes):
    """
    Compute the node ordering used to group a graph by its labels.

    Nodes are ordered by outer group (largest group first, ties broken by
    label), then by inner group in the same manner, and optionally by the
    node's total edge weight (largest first).

    Parameters
    ----------
    graph : np.ndarray (2D)
        Adjacency matrix; its row and column sums give the per-node edge sums
    inner_labels : array-like
        Inner (fine) label of each node
    outer_labels : array-like
        Outer (coarse) label of each node
    sort_nodes : boolean
        Whether to additionally order nodes within a group by edge sum

    Returns
    -------
    sorted_inds : np.ndarray
        Index values of the sorted frame, i.e. the permutation to apply
    """
    frame = pd.DataFrame(columns=("inner_labels", "outer_labels"))
    frame["inner_labels"] = inner_labels
    frame["outer_labels"] = outer_labels

    # Subtracting the group size from the node count turns an ascending sort
    # into "largest group first" while still leaving labels available to
    # break ties alphabetically.
    frame["inner_counts"] = len(inner_labels) - _get_freq_vec(inner_labels)
    frame["outer_counts"] = len(outer_labels) - _get_freq_vec(outer_labels)

    # Row sum plus column sum per node (not exactly a degree if weighted),
    # inverted for the same reason as the counts above.
    edge_totals = graph.sum(axis=1) + graph.sum(axis=0)
    frame["node_edgesums"] = edge_totals.max() - edge_totals

    sort_keys = ["outer_counts", "outer_labels", "inner_counts", "inner_labels"]
    if sort_nodes:
        sort_keys.append("node_edgesums")

    ordered = frame.sort_values(by=sort_keys, kind="mergesort")
    return ordered.index.values
955

956

957 0
def _sort_graph(graph, inner_labels, outer_labels, sort_nodes):
    """
    Return ``graph`` with rows and columns permuted into group order.

    The permutation comes from ``_sort_inds`` and is applied to the rows
    first and then to the columns.
    """
    order = _sort_inds(graph, inner_labels, outer_labels, sort_nodes)
    rows_sorted = graph[order, :]
    return rows_sorted[:, order]
961 0

962 0

963 0
def _get_freqs(inner_labels, outer_labels=None):
    """
    Count the sizes of the outer groups and of the inner groups within them.

    The labels are sliced in contiguous runs, so they are expected to be
    already sorted such that each outer group (and each inner group inside
    it) is contiguous, as ``_plot_groups`` does before calling this.

    Parameters
    ----------
    inner_labels : array-like
        Inner (fine) label of each node, in sorted order
    outer_labels : array-like or None, default: None
        Outer (coarse) label of each node, in sorted order. If None, all
        nodes are treated as one single outer group.

    Returns
    -------
    inner_freq : np.ndarray
        Size of each inner group, listed outer group by outer group
    inner_freq_cumsum : np.ndarray
        Boundaries of the inner groups: 0 followed by the running total
    outer_freq : np.ndarray
        Size of each outer group, in order of first appearance
    outer_freq_cumsum : np.ndarray
        Boundaries of the outer groups: 0 followed by the running total
    """
    if outer_labels is None:
        # Passing None on to np.unique would be treated as a single element
        # and only the first inner label would be counted; use one constant
        # label per node instead so everything falls in one outer group.
        outer_labels = np.zeros(len(inner_labels), dtype=int)

    # use this because unique would give alphabetical
    _, outer_freq = _unique_like(outer_labels)
    outer_freq_cumsum = np.hstack((0, outer_freq.cumsum()))

    # for each group of outer labels, calculate the boundaries of the inner
    # labels; the leading empty float array keeps the result a float array
    # even when there are no groups at all
    per_group_freqs = [np.array([])]
    for start_ind, stop_ind in zip(outer_freq_cumsum[:-1], outer_freq_cumsum[1:]):
        _, temp_freq = _unique_like(inner_labels[start_ind:stop_ind])
        per_group_freqs.append(temp_freq)
    inner_freq = np.hstack(per_group_freqs)
    inner_freq_cumsum = np.hstack((0, inner_freq.cumsum()))

    return inner_freq, inner_freq_cumsum, outer_freq, outer_freq_cumsum
978 0

979 0

980 0
def _get_freq_vec(vals):
981 0
    # give each set of labels a vector corresponding to its frequency
982
    _, inv, counts = np.unique(vals, return_counts=True, return_inverse=True)
983
    count_vec = counts[inv]
984
    return count_vec
985 2

986 0

987 0
def _unique_like(vals):
988 0
    # gives output like
989 0
    uniques, inds, counts = np.unique(vals, return_index=True, return_counts=True)
990 0
    inds_sort = np.argsort(inds)
991
    uniques = uniques[inds_sort]
992 0
    counts = counts[inds_sort]
993 0
    return uniques, counts
994 0

995

996 0
# assume that the graph has already been plotted in sorted form
def _plot_groups(ax, graph, inner_labels, outer_labels=None, fontsize=30):
    """
    Draw group separators and labelled brackets around a plotted graph.

    Dashed lines are drawn at the inner group boundaries and around the
    border of ``ax``; bracket axes naming each inner group (and each outer
    group, when ``outer_labels`` is given) are attached above and to the
    left of ``ax``.

    Parameters
    ----------
    ax : matplotlib axis object
        Axis the graph has been drawn on
    graph : np.ndarray (2D)
        Adjacency matrix; used for the node count and for the sort order
    inner_labels : array-like
        Inner (fine) label of each node
    outer_labels : array-like or None, default: None
        Outer (coarse) label of each node. If None, only the inner brackets
        are drawn.
    fontsize : int or float, default: 30
        Font size of the bracket labels; also scales the gap between the
        inner and outer brackets

    Returns
    -------
    ax : matplotlib axis object
        The axis that was passed in
    """
    inner_labels = np.array(inner_labels)
    plot_outer = True
    if outer_labels is None:
        outer_labels = np.ones_like(inner_labels)
        plot_outer = False
    else:
        # lists and other sequences cannot be indexed with an index array
        outer_labels = np.array(outer_labels)

    sorted_inds = _sort_inds(graph, inner_labels, outer_labels, False)
    inner_labels = inner_labels[sorted_inds]
    outer_labels = outer_labels[sorted_inds]

    inner_freq, inner_freq_cumsum, outer_freq, outer_freq_cumsum = _get_freqs(
        inner_labels, outer_labels
    )
    outer_unique, _ = _unique_like(outer_labels)

    # Names of the inner groups, collected outer group by outer group so
    # there is exactly one name per entry of inner_freq even when an outer
    # group does not contain every inner label.
    inner_names = [
        _unique_like(inner_labels[start:stop])[0]
        for start, stop in zip(outer_freq_cumsum[:-1], outer_freq_cumsum[1:])
    ]
    inner_names = np.hstack(inner_names) if inner_names else inner_labels[:0]

    n_verts = graph.shape[0]
    axline_kws = dict(linestyle="dashed", lw=0.9, alpha=0.3, zorder=3, color="grey")
    # draw lines at the interior inner-group boundaries
    for x in inner_freq_cumsum[1:-1]:
        ax.vlines(x, 0, n_verts + 1, **axline_kws)
        ax.hlines(x, 0, n_verts + 1, **axline_kws)

    # add specific lines for the borders of the plot (in axes coordinates)
    pad = 0.001
    low = pad
    high = 1 - pad
    ax.plot((low, low), (low, high), transform=ax.transAxes, **axline_kws)
    ax.plot((low, high), (low, low), transform=ax.transAxes, **axline_kws)
    ax.plot((high, high), (low, high), transform=ax.transAxes, **axline_kws)
    ax.plot((low, high), (high, high), transform=ax.transAxes, **axline_kws)

    # generic curve that we will use for everything: a mirrored tangent
    # (2 x 500 = 1000 samples) giving the bracket shape
    lx = np.linspace(-np.pi / 2.0 + 0.05, np.pi / 2.0 - 0.05, 500)
    tan = np.tan(lx)
    curve = np.hstack((tan[::-1], tan))

    divider = make_axes_locatable(ax)

    # inner curve generation: bracket centers and half-widths
    inner_tick_loc = inner_freq.cumsum() - inner_freq / 2
    inner_tick_width = inner_freq / 2
    # outer curve generation: bracket centers and half-widths
    outer_tick_loc = outer_freq.cumsum() - outer_freq / 2
    outer_tick_width = outer_freq / 2

    # top inner curves
    ax_x = divider.new_vertical(size="5%", pad=0.0, pack_start=False)
    ax.figure.add_axes(ax_x)
    _plot_brackets(
        ax_x,
        inner_names,
        inner_tick_loc,
        inner_tick_width,
        curve,
        "inner",
        "x",
        n_verts,
        fontsize,
    )
    # side inner curves
    ax_y = divider.new_horizontal(size="5%", pad=0.0, pack_start=True)
    ax.figure.add_axes(ax_y)
    _plot_brackets(
        ax_y,
        inner_names,
        inner_tick_loc,
        inner_tick_width,
        curve,
        "inner",
        "y",
        n_verts,
        fontsize,
    )

    if plot_outer:
        # top outer curves; the gap to the inner brackets grows with fontsize
        pad_scalar = 0.35 / 30 * fontsize
        ax_x2 = divider.new_vertical(size="5%", pad=pad_scalar, pack_start=False)
        ax.figure.add_axes(ax_x2)
        _plot_brackets(
            ax_x2,
            outer_unique,
            outer_tick_loc,
            outer_tick_width,
            curve,
            "outer",
            "x",
            n_verts,
            fontsize,
        )
        # side outer curves
        ax_y2 = divider.new_horizontal(size="5%", pad=pad_scalar, pack_start=True)
        ax.figure.add_axes(ax_y2)
        _plot_brackets(
            ax_y2,
            outer_unique,
            outer_tick_loc,
            outer_tick_width,
            curve,
            "outer",
            "y",
            n_verts,
            fontsize,
        )
    return ax
1104 0

1105 0

1106 0
def _plot_brackets(
    ax, group_names, tick_loc, tick_width, curve, level, axis, max_size, fontsize
):
    """
    Draw one bracket per group on ``ax`` and label it with the group name.

    Parameters
    ----------
    ax : matplotlib axis object
        Axis to draw the brackets on
    group_names : array-like
        Label text for each bracket
    tick_loc : array-like
        Center position of each bracket
    tick_width : array-like
        Half-width of each bracket
    curve : np.ndarray
        Bracket profile; it is plotted against a 1000-point span, so it
        must have the same length
    level : str
        Not used in the body of this function
    axis : str
        "x" to draw brackets along the top, "y" to draw them along the side
    max_size : int
        Upper limit of the bracket axis (the lower limit is 0)
    fontsize : int or float
        Font size of the group labels
    """
    along_x = axis == "x"
    along_y = axis == "y"

    for center, half_width in zip(tick_loc, tick_width):
        span = np.linspace(center - half_width, center + half_width, 1000)
        if along_x:
            ax.plot(span, -curve, c="k")
        elif along_y:
            ax.plot(curve, span, c="k")
        else:
            continue
        # transparent background so the bracket axis does not hide the plot
        ax.patch.set_alpha(0)

    # start from a bare axis: no ticks, no tick marks, no frame
    ax.set_yticks([])
    ax.set_xticks([])
    ax.tick_params(axis=axis, which="both", length=0, pad=7)
    for side in ("left", "right", "bottom", "top"):
        ax.spines[side].set_visible(False)

    if along_x:
        ax.set_xticks(tick_loc)
        ax.set_xticklabels(group_names, fontsize=fontsize, verticalalignment="center")
        ax.xaxis.set_label_position("top")
        ax.xaxis.tick_top()
        ax.xaxis.labelpad = 30
        ax.set_xlim(0, max_size)
        ax.tick_params(axis="x", which="major", pad=5 + fontsize / 4)
    elif along_y:
        ax.set_yticks(tick_loc)
        ax.set_yticklabels(group_names, fontsize=fontsize, verticalalignment="center")
        ax.set_ylim(0, max_size)
        ax.invert_yaxis()

Read our documentation on viewing source code.

Loading