1
# Copyright (c) 2016, Thomas P. Robitaille
2
# All rights reserved.
3
#
4
# Redistribution and use in source and binary forms, with or without
5
# modification, are permitted provided that the following conditions are met:
6
#
7
# 1. Redistributions of source code must retain the above copyright notice,
8
# this list of conditions and the following disclaimer.
9
#
10
# 2. Redistributions in binary form must reproduce the above copyright notice,
11
# this list of conditions and the following disclaimer in the documentation
12
# and/or other materials provided with the distribution.
13
#
14
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
15
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
16
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
17
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
18
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
19
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
20
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
21
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
22
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
23
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
24
# POSSIBILITY OF SUCH DAMAGE.
25
#
26
# This package was derived from pytest-mpl, which is released under a BSD
27
# license and can be found here:
28
#
29
#   https://github.com/astrofrog/pytest-mpl
30

31

32 3
from functools import wraps
33

34 3
import os
35 3
import abc
36 3
import shutil
37 3
import tempfile
38 3
import warnings
39 3
from distutils.version import StrictVersion
40

41 3
import six
42 3
from six.moves.urllib.request import urlopen
43

44 3
import pytest
45 3
import numpy as np
46

47

48 3
if not six.PY2:
    # Python 3's abc module provides both decorators directly.
    abstractstaticmethod = abc.abstractstaticmethod
    abstractclassmethod = abc.abstractclassmethod
else:
    # Python 2's abc module lacks these decorators. Identity stand-ins let the
    # class definitions below load; on Python 2 the decorated methods are
    # therefore returned unchanged (neither abstract nor static/class methods).
    def abstractstaticmethod(func):
        return func

    def abstractclassmethod(func):
        return func
56

57

58 3
@six.add_metaclass(abc.ABCMeta)
class BaseDiff(object):
    """
    Abstract interface for a file format in which test output can be stored
    and compared.

    A format provides a ``read``/``write`` pair of static methods and a
    ``compare`` classmethod that checks a test file against a reference file.
    """

    @abstractstaticmethod
    def read(filename):
        """
        Given a filename, return a data object.
        """
        raise NotImplementedError()

    @abstractstaticmethod
    def write(filename, data, **kwargs):
        """
        Given a filename and a data object (and optional keyword arguments),
        write the data to a file.
        """
        raise NotImplementedError()

    @abstractclassmethod
    def compare(cls, reference_file, test_file, atol=None, rtol=None):
        """
        Given a reference and test filename, compare the data to the specified
        absolute (``atol``) and relative (``rtol``) tolerances.

        Should return two arguments: a boolean indicating whether the data are
        identical, and a string giving the full error message if not.
        """
        # Decorated as a classmethod, so the first argument is the class
        # (previously misleadingly named ``self``).
        raise NotImplementedError()
86

87

88 3
class SimpleArrayDiff(BaseDiff):
    """
    Base for formats whose ``read`` returns a numerical array.

    Subclasses implement ``read`` and ``write``; ``compare`` loads both files
    with ``cls.read`` and checks them element-wise with
    :func:`numpy.testing.assert_allclose`.
    """

    @classmethod
    def compare(cls, reference_file, test_file, atol=None, rtol=None):
        """
        Compare the arrays stored in ``reference_file`` and ``test_file``.

        Parameters
        ----------
        reference_file, test_file : str
            Paths that ``cls.read`` can load.
        atol, rtol : float, optional
            Absolute and relative tolerances. ``None`` selects numpy's own
            defaults (``atol=0``, ``rtol=1e-7``). The relative tolerance is
            applied to the reference values.

        Returns
        -------
        identical : bool
            True when the arrays agree within the tolerances.
        message : str
            Empty on success, otherwise a report naming both files followed by
            numpy's mismatch description.
        """
        # assert_allclose does not accept None for its tolerances, although
        # BaseDiff.compare advertises None as the default.
        if atol is None:
            atol = 0.
        if rtol is None:
            rtol = 1e-7

        array_ref = cls.read(reference_file)
        array_new = cls.read(test_file)

        try:
            # assert_allclose(actual, desired) checks
            # |actual - desired| <= atol + rtol * |desired|, so the reference
            # is passed as ``desired``. numpy then reports the test array
            # first and the reference second, matching the a/b header below.
            np.testing.assert_allclose(array_new, array_ref, atol=atol, rtol=rtol)
        except AssertionError as exc:
            message = "\n\na: {0}".format(test_file) + '\n'
            message += "b: {0}".format(reference_file) + '\n'
            message += exc.args[0]
            return False, message
        else:
            return True, ""
105

