#325 Explain text predictions of Keras classifiers

Open Tomas Baltrunas teabolt
Showing 9 of 86 files from the diff.
Newly tracked file
eli5/nn/text.py created.
Newly tracked file
eli5/nn/gradcam.py created.
Other files ignored by Codecov
.travis.yml has changed.
tox.ini has changed.
README.rst has changed.

@@ -1,4 +1,9 @@
Loading
1 1
# -*- coding: utf-8 -*-
2 2
3 -
from .explain_prediction import explain_prediction_keras
4 -
from .gradcam import gradcam, gradcam_backend

@@ -1,6 +1,6 @@
Loading
1 1
# -*- coding: utf-8 -*-
2 2
from __future__ import absolute_import
3 -
from typing import Union, Optional, Callable, Tuple, List, TYPE_CHECKING
3 +
from typing import Union, Optional, Callable, List, Tuple, Generator, TYPE_CHECKING
4 4
if TYPE_CHECKING:
5 5
    import PIL
6 6
@@ -15,26 +15,57 @@
Loading
15 15
    AveragePooling2D,
16 16
    GlobalMaxPooling2D,
17 17
    GlobalAveragePooling2D,
18 +
    Conv1D,
19 +
    Embedding,
20 +
    AveragePooling1D,
21 +
    MaxPooling1D,
22 +
    GlobalAveragePooling1D,
23 +
    GlobalMaxPooling1D,
24 +
    RNN,
25 +
    LSTM,
26 +
    GRU,
27 +
    Bidirectional,
18 28
)
19 29
from keras.preprocessing.image import array_to_img
20 30
21 -
from eli5.base import Explanation, TargetExplanation
31 +
from eli5.base import (
32 +
    Explanation, 
33 +
    TargetExplanation, 
34 +
)
22 35
from eli5.explain import explain_prediction
23 -
from .gradcam import gradcam, gradcam_backend
36 +
from eli5.nn.gradcam import (
37 +
    gradcam_heatmap,
38 +
    DESCRIPTION_GRADCAM,
39 +
)
40 +
from eli5.nn.text import (
41 +
    gradcam_spans,
42 +
)
43 +
from .gradcam import (
44 +
    gradcam_backend_keras,
45 +
)
46 +
47 +
48 +
image_model_layers = (Conv2D, MaxPooling2D, AveragePooling2D, GlobalMaxPooling2D,
49 +
                      GlobalAveragePooling2D,)
24 50
25 51
26 -
DESCRIPTION_KERAS = """Grad-CAM visualization for image classification; 
27 -
output is explanation object that contains input image 
28 -
and heatmap image for a target.
29 -
"""
52 +
text_layers = (Conv1D, RNN, LSTM, GRU, Bidirectional,)
53 +
temporal_layers = (AveragePooling1D, MaxPooling1D, GlobalAveragePooling1D, GlobalMaxPooling1D,)
54 +
30 55
31 56
# note that keras.models.Sequential subclasses keras.models.Model
32 57
@explain_prediction.register(Model)
33 58
def explain_prediction_keras(model, # type: Model
34 59
                             doc, # type: np.ndarray
35 60
                             targets=None, # type: Optional[list]
36 61
                             layer=None, # type: Optional[Union[int, str, Layer]]
62 +
                             relu=True, # type: bool
63 +
                             counterfactual=False, # type: bool
37 64
                             image=None,
65 +
                             tokens=None,
66 +
                             pad_value=None,
67 +
                             pad_token=None,
68 +
                             interpolation_kind='linear',
38 69
                             ):
39 70
    # type: (...) -> Explanation
40 71
    """
@@ -46,51 +77,65 @@
Loading
46 77
        Instance of a Keras neural network model,
47 78
        whose predictions are to be explained.
48 79
80 +
81 +
        :raises ValueError: if ``model`` can not be differentiated.
49 82
    :param numpy.ndarray doc:
50 83
        An input to ``model`` whose prediction will be explained.
51 84
52 85
        Currently only numpy arrays are supported.
86 +
        The data format must be in "channels last".
53 87
54 88
        The tensor must be of suitable shape for the ``model``.
55 89
56 90
        Check ``model.input_shape`` to confirm the required dimensions of the input tensor.
57 91
58 92
59 93
        :raises TypeError: if ``doc`` is not a numpy array.
60 -
        :raises ValueError: if ``doc`` shape does not match.
94 +
        :raises ValueError: if ``doc`` batch size is not 1.
61 95
62 96
    :param targets:
63 97
        Prediction ID's to focus on.
64 98
65 99
        *Currently only the first prediction from the list is explained*.
66 100
        The list must be length one.
67 101
68 -
        If None, the model is fed the input image and its top prediction
102 +
        If None, the model is fed the input ``doc`` and the top prediction 
69 103
        is taken as the target automatically.
70 104
71 105
72 106
        :raises ValueError: if ``targets`` is a list with more than one item.
73 107
        :raises TypeError: if ``targets`` is not list or None.
108 +
        :raises TypeError: if ``targets`` does not contain an integer target.
109 +
        :raises ValueError: if integer target is not in the classes that ``model`` predicts.
74 110
    :type targets: list[int], optional
75 111
76 112
    :param layer:
77 113
        The activation layer in the model to perform Grad-CAM on:
78 114
        a valid keras layer name, layer index, or an instance of a Keras layer.
79 115
80 -
        If None, a suitable layer is attempted to be retrieved.
81 -
        For best results, pick a layer that:
82 -
83 -
        * has spatial or temporal information (conv, recurrent, pooling, embedding)
84 -
          (not dense layers).
85 -
        * shows high level features.
86 -
        * has large enough dimensions for resizing over input to work.
116 +
        If None, a suitable layer is attempted to be chosen automatically.
87 117
88 118
89 119
        :raises TypeError: if ``layer`` is not None, str, int, or keras.layers.Layer instance.
90 -
        :raises ValueError: if suitable layer can not be found.
120 +
        :raises ValueError: if suitable layer can not be found automatically.
91 121
        :raises ValueError: if differentiation fails with respect to retrieved ``layer``.
92 122
    :type layer: int or str or keras.layers.Layer, optional
93 123
124 +
    :param relu:
125 +
        Whether to apply ReLU on the resulting heatmap.
126 +
127 +
        Set to `False` to see the "negative" of a class.
128 +
129 +
        Default is `True`.
130 +
    :type relu: bool, optional
131 +
132 +
    :param counterfactual:
133 +
        Whether to negate gradients during the heatmap calculation.
134 +
        Useful for highlighting what makes the prediction or class score go down.
135 +
136 +
        Default is `False`.
137 +
    :type counterfactual: bool, optional
138 +
94 139
95 140
    See :func:`eli5.explain_prediction` for more information about the ``model``,
96 141
    ``doc``, and ``targets`` parameters.
@@ -104,16 +149,37 @@
Loading
104 149
    -------
105 150
      expl : :class:`eli5.base.Explanation`
106 151
        An :class:`eli5.base.Explanation` object for the relevant implementation.
152 +
153 +
        The following attributes are supported by all concrete implementations:
154 +
155 +
        * ``targets`` a list of :class:`eli5.base.TargetExplanation` objects \
156 +
            for each target. Currently only 1 target is supported.
157 +
        * ``layer`` used for Grad-CAM.
158 +
107 159
    """
