from typing import Iterator, List

from tqdm import tqdm

from snorkel.augmentation.policy.core import Policy
from snorkel.augmentation.tf import BaseTransformationFunction
from snorkel.types import DataPoint, DataPoints
from snorkel.utils.data_operators import check_unique_names
class BaseTFApplier:
    """Base class for TF applier objects.

    Base class for TF applier objects, which execute a set of TFs
    on a collection of data points. Subclasses should operate on
    a single data point collection format (e.g. ``DataFrame``).
    Subclasses must implement the ``apply`` method.

    Parameters
    ----------
    tfs
        TFs that this applier executes on examples. The TFs are copied
        into a new list, so any iterable of TFs (including a one-shot
        generator) is accepted, and later changes to the caller's
        container do not affect this applier.
    policy
        Augmentation policy used to generate sequences of TFs

    Raises
    ------
    ValueError
        If names of TFs are not unique
    """

    def __init__(self, tfs: List[BaseTransformationFunction], policy: Policy) -> None:
        # Copy into a list exactly once. ``_apply_policy_to_data_point``
        # indexes into ``self._tfs`` by position, and a one-shot iterable
        # would otherwise be consumed by the name comprehension below.
        # Deriving the names from the copy also keeps ``_tf_names`` in sync
        # with the TFs actually stored here.
        self._tfs = list(tfs)
        self._tf_names = [tf.name for tf in self._tfs]
        check_unique_names(self._tf_names)
        self._policy = policy

    def _apply_policy_to_data_point(self, x: DataPoint) -> DataPoints:
        """Apply every TF sequence drawn from the policy to one data point.

        For each sequence yielded by ``self._policy.generate_for_example()``,
        the TFs (referenced by index into ``self._tfs``) are applied in order,
        each to the running value. A TF that returns ``None`` leaves the
        running value unchanged.

        A sequence's result is kept only if at least one of its TFs returned
        a value, or if the sequence is empty. An empty sequence yields the
        original data point unchanged (presumably emitted by policies
        configured with ``keep_original``).

        Parameters
        ----------
        x
            Data point to augment

        Returns
        -------
        DataPoints
            Augmented data points, one per retained sequence
        """
        x_transformed: List[DataPoint] = []
        for seq in self._policy.generate_for_example():
            x_t = x
            # Handle empty sequence for `keep_original`
            transform_applied = len(seq) == 0
            for tf_idx in seq:
                x_t_or_none = self._tfs[tf_idx](x_t)
                # ``None`` means the TF did not apply; keep the running value
                if x_t_or_none is not None:
                    transform_applied = True
                    x_t = x_t_or_none
            # Add example if original or transformations applied
            if transform_applied:
                x_transformed.append(x_t)
        return x_transformed

    def __repr__(self) -> str:
        policy_name = type(self._policy).__name__
        return f"{type(self).__name__}, Policy: {policy_name}, TFs: {self._tf_names}"
class TFApplier(BaseTFApplier):
    """TF applier for a list of data points.

    Augments a list of data points (e.g. ``SimpleNamespace``). Primarily
    useful for testing.
    """

    def apply_generator(
        self, data_points: DataPoints, batch_size: int
    ) -> Iterator[List[DataPoint]]:
        """Augment a list of data points using TFs and policy in batches.

        This method acts as a generator, yielding augmented data points for
        a given input batch of data points. This can be useful in a training
        loop when it is too memory-intensive to pregenerate all transformed
        examples.

        Parameters
        ----------
        data_points
            List containing data points to be transformed
        batch_size
            Batch size for generator. Yields augmented data points
            for the next ``batch_size`` input data points. Must be >= 1.

        Yields
        ------
        List[DataPoint]
            List of data points in augmented data set for batches of inputs

        Raises
        ------
        ValueError
            If ``batch_size`` is less than 1. Because this is a generator,
            the error is raised when it is first advanced.
        """
        # Without this check a zero step surfaces as an opaque ``range()``
        # error, and a negative step silently yields nothing.
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        for start in range(0, len(data_points), batch_size):
            batch_transformed: List[DataPoint] = []
            for x in data_points[start : start + batch_size]:
                batch_transformed.extend(self._apply_policy_to_data_point(x))
            yield batch_transformed

    def apply(
        self, data_points: DataPoints, progress_bar: bool = True
    ) -> List[DataPoint]:
        """Augment a list of data points using TFs and policy.

        Parameters
        ----------
        data_points
            List containing data points to be transformed
        progress_bar
            Display a progress bar?

        Returns
        -------
        List[DataPoint]
            List of data points in augmented data set
        """
        x_transformed: List[DataPoint] = []
        for x in tqdm(data_points, disable=(not progress_bar)):
            x_transformed.extend(self._apply_policy_to_data_point(x))
        return x_transformed