106

107 3
class FITSDiff(BaseDiff):
    """
    FITS files, read and written with ``astropy.io.fits`` and compared with
    astropy's FITS differ.
    """

    extension = 'fits'

    @staticmethod
    def read(filename):
        """Return the data that ``astropy.io.fits.getdata`` reads from ``filename``."""
        from astropy.io import fits
        return fits.getdata(filename)

    @staticmethod
    def write(filename, data, **kwargs):
        """
        Write ``data`` to ``filename``.

        A numpy array is wrapped in a ``PrimaryHDU`` first. Any other object
        must provide ``writeto(filename, **kwargs)`` itself (presumably an
        astropy HDU or HDU list). ``kwargs`` are forwarded to ``writeto``.
        """
        from astropy.io import fits
        if isinstance(data, np.ndarray):
            data = fits.PrimaryHDU(data)
        return data.writeto(filename, **kwargs)

    @classmethod
    def compare(cls, reference_file, test_file, atol=None, rtol=None):
        """
        Compare two FITS files with ``astropy.io.fits.diff.FITSDiff``.

        Returns ``(identical, report)``, where ``report`` is astropy's
        textual diff report. ``None`` tolerances mean exact comparison (0.0,
        which is also astropy's default). On astropy < 2.0 only the relative
        tolerance is supported, and ``atol`` is ignored.
        """
        import astropy
        # Aliased so it does not shadow this class's own name.
        from astropy.io.fits.diff import FITSDiff as AstropyFITSDiff
        from astropy.utils.introspection import minversion

        # Pass explicit numbers rather than forwarding None to astropy.
        rtol = 0.0 if rtol is None else rtol
        atol = 0.0 if atol is None else atol

        if minversion(astropy, '2.0'):
            diff = AstropyFITSDiff(reference_file, test_file, rtol=rtol, atol=atol)
        else:
            # Older astropy only knows a single relative ``tolerance``.
            diff = AstropyFITSDiff(reference_file, test_file, tolerance=rtol)
        return diff.identical, diff.report()
133

134

135 3
class TextDiff(SimpleArrayDiff):
    """
    Plain-text arrays: loaded with ``numpy.loadtxt`` and saved with
    ``numpy.savetxt``.
    """

    extension = 'txt'

    @staticmethod
    def read(filename):
        """Load and return the array stored in the text file ``filename``."""
        return np.loadtxt(filename)

    @staticmethod
    def write(filename, data, **kwargs):
        """
        Save ``data`` to ``filename`` via ``numpy.savetxt``.

        Keyword arguments are forwarded to ``numpy.savetxt``; ``fmt`` falls
        back to ``'%g'`` when the caller does not supply one.
        """
        fmt = kwargs.pop('fmt', '%g')
        # numpy.savetxt has a known issue with a unicode ``fmt`` on Python 2:
        # https://github.com/numpy/numpy/pull/4053#issuecomment-263808221
        # so on Python 2 a text ``fmt`` is converted to a byte string.
        needs_bytes = six.PY2 and isinstance(fmt, six.text_type)
        if needs_bytes:
            fmt = fmt.encode('ascii')
        return np.savetxt(filename, data, fmt=fmt, **kwargs)
155

156

157 3
class PDHDFDiff(BaseDiff):
    """
    pandas objects stored in HDF5 files through ``to_hdf`` / ``read_hdf``.
    """

    extension = 'h5'

    @staticmethod
    def read(filename):
        """
        Return the pandas object stored in ``filename``.

        No key is passed, which pandas allows only when the file holds a
        single pandas object (as files produced by ``write`` do).
        """
        import pandas as pd
        return pd.read_hdf(filename)

    @staticmethod
    def write(filename, data, **kwargs):
        """
        Store ``data`` (a pandas object) in ``filename`` with ``to_hdf``.

        The HDF5 key is the file's base name with every ``'.h5'`` substring
        removed. ``kwargs`` are forwarded to ``to_hdf``.
        """
        key = os.path.basename(filename).replace('.h5', '')
        # ``key`` by keyword: recent pandas makes it keyword-only.
        return data.to_hdf(filename, key=key, **kwargs)

    @classmethod
    def compare(cls, reference_file, test_file, atol=None, rtol=None):
        """
        Compare the frames stored in two HDF5 files with
        ``pandas.testing.assert_frame_equal``.

        ``atol`` and ``rtol`` are accepted for interface compatibility but
        are not used. Returns ``(identical, message)``.
        """
        import pandas.testing as pdt

        # Load the data: assert_frame_equal compares frames, not file paths.
        frame_ref = cls.read(reference_file)
        frame_new = cls.read(test_file)

        try:
            pdt.assert_frame_equal(frame_ref, frame_new)
        except AssertionError as exc:
            message = "\n\na: {0}".format(test_file) + '\n'
            message += "b: {0}".format(reference_file) + '\n'
            message += exc.args[0]
            return False, message
        else:
            return True, ""
