1 2
import hashlib
2 2
from typing import Dict, List
3

4 2
import numpy as np
5

6

7 2
def _hash(i: int) -> int:
    """Map a value to a large integer that is identical on every run.

    The string form of ``i`` is UTF-8 encoded and digested with SHA-1; the
    20-byte digest is then interpreted as one big-endian unsigned integer.

    Parameters
    ----------
    i
        The value to hash (callers in this module pass a row index)

    Returns
    -------
    int
        The SHA-1 digest of ``str(i)`` read as a non-negative integer
    """
    encoded = f"{i}".encode("utf-8")
    raw_digest = hashlib.sha1(encoded).digest()
    # Big-endian bytes -> int is the same number as parsing the hex digest.
    return int.from_bytes(raw_digest, "big")
11

12

13 2
def probs_to_preds(
    probs: np.ndarray, tie_break_policy: str = "random", tol: float = 1e-5
) -> np.ndarray:
    """Convert an array of probabilistic labels into an array of predictions.

    For every row, the classes whose probability lies within ``tol`` of the row
    maximum are considered tied. If exactly one class qualifies it is the
    prediction; otherwise the tie is broken according to ``tie_break_policy``.

    Policies to break ties include:
    "abstain": return an abstain vote (-1)
    "true-random": randomly choose among the tied options
    "random": randomly choose among tied option using deterministic hash

    NOTE: if tie_break_policy="true-random", repeated runs may have slightly
    different results due to difference in broken ties

    Parameters
    ----------
    probs
        A [num_datapoints, num_classes] array of probabilistic labels such that each
        row sums to 1.
    tie_break_policy
        Policy to break ties when converting probabilistic labels to predictions
    tol
        The maximum distance from the row maximum at which a probability is still
        considered tied with it

    Returns
    -------
    np.ndarray
        A [n] integer array of predictions (integers in [0, ..., num_classes - 1],
        or -1 for abstains)

    Raises
    ------
    ValueError
        If ``probs`` has fewer than 2 classes, or if a tie is encountered and
        ``tie_break_policy`` is not one of the recognized policies

    Examples
    --------
    >>> probs_to_preds(np.array([[0.5, 0.5, 0.5]]), tie_break_policy="abstain")
    array([-1])
    >>> probs_to_preds(np.array([[0.8, 0.1, 0.1]]))
    array([0])
    """
    num_datapoints, num_classes = probs.shape
    if num_classes <= 1:
        raise ValueError(
            f"probs must have probabilities for at least 2 classes. "
            f"Instead, got {num_classes} classes."
        )

    # Allocate an integer buffer directly: this avoids the float round-trip and
    # the ``np.int`` alias, which no longer exists in NumPy >= 1.24.
    Y_pred = np.empty(num_datapoints, dtype=int)
    # Distance of every probability from its row's maximum.
    diffs = np.abs(probs - probs.max(axis=1, keepdims=True))

    for i in range(num_datapoints):
        max_idxs = np.where(diffs[i, :] < tol)[0]
        if len(max_idxs) == 1:
            Y_pred[i] = max_idxs[0]
        # Deal with "tie votes" according to the specified policy
        elif tie_break_policy == "random":
            # Hashing the row index makes the choice reproducible across runs.
            Y_pred[i] = max_idxs[_hash(i) % len(max_idxs)]
        elif tie_break_policy == "true-random":
            Y_pred[i] = np.random.choice(max_idxs)
        elif tie_break_policy == "abstain":
            Y_pred[i] = -1
        else:
            # The policy is only validated lazily, i.e. when a tie actually occurs.
            raise ValueError(
                f"tie_break_policy={tie_break_policy} policy not recognized."
            )
    return Y_pred
73

74

75 2
def preds_to_probs(preds: np.ndarray, num_classes: int) -> np.ndarray:
    """Convert an array of predictions into an array of probabilistic labels.

    Parameters
    ----------
    preds
        A [num_datapoints] or [num_datapoints, 1] array of predictions
    num_classes
        The number of classes, i.e. the number of columns in the returned array

    Returns
    -------
    np.ndarray
        A [num_datapoints, num_classes] array of probabilistic labels with probability
        of 1.0 in the column corresponding to the prediction

    Raises
    ------
    ValueError
        If any prediction is negative (an abstain cannot be turned into a
        probability distribution)
    """
    if np.any(preds < 0):
        raise ValueError("Could not convert abstained vote to probability")
    # squeeze() turns a single prediction into a 0-d array, which would select one
    # row of the identity matrix as a flat [num_classes] vector. atleast_1d keeps
    # the datapoint axis so the result is always [num_datapoints, num_classes].
    class_idxs = np.atleast_1d(preds.squeeze())
    return np.eye(num_classes)[class_idxs]
