import inspect
import pickle
from collections.abc import Hashable
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Mapping, Optional

import numpy as np
import pandas as pd

from snorkel.types import DataPoint, FieldMap, HashingFunction

MapFunction = Callable[[DataPoint], Optional[DataPoint]]


def get_parameters(
    f: Callable[..., Any], allow_args: bool = False, allow_kwargs: bool = False
) -> List[str]:
    """Get names of function parameters."""
    params = inspect.getfullargspec(f)
    if not allow_args and params[1] is not None:
        raise ValueError(f"Function {f.__name__} should not have *args")
    if not allow_kwargs and params[2] is not None:
        raise ValueError(f"Function {f.__name__} should not have **kwargs")
    return params[0]
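
# Illustrative sketch (hypothetical usage, not from the original module):
# positional parameter names are returned, and *args/**kwargs are rejected
# unless explicitly allowed.
# >>> def f(a, b, c=1):
# ...     pass
# >>> get_parameters(f)
# ['a', 'b', 'c']
# >>> def g(a, *args):
# ...     pass
# >>> get_parameters(g)
# Traceback (most recent call last):
#     ...
# ValueError: Function g should not have *args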


def is_hashable(obj: Any) -> bool:
    """Test if an object is hashable via duck typing.

    NB: not using ``collections.abc.Hashable``, as some objects
    (e.g. ``pandas.Series``) define a ``__hash__`` method only to
    raise a more specific exception.
    """
    try:
        hash(obj)
        return True
    except Exception:
        return False
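
# Illustrative sketch (hypothetical values, not from the original source):
# strings hash normally, while ``pd.Series`` raises on ``hash(...)``.
# >>> is_hashable("a string")
# True
# >>> is_hashable(pd.Series([1, 2]))
# False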


def get_hashable(obj: Any) -> Hashable:
    """Get a hashable version of a potentially unhashable object.

    This helper is used for caching mapper outputs of data points.
    For common data point formats (e.g. SimpleNamespace, pandas.Series),
    produces hashable representations of the values using a ``frozenset``.
    For objects like ``pandas.Series``, the name/index identifier is dropped.

    Parameters
    ----------
    obj
        Object to get hashable version of

    Returns
    -------
    Hashable
        Hashable representation of object values

    Raises
    ------
    ValueError
        No hashable proxy for object
    """
    # If hashable already, just return
    if is_hashable(obj):
        return obj
    # Get dictionary from SimpleNamespace
    if isinstance(obj, SimpleNamespace):
        obj = vars(obj)
    # For dictionaries or pd.Series, construct a frozenset from items
    # Also recurse on values in case they aren't hashable
    if isinstance(obj, (dict, pd.Series)):
        return frozenset((k, get_hashable(v)) for k, v in obj.items())
    # For lists and tuples, recurse on values
    if isinstance(obj, (list, tuple)):
        return tuple(get_hashable(v) for v in obj)
    # For NumPy arrays, use the byte representation of the data buffer
    if isinstance(obj, np.ndarray):
        return obj.data.tobytes()
    raise ValueError(f"Object {obj} has no hashing proxy.")


class BaseMapper:
    """Base class for ``Mapper`` and ``LambdaMapper``.

    Implements nesting, memoization, and deep copy functionality.
    Used primarily for type checking.

    Parameters
    ----------
    name
        Name of the mapper
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?
    memoize_key
        Hashing function to handle the memoization (defaults to ``snorkel.map.core.get_hashable``)

    Raises
    ------
    NotImplementedError
        Subclasses need to implement ``_generate_mapped_data_point``

    Attributes
    ----------
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: str,
        pre: List["BaseMapper"],
        memoize: bool,
        memoize_key: Optional[HashingFunction] = None,
    ) -> None:
        if memoize_key is None:
            memoize_key = get_hashable
        self.name = name
        self._pre = pre
        self._memoize_key = memoize_key
        self.memoize = memoize
        self.reset_cache()

    def reset_cache(self) -> None:
        """Reset the memoization cache."""
        self._cache: Dict[DataPoint, DataPoint] = {}

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        raise NotImplementedError

    def __call__(self, x: DataPoint) -> Optional[DataPoint]:
        """Run mapping function on input data point.

        Deep copies the data point first so as not to make
        accidental in-place changes. If ``memoize`` is set to
        ``True``, an internal cache is checked for results. If
        no cached results are found, the computed results are
        added to the cache.

        Parameters
        ----------
        x
            Data point to run mapping function on

        Returns
        -------
        DataPoint
            Mapped data point of same format but possibly different fields
        """
        if self.memoize:
            # NB: don't do ``self._cache.get(...)`` first in case cached value is ``None``
            x_hashable = self._memoize_key(x)
            if x_hashable in self._cache:
                return self._cache[x_hashable]
        # NB: using a pickle roundtrip as a more robust deepcopy
        # As an example, calling deepcopy on a pd.Series or SimpleNamespace
        # with a dictionary attribute won't create a copy of the dictionary
        x_mapped = pickle.loads(pickle.dumps(x))
        for mapper in self._pre:
            x_mapped = mapper(x_mapped)
        x_mapped = self._generate_mapped_data_point(x_mapped)
        if self.memoize:
            self._cache[x_hashable] = x_mapped
        return x_mapped

    def __repr__(self) -> str:
        pre_str = f", Pre: {self._pre}"
        return f"{type(self).__name__} {self.name}{pre_str}"


class Mapper(BaseMapper):
    """Base class for any data point to data point mapping in the pipeline.

    Map data points to new data points by transforming, adding
    additional information, or decomposing into primitives. This module
    provides base classes for other operators like ``TransformationFunction``
    and ``Preprocessor``. We don't expect people to construct ``Mapper``
    objects directly.

    A Mapper maps a data point to a new data point, possibly with
    a different schema. Subclasses of Mapper need to implement the
    ``run`` method, which takes fields of the data point as input
    and outputs new fields for the mapped data point as a dictionary.
    The ``run`` method should only be called internally by the ``Mapper``
    object, not directly by a user.

    Mapper derivatives work for data points that have mutable attributes,
    like ``SimpleNamespace``, ``pd.Series``, or ``dask.Series``. An example
    of a data point type without mutable fields is ``pyspark.sql.Row``.
    Use ``snorkel.map.spark.make_spark_mapper`` for PySpark compatibility.

    For an example of a Mapper, see
        ``snorkel.preprocess.nlp.SpacyPreprocessor``

    Parameters
    ----------
    name
        Name of mapper
    field_names
        A map from input argument names of the ``run`` method to
        attribute names of the incoming data points. If None,
        the parameter names in the function signature are used.
    mapped_field_names
        A map from output keys of the ``run`` method to attribute
        names of the output data points. If None, the original
        output keys are used.
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?
    memoize_key
        Hashing function to handle the memoization (defaults to ``snorkel.map.core.get_hashable``)

    Raises
    ------
    NotImplementedError
        Subclasses must implement the ``run`` method

    Attributes
    ----------
    field_names
        See above
    mapped_field_names
        See above
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: str,
        field_names: Optional[Mapping[str, str]] = None,
        mapped_field_names: Optional[Mapping[str, str]] = None,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
        memoize_key: Optional[HashingFunction] = None,
    ) -> None:
        if field_names is None:
            # Parse field names from ``run(...)`` if not provided
            field_names = {k: k for k in get_parameters(self.run)[1:]}
        self.field_names = field_names
        self.mapped_field_names = mapped_field_names
        super().__init__(name, pre or [], memoize, memoize_key)

    def run(self, **kwargs: Any) -> Optional[FieldMap]:
        """Run the mapping operation using the input fields.

        The inputs to this function are extracted from the incoming data
        point via the attribute names in ``field_names`` and passed in as
        keyword arguments. The output field names are converted using
        ``mapped_field_names`` and added to the data point.

        Returns
        -------
        Optional[FieldMap]
            A mapping from canonical output field names to their values.

        Raises
        ------
        NotImplementedError
            Subclasses must implement this method
        """
        raise NotImplementedError

    def _update_fields(self, x: DataPoint, mapped_fields: FieldMap) -> DataPoint:
        # ``SimpleNamespace``, ``pd.Series``, and ``dask.Series`` objects all
        # support attribute assignment.
        for k, v in mapped_fields.items():
            setattr(x, k, v)
        return x

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        field_map = {k: getattr(x, v) for k, v in self.field_names.items()}
        mapped_fields = self.run(**field_map)
        if mapped_fields is None:
            return None
        if self.mapped_field_names is not None:
            mapped_fields = {
                v: mapped_fields[k] for k, v in self.mapped_field_names.items()
            }
        return self._update_fields(x, mapped_fields)
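
# A minimal subclass sketch (editor's illustration of the ``run`` contract;
# the class and field names here are hypothetical):
# >>> class SplitWordsMapper(Mapper):
# ...     def run(self, text):
# ...         return dict(words=text.split())
# >>> mapper = SplitWordsMapper("split_words")
# >>> mapper(SimpleNamespace(text="hello world")).words
# ['hello', 'world']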


class LambdaMapper(BaseMapper):
    """Define a mapper from a function.

    Convenience class for mappers that execute a simple
    function with no setup. The function should map from
    an input data point to a new data point directly, unlike
    ``Mapper.run``. The original data point will not be updated,
    so in-place operations are safe.

    Parameters
    ----------
    name
        Name of mapper
    f
        Function executing the mapping operation
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?
    memoize_key
        Hashing function to handle the memoization (defaults to ``snorkel.map.core.get_hashable``)
    """

    def __init__(
        self,
        name: str,
        f: MapFunction,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
        memoize_key: Optional[HashingFunction] = None,
    ) -> None:
        self._f = f
        super().__init__(name, pre or [], memoize, memoize_key)

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        return self._f(x)
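
# Direct-construction sketch (editor's illustration; the ``lambda_mapper``
# decorator below is the more common entry point):
# >>> square = LambdaMapper("square", lambda x: SimpleNamespace(num=x.num ** 2))
# >>> square(SimpleNamespace(num=3)).num
# 9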


class lambda_mapper:
    """Decorate a function to define a LambdaMapper object.

    Example
    -------
    >>> @lambda_mapper()
    ... def concatenate_text(x):
    ...     x.article = f"{x.title} {x.body}"
    ...     return x
    >>> isinstance(concatenate_text, LambdaMapper)
    True
    >>> from types import SimpleNamespace
    >>> x = SimpleNamespace(title="my title", body="my text")
    >>> concatenate_text(x).article
    'my title my text'

    Parameters
    ----------
    name
        Name of mapper. If None, uses the name of the wrapped function.
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?
    memoize_key
        Hashing function to handle the memoization (defaults to ``snorkel.map.core.get_hashable``)

    Attributes
    ----------
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: Optional[str] = None,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
        memoize_key: Optional[HashingFunction] = None,
    ) -> None:
        if callable(name):
            raise ValueError("Looks like this decorator is missing parentheses!")
        self.name = name
        self.pre = pre
        self.memoize = memoize
        self.memoize_key = memoize_key

    def __call__(self, f: MapFunction) -> LambdaMapper:
        """Wrap a function to create a ``LambdaMapper``.

        Parameters
        ----------
        f
            Function executing the mapping operation

        Returns
        -------
        LambdaMapper
            New ``LambdaMapper`` executing operation in wrapped function
        """
        name = self.name or f.__name__
        return LambdaMapper(
            name=name,
            f=f,
            pre=self.pre,
            memoize=self.memoize,
            memoize_key=self.memoize_key,
        )
