import inspect
import pickle
from collections.abc import Hashable
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Mapping, Optional

import numpy as np
import pandas as pd

from snorkel.types import DataPoint, FieldMap
11

12 4
# Signature shared by all mapping callables: a data point in, a mapped data
# point out (or ``None`` to indicate the data point should be dropped).
MapFunction = Callable[[DataPoint], Optional[DataPoint]]
13

14

15 4
def get_parameters(
    f: Callable[..., Any], allow_args: bool = False, allow_kwargs: bool = False
) -> List[str]:
    """Get names of the explicit positional parameters of a function.

    Parameters
    ----------
    f
        Function to inspect
    allow_args
        Is ``f`` allowed to declare a ``*args`` catch-all?
    allow_kwargs
        Is ``f`` allowed to declare a ``**kwargs`` catch-all?

    Returns
    -------
    List[str]
        Names of the explicit (non-variadic) parameters of ``f``

    Raises
    ------
    ValueError
        If ``f`` declares ``*args``/``**kwargs`` and they are not allowed
    """
    # Use the named fields of FullArgSpec (args/varargs/varkw) instead of
    # opaque positional indexing (params[0], params[1], params[2]).
    params = inspect.getfullargspec(f)
    if not allow_args and params.varargs is not None:
        raise ValueError(f"Function {f.__name__} should not have *args")
    if not allow_kwargs and params.varkw is not None:
        raise ValueError(f"Function {f.__name__} should not have **kwargs")
    return params.args
25

26

27 4
def is_hashable(obj: Any) -> bool:
    """Check hashability by actually attempting to hash the object.

    NB: deliberately not relying on ``collections.abc.Hashable``,
    since some objects (e.g. ``pandas.Series``) define ``__hash__``
    only so it can raise a more informative exception.
    """
    try:
        hash(obj)
    except Exception:
        return False
    return True
39

40

41 4
def get_hashable(obj: Any) -> Hashable:
    """Get a hashable version of a potentially unhashable object.

    This helper is used for caching mapper outputs of data points.
    For common data point formats (e.g. SimpleNamespace, pandas.Series),
    produces hashable representations of the values using a ``frozenset``.
    For objects like ``pandas.Series``, the name/index identifier is dropped.

    Parameters
    ----------
    obj
        Object to get hashable version of

    Returns
    -------
    Hashable
        Hashable representation of object values

    Raises
    ------
    ValueError
        No hashable proxy for object
    """
    # Pass already-hashable objects straight through. Hashability is
    # checked by attempting the hash, since some objects (e.g. pd.Series)
    # define __hash__ only to raise a more specific exception.
    try:
        hash(obj)
    except Exception:
        pass
    else:
        return obj
    # A SimpleNamespace is represented by its attribute dictionary
    if isinstance(obj, SimpleNamespace):
        obj = vars(obj)
    # Dictionaries and pd.Series become frozensets of (key, value) pairs,
    # recursing on the values in case they are themselves unhashable
    if isinstance(obj, (dict, pd.Series)):
        return frozenset((key, get_hashable(value)) for key, value in obj.items())
    # Sequences become tuples, recursing on the elements
    if isinstance(obj, (list, tuple)):
        return tuple(get_hashable(element) for element in obj)
    # NumPy arrays are represented by the bytes of their data buffer
    if isinstance(obj, np.ndarray):
        return obj.data.tobytes()
    raise ValueError(f"Object {obj} has no hashing proxy.")
81

82

83 4
class BaseMapper:
    """Base class for ``Mapper`` and ``LambdaMapper``.

    Implements nesting, memoization, and deep copy functionality.
    Used primarily for type checking.

    Parameters
    ----------
    name
        Name of the mapper
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?

    Raises
    ------
    NotImplementedError
        Subclasses need to implement ``_generate_mapped_data_point``

    Attributes
    ----------
    memoize
        Memoize mapper outputs?
    """

    def __init__(self, name: str, pre: List["BaseMapper"], memoize: bool) -> None:
        self.name = name
        self._pre = pre
        self.memoize = memoize
        self.reset_cache()

    def reset_cache(self) -> None:
        """Reset the memoization cache."""
        self._cache: Dict[DataPoint, DataPoint] = {}

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        raise NotImplementedError

    def __call__(self, x: DataPoint) -> Optional[DataPoint]:
        """Run mapping function on input data point.

        Deep copies the data point first so as not to make
        accidental in-place changes. If ``memoize`` is set to
        ``True``, an internal cache is checked for results. If
        no cached results are found, the computed results are
        added to the cache.

        Parameters
        ----------
        x
            Data point to run mapping function on

        Returns
        -------
        DataPoint
            Mapped data point of same format but possibly different fields
        """
        if self.memoize:
            key = get_hashable(x)
            # NB: membership test rather than ``self._cache.get(...)`` so
            # that a legitimately cached value of ``None`` is still a hit
            if key in self._cache:
                return self._cache[key]
        # Pickle roundtrip serves as a more robust deepcopy: for example,
        # calling deepcopy on a pd.Series or SimpleNamespace holding a
        # dictionary attribute won't create a copy of that dictionary
        mapped = pickle.loads(pickle.dumps(x))
        for pre_mapper in self._pre:
            mapped = pre_mapper(mapped)
        mapped = self._generate_mapped_data_point(mapped)
        if self.memoize:
            self._cache[key] = mapped
        return mapped

    def __repr__(self) -> str:
        return f"{type(self).__name__} {self.name}, Pre: {self._pre}"
