1 1
import logging
2 1
from collections import defaultdict
3 1
from typing import DefaultDict, Dict, List, Tuple
4

5 1
import numpy as np
6

7 1
from snorkel.utils import to_int_label_array
8

9

10 1
def get_label_buckets(*y: np.ndarray) -> Dict[Tuple[int, ...], np.ndarray]:
    """Return data point indices bucketed by label combinations.

    Parameters
    ----------
    *y
        One or more np.ndarray of (int) labels, all with the same number of elements

    Returns
    -------
    Dict[Tuple[int, ...], np.ndarray]
        A mapping of each label bucket to a NumPy array of its corresponding indices

    Raises
    ------
    ValueError
        If no label arrays are passed in, or if the arrays do not all have the
        same number of elements

    Example
    -------
    A common use case is calling ``buckets = get_label_buckets(Y_gold, Y_pred)`` where
    ``Y_gold`` is a set of gold (i.e. ground truth) labels and ``Y_pred`` is a
    corresponding set of predicted labels.

    >>> Y_gold = np.array([1, 1, 1, 0])
    >>> Y_pred = np.array([1, 1, -1, -1])
    >>> buckets = get_label_buckets(Y_gold, Y_pred)

    The returned ``buckets[(i, j)]`` is a NumPy array of data point indices with
    true label i and predicted label j.

    More generally, the returned indices within each bucket refer to the order of the
    labels that were passed in as function arguments.

    >>> buckets[(1, 1)]  # true positives
    array([0, 1])
    >>> (1, 0) in buckets  # false positives
    False
    >>> (0, 1) in buckets  # false negatives
    False
    >>> (0, 0) in buckets  # true negatives
    False
    >>> buckets[(1, -1)]  # abstained positives
    array([2])
    >>> buckets[(0, -1)]  # abstained negatives
    array([3])
    """
    # Without this guard a call with no arrays trips the length check below and
    # reports a misleading "same number of elements" error.
    if not y:
        raise ValueError("At least one label array must be provided")
    y_flat = [to_int_label_array(labels, flatten_vector=True) for labels in y]
    if len({len(labels) for labels in y_flat}) != 1:
        raise ValueError("Arrays must all have the same number of elements")
    # Keys have one entry per input array, hence the variadic tuple type.
    buckets: DefaultDict[Tuple[int, ...], List[int]] = defaultdict(list)
    for i, labels in enumerate(zip(*y_flat)):
        buckets[labels].append(i)
    return {k: np.array(v) for k, v in buckets.items()}
59

60

61 1
def get_label_instances(
    bucket: Tuple[int, ...], x: np.ndarray, *y: np.ndarray
) -> np.ndarray:
    """Return instances in x with the specified combination of labels.

    Parameters
    ----------
    bucket
        A tuple of label values corresponding to which instances from x are returned
    x
        NumPy array of data instances to be returned
    *y
        One or more np.ndarray of (int) labels, one per entry in ``bucket``

    Returns
    -------
    np.ndarray
        NumPy array of instances from x with the specified combination of labels.
        If the bucket does not occur in the labels, a warning is logged and an
        empty array is returned.

    Raises
    ------
    ValueError
        If the number of label lists differs from the number of labels in
        ``bucket``, if no label lists are given, or if the number of rows in
        ``x`` differs from the number of elements in the first label list

    Example
    -------
    A common use case is calling ``get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)``
    where ``x`` is a NumPy array of data instances that the labels correspond to,
    ``Y_gold`` is a list of gold (i.e. ground truth) labels, and
    ``Y_pred`` is a corresponding list of predicted labels.

    >>> import pandas as pd
    >>> x = pd.DataFrame(data={'col1': ["this is a string", "a second string", "a third string"], 'col2': ["1", "2", "3"]})
    >>> Y_gold = np.array([1, 1, 1])
    >>> Y_pred = np.array([1, 0, 0])
    >>> bucket = (1, 0)

    The returned NumPy array of data instances from ``x`` will correspond to
    the rows where the first list had a 1 and the second list had a 0.

    >>> get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)
    array([['a second string', '2'],
           ['a third string', '3']], dtype=object)

    More generally, given bucket ``(i, j, ...)`` and lists ``y1, y2, ...``
    the returned data instances from ``x`` will correspond to the rows where
    y1 had label i, y2 had label j, and so on. Note that ``x`` and ``y``
    must all be the same length.
    """
    if len(y) != len(bucket):
        raise ValueError("Number of lists must match the amount of labels in bucket")
    if not y:
        # An empty bucket with no label lists passes the check above; fail with
        # a clear error here instead of an IndexError on y[0] below.
        raise ValueError("At least one label list must be provided")
    if x.shape[0] != len(y[0]):
        # Note: the check for all y having the same number of elements occurs in get_label_buckets
        raise ValueError(
            "Number of rows in x does not match number of elements in at least one label list"
        )
    buckets = get_label_buckets(*y)
    try:
        indices = buckets[bucket]
    except KeyError:
        # Deferred formatting; the rendered message is unchanged.
        logging.warning("Bucket%s does not exist.", bucket)
        return np.array([])
    return x[indices]

Read our documentation on viewing source code.

Loading