1
|
4
|
import inspect
import pickle
from collections.abc import Hashable
from types import SimpleNamespace
from typing import Any, Callable, Dict, List, Mapping, Optional

import numpy as np
import pandas as pd

from snorkel.types import DataPoint, FieldMap
|
11
|
|
|
12
|
4
|
# Type of a mapping callable: takes a data point and returns the mapped
# data point, or ``None`` when there is no output for this point.
MapFunction = Callable[[DataPoint], Optional[DataPoint]]
|
13
|
|
|
14
|
|
|
15
|
4
|
def get_parameters(
    f: Callable[..., Any], allow_args: bool = False, allow_kwargs: bool = False
) -> List[str]:
    """Get names of function parameters.

    Parameters
    ----------
    f
        Function to inspect
    allow_args
        Permit ``f`` to declare ``*args``?
    allow_kwargs
        Permit ``f`` to declare ``**kwargs``?

    Returns
    -------
    List[str]
        Names of the (positional) parameters of ``f``

    Raises
    ------
    ValueError
        If ``f`` declares ``*args``/``**kwargs`` and they are not allowed
    """
    params = inspect.getfullargspec(f)
    # Use the named fields of FullArgSpec instead of opaque tuple indices.
    if not allow_args and params.varargs is not None:
        raise ValueError(f"Function {f.__name__} should not have *args")
    if not allow_kwargs and params.varkw is not None:
        raise ValueError(f"Function {f.__name__} should not have **kwargs")
    return params.args
|
25
|
|
|
26
|
|
|
27
|
4
|
def is_hashable(obj: Any) -> bool:
    """Test if object is hashable via duck typing.

    NB: not using ``collections.Hashable`` as some objects
    (e.g. pandas.Series) have a ``__hash__`` method to throw
    a more specific exception.
    """
    # Simply attempt the hash; any failure at all means "not hashable".
    try:
        hash(obj)
    except Exception:
        return False
    return True
|
39
|
|
|
40
|
|
|
41
|
4
|
def get_hashable(obj: Any) -> Hashable:
    """Get a hashable version of a potentially unhashable object.

    Used for caching mapper outputs of data points. Common data point
    formats (e.g. ``SimpleNamespace``, ``pandas.Series``) are converted
    to a ``frozenset`` of their (key, hashable value) items; note that
    for ``pandas.Series`` this drops the name/index identifier.

    Parameters
    ----------
    obj
        Object to get hashable version of

    Returns
    -------
    Hashable
        Hashable representation of object values

    Raises
    ------
    ValueError
        No hashable proxy for object
    """
    # Fast path: the object hashes as-is.
    if is_hashable(obj):
        return obj
    # Reduce a SimpleNamespace to its attribute dictionary.
    if isinstance(obj, SimpleNamespace):
        obj = vars(obj)
    # Mappings (dict, pd.Series): freeze the items, recursing on each
    # value since values may themselves be unhashable.
    if isinstance(obj, (dict, pd.Series)):
        return frozenset((key, get_hashable(value)) for key, value in obj.items())
    # Sequences: recurse element-wise into a tuple.
    if isinstance(obj, (list, tuple)):
        return tuple(map(get_hashable, obj))
    # NumPy arrays: hash the raw byte representation of the buffer.
    if isinstance(obj, np.ndarray):
        return obj.data.tobytes()
    raise ValueError(f"Object {obj} has no hashing proxy.")