186

187

188 3
# Registry mapping each supported ``file_format`` name to its handler class.
FORMATS = {
    'fits': FITSDiff,
    'text': TextDiff,
    'pdhdf': PDHDFDiff,
}
192

193

194 3
def _download_file(url):
    """
    Download ``url`` into a fresh temporary directory.

    Returns the path of the downloaded file, named ``downloaded`` inside a
    directory created with :func:`tempfile.mkdtemp`. The directory is not
    removed here.
    """
    import contextlib

    # closing() instead of ``with urlopen(...)``: the response object may not
    # support the with statement on Python 2. Either way it is now closed.
    with contextlib.closing(urlopen(url)) as response:
        result_dir = tempfile.mkdtemp()
        filename = os.path.join(result_dir, 'downloaded')
        with open(filename, 'wb') as tmpfile:
            # Stream in chunks rather than reading the whole body into memory.
            shutil.copyfileobj(response, tmpfile)
    return filename
201

202

203 3
def pytest_addoption(parser):
    """Register the arraydiff command-line options in pytest's "general" group."""
    group = parser.getgroup("general")
    group.addoption('--arraydiff', action='store_true',
                    help="Enable comparison of arrays to reference arrays stored in files")
    group.addoption('--arraydiff-generate-path',
                    help="directory to generate reference files in, relative to location where py.test is run", action='store')
    group.addoption('--arraydiff-reference-path',
                    help="directory containing reference files, relative to location where py.test is run", action='store')
    # Keep this list in sync with the keys of FORMATS.
    group.addoption('--arraydiff-default-format',
                    help="Default format for the reference arrays (can be 'fits', 'text' or 'pdhdf' currently; defaults to 'text')")
213

214

215 3
def pytest_configure(config):
    """
    Register the ``array_compare`` marker and, when comparison
    (``--arraydiff``) or generation (``--arraydiff-generate-path``) is
    requested, an ``ArrayComparison`` plugin instance.
    """
    # Declare the marker so pytest does not warn about (or, under
    # --strict-markers, reject) the otherwise unknown ``array_compare`` mark.
    config.addinivalue_line('markers',
                            'array_compare: for functions using array comparison')

    if config.getoption("--arraydiff") or config.getoption("--arraydiff-generate-path") is not None:

        reference_dir = config.getoption("--arraydiff-reference-path")
        generate_dir = config.getoption("--arraydiff-generate-path")

        if reference_dir is not None and generate_dir is not None:
            warnings.warn("Ignoring --arraydiff-reference-path since --arraydiff-generate-path is set")

        if reference_dir is not None:
            reference_dir = os.path.abspath(reference_dir)
        if generate_dir is not None:
            # Resolve now, relative to where pytest was started (as the option
            # help promises), so a test that changes the working directory
            # cannot redirect the generated files.
            generate_dir = os.path.abspath(generate_dir)
            reference_dir = generate_dir

        default_format = config.getoption("--arraydiff-default-format") or 'text'

        config.pluginmanager.register(ArrayComparison(config,
                                                      reference_dir=reference_dir,
                                                      generate_dir=generate_dir,
                                                      default_format=default_format))
236

237