108 -
    # Note that this function should only do dispatch
109 -
    # and no other processing
160 +
    # Note that this function should only do dispatch and no other processing
161 +
    assert image is None or tokens is None # only one of image or tokens must be passed
110 162
    if image is not None or _maybe_image(model, doc):
111 163
        return explain_prediction_keras_image(model,
112 164
                                              doc,
113 165
                                              image=image,
114 166
                                              targets=targets,
115 167
                                              layer=layer,
168 +
                                              relu=relu,
169 +
                                              counterfactual=counterfactual,
116 170
                                              )
171 +
    elif tokens is not None:
172 +
        return explain_prediction_keras_text(model,
173 +
                                             doc,
174 +
                                             tokens=tokens,
175 +
                                             pad_value=pad_value,
176 +
                                             pad_token=pad_token,
177 +
                                             interpolation_kind=interpolation_kind,
178 +
                                             targets=targets,
179 +
                                             layer=layer,
180 +
                                             relu=relu,
181 +
                                             counterfactual=counterfactual,
182 +
                                             )
117 183
    else:
118 184
        return explain_prediction_keras_not_supported(model, doc)
119 185
@@ -123,19 +189,23 @@
Loading
123 189
    Can not do an explanation based on the passed arguments.
124 190
    Did you pass either "image" or "tokens"?
125 191
    """
192 +
    # An alternative to giving up is to still generate a heatmap
193 +
    # but not do any formatting, i.e. useful if we do not support some task.
126 194
    return Explanation(
127 195
        model.name,
128 196
        error='model "{}" is not supported, '
129 -
              'try passing the "image" argument if explaining an image model.'.format(model.name),
197 +
              'try passing the "image" argument if explaining an image model, '
198 +
              'or the "tokens" argument if explaining a text model.'.format(model.name),
130 199
    )
131 -
    # TODO (open issue): implement 'other'/differentiable network type explanations
132 200
133 201
134 202
def explain_prediction_keras_image(model,
135 203
                                   doc,
136 204
                                   image=None, # type: Optional['PIL.Image.Image']
137 205
                                   targets=None,
138 206
                                   layer=None,
207 +
                                   relu=True,
208 +
                                   counterfactual=False,
139 209
                                   ):
140 210
    """
141 211
    Explain an image-based model, highlighting what contributed in the image.
@@ -167,82 +237,196 @@
Loading
167 237
    expl : eli5.base.Explanation
168 238
      An :class:`eli5.base.Explanation` object with the following attributes:
169 239
          * ``image`` a Pillow image representing the input.
170 -
          * ``targets`` a list of :class:`eli5.base.TargetExplanation` objects \
171 -
              for each target. Currently only 1 target is supported.
240 +
          * ``targets`` and ``layer`` attributes:
241 +
            See :func:`eli5.keras.explain_prediction.explain_prediction_keras`.
242 +
172 243
      The :class:`eli5.base.TargetExplanation` objects will have the following attributes:
173 244
          * ``heatmap`` a rank 2 numpy array with the localization map \
174 245
            values as floats.
175 246
          * ``target`` ID of target class.
176 247
          * ``score`` value for predicted class.
177 248
    """
249 +
    _validate_params(model, doc)
178 250
    if image is None:
179 251
        image = _extract_image(doc)
180 -
    _validate_doc(model, doc)
181 -
    activation_layer = _get_activation_layer(model, layer)
182 252
183 -
    # TODO: maybe do the sum / loss calculation in this function and pass it to gradcam.
184 -
    # This would be consistent with what is done in
185 -
    # https://github.com/ramprs/grad-cam/blob/master/misc/utils.lua
186 -
    # and https://github.com/ramprs/grad-cam/blob/master/classification.lua
187 -
    values = gradcam_backend(model, doc, targets, activation_layer)
188 -
    weights, activations, grads, predicted_idx, predicted_val = values
189 -
    heatmap = gradcam(weights, activations)
253 +
    if layer is not None:
254 +
        activation_layer = _get_layer(model, layer)
255 +
    else:
256 +
        activation_layer = _autoget_layer_image(model)
257 +
258 +
    vals = gradcam_backend_keras(model, doc, targets, activation_layer)
259 +
    activations, grads, predicted_idx, predicted_val = vals
260 +
    heatmap = gradcam_heatmap(activations,
261 +
                              grads,
262 +
                              relu=relu,
263 +
                              counterfactual=counterfactual,
264 +
                              )
265 +
    # take from batch
266 +
    predicted_idx, = predicted_idx
267 +
    predicted_val, = predicted_val
268 +
    heatmap, = heatmap
190 269
191 270
    return Explanation(
192 271
        model.name,
193 -
        description=DESCRIPTION_KERAS,
272 +
        description=DESCRIPTION_GRADCAM,
194 273
        error='',
195 274
        method='Grad-CAM',
196 -
        image=image,
275 +
        image=image, # RGBA Pillow image
276 +
        layer=activation_layer,
197 277
        targets=[TargetExplanation(
198 278
            predicted_idx,
199 279
            score=predicted_val, # for now we keep the prediction in the .score field (not .proba)
200 -
            heatmap=heatmap, # 2D [0, 1] numpy array
280 +
            heatmap=heatmap, # 2D numpy array
201 281
        )],
202 282
        is_regression=False, # might be relevant later when explaining for regression tasks
203 -
        highlight_spaces=None, # might be relevant later when explaining text models
283 +
    )
284 +
285 +
286 +
def explain_prediction_keras_text(model,
287 +
                                  doc,
288 +
                                  tokens=None, # type: Optional[Union[List[str], np.ndarray]]
289 +
                                  pad_value=None, # type: Optional[Union[int, float]]
290 +
                                  pad_token=None, # type: Optional[str]
291 +
                                  interpolation_kind='linear', # type: Union[str, int]
292 +
                                  targets=None,
293 +
                                  layer=None,
294 +
                                  relu=True,
295 +
                                  counterfactual=False,
296 +
                                  ):
297 +
    """
