from functools import partial
from typing import Callable, Dict, List, Mapping, Optional, Union

import numpy as np
import pandas as pd

from snorkel.analysis.metrics import METRICS, metric_score


class Scorer:
    """Compute scores from built-in metrics and/or user-supplied metric functions.

    Parameters
    ----------
    metrics
        Names of built-in metrics to compute; every name must be a key of METRICS
    custom_metric_funcs
        Optional mapping from custom metric names to callables. Each callable is
        invoked as ``func(golds, preds, probs)`` (the same positional arguments the
        built-in metrics receive) and may return either a single float score or a
        dictionary of metric names to scores (e.g. when one function computes
        several values). See the unit tests for an example.
    abstain_label
        Label whose examples are ignored by the built-in metrics. For every
        built-in metric except "coverage", it is forwarded to ``metric_score`` as
        the ``filter_dict`` value for both golds and preds. Pass None to disable
        this filtering. Defaults to -1, the usual abstain convention. Custom metric
        functions are called on the unfiltered inputs.

    Raises
    ------
    ValueError
        If a name in ``metrics`` is not a key of METRICS

    Attributes
    ----------
    metrics
        Dictionary mapping each metric name to the callable that computes it
    abstain_label
        The ``abstain_label`` passed to the constructor
    """

    def __init__(
        self,
        metrics: Optional[List[str]] = None,
        custom_metric_funcs: Optional[Mapping[str, Callable[..., float]]] = None,
        abstain_label: Optional[int] = -1,
    ) -> None:
        self.metrics: Dict[str, Callable[..., float]] = {}

        for name in metrics or []:
            if name not in METRICS:
                raise ValueError(f"Unrecognized metric: {name}")

            # "coverage" never receives an abstain filter; every other built-in
            # metric filters abstains unless filtering was disabled with None.
            if abstain_label is None or name == "coverage":
                filters = {}
            else:
                filters = {"golds": [abstain_label], "preds": [abstain_label]}

            self.metrics[name] = partial(
                metric_score, metric=name, filter_dict=filters
            )

        # Custom metrics are added after the built-ins, so a custom function
        # registered under a built-in name replaces that built-in.
        if custom_metric_funcs is not None:
            self.metrics.update(custom_metric_funcs)

        self.abstain_label = abstain_label

    def score(
        self,
        golds: np.ndarray,
        preds: Optional[np.ndarray] = None,
        probs: Optional[np.ndarray] = None,
    ) -> Dict[str, float]:
        """Evaluate every configured metric on the given labels.

        Parameters
        ----------
        golds
            An array of gold (int) labels to base scores on
        preds
            An [n_datapoints,] or [n_datapoints, 1] array of (int) predictions
        probs
            An [n_datapoints, n_classes] array of probabilistic (float) predictions

        ``preds`` and ``probs`` are both optional because most metrics need only
        one of them. Each metric function is responsible for raising if an input
        it requires was not provided.

        Returns
        -------
        Dict[str, float]
            Metric names mapped to scores. A metric function that returns a dict
            contributes that dict's entries instead of a single entry under its
            own name.

        Raises
        ------
        ValueError
            If ``golds`` is empty
        """
        if not len(golds):  # type: ignore
            raise ValueError("Cannot score empty labels")

        results: Dict[str, float] = {}
        for name, func in self.metrics.items():
            value = func(golds, preds, probs)
            if not isinstance(value, dict):
                results[name] = value
                continue
            # Multi-valued metric: merge its named scores into the output.
            results.update(value)

        return results

    def score_slices(
        self,
        S: np.recarray,
        golds: np.ndarray,
        preds: np.ndarray,
        probs: np.ndarray,
        as_dataframe: bool = False,
    ) -> Union[Dict[str, Dict[str, float]], pd.DataFrame]:
        """Score the full dataset and each slice defined by the fields of ``S``.

        Parameters
        ----------
        S
            A recarray of length n_examples whose field names are slice names; each
            field is cast to bool and used as that slice's membership mask
        golds
            Gold (aka ground truth) labels (integers)
        preds
            Predictions (integers)
        probs
            Probabilities (floats)
        as_dataframe
            If True, return a pandas ``DataFrame`` with one row per slice;
            otherwise return a dict

        Returns
        -------
        Union[Dict, pd.DataFrame]
            Scores under the key "overall" plus one entry per slice name, each
            mapping metric names to scores, or the same data as a ``DataFrame``

        Raises
        ------
        ValueError
            If S, golds, preds, and probs differ in length, or (from ``score``)
            if any slice selects no examples
        """
        same_length = S.shape[0] == len(golds) == len(preds) == len(probs)
        if not same_length:
            raise ValueError(
                "S, golds, preds, and probs must have the same number of elements"
            )

        scores = {"overall": self.score(golds, preds, probs)}

        for name in S.dtype.names:
            in_slice = S[name].astype(bool)
            scores[name] = self.score(
                golds[in_slice], preds[in_slice], probs[in_slice]
            )

        # Transposing puts slices on rows and metrics on columns.
        return pd.DataFrame.from_dict(scores).T if as_dataframe else scores

