@@ -1,3 +1,4 @@
Loading
1 +
import logging
1 2
from collections import defaultdict
2 3
from typing import DefaultDict, Dict, List, Tuple
3 4
@@ -55,3 +56,63 @@
Loading
55 56
    for i, labels in enumerate(zip(*y_flat)):
56 57
        buckets[labels].append(i)
57 58
    return {k: np.array(v) for k, v in buckets.items()}
59 +
60 +
61 +
def get_label_instances(
62 +
    bucket: Tuple[int, ...], x: np.ndarray, *y: np.ndarray
63 +
) -> np.ndarray:
64 +
    """Return instances in x with the specified combination of labels.
65 +
66 +
    Parameters
67 +
    ----------
68 +
    bucket
69 +
        A tuple of label values corresponding to which instances from x are returned
70 +
    x
71 +
        NumPy array of data instances to be returned
72 +
    *y
73 +
        A list of np.ndarray of (int) labels
74 +
75 +
    Returns
76 +
    -------
77 +
    np.ndarray
78 +
        NumPy array of instances from x with the specified combination of labels
79 +
80 +
    Example
81 +
    -------
82 +
    A common use case is calling ``get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)``
83 +
    where ``x`` is a NumPy array of data instances that the labels correspond to,
84 +
    ``Y_gold`` is a list of gold (i.e. ground truth) labels, and
85 +
    ``Y_pred`` is a corresponding list of predicted labels.
86 +
87 +
    >>> import pandas as pd
88 +
    >>> x = pd.DataFrame(data={'col1': ["this is a string", "a second string", "a third string"], 'col2': ["1", "2", "3"]})
89 +
    >>> Y_gold = np.array([1, 1, 1])
90 +
    >>> Y_pred = np.array([1, 0, 0])
91 +
    >>> bucket = (1, 0)
92 +
93 +
    The returned NumPy array of data instances from ``x`` will correspond to
94 +
    the rows where the first list had a 1 and the second list had a 0.
95 +
    >>> get_label_instances(bucket, x.to_numpy(), Y_gold, Y_pred)
96 +
    array([['a second string', '2'],
97 +
           ['a third string', '3']], dtype=object)
98 +
99 +
    More generally, given bucket ``(i, j, ...)`` and lists ``y1, y2, ...``
100 +
    the returned data instances from ``x`` will correspond to the rows where
101 +
    y1 had label i, y2 had label j, and so on. Note that ``x`` and ``y``
102 +
    must all be the same length.
103 +
    """
104 +
    if len(y) != len(bucket):
105 +
        raise ValueError("Number of lists must match the amount of labels in bucket")
106 +
    if x.shape[0] != len(y[0]):
107 +
        # Note: the check for all y having the same number of elements occurs in get_label_buckets
108 +
        raise ValueError(
109 +
            "Number of rows in x does not match number of elements in at least one label list"
110 +
        )
111 +
    buckets = get_label_buckets(*y)
112 +
    try:
113 +
        indices = buckets[bucket]
114 +
    except KeyError:
115 +
        logging.warning("Bucket" + str(bucket) + " does not exist.")
116 +
        return np.array([])
117 +
    instances = x[indices]
118 +
    return instances

@@ -1,5 +1,5 @@
Loading
1 1
"""Generic model analysis utilities shared across Snorkel."""
2 2
3 -
from .error_analysis import get_label_buckets  # noqa: F401
3 +
from .error_analysis import get_label_buckets, get_label_instances  # noqa: F401
4 4
from .metrics import metric_score  # noqa: F401
5 5
from .scorer import Scorer  # noqa: F401
Files Coverage
snorkel 97.21%
Project Totals (68 files) 97.21%
347.1
TRAVIS_PYTHON_VERSION=3.6
TRAVIS_OS_NAME=linux
TOXENV=coverage,complex,spark,doctest,type,check
347.2
TRAVIS_PYTHON_VERSION=3.7
TRAVIS_OS_NAME=linux
TOXENV=coverage,complex,spark,doctest,type,check
1
coverage:
2
  status:
3
    project:
4
      default:
5
        target: 95%
6
    patch:
7
      default:
8
        threshold: 2%
9

10
comment:
11
  layout: "header, diff, flags, files"
Sunburst
The inner-most circle is the entire project, moving away from the center are folders then, finally, a single file. The size and color of each slice is representing the number of statements and the coverage, respectively.
Icicle
The top section represents the entire project. Proceeding with folders and finally individual files. The size and color of each slice is representing the number of statements and the coverage, respectively.
Grid
Each block represents a single file in the project. The size and color of each block is represented by the number of statements and the coverage, respectively.
Loading