238 3
class ArrayComparison(object):
    """
    pytest plugin that wraps tests marked with ``array_compare``.

    The marked test returns a data object. Depending on configuration, the
    wrapper either writes that object into ``generate_dir`` and skips the
    test, or writes it to a temporary file and compares it with a reference
    file using the format's ``compare``.

    Parameters
    ----------
    config : pytest config object
        Stored on the instance.
    reference_dir : str or None
        Directory holding reference files. When None (and the marker gives
        no ``reference_dir``), a ``reference`` directory next to the test
        module is used.
    generate_dir : str or None
        When set, reference files are generated here instead of compared.
    default_format : str
        Key into ``FORMATS`` used when the marker gives no ``file_format``.
    """

    def __init__(self, config, reference_dir=None, generate_dir=None, default_format='text'):
        self.config = config
        self.reference_dir = reference_dir
        self.generate_dir = generate_dir
        self.default_format = default_format

    @staticmethod
    def _get_marker(item):
        """Return the ``array_compare`` marker of ``item``, or None if absent."""
        # Feature-detect instead of parsing pytest.__version__: version parsing
        # with StrictVersion raises on release-candidate/dev version strings.
        # ``get_closest_marker`` is the newer API; ``get_marker`` the legacy one.
        get_closest_marker = getattr(item, 'get_closest_marker', None)
        if get_closest_marker is not None:
            return get_closest_marker('array_compare')
        return item.get_marker('array_compare')

    def pytest_runtest_setup(self, item):
        """
        Replace the test function of ``item`` with a comparing/generating
        wrapper when the test carries an ``array_compare`` marker.

        Marker keyword arguments used: ``file_format``, ``extension``,
        ``atol`` (default 0.), ``rtol`` (default 1e-7), ``single_reference``,
        ``write_kwargs``, ``reference_dir`` and ``filename``.

        Raises
        ------
        ValueError
            If ``file_format`` is not a key of ``FORMATS``.
        """
        import inspect

        compare = self._get_marker(item)
        if compare is None:
            return

        options = compare.kwargs

        file_format = options.get('file_format', self.default_format)
        if file_format not in FORMATS:
            raise ValueError("Unknown format: {0}".format(file_format))
        diff_class = FORMATS[file_format]

        extension = options.get('extension', diff_class.extension)

        atol = options.get('atol', 0.)
        rtol = options.get('rtol', 1e-7)

        single_reference = options.get('single_reference', False)

        write_kwargs = options.get('write_kwargs', {})

        test_dir = os.path.dirname(item.fspath.strpath)
        original = item.function

        @wraps(item.function)
        def item_function_wrapper(*args, **kwargs):

            # Where reference files live: the marker's own ``reference_dir``
            # (relative to the test module unless it is a URL), else the
            # command-line path, else ``reference/`` next to the test module.
            reference_dir = options.get('reference_dir', None)
            if reference_dir is None:
                if self.reference_dir is None:
                    reference_dir = os.path.join(test_dir, 'reference')
                else:
                    reference_dir = self.reference_dir
            elif not reference_dir.startswith(('http://', 'https://')):
                reference_dir = os.path.join(test_dir, reference_dir)

            baseline_remote = reference_dir.startswith('http')

            # Run the test to obtain the data object
            if inspect.ismethod(original):  # method
                array = original(*args[1:], **kwargs)
            else:  # function
                array = original(*args, **kwargs)

            # Name the reference file after the test unless told otherwise
            filename = options.get('filename', None)
            if filename is None:
                if single_reference:
                    filename = original.__name__ + '.' + extension
                else:
                    filename = item.name + '.' + extension
                    # Parametrized test ids contain brackets; make them
                    # file-name friendly.
                    filename = filename.replace('[', '_').replace(']', '_')
                    filename = filename.replace('_.' + extension, '.' + extension)

            # What we do now depends on whether we are generating the reference
            # files or simply running the test.
            if self.generate_dir is None:

                # Save the test output to a temporary directory
                result_dir = tempfile.mkdtemp()
                test_array = os.path.abspath(os.path.join(result_dir, filename))

                diff_class.write(test_array, array, **write_kwargs)

                # Find path to baseline array
                if baseline_remote:
                    # Ensure exactly one '/' between the base URL and the name.
                    base_url = reference_dir if reference_dir.endswith('/') else reference_dir + '/'
                    baseline_file_ref = _download_file(base_url + filename)
                else:
                    baseline_file_ref = os.path.abspath(os.path.join(test_dir, reference_dir, filename))

                if not os.path.exists(baseline_file_ref):
                    raise Exception("""File not found for comparison test
                                    Generated file:
                                    \t{test}
                                    This is expected for new tests.""".format(
                        test=test_array))

                # distutils may put the baseline arrays in non-accessible places,
                # copy to our tmpdir to be sure to keep them in case of failure
                baseline_file = os.path.abspath(os.path.join(result_dir, 'reference-' + filename))
                shutil.copyfile(baseline_file_ref, baseline_file)

                identical, msg = diff_class.compare(baseline_file, test_array, atol=atol, rtol=rtol)

                if identical:
                    shutil.rmtree(result_dir)
                else:
                    # result_dir is kept so both files can be inspected.
                    raise Exception(msg)

            else:

                if not os.path.exists(self.generate_dir):
                    os.makedirs(self.generate_dir)

                diff_class.write(os.path.abspath(os.path.join(self.generate_dir, filename)), array, **write_kwargs)

                pytest.skip("Skipping test, since generating data")

        if item.cls is not None:
            setattr(item.cls, item.function.__name__, item_function_wrapper)
        else:
            item.obj = item_function_wrapper

Read our documentation on viewing source code .

Loading