1
# -*- encoding: utf-8 -*-
2
# pylint: disable=E0203,E1101,C0111
3 2
"""
4
@file
5
@brief Runtime operator.
6
"""
7 2
import numpy
8 2
from ._op import OpRun
9 2
from ..shape_object import ShapeObject
10

11

12 2
def _pad_impl(data, raw_pads, mode, constant_values=0.0):
13 2
    input_rank = data.ndim
14 2
    if input_rank * 2 != raw_pads.size:
15 0
        raise Exception(
16
            'The number of elements in raw_pads should be 2 * data_rank')
17

18 2
    half = raw_pads.shape[0] // 2
19 2
    pad_width = tuple((raw_pads[i], raw_pads[i + half])
20
                      for i in range(0, half))
21

22 2
    if mode == 'constant':
23 2
        return numpy.pad(data, pad_width=pad_width, mode=mode,
24
                         constant_values=constant_values)
25 2
    return numpy.pad(data, pad_width=pad_width, mode=mode)
26

27

28 2
class Pad(OpRun):
    """
    Runtime for the ONNX *Pad* node: the padding itself is delegated to
    :func:`_pad_impl`, driven by the node's ``mode`` attribute.
    """

    atts = {'mode': b'constant'}

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Pad.atts,
                       **options)
        # NOTE(review): OpRun presumably exposes each expected attribute
        # as ``self.<name>``; the mode arrives as bytes (see the default
        # above), so keep a decoded str copy for numpy.
        encoded_mode = self.mode
        self.mode_ = encoded_mode.decode('ascii')

    def _run(self, data, pads, constant_value=None):  # pylint: disable=W0221
        padded = _pad_impl(data, pads, mode=self.mode_,
                           constant_values=constant_value)
        return (padded, )

    def _infer_shapes(self, data, pads, constant_value=None):  # pylint: disable=E0202,W0221
        """
        Returns a single ShapeObject built with ``None`` as shape and the
        dtype of *data* (presumably meaning "shape unknown, same dtype" --
        confirm against ShapeObject).
        """
        result_shape = ShapeObject(None, data.dtype)
        return (result_shape, )

Read our documentation on viewing source code.

Loading