mlprodict/onnxrt/onnx_inference.py
has changed.
Newly tracked file
mlprodict/onnxrt/ops_cpu/op_loop.py
was created.
Other files ignored by Codecov
HISTORY.rst
has changed.
220 | 220 | ".".format(type(self.ops_))) |
|
221 | 221 | if len(self.outputs) != len(res): |
|
222 | 222 | raise RuntimeError( # pragma: no cover |
|
223 | - | "Mismatch number of outputs got {} for names {} (node='{}')." |
|
223 | + | "Mismatch number of outputs got {} != {} for names {} (node='{}')." |
|
224 | 224 | "\n{}".format( |
|
225 | - | len(res), list(self.outputs), |
|
225 | + | len(res), len(self.outputs), list(self.outputs), |
|
226 | 226 | self.ops_.__class__.__name__, |
|
227 | - | pprint.pformat(self.desc))) |
|
227 | + | pprint.pformat(self.desc, depth=2))) |
|
228 | 228 | for name, value in zip(self.outputs, res): |
|
229 | 229 | values[name] = value |
|
230 | 230 | return values |
52 | 52 | from .op_linear_classifier import LinearClassifier |
|
53 | 53 | from .op_linear_regressor import LinearRegressor |
|
54 | 54 | from .op_log import Log |
|
55 | + | from .op_loop import Loop |
|
55 | 56 | from .op_lp_normalization import LpNormalization |
|
56 | 57 | from .op_matmul import MatMul |
|
57 | 58 | from .op_max import Max |
1 | + | # -*- encoding: utf-8 -*- |
|
2 | + | # pylint: disable=E0203,E1101,C0111 |
|
3 | + | """ |
|
4 | + | @file |
|
5 | + | @brief Runtime operator. |
|
6 | + | """ |
|
7 | + | import numpy |
|
8 | + | from ._op import OpRun |
|
9 | + | ||
10 | + | ||
class Loop(OpRun):
    """
    Runtime operator for the ONNX *Loop* operator.

    Attribute ``body`` holds the subgraph executed at every iteration;
    it must expose a ``run`` method (and may expose ``run_in_scan``,
    which is preferred when present).
    """

    atts = {
        'body': None,
    }

    def __init__(self, onnx_node, desc=None, **options):
        OpRun.__init__(self, onnx_node, desc=desc,
                       expected_attributes=Loop.atts,
                       **options)
        if not hasattr(self.body, 'run'):
            raise RuntimeError("Parameter 'body' must have a method 'run', "
                               "type {}.".format(type(self.body)))
        # Prefer the scan-aware entry point when the subgraph provides one.
        self._run_meth = (self.body.run_in_scan
                          if hasattr(self.body, 'run_in_scan')
                          else self.body.run)

    def _run(self, M, cond, v_initial, *args):  # pylint: disable=W0221
        """
        Executes the loop body at most *M* times, stopping as soon as the
        condition output of the body becomes false.

        :param M: maximum number of iterations
        :param cond: initial condition value
        :param v_initial: initial value of the first loop-carried variable
        :param args: remaining loop-carried variables, mapped onto the
            trailing inputs of the body graph
        :return: tuple of the body outputs, skipping the first one
            (the condition output)
        """
        # Seed the body inputs: first the condition, then the first
        # loop-carried value, then any extra loop-carried values aligned
        # with the trailing input names.
        inputs = {name: None for name in self.body.input_names}
        inputs[self.body.input_names[0]] = cond
        inputs[self.body.input_names[1]] = v_initial
        cond_name = self.body.output_names[0]
        if len(args) > 0:
            begin = len(self.body.input_names) - len(args)
            for name, val in zip(self.body.input_names[begin:], args):
                inputs[name] = val
        it = 0
        while cond and it < M:
            # FIX: the original called self._run_meth_then, which is never
            # defined (__init__ sets self._run_meth) and raised
            # AttributeError on the first iteration.
            outputs = self._run_meth(inputs)
            cond = outputs[cond_name]
            # Feed each loop-carried output back into the matching input.
            for i, o in zip(self.body.input_names[2:],
                            self.body.output_names[2:]):
                inputs[i] = outputs[o]
            it += 1
        if it == 0:
            # The body never ran: forward the initial values unchanged.
            outputs = {self.body.output_names[1]: cond}
            for i, o in zip(self.body.input_names[2:],
                            self.body.output_names[2:]):
                outputs[o] = inputs[i]
        # Make sure every declared output exists, even if the body
        # did not produce it.
        for o in self.body.output_names:
            if o not in outputs:
                outputs[o] = numpy.empty(shape=tuple())
        return tuple([outputs[name] for name in self.body.output_names[1:]])

    def _infer_shapes(self, M, cond, v_initial, *args):  # pylint: disable=W0221
        """
        Infers output shapes by running shape inference on the body graph.
        """
        res = self.body._set_shape_inference_runtime()  # pylint: disable=W0212
        return tuple([res[name] for name in self.body.output_names[1:]])
Files | Coverage |
---|---|
mlprodict | 91.25% |
Project Totals (225 files) | 91.25% |