1
|
|
# Copyright (c) 2016, Thomas P. Robitaille
|
2
|
|
# All rights reserved.
|
3
|
|
#
|
4
|
|
# Redistribution and use in source and binary forms, with or without
|
5
|
|
# modification, are permitted provided that the following conditions are met:
|
6
|
|
#
|
7
|
|
# 1. Redistributions of source code must retain the above copyright notice,
|
8
|
|
# this list of conditions and the following disclaimer.
|
9
|
|
#
|
10
|
|
# 2. Redistributions in binary form must reproduce the above copyright notice,
|
11
|
|
# this list of conditions and the following disclaimer in the documentation
|
12
|
|
# and/or other materials provided with the distribution.
|
13
|
|
#
|
14
|
|
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
15
|
|
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
16
|
|
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
17
|
|
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
18
|
|
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
19
|
|
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
20
|
|
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
21
|
|
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
22
|
|
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
23
|
|
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
24
|
|
# POSSIBILITY OF SUCH DAMAGE.
|
25
|
|
#
|
26
|
|
# This package was derived from pytest-mpl, which is released under a BSD
|
27
|
|
# license and can be found here:
|
28
|
|
#
|
29
|
|
# https://github.com/astrofrog/pytest-mpl
|
30
|
|
|
31
|
|
|
32
|
3
|
from functools import wraps
|
33
|
|
|
34
|
3
|
import os
|
35
|
3
|
import abc
|
36
|
3
|
import shutil
|
37
|
3
|
import tempfile
|
38
|
3
|
import warnings
|
39
|
3
|
from distutils.version import StrictVersion
|
40
|
|
|
41
|
3
|
import six
|
42
|
3
|
from six.moves.urllib.request import urlopen
|
43
|
|
|
44
|
3
|
import pytest
|
45
|
3
|
import numpy as np
|
46
|
|
|
47
|
|
|
48
|
3
|
# ``abc.abstractstaticmethod`` / ``abc.abstractclassmethod`` only exist on
# Python 3; on Python 2 both decorators degrade to identity functions, so no
# abstractness is enforced there.
if not six.PY2:
    abstractstaticmethod = abc.abstractstaticmethod
    abstractclassmethod = abc.abstractclassmethod
else:
    def abstractstaticmethod(func):
        """Identity stand-in for ``abc.abstractstaticmethod`` on Python 2."""
        return func

    def abstractclassmethod(func):
        """Identity stand-in for ``abc.abstractclassmethod`` on Python 2."""
        return func
|
56
|
|
|
57
|
|
|
58
|
3
|
@six.add_metaclass(abc.ABCMeta)
class BaseDiff(object):
    """
    Abstract interface for one reference-file format.

    Concrete subclasses implement reading, writing and comparing data files
    of a single format. Subclasses registered in ``FORMATS`` also define an
    ``extension`` class attribute, which is used to build reference
    filenames.
    """

    @abstractstaticmethod
    def read(filename):
        """
        Given a filename, return a data object.
        """
        raise NotImplementedError()

    @abstractstaticmethod
    def write(filename, data, **kwargs):
        """
        Given a filename and a data object (and optional keyword arguments),
        write the data to a file.
        """
        raise NotImplementedError()

    @abstractclassmethod
    def compare(cls, reference_file, test_file, atol=None, rtol=None):
        """
        Given a reference and test filename, compare the data to the specified
        absolute (``atol``) and relative (``rtol``) tolerances.

        Should return two arguments: a boolean indicating whether the data are
        identical, and a string giving the full error message if not.
        """
        # First parameter is ``cls`` (not ``self``): this is a classmethod,
        # matching every concrete override.
        raise NotImplementedError()
|
86
|
|
|
87
|
|
|
88
|
3
|
class SimpleArrayDiff(BaseDiff):
    """
    Base class for formats whose files are loaded (via ``cls.read``) into
    array-like data and compared element-wise with
    ``numpy.testing.assert_allclose``.
    """

    @classmethod
    def compare(cls, reference_file, test_file, atol=None, rtol=None):
        """
        Compare the data stored in ``reference_file`` and ``test_file``.

        Parameters
        ----------
        reference_file, test_file : str
            Paths of the files to compare; both are loaded with ``cls.read``.
        atol, rtol : float, optional
            Absolute and relative tolerances. ``None`` selects the
            ``numpy.testing.assert_allclose`` defaults (``atol=0``,
            ``rtol=1e-7``).

        Returns
        -------
        identical : bool
            Whether the data agree within the tolerances.
        message : str
            Empty on success; otherwise both file names followed by NumPy's
            assertion message.
        """
        # assert_allclose raises TypeError for None tolerances, so map them
        # onto its own defaults.
        if atol is None:
            atol = 0.
        if rtol is None:
            rtol = 1e-7

        array_ref = cls.read(reference_file)
        array_new = cls.read(test_file)

        try:
            np.testing.assert_allclose(array_ref, array_new, atol=atol, rtol=rtol)
        except AssertionError as exc:
            message = "\n\na: {0}".format(test_file) + '\n'
            message += "b: {0}".format(reference_file) + '\n'
            message += str(exc)
            return False, message
        else:
            return True, ""
