```python
import logging
from collections import defaultdict
from typing import DefaultDict, Dict, List, Tuple

import numpy as np

from snorkel.utils import to_int_label_array


def get_label_buckets(*y: np.ndarray) -> Dict[Tuple[int, ...], np.ndarray]:
    """Return data point indices bucketed by label combinations.

    Parameters
    ----------
    *y
        A list of np.ndarray of (int) labels

    Returns
    -------
    Dict[Tuple[int, ...], np.ndarray]
        A mapping of each label bucket to a NumPy array of its corresponding indices

    Example
    -------
    A common use case is calling ``buckets = get_label_buckets(Y_gold, Y_pred)`` where
    ``Y_gold`` is a set of gold (i.e. ground truth) labels and ``Y_pred`` is a
    corresponding set of predicted labels.

    >>> Y_gold = np.array([1, 1, 1, 0])
    >>> Y_pred = np.array([1, 1, -1, -1])
    >>> buckets = get_label_buckets(Y_gold, Y_pred)

    The returned ``buckets[(i, j)]`` is a NumPy array of data point indices with
    true label i and predicted label j.

    More generally, the returned indices within each bucket refer to the order of the
    labels that were passed in as function arguments.

    >>> buckets[(1, 1)]  # true positives
    array([0, 1])
    >>> (1, 0) in buckets  # false negatives
    False
    >>> (0, 1) in buckets  # false positives
    False
    >>> (0, 0) in buckets  # true negatives
    False
    >>> buckets[(1, -1)]  # abstained positives
    array([2])
    >>> buckets[(0, -1)]  # abstained negatives
    array([3])
    """
    # Map each observed label combination to the indices where it occurs.
    buckets: DefaultDict[Tuple[int, ...], List[int]] = defaultdict(list)
    y_flat = list(map(lambda x: to_int_label_array(x, flatten_vector=True), y))
    if len(set(map(len, y_flat))) != 1:
        raise ValueError("Arrays must all have the same number of elements")
    for i, labels in enumerate(zip(*y_flat)):
        buckets[labels].append(i)
    return {k: np.array(v) for k, v in buckets.items()}


def get_label_instances(
    bucket: Tuple[int, ...], x: np.ndarray, *y: np.ndarray
) -> np.ndarray:
    """Return instances in x with the specified combination of labels.

    Parameters
    ----------
    bucket
        A tuple of label values corresponding to which instances from x are returned
    x
        NumPy array of data instances to be returned
    *y
        A list of np.ndarray of (int) labels

    Returns
    -------
    np.ndarray
        NumPy array of instances from x with the specified combination of labels

    Example
    -------
    A common use case is calling ``get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)``
    where ``x`` is a NumPy array of data instances that the labels correspond to,
    ``Y_gold`` is a list of gold (i.e. ground truth) labels, and
    ``Y_pred`` is a corresponding list of predicted labels.

    >>> import pandas as pd
    >>> x = pd.DataFrame(data={'col1': ["this is a string", "a second string", "a third string"], 'col2': ["1", "2", "3"]})
    >>> Y_gold = np.array([1, 1, 1])
    >>> Y_pred = np.array([1, 0, 0])
    >>> bucket = (1, 0)

    The returned NumPy array of data instances from ``x`` will correspond to
    the rows where the first list had a 1 and the second list had a 0.

    >>> get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)
    array([['a second string', '2'],
           ['a third string', '3']], dtype=object)

    More generally, given bucket ``(i, j, ...)`` and lists ``y1, y2, ...``
    the returned data instances from ``x`` will correspond to the rows where
    y1 had label i, y2 had label j, and so on. Note that ``x`` and ``y``
    must all be the same length.
    """
    if len(y) != len(bucket):
        raise ValueError("Number of lists must match the amount of labels in bucket")
    if x.shape[0] != len(y[0]):
        # Note: the check for all y having the same number of elements occurs in get_label_buckets
        raise ValueError(
            "Number of rows in x does not match number of elements in at least one label list"
        )
    buckets = get_label_buckets(*y)
    try:
        indices = buckets[bucket]
    except KeyError:
        # An unseen label combination yields an empty result rather than an error.
        logging.warning("Bucket %s does not exist.", bucket)
        return np.array([])
    instances = x[indices]
    return instances
```
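The bucketing idea behind `get_label_buckets` and `get_label_instances` can be illustrated without NumPy or Snorkel. The following is a minimal sketch, not the library's implementation: `bucket_indices` and `select_instances` are hypothetical helper names, and plain Python lists stand in for the label arrays and for `x`.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def bucket_indices(*label_lists: Sequence[int]) -> Dict[Tuple[int, ...], List[int]]:
    """Group positions by the tuple of labels found at each position.

    Hypothetical stand-in for get_label_buckets, using plain lists.
    """
    if len({len(labels) for labels in label_lists}) != 1:
        raise ValueError("All label lists must have the same number of elements")
    buckets: Dict[Tuple[int, ...], List[int]] = defaultdict(list)
    for i, combo in enumerate(zip(*label_lists)):
        buckets[combo].append(i)
    return dict(buckets)


def select_instances(
    bucket: Tuple[int, ...], rows: Sequence, *label_lists: Sequence[int]
) -> list:
    """Return the rows whose labels match ``bucket``; empty list if none do.

    Hypothetical stand-in for get_label_instances.
    """
    indices = bucket_indices(*label_lists).get(bucket, [])
    return [rows[i] for i in indices]


y_gold = [1, 1, 1, 0]
y_pred = [1, 1, -1, -1]
rows = ["a", "b", "c", "d"]

buckets = bucket_indices(y_gold, y_pred)
print(buckets)  # {(1, 1): [0, 1], (1, -1): [2], (0, -1): [3]}
print(select_instances((1, -1), rows, y_gold, y_pred))  # ['c']
print(select_instances((0, 0), rows, y_gold, y_pred))  # []
```

Only label combinations that actually occur become keys, which is why the module's docstring checks membership with `in` before indexing buckets such as `(0, 0)`.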