298 +
    Explain a text-based model, highlighting parts of text that contributed to the prediction.
299 +
300 +
    In the case of binary classification, this highlights what makes the output go up.
301 +
302 +
    See :func:`eli5.keras.explain_prediction.explain_prediction_keras` for description of ``targets``, 
303 +
    ``layer``, ``relu``, and ``counterfactual`` parameters.
304 +
305 +
    :param numpy.ndarray doc:
306 +
        Suitable input tensor. Temporal with batch size. May have padding.
307 +
308 +
    :param tokens:
309 +
        Tokens that correspond to ``doc``.
310 +
        With padding if ``doc`` has padding.
311 +
312 +
        A Python list or a numpy array of strings. With the same length as ``doc``.
313 +
        If ``doc`` has batch size = 1, batch dimension from tokens may be omitted.
314 +
315 +
        These tokens will be highlighted for text-based explanations.
316 +
317 +
318 +
        :raises TypeError: if ``tokens`` has wrong type or contains wrong types.
319 +
        :raises ValueError: if ``tokens`` dimensions do not match.
320 +
    :type tokens: list[str], optional
321 +
322 +
    :param pad_value:
323 +
        Integer identifier of the padding token.
324 +
325 +
        If given, cuts padding off.
326 +
    :type pad_value: int or float, optional
327 +
328 +
    :param pad_token:
329 +
        A string token inside ``tokens`` identifying padding.
330 +
331 +
        If given, cuts padding off.
332 +
    :type pad_token: str, optional
333 +
334 +
    :param interpolation_kind:
335 +
        Interpolation method. See :func:`eli5.nn.text.resize_1d` for more details.
336 +
    :type interpolation_kind: str or int, optional
337 +
338 +
    Returns
339 +
    -------
340 +
    expl : eli5.base.Explanation
341 +
      An :class:`eli5.base.Explanation` object with the following attributes:
342 +
          * ``targets`` and ``layer`` attributes:
343 +
            See :func:`eli5.keras.explain_prediction.explain_prediction_keras`.
344 +
345 +
      The :class:`eli5.base.TargetExplanation` objects will have the following attributes:
346 +
          * ``weighted_spans`` a :class:`eli5.base.WeightedSpans` object with \
347 +
            weights for parts of text to be highlighted.
348 +
          * ``heatmap`` a rank 1 numpy array with the localization map \
349 +
              values as floats.
350 +
          * ``target`` ID of target class.
351 +
          * ``score`` value for predicted class.
352 +
353 +
    """
354 +
    assert tokens is not None
355 +
    _validate_params(model, doc)
356 +
    tokens = _unbatch_tokens(tokens)
357 +
358 +
    if layer is not None:
359 +
        activation_layer = _get_layer(model, layer)
360 +
    else:
361 +
        activation_layer = _autoget_layer_text(model)
362 +
363 +
    vals = gradcam_backend_keras(model, doc, targets, activation_layer)
364 +
    activations, grads, predicted_idx, predicted_val = vals
365 +
    heatmap = gradcam_heatmap(activations,
366 +
                              grads,
367 +
                              relu=relu,
368 +
                              counterfactual=counterfactual,
369 +
                              )
370 +
    # take from batch
371 +
    predicted_idx, = predicted_idx
372 +
    predicted_val, = predicted_val
373 +
    heatmap, = heatmap
374 +
    text_vals = gradcam_spans(heatmap,
375 +
                              tokens,
376 +
                              doc,
377 +
                              pad_value=pad_value,
378 +
                              pad_token=pad_token,
379 +
                              interpolation_kind=interpolation_kind,
380 +
                              )
381 +
    tokens, heatmap, weighted_spans = text_vals
382 +
    return Explanation(
383 +
        model.name,
384 +
        description=DESCRIPTION_GRADCAM,
385 +
        error='',
386 +
        method='Grad-CAM',
387 +
        layer=activation_layer,
388 +
        targets=[TargetExplanation(
389 +
            predicted_idx,
390 +
            weighted_spans=weighted_spans,
391 +
            score=predicted_val,
392 +
            heatmap=heatmap,
393 +
        )],
394 +
        is_regression=False,
204 395
    )
205 396
206 397
207 398
def _maybe_image(model, doc):
208 399
    # type: (Model, np.ndarray) -> bool
209 -
    """Decide whether we are dealing with a image-based explanation 
210 -
    based on heuristics on ``model`` and ``doc``."""
400 +
    """
401 +
    Decide whether we are dealing with an image-based explanation
402 +
    based on heuristics on ``model`` and ``doc``.