|
105
|
|
|
106
|
|
|
107
|
3
|
class FITSDiff(BaseDiff):
    """
    Reference-file format for FITS files, handled with ``astropy.io.fits``.
    """

    extension = 'fits'

    @staticmethod
    def read(filename):
        """Return the data loaded by ``astropy.io.fits.getdata(filename)``."""
        from astropy.io import fits
        return fits.getdata(filename)

    @staticmethod
    def write(filename, data, **kwargs):
        """
        Write ``data`` to ``filename``.

        A NumPy array is first wrapped in a ``fits.PrimaryHDU``. Any other
        object is assumed (not checked) to provide a compatible ``writeto``
        method, e.g. an HDU or ``HDUList``. Extra keyword arguments are
        forwarded to ``writeto``.
        """
        from astropy.io import fits
        if isinstance(data, np.ndarray):
            data = fits.PrimaryHDU(data)
        return data.writeto(filename, **kwargs)

    @classmethod
    def compare(cls, reference_file, test_file, atol=None, rtol=None):
        """
        Compare two FITS files with ``astropy.io.fits.FITSDiff``.

        On astropy >= 2.0, ``rtol`` and ``atol`` are forwarded when they are
        not ``None``; otherwise astropy's own defaults apply. Older astropy
        only accepts a relative ``tolerance``, so ``atol`` cannot be honoured
        there.

        Returns ``(identical, report)``, where ``report`` is the text from
        ``FITSDiff.report()``.
        """
        import astropy
        # Aliased so the import does not shadow this class's own name.
        from astropy.io.fits.diff import FITSDiff as AstropyFITSDiff
        from astropy.utils.introspection import minversion

        if minversion(astropy, '2.0'):
            tolerances = {}
            if rtol is not None:
                tolerances['rtol'] = rtol
            if atol is not None:
                tolerances['atol'] = atol
            diff = AstropyFITSDiff(reference_file, test_file, **tolerances)
        else:
            diff = AstropyFITSDiff(reference_file, test_file, tolerance=rtol)
        return diff.identical, diff.report()
|
133
|
|
|
134
|
|
|
135
|
3
|
class TextDiff(SimpleArrayDiff):
    """
    Reference-file format for plain-text numeric tables, read with
    ``numpy.loadtxt`` and written with ``numpy.savetxt``.
    """

    extension = 'txt'

    @staticmethod
    def read(filename):
        """Load and return the array stored in ``filename``."""
        return np.loadtxt(filename)

    @staticmethod
    def write(filename, data, **kwargs):
        """
        Save ``data`` to ``filename`` with ``numpy.savetxt``.

        ``fmt`` defaults to ``'%g'``; every other keyword argument is passed
        through unchanged.
        """
        number_format = kwargs.pop('fmt', '%g')
        # numpy.savetxt mishandles a ``unicode`` fmt on Python 2, so it is
        # converted to a byte string there. See
        # https://github.com/numpy/numpy/pull/4053#issuecomment-263808221
        if six.PY2 and isinstance(number_format, six.text_type):
            number_format = number_format.encode('ascii')
        return np.savetxt(filename, data, fmt=number_format, **kwargs)
|
155
|
|
|
156
|
|
|
157
|
3
|
class PDHDFDiff(BaseDiff):
    """
    Reference-file format for pandas objects stored in HDF5 files
    (written with ``to_hdf`` and read with ``pandas.read_hdf``).
    """

    extension = 'h5'

    @staticmethod
    def read(filename):
        """Return the pandas object stored in ``filename``."""
        import pandas as pd
        return pd.read_hdf(filename)

    @staticmethod
    def write(filename, data, **kwargs):
        """
        Write the pandas object ``data`` to ``filename``.

        The HDF5 key is the file's base name with every ``'.h5'`` removed.
        Extra keyword arguments are forwarded to ``data.to_hdf``.
        """
        key = os.path.basename(filename).replace('.h5', '')
        # ``key`` is passed by keyword: recent pandas releases deprecate
        # positional arguments after ``path_or_buf`` in ``to_hdf``.
        return data.to_hdf(filename, key=key, **kwargs)

    @classmethod
    def compare(cls, reference_file, test_file, atol=None, rtol=None):
        """
        Compare the pandas objects stored in two HDF5 files.

        Both files are loaded with ``cls.read`` and checked with
        ``pandas.testing.assert_frame_equal`` at its default settings.
        ``atol`` and ``rtol`` are accepted for interface compatibility with
        the other formats but are currently not applied.

        Returns ``(identical, message)``; ``message`` is empty when the data
        match, otherwise it names both files and includes pandas' assertion
        message.
        """
        import pandas.testing as pdt

        ref_data = cls.read(reference_file)
        test_data = cls.read(test_file)
        try:
            pdt.assert_frame_equal(ref_data, test_data)
        except AssertionError as exc:
            message = "\n\na: {0}".format(test_file) + '\n'
            message += "b: {0}".format(reference_file) + '\n'
            message += str(exc)
            return False, message
        else:
            return True, ""