|
81
|
|
|
82
|
|
|
83
|
4
|
class BaseMapper:
    """Base class for ``Mapper`` and ``LambdaMapper``.

    Implements nesting, memoization, and deep copy functionality.
    Used primarily for type checking.

    Parameters
    ----------
    name
        Name of the mapper
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?

    Raises
    ------
    NotImplementedError
        Subclasses need to implement ``_generate_mapped_data_point``

    Attributes
    ----------
    memoize
        Memoize mapper outputs?
    """

    def __init__(self, name: str, pre: List["BaseMapper"], memoize: bool) -> None:
        self.name = name
        # Pre-mappers are applied, in order, before this mapper runs.
        self._pre = pre
        self.memoize = memoize
        self.reset_cache()

    def reset_cache(self) -> None:
        """Reset the memoization cache."""
        # Keys are hashable proxies of data points (see ``get_hashable``);
        # values may be ``None``, which is itself a valid cached result.
        self._cache: Dict[Hashable, Optional[DataPoint]] = {}

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        # Core mapping hook; subclasses must override.
        raise NotImplementedError

    def __call__(self, x: DataPoint) -> Optional[DataPoint]:
        """Run mapping function on input data point.

        Deep copies the data point first so as not to make
        accidental in-place changes. If ``memoize`` is set to
        ``True``, an internal cache is checked for results. If
        no cached results are found, the computed results are
        added to the cache.

        Parameters
        ----------
        x
            Data point to run mapping function on

        Returns
        -------
        DataPoint
            Mapped data point of same format but possibly different fields
        """
        if self.memoize:
            # NB: don't do ``self._cache.get(...)`` first in case cached value is ``None``
            x_hashable = get_hashable(x)
            if x_hashable in self._cache:
                return self._cache[x_hashable]
        # NB: using pickle roundtrip as a more robust deepcopy
        # As an example, calling deepcopy on a pd.Series or SimpleNamespace
        # with a dictionary attribute won't create a copy of the dictionary
        x_mapped = pickle.loads(pickle.dumps(x))
        # NOTE(review): if a pre-mapper returns ``None``, that ``None`` is fed
        # into the next mapper / ``_generate_mapped_data_point`` unchecked —
        # confirm callers never chain a dropping pre-mapper, or guard here.
        for mapper in self._pre:
            x_mapped = mapper(x_mapped)
        x_mapped = self._generate_mapped_data_point(x_mapped)
        if self.memoize:
            self._cache[x_hashable] = x_mapped
        return x_mapped

    def __repr__(self) -> str:
        """Return a debug string with the mapper type, name, and pre-mappers."""
        pre_str = f", Pre: {self._pre}"
        return f"{type(self).__name__} {self.name}{pre_str}"