403 +
    """
211 404
    return _maybe_image_input(doc) and _maybe_image_model(model)
212 405
213 406
214 407
def _maybe_image_input(doc):
215 408
    # type: (np.ndarray) -> bool
216 409
    """Decide whether ``doc`` represents an image input."""
217 -
    rank = len(doc.shape)
218 -
    # image with channels or without (spatial only)
219 -
    return rank == 4 or rank == 3
410 +
    try:
411 +
        _validate_doc(doc)
412 +
    except (TypeError, ValueError):
413 +
        return False
414 +
    else:
415 +
        rank = len(doc.shape)
416 +
        # image with channels or without (spatial only)
417 +
        return rank == 4 or rank == 3
220 418
221 419
222 420
def _maybe_image_model(model):
223 421
    # type: (Model) -> bool
224 422
    """Decide whether ``model`` is used for images."""
225 -
    # FIXME: replace try-except with something else
226 -
    try:
227 -
        # search for the first occurrence of an "image" layer
228 -
        _search_layer_backwards(model, _is_possible_image_model_layer)
229 -
        return True
230 -
    except ValueError:
231 -
        return False
232 -
233 -
234 -
image_model_layers = (Conv2D,
235 -
                      MaxPooling2D,
236 -
                      AveragePooling2D,
237 -
                      GlobalMaxPooling2D,
238 -
                      GlobalAveragePooling2D,
239 -
                      )
240 -
241 -
242 -
def _is_possible_image_model_layer(model, layer):
243 -
    # type: (Model, Layer) -> bool
244 -
    """Check that the given ``layer`` is usually used for images."""
245 -
    return isinstance(layer, image_model_layers)
423 +
    # search for the first occurrence of an "image" layer
424 +
    l = _search_layer(model,
425 +
                     _backward_layers,
426 +
                     lambda model, layer:
427 +
                     isinstance(layer, image_model_layers)
428 +
                     )
429 +
    return l is not None
246 430
247 431
248 432
def _extract_image(doc):
@@ -253,90 +437,127 @@
Loading
253 437
    return image
254 438
255 439
256 -
def _validate_doc(model, doc):
257 -
    # type: (Model, np.ndarray) -> None
258 -
    """
259 -
    Check that the input ``doc`` is suitable for ``model``.
260 -
    """
261 -
    if not isinstance(doc, np.ndarray):
262 -
        raise TypeError('doc must be a numpy.ndarray, got: {}'.format(doc))
263 -
    input_sh = model.input_shape
264 -
    doc_sh = doc.shape
265 -
    if len(input_sh) == 4:
266 -
        # rank 4 with (batch, ...) shape
267 -
        # check that we have only one image (batch size 1)
268 -
        single_batch = (1,) + input_sh[1:]
269 -
        if doc_sh != single_batch:
270 -
            raise ValueError('Batch size does not match (must be 1). ' 
271 -
                             'doc must be of shape: {}, '
272 -
                             'got: {}'.format(single_batch, doc_sh))
440 +
def _unbatch_tokens(tokens):
441 +
    # type: (np.ndarray) -> np.ndarray
442 +
    """If ``tokens`` has batch size, take out the first sample from the batch."""
443 +
    an_entry = tokens[0]
444 +
    if isinstance(an_entry, str):
445 +
        # not batched
446 +
        return tokens
273 447
    else:
274 -
        # other shapes
275 -
        if doc_sh != input_sh:
276 -
            raise ValueError('Input and doc shapes do not match.'
277 -
                             'input: {}, doc: {}'.format(input_sh, doc_sh))
448 +
        # batched, return first entry
449 +
        return an_entry
278 450
279 451
280 -
def _get_activation_layer(model, layer):
281 -
    # type: (Model, Union[None, int, str, Layer]) -> Layer
452 +
def _get_layer(model, layer): 
453 +
    # type: (Model, Union[int, str, Layer]) -> Layer
282 454
    """
283 -
    Get an instance of the desired activation layer in ``model``,
284 -
    as specified by ``layer``.
455 +
    Wrapper around ``model.get_layer()`` for an int, str, or Layer argument.
456 +
    Return a keras Layer instance.
285 457
    """
286 -
    if layer is None:
287 -
        # Automatically get the layer if not provided
288 -
        activation_layer = _search_layer_backwards(model, _is_suitable_activation_layer)
289 -
        return activation_layer
290 -
458 +
    # currently we don't do any validation on the retrieved layer
291 459
    if isinstance(layer, Layer):
292 -
        activation_layer = layer
293 -
    # get_layer() performs a bottom-up horizontal graph traversal
294 -
    # it can raise ValueError if the layer index / name specified is not found
460 +
        return layer
295 461
    elif isinstance(layer, int):
296 -
        activation_layer = model.get_layer(index=layer)
462 +
        # keras.get_layer() performs a bottom-up horizontal graph traversal
463 +
        # the function raises ValueError if the layer index / name specified is not found
464 +
        return model.get_layer(index=layer)
297 465
    elif isinstance(layer, str):
298 -
        activation_layer = model.get_layer(name=layer)
466 +
        return model.get_layer(name=layer)
467 +
    else:
468 +
        raise TypeError('Invalid layer (must be str, int, or keras.layers.Layer): %s' % layer)
469 +
470 +
471 +
def _autoget_layer_image(model):
472 +
    # type: (Model) -> Layer
473 +
    """Try find a suitable hidden layer for an image ``model``."""
474 +
    l = _search_layer(model,
475 +
                      _backward_layers,
476 +
                      lambda model, layer:
477 +
                      (len(layer.output_shape) == len(model.input_shape) and
478 +
                       '2d' in layer.__class__.__name__.lower())
479 +
                      )
480 +
481 +
    if l is None:
482 +
        raise ValueError('Could not find a suitable image layer automatically. '
483 +
                         'Try passing the "layer" argument.')
299 484
    else:
300 -
        raise TypeError('Invalid layer (must be str, int, keras.layers.Layer, or None): %s' % layer)
485 +
        return l
301 486
302 -
    if _is_suitable_activation_layer(model, activation_layer):
303 -
        # final validation step
304 -
        return activation_layer
487 +
488 +
def _autoget_layer_text(model):
489 +
    # type: (Model) -> Layer
490 +
    """Try find a suitable hidden layer for a text ``model``."""
491 +
    def wanted(layers):
492 +
        return lambda model, layer: (len(layer.output_shape) == 3 and
493 +
                                     isinstance(layer, layers))
494 +
495 +
    l = _search_layer(model, _backward_layers, wanted(text_layers))
496 +
    if l is None:
497 +
        l = _search_layer(model, _backward_layers, wanted(temporal_layers))
498 +
        if l is None:
499 +
            l = _search_layer(model, _backward_layers, wanted((Embedding,)))
500 +
501 +
    if l is None:
502 +
        raise ValueError('Could not find a suitable text layer automatically. '
503 +
                         'Try passing the "layer" argument.')
305 504
    else:
306 -
        raise ValueError('Can not perform Grad-CAM on the retrieved activation layer')
505 +
        return l
307 506
308 507
309 -
def _search_layer_backwards(model, condition):
310 -
    # type: (Model, Callable[[Model, Layer], bool]) -> Layer
508 +
def _search_layer(model, # type: Model
509 +
                  layers_generator, # type: Callable[[Model], Generator[Layer, None, None]]
510 +
                  layer_condition, # type: Callable[[Model, Layer], bool]
511 +
                  ):
512 +
    # type: (...) -> Optional[Layer]
311 513
    """
