microsoft / graspologic
1
# Copyright 2019 NeuroData (http://neurodata.io)
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4 2
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6 2
#
7 2
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9 2
# Unless required by applicable law or agreed to in writing, software
10 2
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13 2
# limitations under the License.
14

15
from abc import abstractmethod
16

17
import numpy as np
18
from sklearn.base import BaseEstimator
19

20
from ..utils import import_graph, is_almost_symmetric
21
from .svd import selectSVD
22

23

24
class BaseEmbed(BaseEstimator):
    """
    A base class for embedding a graph.

    Parameters
    ----------
    n_components : int or None, default = None
        Desired dimensionality of output data. If ``algorithm`` is "full",
        n_components must be <= min(X.shape). Otherwise, n_components must be
        < min(X.shape). If None, then optimal dimensions will be chosen by
        ``select_dimension`` using ``n_elbows`` argument.
    n_elbows : int, optional, default: 2
        If ``n_components=None``, then compute the optimal embedding dimension
        using ``select_dimension``. Otherwise, ignored.
    algorithm : {'full', 'truncated', 'randomized' (default)}, optional
        SVD solver to use:

        - 'full'
            Computes full svd using ``scipy.linalg.svd``
        - 'truncated'
            Computes truncated svd using ``scipy.sparse.linalg.svds``
        - 'randomized'
            Computes randomized svd using
            ``sklearn.utils.extmath.randomized_svd``
    n_iter : int, optional (default = 5)
        Number of iterations for randomized SVD solver. Not used by 'full' or
        'truncated'. The default is larger than the default in randomized_svd
        to handle sparse matrices that may have large slowly decaying spectrum.
    check_lcc : bool, optional (default = True)
        Whether to check if input graph is connected. May result in non-optimal
        results if the graph is unconnected. Not checking for connectedness may
        result in faster computation. This class only stores the flag; acting
        on it is left to the concrete ``fit`` implementations.

    Attributes
    ----------
    n_components_ : int
        Dimensionality of the embedded space.
    singular_values_ : array
        Singular values returned by ``selectSVD`` for the embedded matrix.
    latent_left_ : array, shape (n_vertices, n_components_)
        Left singular vectors scaled column-wise by the square root of the
        singular values.
    latent_right_ : array, shape (n_vertices, n_components_), or None
        Right singular vectors scaled the same way. None when the embedded
        matrix is (almost) symmetric.

    See Also
    --------
    graspy.embed.selectSVD, graspy.embed.select_dimension
    """

    def __init__(
        self,
        n_components=None,
        n_elbows=2,
        algorithm="randomized",
        n_iter=5,
        check_lcc=True,
    ):
        # Hyperparameters are stored verbatim and validated nowhere in this
        # class, following the scikit-learn estimator convention.
        self.n_components = n_components
        self.n_elbows = n_elbows
        self.algorithm = algorithm
        self.n_iter = n_iter
        self.check_lcc = check_lcc

    def _reduce_dim(self, A):
        """
        A function that reduces the dimensionality of an adjacency matrix
        using the desired embedding method.

        Sets ``n_components_``, ``singular_values_``, ``latent_left_`` and
        ``latent_right_`` on the estimator; nothing is returned.

        Parameters
        ----------
        A: array-like, shape (n_vertices, n_vertices)
            Adjacency matrix to embed.
        """
        U, D, V = selectSVD(
            A,
            n_components=self.n_components,
            n_elbows=self.n_elbows,
            algorithm=self.algorithm,
            n_iter=self.n_iter,
        )

        self.n_components_ = D.size
        self.singular_values_ = D

        # Scaling column j by sqrt(D[j]) is what multiplying by
        # diag(sqrt(D)) does; broadcasting avoids materialising the dense
        # (d, d) diagonal matrix and the matrix product with it.
        root_singular_values = np.sqrt(D)
        self.latent_left_ = U * root_singular_values

        # Right latent positions are only kept for asymmetric input.
        if is_almost_symmetric(A):
            self.latent_right_ = None
        else:
            self.latent_right_ = V.T * root_singular_values

    @property
    def _pairwise(self):
        """This is for sklearn compliance."""
        return True

    @abstractmethod
    def fit(self, graph, y=None):
        """
        A method for embedding.

        This base implementation performs no computation. Concrete embedding
        classes are expected to convert ``graph`` to an adjacency matrix
        (e.g. with ``import_graph``) and then call ``self._reduce_dim(A)``.

        Parameters
        ----------
        graph: np.ndarray or networkx.Graph

        y : Ignored

        Returns
        -------
        self : returns an instance of self.

        See Also
        --------
        import_graph
        """
        # NOTE(review): ``abstractmethod`` is only enforced when the class's
        # metaclass derives from ``ABCMeta``; nothing in this class sets that
        # up explicitly - confirm whether instantiation is meant to be blocked.
        return self

    def _fit_transform(self, graph):
        """
        Fits the model and returns the estimated latent positions.

        Returns ``latent_left_`` alone when ``latent_right_`` is None,
        otherwise the tuple ``(latent_left_, latent_right_)``.
        """
        self.fit(graph)

        if self.latent_right_ is None:
            return self.latent_left_
        return self.latent_left_, self.latent_right_

    def fit_transform(self, graph, y=None):
        """
        Fit the model with graphs and apply the transformation.

        n_dimension is either automatically determined or based on user input.

        Parameters
        ----------
        graph: np.ndarray or networkx.Graph
            Input graph to embed.

        y : Ignored

        Returns
        -------
        out : np.ndarray, shape (n_vertices, n_dimension) OR tuple (len 2)
            Where both elements have shape (n_vertices, n_dimension)
            A single np.ndarray represents the latent position of an undirected
            graph, whereas a tuple represents the left and right latent
            positions for a directed graph
        """
        return self._fit_transform(graph)