92

93

94 2
def to_int_label_array(X: np.ndarray, flatten_vector: bool = True) -> np.ndarray:
    """Convert an array to a (possibly flattened) array of ints.

    Cast all values to ints and possibly flatten [n, 1] arrays to [n].
    This method is typically used to sanitize labels before use with analysis tools or
    metrics that expect 1D arrays as inputs.

    Parameters
    ----------
    X
        An array (or array-like, e.g. a nested list) to cast to int and possibly
        flatten
    flatten_vector
        If True, flatten array into a 1D array

    Returns
    -------
    np.ndarray
        The converted array

    Raises
    ------
    ValueError
        If the input contains a non-integer value, or if ``flatten_vector`` is True
        and the input cannot be squeezed down to a 1D array
    """
    # Accept array-likes as well as arrays; an ndarray (or subclass) passes through
    # unchanged, so existing callers see no difference.
    X = np.asanyarray(X)
    # A value is integral exactly when its remainder modulo 1 is zero; NaN fails too.
    if np.any(np.not_equal(np.mod(X, 1), 0)):
        raise ValueError("Input contains at least one non-integer value.")
    X = X.astype(np.dtype(int))
    # Correct shape
    if flatten_vector:
        X = X.squeeze()
        # A single value squeezes to a 0-d array; restore it to a length-1 vector.
        if X.ndim == 0:
            X = np.expand_dims(X, 0)
        if X.ndim != 1:
            raise ValueError("Input could not be converted to 1d np.array")
    return X
129

130

131 2
def filter_labels(
    label_dict: Dict[str, np.ndarray], filter_dict: Dict[str, List[int]]
) -> Dict[str, np.ndarray]:
    """Filter out examples from arrays based on specified labels to filter.

    The most common use of this method is to remove examples whose gold label is
    unknown (marked with a -1) or examples whose predictions were abstains (also -1)
    before calculating metrics.

    NB: If an example matches the filter criteria for any label set, it will be removed
    from all label sets (so that the returned arrays are of the same size and still
    aligned).

    Parameters
    ----------
    label_dict
        A mapping from label set name to the array of labels
        The arrays in a label_dict.values() are assumed to be aligned
        A value of None is allowed and is passed through as None
    filter_dict
        A mapping from label set name to the labels that should be filtered out for
        that label set

    Returns
    -------
    Dict[str, np.ndarray]
        A mapping with the same keys as label_dict but with filtered arrays as values.
        If no filter applies (filter_dict is empty or only names label sets that are
        None), the values are returned unfiltered.

    Raises
    ------
    KeyError
        If filter_dict names a label set that is not present in label_dict

    Example
    -------
    >>> golds = np.array([-1, 0, 0, 1, 0])
    >>> preds = np.array([0, 0, 0, 1, -1])
    >>> filtered = filter_labels(
    ...     label_dict={"golds": golds, "preds": preds},
    ...     filter_dict={"golds": [-1], "preds": [-1]}
    ... )
    >>> filtered["golds"]
    array([0, 0, 1])
    >>> filtered["preds"]
    array([0, 0, 1])
    """
    masks = []
    for label_name, filter_values in filter_dict.items():
        labels = label_dict[label_name]
        if labels is not None:
            keep = _get_mask(labels, filter_values)
            # Normalize [n, 1] masks to [n] before combining so that mixing [n, 1]
            # and [n] label sets does not broadcast to [n, n]; atleast_1d keeps a
            # single-example mask from collapsing to a 0-d array.
            masks.append(np.atleast_1d(keep.squeeze()))

    # Nothing to filter on: hand back the label sets untouched.
    if not masks:
        return dict(label_dict)

    # Combine pairwise. np.multiply(*masks) is only correct for exactly two masks:
    # a third positional argument is interpreted as the ufunc's ``out`` buffer.
    mask = masks[0]
    for other in masks[1:]:
        mask = mask & other

    filtered = {}
    for label_name, label_array in label_dict.items():
        filtered[label_name] = label_array[mask] if label_array is not None else None
    return filtered


def _get_mask(label_array: np.ndarray, filter_values: List[int]) -> np.ndarray:
    """Return a boolean mask marking which labels are not in filter_values.

    Parameters
    ----------
    label_array
        An array of labels
    filter_values
        A list of values that should be filtered out of the label array

    Returns
    -------
    np.ndarray
        A boolean mask with the same shape as label_array indicating whether to
        keep (True) or filter (False) each example
    """
    mask = np.ones_like(label_array, dtype=bool)
    for value in filter_values:
        # An example survives only if it differs from every filtered value.
        mask &= label_array != value
    return mask

Read our documentation on viewing source code.

Loading