312 -
    Search for a layer in ``model``, backwards (starting from the output layer),
313 -
    checking if the layer is suitable with the callable ``condition``,
514 +
    Search for a layer in ``model``, iterating through layers in the order specified by
515 +
    ``layers_generator``, returning the first layer that matches ``layer_condition``.
516 +
    If no layer could be found, return None.
314 517
    """
315 518
    # linear search in reverse through the flattened layers
316 -
    for layer in model.layers[::-1]:
317 -
        if condition(model, layer):
519 +
    for layer in layers_generator(model):
520 +
        if layer_condition(model, layer):
318 521
            # linear search succeeded
319 522
            return layer
320 523
    # linear search ended with no results
321 -
    raise ValueError('Could not find a suitable target layer automatically.')        
524 +
    return None
322 525
323 526
324 -
def _is_suitable_activation_layer(model, layer):
325 -
    # type: (Model, Layer) -> bool
326 -
    """
327 -
    Check whether the layer ``layer`` matches what is required 
328 -
    by ``model`` to do Grad-CAM on ``layer``.
329 -
    Returns a boolean.
527 +
def _backward_layers(model):
528 +
    # type: (Model) -> Generator[Layer, None, None]
529 +
    """Return layers going from output to input (backwards)."""
530 +
    return (model.get_layer(index=i) for i in range(len(model.layers)-1, -1, -1))
531 +
532 +
533 +
def _validate_params(model, # type: Model
534 +
                     doc, # type: np.ndarray
535 +
                     ):
536 +
    # type: (...) -> None
537 +
    """Helper for validating explanation function parameters."""
538 +
    _validate_model(model)
539 +
    _validate_doc(doc)
330 540
331 -
    Matching Criteria:
332 -
        * Rank of the layer's output tensor.
541 +
542 +
def _validate_model(model):
543 +
    if len(model.layers) == 0:
544 +
        # "empty" model
545 +
        raise ValueError('Model must have at least 1 layer. '
546 +
                         'Got model with layers: "{}"'.format(model.layers))
547 +
548 +
549 +
def _validate_doc(doc):
550 +
    # type: (np.ndarray) -> None
333 551
    """
334 -
    # TODO: experiment with this, using many models and images, to find what works best
335 -
    # Some ideas: 
336 -
    # check layer type, i.e.: isinstance(l, keras.layers.Conv2D)
337 -
    # check layer name
338 -
339 -
    # a check that asks "can we resize this activation layer over the image?"
340 -
    rank = len(layer.output_shape)
341 -
    required_rank = len(model.input_shape)
342 -
    return rank == required_rank
552 +
    Check that the input ``doc`` has the correct type and values.
553 +
    """
554 +
    if not isinstance(doc, np.ndarray):
555 +
        raise TypeError('"doc" must be an instance of numpy.ndarray. '
556 +
                        'Got: {} (type "{}")'.format(doc, type(doc)))
557 +
558 +
    # check that batch=1 (batch greater than 1 is currently not supported)
559 +
    batch_size = doc.shape[0]
560 +
    if batch_size != 1:
561 +
        raise ValueError('"doc" batch size must be 1. '
562 +
                         'Got doc with batch size: %d' % batch_size)
563 +
    # Note that validation of the input shape, etc is done by Keras

@@ -0,0 +1,299 @@
Loading
1 +
# -*- coding: utf-8 -*-
2 +
from typing import Optional, Union, Tuple, List
3 +
4 +
import numpy as np # type: ignore
5 +
from scipy.interpolate import interp1d # type: ignore
6 +
7 +
from eli5.base import (
8 +
    WeightedSpans,
9 +
    DocWeightedSpans,
10 +
)
11 +
from eli5.nn.gradcam import (
12 +
    _validate_heatmap,
13 +
)
14 +
15 +
16 +
def gradcam_spans(heatmap, # type: np.ndarray
17 +
                  tokens, # type: Union[np.ndarray, list]
18 +
                  doc, # type: np.ndarray
19 +
                  pad_value=None, # type: Optional[Union[int, float]]
20 +
                  pad_token=None, # type: Optional[str]
21 +
                  interpolation_kind='linear' # type: Union[str, int]
22 +
                  ):
23 +
    # type: (...) -> Tuple[Union[np.ndarray, list], np.ndarray, WeightedSpans]