|
188
|
|
|
189
|
|
|
190
|
3
|
# Registry mapping the ``file_format`` names accepted by the
# ``array_compare`` marker / ``--arraydiff-default-format`` option to the
# class that reads, writes and compares files of that format.
FORMATS = {
    'fits': FITSDiff,
    'text': TextDiff,
    'pdhdf': PDHDFDiff,
}
|
194
|
|
|
195
|
|
|
196
|
3
|
def _download_file(url):
    """
    Download ``url`` into a new temporary directory and return the file path.

    The local file is always named ``downloaded``. The temporary directory is
    not removed by this function.
    """
    from contextlib import closing

    # closing() guarantees the response is released even when the response
    # object does not itself support ``with`` (as may be the case on Python 2).
    with closing(urlopen(url)) as response:
        result_dir = tempfile.mkdtemp()
        filename = os.path.join(result_dir, 'downloaded')
        with open(filename, 'wb') as tmpfile:
            # Stream in chunks rather than loading the whole payload into memory.
            shutil.copyfileobj(response, tmpfile)
    return filename
|
203
|
|
|
204
|
|
|
205
|
3
|
def pytest_addoption(parser):
    """
    Add the ``--arraydiff*`` command-line options to pytest's "general"
    option group.
    """
    group = parser.getgroup("general")
    group.addoption('--arraydiff', action='store_true',
                    help="Enable comparison of arrays to reference arrays stored in files")
    group.addoption('--arraydiff-generate-path',
                    help="directory to generate reference files in, relative to location where py.test is run", action='store')
    group.addoption('--arraydiff-reference-path',
                    help="directory containing reference files, relative to location where py.test is run", action='store')
    # Build the list of valid formats from FORMATS so the help text stays in
    # sync with the registered formats.
    format_names = ", ".join("'{0}'".format(name) for name in sorted(FORMATS))
    group.addoption('--arraydiff-default-format',
                    help="Default format for the reference arrays "
                         "(one of {0}; defaults to 'text')".format(format_names))
|
215
|
|
|
216
|
|
|
217
|
3
|
def pytest_configure(config):
    """
    Register the ``array_compare`` marker and, when ``--arraydiff`` or
    ``--arraydiff-generate-path`` is given, register an ``ArrayComparison``
    plugin configured from the command-line options.
    """
    # Declare the marker so pytest recognises it: unregistered markers emit
    # PytestUnknownMarkWarning, or fail outright under --strict-markers.
    config.addinivalue_line(
        "markers",
        "array_compare(**kwargs): compare the array returned by the test "
        "with a reference file")

    if config.getoption("--arraydiff") or config.getoption("--arraydiff-generate-path") is not None:

        reference_dir = config.getoption("--arraydiff-reference-path")
        generate_dir = config.getoption("--arraydiff-generate-path")

        if reference_dir is not None and generate_dir is not None:
            warnings.warn("Ignoring --arraydiff-reference-path since --arraydiff-generate-path is set")

        if reference_dir is not None:
            reference_dir = os.path.abspath(reference_dir)
        if generate_dir is not None:
            # When generating, the generate path overrides the reference path.
            reference_dir = os.path.abspath(generate_dir)

        default_format = config.getoption("--arraydiff-default-format") or 'text'

        config.pluginmanager.register(ArrayComparison(config,
                                                      reference_dir=reference_dir,
                                                      generate_dir=generate_dir,
                                                      default_format=default_format))