|
160
|
|
|
161
|
|
|
162
|
4
|
class Mapper(BaseMapper):
    """Base class for any data point to data point mapping in the pipeline.

    Mappers transform data points, attach additional information, or
    decompose them into primitives. This class underlies operators such
    as ``TransformationFunction`` and ``Preprocessor``; users are not
    expected to construct ``Mapper`` objects directly.

    A ``Mapper`` turns one data point into a new one, possibly with a
    different schema. Subclasses implement ``run``, which receives
    selected fields of the data point as keyword arguments and returns
    the new fields as a dictionary. ``run`` is invoked internally by the
    ``Mapper`` machinery and should not be called by users directly.

    ``Mapper`` derivatives require data points with settable attributes
    (e.g. ``SimpleNamespace``, ``pd.Series``, or ``dask.Series``). Types
    without mutable fields, such as ``pyspark.sql.Row``, need
    ``snorkel.map.spark.make_spark_mapper`` for compatibility.

    For an example of a Mapper, see
    ``snorkel.preprocess.nlp.SpacyPreprocessor``

    Parameters
    ----------
    name
        Name of mapper
    field_names
        A map whose keys name the ``run`` keyword arguments and whose
        values name the incoming data point attributes they are read
        from. If None, the parameter names in the ``run`` signature are
        used for both.
    mapped_field_names
        A map from output keys of the ``run`` method to attribute
        names of the output data points. If None, the original
        output keys are used.
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?

    Raises
    ------
    NotImplementedError
        Subclasses must implement the ``run`` method

    Attributes
    ----------
    field_names
        See above
    mapped_field_names
        See above
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: str,
        field_names: Optional[Mapping[str, str]] = None,
        mapped_field_names: Optional[Mapping[str, str]] = None,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
    ) -> None:
        if field_names is None:
            # No explicit mapping given: mirror the ``run(...)`` parameter
            # names (dropping ``self``) so each argument reads the data
            # point attribute of the same name.
            run_params = get_parameters(self.run)[1:]
            field_names = {param: param for param in run_params}
        self.field_names = field_names
        self.mapped_field_names = mapped_field_names
        super().__init__(name, pre or [], memoize)

    def run(self, **kwargs: Any) -> Optional[FieldMap]:
        """Run the mapping operation using the input fields.

        Arguments are supplied by reading the data point attributes
        configured in ``field_names``. The returned dictionary's keys
        are translated via ``mapped_field_names`` (when set) and the
        values are written back onto the data point.

        Returns
        -------
        Optional[FieldMap]
            A mapping from canonical output field names to their values,
            or ``None`` if there is no output for this data point.

        Raises
        ------
        NotImplementedError
            Subclasses must implement this method
        """
        raise NotImplementedError

    def _update_fields(self, x: DataPoint, mapped_fields: FieldMap) -> DataPoint:
        # ``SimpleNamespace``, ``pd.Series``, and ``dask.Series`` objects all
        # support attribute assignment, so write each mapped field onto ``x``.
        for field, value in mapped_fields.items():
            setattr(x, field, value)
        return x

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        # Gather ``run`` inputs by reading the configured attributes off ``x``.
        run_inputs = {arg: getattr(x, attr) for arg, attr in self.field_names.items()}
        run_outputs = self.run(**run_inputs)
        if run_outputs is None:
            # ``run`` produced no output: propagate ``None``.
            return None
        if self.mapped_field_names is not None:
            # Rename output keys to their configured attribute names.
            run_outputs = {
                attr: run_outputs[key]
                for key, attr in self.mapped_field_names.items()
            }
        return self._update_fields(x, run_outputs)
|
270
|
|
|
271
|
|
|
272
|
4
|
class LambdaMapper(BaseMapper):
    """Define a mapper from a function.

    Convenience class for mappers built from a single function with no
    setup. Unlike ``Mapper.run``, the wrapped function maps an input
    data point directly to a new data point. The original data point is
    never modified (``BaseMapper.__call__`` operates on a deep copy), so
    in-place operations inside the function are safe.

    Parameters
    ----------
    name:
        Name of mapper
    f
        Function executing the mapping operation
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: str,
        f: MapFunction,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
    ) -> None:
        super().__init__(name, pre or [], memoize)
        self._f = f

    def _generate_mapped_data_point(self, x: DataPoint) -> Optional[DataPoint]:
        # The wrapped function does all of the work.
        return self._f(x)
|
305
|
|
|
306
|
|
|
307
|
4
|
class lambda_mapper:
    """Decorate a function to define a LambdaMapper object.

    Example
    -------
    >>> @lambda_mapper()
    ... def concatenate_text(x):
    ...     x.article = f"{x.title} {x.body}"
    ...     return x
    >>> isinstance(concatenate_text, LambdaMapper)
    True
    >>> from types import SimpleNamespace
    >>> x = SimpleNamespace(title="my title", body="my text")
    >>> concatenate_text(x).article
    'my title my text'

    Parameters
    ----------
    name
        Name of mapper. If None, uses the name of the wrapped function.
    pre
        Mappers to run before this mapper is executed
    memoize
        Memoize mapper outputs?

    Attributes
    ----------
    memoize
        Memoize mapper outputs?
    """

    def __init__(
        self,
        name: Optional[str] = None,
        pre: Optional[List[BaseMapper]] = None,
        memoize: bool = False,
    ) -> None:
        # A callable ``name`` means the decorator was applied without
        # parentheses (``@lambda_mapper`` instead of ``@lambda_mapper()``).
        if callable(name):
            raise ValueError("Looks like this decorator is missing parentheses!")
        self.name = name
        self.pre = pre
        self.memoize = memoize

    def __call__(self, f: MapFunction) -> LambdaMapper:
        """Wrap a function to create a ``LambdaMapper``.

        Parameters
        ----------
        f
            Function executing the mapping operation

        Returns
        -------
        LambdaMapper
            New ``LambdaMapper`` executing operation in wrapped function
        """
        mapper_name = self.name or f.__name__
        return LambdaMapper(
            name=mapper_name, f=f, pre=self.pre, memoize=self.memoize
        )
|