24 +
    """
25 +
    Create text spans from a Grad-CAM ``heatmap`` imposed over ``tokens``.
26 +
    Optionally cut off the padding from the explanation
27 +
    with the ``pad_value`` or ``pad_token`` arguments.
28 +
29 +
    Parameters
30 +
    ----------
31 +
    heatmap : numpy.ndarray
32 +
        Array of weights. May be resized to match the length of tokens.
33 +
34 +
        **Should be rank 1 (no batch dimension).**
35 +
36 +
37 +
        :raises TypeError: if ``heatmap`` is wrong type.
38 +
39 +
    tokens : numpy.ndarray or list
40 +
        Tokens that will be highlighted using weights from ``heatmap``.
41 +
42 +
43 +
        :raises TypeError: if ``tokens`` is wrong type.
44 +
        :raises ValueError: if ``tokens`` contents are unexpected.
45 +
46 +
    doc: numpy.ndarray
47 +
        Original input to the network, from which ``heatmap`` was created.
48 +
49 +
    pad_value: int or float, optional
50 +
        Value used for padding in ``doc``.
51 +
52 +
    pad_token: str, optional
53 +
        Token used for padding in ``tokens``.
54 +
55 +
        Pass one of either `pad_value` or `pad_token` to cut off padding.
56 +
57 +
    interpolation_kind: str or int, optional
58 +
        Interpolation method. See :func:`eli5.nn.text.resize_1d` for more details.
59 +
60 +
    Returns
61 +
    -------
62 +
    (tokens, heatmap, weighted_spans) : (list or numpy.ndarray, numpy.ndarray, WeightedSpans)
63 +
        ``tokens`` and ``heatmap`` optionally cut from padding.
64 +
        A :class:`eli5.base.WeightedSpans` object with a weight for each token.
65 +
    """
66 +
    # We call this before returning the explanation, NOT when formatting the explanation
67 +
    # Because WeightedSpans, etc are attributes of a returned explanation
68 +
    _validate_tokens(tokens)
69 +
    _validate_tokens_value(tokens, doc)
70 +
    if isinstance(tokens, list):
71 +
        # convert to a common data type
72 +
        tokens = np.array(tokens)
73 +
74 +
    length = len(tokens)
75 +
    heatmap = resize_1d(heatmap, length, interpolation_kind=interpolation_kind)
76 +
77 +
    # values will be cut off from the *resized* heatmap
78 +
    if pad_value is not None or pad_token is not None:
79 +
        # remove padding
80 +
        pad_indices = _find_padding(pad_value=pad_value, pad_token=pad_token, doc=doc, tokens=tokens)
81 +
        # If passed padding argument is not the actual padding token/value, behaviour is unknown
82 +
        tokens, heatmap = _trim_padding(pad_indices, tokens, heatmap)
83 +
84 +
    document = _construct_document(tokens)
85 +
    spans = _build_spans(tokens, heatmap, document)
86 +
    weighted_spans = WeightedSpans([
87 +
        DocWeightedSpans(document, spans=spans)
88 +
    ])
89 +
    # why do we have a list of WeightedSpans? One for each vectorizer?
90 +
    # But we do not use multiple vectorizers?
91 +
    return tokens, heatmap, weighted_spans
92 +
93 +
94 +
def resize_1d(heatmap, length, interpolation_kind='linear'):
    # type: (np.ndarray, int, Union[str, int]) -> np.ndarray
    """
    Resize the ``heatmap`` 1D array to match the specified ``length``.

    For example, upscale/upsample a heatmap with length 400
    to have length 500.

    Parameters
    ----------

    heatmap : numpy.ndarray
        Heatmap to be resized.


        :raises TypeError: if ``heatmap`` is wrong type.

    length : int
        Required width.

    interpolation_kind : str or int, optional
        Interpolation method used by the underlying ``scipy.interpolate.interp1d`` resize function.

        Used when resizing ``heatmap`` to the correct ``length``.

        Default is ``linear``.

    Returns
    -------
    heatmap : numpy.ndarray
        The heatmap resized.
    """
    _validate_heatmap(heatmap)
    if heatmap.size == 1:
        # A single weight can not be interpolated: interp1d requires
        # at least 2 data points. Repeat the weight over the whole length.
        # (Checking .size rather than shape == (1,) also covers size-1
        # arrays with extra singleton dimensions, which would otherwise
        # crash in interp1d below.)
        heatmap = heatmap.repeat(length)
    else:
        # more than one value - interpolate

        # scipy.interpolate solution
        # https://stackoverflow.com/questions/29085268/resample-a-numpy-array
        y = heatmap  # data to interpolate
        x = np.linspace(0, 1, heatmap.size)  # array matching y
        interpolant = interp1d(x, y, kind=interpolation_kind)  # approximates y = f(x)
        z = np.linspace(0, 1, length)  # points where to interpolate
        heatmap = interpolant(z)  # interpolation result

        # other solutions include scipy.signal.resample (periodic, so doesn't apply)
        # and Pillow image fromarray with mode 'F'/etc and resizing (didn't seem to work)
    return heatmap
145 +
146 +
147 +
def _build_spans(tokens, # type: Union[np.ndarray, list]
148 +
                 heatmap, # type: np.ndarray
149 +
                 document, # type: str
150 +
                 ):
151 +
    """Highlight ``tokens`` in ``document``, with weights from ``heatmap``."""
152 +
    assert len(tokens) == len(heatmap)
153 +
    spans = []
154 +
    running = 0  # where to start looking for token in document
155 +
    for (token, weight) in zip(tokens, heatmap):
156 +
        # find first occurrence of token, on or after running count
157 +
        t_start = document.index(token, running)
158 +
        # create span
159 +
        t_end = t_start + len(token)
160 +
        span = (token, [(t_start, t_end,)], weight)
161 +
        spans.append(span)
162 +
        # update run
163 +
        running = t_end
164 +
    return spans
165 +
166 +
167 +
def _construct_document(tokens):
168 +
    # type: (Union[list, np.ndarray]) -> str
169 +
    """Create a document string by joining ``tokens`` sequence."""
170 +
    if _is_character_tokenization(tokens):
171 +
        sep = ''
172 +
    else:
173 +
        sep = ' '
174 +
    return sep.join(tokens)