|
238
|
|
|
239
|
|
|
240
|
3
|
class ArrayComparison(object):
    """
    pytest plugin that wraps tests carrying the ``array_compare`` marker.

    The wrapped test's return value is written to a file in the marker's
    format. Depending on ``generate_dir``, that file is then either compared
    against a reference file or saved as a new reference file.
    """

    def __init__(self, config, reference_dir=None, generate_dir=None, default_format='text'):
        self.config = config
        self.reference_dir = reference_dir
        self.generate_dir = generate_dir
        self.default_format = default_format

    def pytest_runtest_setup(self, item):
        """
        Replace the test function of ``item`` with a comparing wrapper.

        Marker keyword arguments read here: ``file_format``, ``extension``,
        ``atol`` (default ``0.``), ``rtol`` (default ``1e-7``),
        ``single_reference``, ``write_kwargs``, ``reference_dir`` and
        ``filename``.

        Raises
        ------
        ValueError
            If ``file_format`` is not a key of ``FORMATS``.
        """
        # ``get_closest_marker`` replaced ``get_marker`` in pytest 3.6.
        # Feature-detect it instead of comparing version strings: parsing
        # ``pytest.__version__`` with StrictVersion fails on rc/dev releases.
        get_marker = getattr(item, 'get_closest_marker', None)
        if get_marker is None:
            get_marker = item.get_marker
        compare = get_marker('array_compare')

        if compare is None:
            return

        file_format = compare.kwargs.get('file_format', self.default_format)

        if file_format not in FORMATS:
            raise ValueError("Unknown format: {0}".format(file_format))

        if 'extension' in compare.kwargs:
            extension = compare.kwargs['extension']
        else:
            extension = FORMATS[file_format].extension

        atol = compare.kwargs.get('atol', 0.)
        rtol = compare.kwargs.get('rtol', 1e-7)

        single_reference = compare.kwargs.get('single_reference', False)

        write_kwargs = compare.kwargs.get('write_kwargs', {})

        original = item.function

        @wraps(item.function)
        def item_function_wrapper(*args, **kwargs):

            reference_dir = compare.kwargs.get('reference_dir', None)
            if reference_dir is None:
                if self.reference_dir is None:
                    reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), 'reference')
                else:
                    reference_dir = self.reference_dir
            else:
                # Local marker paths are relative to the test file's directory.
                if not reference_dir.startswith(('http://', 'https://')):
                    reference_dir = os.path.join(os.path.dirname(item.fspath.strpath), reference_dir)

            baseline_remote = reference_dir.startswith('http')

            # Run the test and get the data object it returns
            import inspect
            if inspect.ismethod(original):  # method
                array = original(*args[1:], **kwargs)
            else:  # function
                array = original(*args, **kwargs)

            # Find the test name to use as the reference filename
            filename = compare.kwargs.get('filename', None)
            if filename is None:
                if single_reference:
                    filename = original.__name__ + '.' + extension
                else:
                    # Parametrized ids contain brackets; make them filename-safe.
                    filename = item.name + '.' + extension
                    filename = filename.replace('[', '_').replace(']', '_')
                    filename = filename.replace('_.' + extension, '.' + extension)

            # What we do now depends on whether we are generating the reference
            # files or simply running the test.
            if self.generate_dir is None:

                # Save the array
                result_dir = tempfile.mkdtemp()
                test_array = os.path.abspath(os.path.join(result_dir, filename))

                FORMATS[file_format].write(test_array, array, **write_kwargs)

                # Find path to baseline array
                if baseline_remote:
                    # Join with exactly one '/' so a base URL with or without
                    # a trailing slash both resolve to <base>/<filename>.
                    baseline_file_ref = _download_file(reference_dir.rstrip('/') + '/' + filename)
                else:
                    baseline_file_ref = os.path.abspath(os.path.join(os.path.dirname(item.fspath.strpath), reference_dir, filename))

                if not os.path.exists(baseline_file_ref):
                    raise Exception("File not found for comparison test\n"
                                    "Generated file:\n"
                                    "\t{test}\n"
                                    "This is expected for new tests.".format(
                                        test=test_array))

                # distutils may put the baseline arrays in non-accessible places,
                # copy to our tmpdir to be sure to keep them in case of failure
                baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
                shutil.copyfile(baseline_file_ref, baseline_file)

                identical, msg = FORMATS[file_format].compare(baseline_file, test_array, atol=atol, rtol=rtol)

                # On failure the temporary directory is kept for inspection.
                if identical:
                    shutil.rmtree(result_dir)
                else:
                    raise Exception(msg)

            else:

                if not os.path.exists(self.generate_dir):
                    os.makedirs(self.generate_dir)

                FORMATS[file_format].write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)

                pytest.skip("Skipping test, since generating data")

        if item.cls is not None:
            setattr(item.cls, item.function.__name__, item_function_wrapper)
        else:
            item.obj = item_function_wrapper
|