from functools import partial
from typing import Callable, Dict, List, Mapping, Optional, Union

import numpy as np
import pandas as pd

from snorkel.analysis.metrics import METRICS, metric_score


class Scorer:
    """Calculate one or more scores from user-specified and/or user-defined metrics.

    Parameters
    ----------
    metrics
        A list of metric names, all of which are defined in METRICS
    custom_metric_funcs
        An optional dictionary mapping the names of custom metrics to the functions
        that produce them. Each custom metric function should accept golds, preds, and
        probs as input (just like the standard metrics in METRICS) and return either a
        single score (float) or a dictionary of metric names to scores (if the function
        calculates multiple values, for example). See the unit tests for an example.
    abstain_label
        The gold label for which examples will be ignored. By default, follows the
        convention that abstains are labeled -1.

    Raises
    ------
    ValueError
        If a specified standard metric is not found in the METRICS dictionary

    Attributes
    ----------
    metrics
        A dictionary mapping metric names to the corresponding functions for calculating
        that metric
    """

    def __init__(
        self,
        metrics: Optional[List[str]] = None,
        custom_metric_funcs: Optional[Mapping[str, Callable[..., float]]] = None,
        abstain_label: Optional[int] = -1,
    ) -> None:

        self.metrics: Dict[str, Callable[..., float]]
        self.metrics = {}
        if metrics:
            for metric in metrics:
                if metric not in METRICS:
                    raise ValueError(f"Unrecognized metric: {metric}")

                filter_dict = (
                    {}
                    if abstain_label is None or metric == "coverage"
                    else {"golds": [abstain_label], "preds": [abstain_label]}
                )
                self.metrics.update(
                    {
                        metric: partial(
                            metric_score, metric=metric, filter_dict=filter_dict
                        )
                    }
                )

        if custom_metric_funcs is not None:
            self.metrics.update(custom_metric_funcs)

        self.abstain_label = abstain_label
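
    # Illustrative construction sketch (comments only; not part of the class).
    # "accuracy" is a standard METRICS key, and ``fraction_positive`` is a
    # hypothetical custom metric with the (golds, preds, probs) signature
    # described in the class docstring:
    #
    #     def fraction_positive(golds, preds, probs):
    #         return float(np.mean(preds == 1))
    #
    #     scorer = Scorer(
    #         metrics=["accuracy"],
    #         custom_metric_funcs={"fraction_positive": fraction_positive},
    #     )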

    def score(
        self,
        golds: np.ndarray,
        preds: Optional[np.ndarray] = None,
        probs: Optional[np.ndarray] = None,
    ) -> Dict[str, float]:
        """Calculate scores for one or more user-specified metrics.

        Parameters
        ----------
        golds
            An array of gold (int) labels to base scores on
        preds
            An [n_datapoints,] or [n_datapoints, 1] array of (int) predictions to score
        probs
            An [n_datapoints, n_classes] array of probabilistic (float) predictions

        Because most metrics require either `preds` or `probs`, but not both, these
        values are optional; it is up to the metric function that will be called to
        raise an exception if a field it requires is not passed to the `score()` method.

        Returns
        -------
        Dict[str, float]
            A dictionary mapping metric names to metric scores

        Raises
        ------
        ValueError
            If no gold labels were provided
        """
        if len(golds) == 0:  # type: ignore
            raise ValueError("Cannot score empty labels")

        metric_dict = dict()

        for metric_name, metric in self.metrics.items():
            score = metric(golds, preds, probs)
            if isinstance(score, dict):
                metric_dict.update(score)
            else:
                metric_dict[metric_name] = score

        return metric_dict
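
    # Illustrative call sketch (comments only). Most standard metrics, such as
    # "accuracy", only need ``preds``, so ``probs`` can be omitted; a
    # probability-based metric would instead require ``probs``. Hypothetical
    # example:
    #
    #     scorer = Scorer(metrics=["accuracy"])
    #     scorer.score(golds=np.array([1, 0, 1]), preds=np.array([1, 1, 1]))
    #     # -> {"accuracy": 0.666...}  (2 of 3 predictions correct)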

    def score_slices(
        self,
        S: np.recarray,
        golds: np.ndarray,
        preds: np.ndarray,
        probs: np.ndarray,
        as_dataframe: bool = False,
    ) -> Union[Dict[str, Dict[str, float]], pd.DataFrame]:
        """Calculate user-specified and/or user-defined metrics overall + slices.

        Parameters
        ----------
        S
            A recarray with entries of length n_examples corresponding to slice names
        golds
            Gold (aka ground truth) labels (integers)
        preds
            Predictions (integers)
        probs
            Probabilities (floats)
        as_dataframe
            A boolean indicating whether to return results as a pandas ``DataFrame``
            (True) or dict (False)

        Returns
        -------
        Union[Dict, pd.DataFrame]
            A dictionary mapping slice_name to metric names to metric scores,
            or the same metrics formatted as a pandas ``DataFrame``
        """

        correct_shapes = S.shape[0] == len(golds) == len(preds) == len(probs)
        if not correct_shapes:
            raise ValueError(
                "S, golds, preds, and probs must have the same number of elements"
            )

        # Include overall metrics
        metrics_dict = {"overall": self.score(golds, preds, probs)}

        # Include slice metrics
        slice_names = S.dtype.names
        for slice_name in slice_names:
            mask = S[slice_name].astype(bool)
            metrics_dict[slice_name] = self.score(golds[mask], preds[mask], probs[mask])

        if as_dataframe:
            return pd.DataFrame.from_dict(metrics_dict).transpose()

        return metrics_dict
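

# Minimal usage sketch (illustrative only; not part of the library source).
# It assumes the standard "accuracy" metric key exists in METRICS, and uses
# hypothetical slice names ("short_docs", "has_link"). The slice matrix S is a
# record array with one 0/1 indicator field per slice, one entry per example,
# as described in the ``score_slices`` docstring.
if __name__ == "__main__":
    golds = np.array([1, 0, 1, 1])
    preds = np.array([1, 1, 1, 0])
    probs = np.array([[0.1, 0.9], [0.4, 0.6], [0.2, 0.8], [0.7, 0.3]])

    # Two slices over the four examples, encoded as 0/1 indicator fields.
    S = np.array(
        [(1, 0), (1, 1), (0, 1), (1, 0)],
        dtype=[("short_docs", np.int64), ("has_link", np.int64)],
    ).view(np.recarray)

    scorer = Scorer(metrics=["accuracy"])
    print(scorer.score(golds=golds, preds=preds, probs=probs))
    print(scorer.score_slices(S, golds, preds, probs, as_dataframe=True))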