175 +
176 +
177 +
def _is_character_tokenization(tokens):
178 +
    # type: (Union[list, np.ndarray]) -> bool
179 +
    """Check whether tokenization is character-level (True) or word-level (False)."""
180 +
    return any(' ' in t for t in tokens)
181 +
182 +
183 +
def _find_padding(pad_value=None,  # type: Union[int, float]
                  pad_token=None,  # type: str
                  doc=None,  # type: Optional[np.ndarray]
                  tokens=None  # type: Optional[Union[np.ndarray, list]]
                  ):
    # type: (...) -> np.ndarray
    """Dispatch to a padding finder based on arguments."""
    # passing both pad_value and pad_token is ambiguous
    # (there is no rule for which one takes precedence)
    assert pad_value is None or pad_token is None
    if pad_value is not None and doc is not None:
        return _find_padding_values(pad_value, doc)
    if pad_token is not None and tokens is not None:
        return _find_padding_tokens(pad_token, tokens)
    raise TypeError('Pass "doc" and "pad_value", '
                    'or "tokens" and "pad_token".')
200 +
201 +
202 +
def _find_padding_values(pad_value, doc):
    # type: (Union[int, float], np.ndarray) -> np.ndarray
    """Find indices (along the second axis) where ``doc`` equals ``pad_value``."""
    if not isinstance(pad_value, (int, float)):
        raise TypeError('"pad_value" must be int or float. Got "{}"'.format(type(pad_value)))
    _validate_doc(doc)
    # doc is batched (2D), so np.where yields (row, column) index arrays;
    # only the positions within the sample are needed
    indices = np.where(doc == pad_value)[1]
    return indices
209 +
210 +
211 +
def _find_padding_tokens(pad_token, tokens):
212 +
    # type: (str, np.ndarray) -> np.ndarray
213 +
    if not isinstance(pad_token, str):
214 +
        raise TypeError('"pad_token" must be str. Got "{}"'.format(type(pad_token)))
215 +
    indices = np.where(tokens == pad_token)
216 +
    return indices
217 +
218 +
219 +
def _trim_padding(pad_indices, # type: np.ndarray
220 +
                  tokens, # type: np.ndarray
221 +
                  heatmap, # type: np.ndarray
222 +
                  ):
223 +
    # type: (...) -> Tuple[Union[list, np.ndarray], np.ndarray]
224 +
    """Remove padding from ``tokens`` and ``heatmap``."""
225 +
    # heatmap and tokens must be same length?
226 +
    if 0 < len(pad_indices):
227 +
        # found some padding symbols
228 +
229 +
        # delete all values along indices
230 +
        # this is not as robust as explicitly finding pre and post padding characters
231 +
        # and we can not detect and raise an error if there is padding in the middle of the text
232 +
        tokens = np.delete(tokens, pad_indices)
233 +
        heatmap = np.delete(heatmap, pad_indices)
234 +
    return tokens, heatmap
235 +
236 +
237 +
def _validate_doc(doc):
238 +
    """Check that ``doc`` has the right type."""
239 +
    if not isinstance(doc, np.ndarray):
240 +
        raise TypeError('"doc" must be an instance of numpy.ndarray. '
241 +
                        'Got "{}" (type "{}")'.format(doc, type(doc)))
242 +
243 +
244 +
def _validate_tokens(tokens):
245 +
    # type: (np.ndarray) -> None
246 +
    """Check that ``tokens`` contains correct items."""
247 +
    if not isinstance(tokens, (list, np.ndarray)):
248 +
        # wrong type
249 +
        raise TypeError('"tokens" must be list or numpy.ndarray. '
250 +
                        'Got "{}".'.format(tokens))
251 +
    if len(tokens) == 0:
252 +
        # empty list
253 +
        raise ValueError('"tokens" is empty: {}'.format(tokens))
254 +
255 +
256 +
def _validate_tokens_value(tokens, doc):
257 +
    # type: (Union[np.ndarray, list], np.ndarray) -> None
258 +
    """Check that the contents of ``tokens`` are consistent with ``doc``."""
259 +
    doc_batch, doc_len = doc.shape[0], doc.shape[1]
260 +
    an_entry = tokens[0]
261 +
    if isinstance(an_entry, str):
262 +
        # no batch
263 +
        if doc_batch != 1:
264 +
            # doc is batched but tokens is not
265 +
            raise ValueError('If passing "tokens" without batch dimension, '
266 +
                             '"doc" must have batch size = 1.'
267 +
                             'Got "doc" with batch size = %d.' % doc_batch)
268 +
        tokens_len = len(tokens)
269 +
    elif isinstance(an_entry, (list, np.ndarray)):
270 +
        # batched
271 +
        tokens_batch = len(tokens)
272 +
        if tokens_batch != doc_batch:
273 +
            # batch lengths do not match
274 +
            raise ValueError('"tokens" must have same number of samples '
275 +
                             'as in doc batch. Got: "tokens" samples: %d, '
276 +
                             'doc samples: %d' % (tokens_batch, doc_batch))
277 +
278 +
        a_token = an_entry[0]
279 +
        if not isinstance(a_token, str):
280 +
            # actual contents are not strings
281 +
            raise TypeError('Second axis in "tokens" must contain strings. '
282 +
                            'Found "{}" (type "{}")'.format(a_token, type(a_token)))
283 +
284 +
        # a way to check that all elements match some condition
285 +
        # https://stackoverflow.com/a/35791116/11555448
286 +
        it = iter(tokens)
287 +
        initial_length = len(next(it))
288 +
        if not all(len(l) == initial_length for l in it):
289 +
            raise ValueError('"tokens" samples do not all have the same length.')
290 +
        tokens_len = initial_length
291 +
    else:
292 +
        raise TypeError('"tokens" must be an array of strings, '
293 +
                        'or an array of string arrays. '
294 +
                        'Got "{}".'.format(tokens))
295 +
296 +
    if tokens_len != doc_len:
297 +
        raise ValueError('"tokens" and "doc" lengths must match. '
298 +
                         '"tokens" length: "%d". "doc" length: "%d"'
299 +
                         % (tokens_len, doc_len))

