microsoft / graspologic
1
# Copyright 2019 NeuroData (http://neurodata.io)
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4 2
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6 2
#
7 2
#     http://www.apache.org/licenses/LICENSE-2.0
8 2
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11 2
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from abc import ABC, abstractmethod
16

17
from sklearn.base import BaseEstimator, ClusterMixin
18
from sklearn.metrics import adjusted_rand_score
19
from sklearn.utils.validation import check_is_fitted
20

21

22
class BaseCluster(ABC, BaseEstimator, ClusterMixin):
    """
    Abstract base class for clustering estimators.

    Subclasses must implement ``fit``. ``predict`` requires a ``model_``
    attribute on the instance and delegates to that object's ``predict``
    method.
    """

    @abstractmethod
    def fit(self, X, y=None):
        """
        Fit the clustering method to the data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Data to cluster; one row per data point, one column per
            feature.

        y : array-like, shape (n_samples,), optional (default=None)
            Known labels for X, if available. Used to compute
            ARI scores.

        Returns
        -------
        self
        """

    def predict(self, X, y=None):  # pragma: no cover
        """
        Assign each data point to a cluster using the fitted model.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Data to label; one row per data point, one column per
            feature.
        y : array-like, shape (n_samples, ), optional (default=None)
            Known labels for X, if available. When given, the adjusted
            Rand index between ``y`` and the predicted labels is returned
            as well.

        Returns
        -------
        labels : array, shape (n_samples,)
            Component labels.

        ari : float
            Adjusted Rand index. Only returned if y is given.
        """
        # Validate that the ``model_`` attribute exists before using it.
        check_is_fitted(self, ["model_"], all_or_any=all)
        predicted = self.model_.predict(X)

        if y is not None:
            score = adjusted_rand_score(y, predicted)
            return predicted, score
        return predicted

    def fit_predict(self, X, y=None):  # pragma: no cover
        """
        Fit to the data, then predict cluster labels for the same data.

        Parameters
        ----------
        X : array-like, shape (n_samples, n_features)
            Data to cluster; one row per data point, one column per
            feature.

        y : array-like, shape (n_samples,), optional (default=None)
            Known labels for X, if available. Passed to both ``fit`` and
            ``predict``.

        Returns
        -------
        labels : array, shape (n_samples,)
            Component labels.

        ari : float
            Adjusted Rand index. Only returned if y is given.
        """
        self.fit(X, y)

        if y is not None:
            # With labels supplied, ``predict`` yields a (labels, ari) pair.
            predicted, score = self.predict(X, y)
            return predicted, score
        return self.predict(X, y)

Read our documentation on viewing source code.

Loading