168

169 2

170
class BaseEmbedMulti(BaseEmbed):
    """
    A base class for embedding several graphs at once.

    Accepts the same parameters as ``BaseEmbed`` and adds input validation
    for a collection of graphs via ``_check_input_graphs``.

    Attributes
    ----------
    n_graphs_ : int
        Number of graphs passed to ``_check_input_graphs``.
    n_vertices_ : int
        Number of vertices of the first graph, i.e. ``out[0].shape[0]``.
    """

    def __init__(
        self,
        n_components=None,
        n_elbows=2,
        algorithm="randomized",
        n_iter=5,
        check_lcc=True,
    ):
        super().__init__(
            n_components=n_components,
            n_elbows=n_elbows,
            algorithm=algorithm,
            n_iter=n_iter,
            check_lcc=check_lcc,
        )

    def _check_input_graphs(self, graphs):
        """
        Validates a collection of graphs and converts it with ``import_graph``.

        Also records ``n_graphs_`` and ``n_vertices_`` on the estimator.

        Parameters
        ----------
        graphs : list of nx.Graph or ndarray, or ndarray
            If list of nx.Graph, each Graph must contain same number of nodes.
            If list of ndarray, each array must have shape (n_vertices, n_vertices).
            If ndarray, then array must have shape (n_graphs, n_vertices, n_vertices).

        Returns
        -------
        out : list of converted graphs, or the converted 3-dimensional input
            For list or tuple input, a list with one ``import_graph`` result
            per graph. For ndarray input, the result of ``import_graph``
            applied to the whole tensor.

        Raises
        ------
        ValueError
            If all graphs do not have same shape, if the input list is empty or
            has one element, if an ndarray input is not 3-dimensional, or if it
            holds fewer than 2 graphs.
        TypeError
            If the input is neither a list, a tuple, nor an ndarray.
        """
        # Lists and tensors are handled separately: stacking a list into a
        # tensor (np.stack) would always duplicate the arrays in memory.
        if isinstance(graphs, (list, tuple)):
            if len(graphs) <= 1:
                msg = "Input {} must have at least 2 graphs, not {}.".format(
                    type(graphs), len(graphs)
                )
                raise ValueError(msg)
            out = [import_graph(g, copy=False) for g in graphs]

            # Enforce the documented contract: every graph must share one
            # shape. A 3-dimensional ndarray satisfies this by construction,
            # so only the list/tuple path needs the check.
            shapes = {g.shape for g in out}
            if len(shapes) > 1:
                msg = "There are {} different sizes of graphs.".format(len(shapes))
                raise ValueError(msg)
        elif isinstance(graphs, np.ndarray):
            if graphs.ndim != 3:
                msg = "Input tensor must be 3-dimensional, not {}-dimensional.".format(
                    graphs.ndim
                )
                raise ValueError(msg)
            elif graphs.shape[0] <= 1:
                msg = "Input tensor must have at least 2 elements, not {}.".format(
                    graphs.shape[0]
                )
                raise ValueError(msg)
            out = import_graph(graphs, copy=False)
        else:
            msg = "Input must be a list or ndarray, not {}.".format(type(graphs))
            raise TypeError(msg)

        # Save attributes
        self.n_graphs_ = len(out)
        self.n_vertices_ = out[0].shape[0]

        return out

Read our documentation on viewing source code.

Loading