160

161

162 4
class Mapper(BaseMapper):
    """Base class for any data point to data point mapping in the pipeline.

    Map data points to new data points by transforming, adding
    additional information, or decomposing into primitives. This module
    provides base classes for other operators like ``TransformationFunction``
    and ``Preprocessor``. We don't expect people to construct ``Mapper``
    objects directly.

    A Mapper maps a data point to a new data point, possibly with
    a different schema. Subclasses of Mapper need to implement the
    ``run`` method, which takes fields of the data point as input
    and outputs new fields for the mapped data point as a dictionary.
    The ``run`` method should only be called internally by the ``Mapper``
    object, not directly by a user.

    Mapper derivatives work for data points that have mutable attributes,
    like ``SimpleNamespace``, ``pd.Series``, or ``dask.Series``. An example
    of a data point type without mutable fields is ``pyspark.sql.Row``.
    Use ``snorkel.map.spark.make_spark_mapper`` for PySpark compatibility.

    For an example of a Mapper, see
        ``snorkel.preprocess.nlp.SpacyPreprocessor``

    Parameters
    ----------
    name
        Name of mapper
    field_names
        A map from attribute names of the incoming data points
        to the input argument names of the ``run`` method. If None,
        the parameter names in the function signature are used.
    mapped_field_names
        A map from output keys of the ``run`` method to attribute
        names of the output data points. If None, the original
        output keys are used.
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?

    Raises
    ------
    NotImplementedError
        Subclasses must implement the ``run`` method

    Attributes
    ----------
    field_names
        See above
    mapped_field_names
        See above
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: str,
        field_names: Optional[Mapping[str, str]] = None,
        mapped_field_names: Optional[Mapping[str, str]] = None,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
    ) -> None:
        if field_names is None:
            # Default to an identity mapping over the parameter names of
            # ``run`` (the leading ``self`` parameter is dropped)
            field_names = {param: param for param in get_parameters(self.run)[1:]}
        self.field_names = field_names
        self.mapped_field_names = mapped_field_names
        super().__init__(name, pre or [], memoize)

    def run(self, **kwargs: Any) -> Optional[FieldMap]:
        """Run the mapping operation using the input fields.

        The inputs to this function are fed by extracting the fields of
        the input data point using the keys of ``field_names``. The output field
        names are converted using ``mapped_field_names`` and added to the
        data point.

        Returns
        -------
        Optional[FieldMap]
            A mapping from canonical output field names to their values.

        Raises
        ------
        NotImplementedError
            Subclasses must implement this method
        """
        raise NotImplementedError

    def _update_fields(self, x: DataPoint, mapped_fields: FieldMap) -> DataPoint:
        # ``SimpleNamespace``, ``pd.Series``, and ``dask.Series`` objects
        # all support attribute setting.
        for field, value in mapped_fields.items():
            setattr(x, field, value)
        return x

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        # Pull the configured attributes off the data point and feed them
        # to ``run`` as keyword arguments
        inputs = {arg: getattr(x, attr) for arg, attr in self.field_names.items()}
        outputs = self.run(**inputs)
        if outputs is None:
            return None
        # Optionally rename the output keys before writing them back
        if self.mapped_field_names is not None:
            outputs = {
                new_key: outputs[old_key]
                for old_key, new_key in self.mapped_field_names.items()
            }
        return self._update_fields(x, outputs)
270

271

272 4
class LambdaMapper(BaseMapper):
    """Define a mapper from a function.

    Convenience class for mappers that execute a simple
    function with no set up. The function should map directly from an
    input data point to a new data point, unlike ``Mapper.run``. The
    original data point will not be updated, so in-place operations
    are safe.

    Parameters
    ----------
    name
        Name of mapper
    f
        Function executing the mapping operation
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: str,
        f: MapFunction,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
    ) -> None:
        self._f = f
        super().__init__(name, pre or [], memoize)

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        # Delegate directly to the wrapped function
        return self._f(x)
305

306

307 4
class lambda_mapper:
    """Decorate a function to define a LambdaMapper object.

    Example
    -------
    >>> @lambda_mapper()
    ... def concatenate_text(x):
    ...     x.article = f"{x.title} {x.body}"
    ...     return x
    >>> isinstance(concatenate_text, LambdaMapper)
    True
    >>> from types import SimpleNamespace
    >>> x = SimpleNamespace(title="my title", body="my text")
    >>> concatenate_text(x).article
    'my title my text'

    Parameters
    ----------
    name
        Name of mapper. If None, uses the name of the wrapped function.
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?

    Attributes
    ----------
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: Optional[str] = None,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
    ) -> None:
        # A bare ``@lambda_mapper`` (no parentheses) would pass the
        # decorated function itself in as ``name`` — catch that early.
        if callable(name):
            raise ValueError("Looks like this decorator is missing parentheses!")
        self.name = name
        self.pre = pre
        self.memoize = memoize

    def __call__(self, f: MapFunction) -> LambdaMapper:
        """Wrap a function to create a ``LambdaMapper``.

        Parameters
        ----------
        f
            Function executing the mapping operation

        Returns
        -------
        LambdaMapper
            New ``LambdaMapper`` executing operation in wrapped function
        """
        return LambdaMapper(
            name=self.name or f.__name__, f=f, pre=self.pre, memoize=self.memoize
        )

Read our documentation on viewing source code .

Loading