@@ -28,6 +28,7 @@
Loading
28 28
                 highlight_spaces=None,  # type: Optional[bool]
29 29
                 transition_features=None,  # type: Optional[TransitionFeatureWeights]
30 30
                 image=None, # type: Any
31 +
                 layer=None,  # type: Any
31 32
                 ):
32 33
        # type: (...) -> None
33 34
        self.estimator = estimator
@@ -40,7 +41,8 @@
Loading
40 41
        self.decision_tree = decision_tree
41 42
        self.highlight_spaces = highlight_spaces
42 43
        self.transition_features = transition_features
43 -
        self.image = image # if arg is not None, assume we are working with images
44 +
        self.image = image
45 +
        self.layer = layer
44 46
45 47
    def _repr_html_(self):
46 48
        """ HTML formatting for the notebook.
@@ -144,9 +146,13 @@
Loading
144 146
        self.other = other
145 147
146 148
149 +
# TODO: Can this be replaced with a namedtuple?
147 150
WeightedSpan = Tuple[
148 -
    Feature,
149 -
    List[Tuple[int, int]],  # list of spans (start, end) for this feature
151 +
    Feature, # feature name - i.e. token name such as 'john', 'software', 'sky'
152 +
    List[Tuple[int, int]],  # list of spans [start, end) for this feature
153 +
                            # indices into the document
154 +
                            # use a list when we have a bag-of-words model
155 +
                            # and each feature has multiple spans (multiple occurrences)
150 156
    float,  # feature weight
151 157
]
152 158

@@ -0,0 +1,11 @@
Loading
1 +
# -*- coding: utf-8 -*-
2 +
3 +
from .gradcam import (
4 +
    gradcam_heatmap,
5 +
    get_localization_map,
6 +
    compute_weights,
7 +
)
8 +
from .text import (
9 +
    gradcam_spans,
10 +
    resize_1d,
11 +
)

Click to load this diff.
Loading diff...

Click to load this diff.
Loading diff...

Click to load this diff.
Loading diff...

Click to load this diff.
Loading diff...

Learn more Showing 52 files with coverage changes found.

Changes in eli5/_graphviz.py
+11
Loading file...
Changes in eli5/sklearn/_span_analyzers.py
+60
Loading file...
Changes in eli5/sklearn/unhashing.py
+159
+3
+2
Loading file...
Changes in eli5/keras/__init__.py
+2
Loading file...
Changes in eli5/sklearn/explain_weights.py
+148
Loading file...
Changes in eli5/utils.py
+82
+1
Loading file...
Changes in eli5/formatters/utils.py
+72
+1
+1
Loading file...
Changes in eli5/_feature_names.py
+104
+4
+2
Loading file...
Changes in eli5/sklearn/utils.py
+128
+6
+8
Loading file...
Changes in eli5/lime/lime.py
+97
+2
+1
Loading file...
Changes in eli5/formatters/fields.py
+3
Loading file...
Changes in eli5/__init__.py
+40
+2
+6
Loading file...
Changes in eli5/_feature_weights.py
+43
Loading file...
Changes in eli5/sklearn_crfsuite/explain_weights.py
+54
Loading file...
Changes in eli5/keras/gradcam.py
+1
Loading file...
Changes in eli5/formatters/__init__.py
+14
Loading file...
Changes in eli5/sklearn/transform.py
+35
Loading file...
Changes in eli5/formatters/features.py
+11
+1
+2
Loading file...
Changes in eli5/transform.py
+4
+1
+1
Loading file...
Changes in eli5/ipython.py
+31
Loading file...
Changes in eli5/_feature_importances.py
+16
Loading file...
Changes in eli5/formatters/as_dict.py
+23
+1
Loading file...
Changes in eli5/base.py
+88
Loading file...
Changes in eli5/formatters/trees.py
+39
Loading file...
Changes in eli5/sklearn/__init__.py
+6
Loading file...
Changes in eli5/xgboost.py
+163
+1
Loading file...
Changes in eli5/_decision_path.py
+31
Loading file...
Changes in eli5/formatters/html.py
+110
Loading file...
Changes in eli5/formatters/text_helpers.py
+33
+1
+1
Loading file...
Changes in eli5/lightgbm.py
+111
+3
+3
Loading file...
Changes in eli5/lightning.py
+38
+2
Loading file...
Changes in eli5/lime/samplers.py
+150
+1
Loading file...
Changes in eli5/formatters/as_dataframe.py
+66
+2
+1
Loading file...
Changes in eli5/lime/__init__.py
+1
Loading file...
Changes in eli5/lime/_vectorizer.py
+33
+1
Loading file...
Changes in eli5/lime/utils.py
+64
+3
+7
Loading file...
Changes in eli5/formatters/text.py
+125
+1
Loading file...
Changes in eli5/catboost.py
+26
Loading file...
Changes in eli5/sklearn_crfsuite/__init__.py
+2
Loading file...
Changes in eli5/explain.py
+5
+2
Loading file...
Changes in eli5/lime/textutils.py
+100
Loading file...
Changes in eli5/sklearn/permutation_importance.py
+95
Loading file...
Changes in eli5/permutation_importance.py
+29
Loading file...
Changes in eli5/sklearn/treeinspect.py
+41
Loading file...
Changes in eli5/base_utils.py
+20
Loading file...
Changes in eli5/formatters/image.py
+31
Loading file...
Changes in eli5/sklearn/text.py
+91
+1
+1
Loading file...
Changes in eli5/sklearn/explain_prediction.py
+236
+3
+2
Loading file...
New file eli5/nn/gradcam.py
New
Loading file...
New file eli5/nn/text.py
New
Loading file...
New file eli5/nn/__init__.py
New
Loading file...
Changes in eli5/keras/explain_prediction.py
+9
Loading file...

159 Commits

Pull Request Base Commit
Files Coverage
eli5 0.02% 97.35%
Project Totals (52 files) 97.35%
Loading