astropy / astroquery
1
# Licensed under a 3-clause BSD style license - see LICENSE.rst
2 1
from __future__ import (absolute_import, division, print_function,
3
                        unicode_literals)
4 1
import abc
5 1
import inspect
6 1
import pickle
7 1
import getpass
8 1
import hashlib
9 1
import keyring
10 1
import io
11 1
import os
12 1
import requests
13

14 1
import six
15 1
from astropy.config import paths
16 1
from astropy.logger import log
17 1
import astropy.units as u
18 1
from astropy.utils.console import ProgressBarOrSpinner
19 1
import astropy.utils.data
20

21 1
from . import version
22 1
from .utils import system_tools
23

24 1
# Public API of this module. Other module-level names (AstroQuery,
# suspend_cache, to_cache, ...) remain importable but are not exported by
# ``from ... import *``.
__all__ = ['BaseQuery', 'QueryWithLogin']
25

26

27 1
def to_cache(response, cache_file):
    """
    Serialize ``response`` with pickle into the file ``cache_file``.

    Any existing file at that path is overwritten.
    """
    message = "Caching data to {0}".format(cache_file)
    log.debug(message)
    output = open(cache_file, "wb")
    with output:
        pickle.dump(response, output)
31

32

33 1
def _replace_none_iterable(iterable):
34 1
    return tuple('' if i is None else i for i in iterable)
35

36

37 1
class AstroQuery(object):
    """
    Description of a single HTTP request that can be sent through a
    `requests` session and cached on disk.

    The cache key (see ``hash``) is derived from the method, URL and every
    payload argument, so identical requests map to the same cache file.
    """

    def __init__(self, method, url,
                 params=None, data=None, headers=None,
                 files=None, timeout=None, json=None):
        self.method = method
        self.url = url
        self.params = params
        self.data = data
        self.json = json
        self.headers = headers
        self.files = files
        # Memoized result of ``hash()``; computed on first use.
        self._hash = None
        self.timeout = timeout

    @property
    def timeout(self):
        """Request timeout in seconds, or None for no timeout."""
        return self._timeout

    @timeout.setter
    def timeout(self, value):
        # Objects with a ``to`` method (astropy Quantities) are converted to
        # a plain number of seconds, which is what `requests` expects.
        if hasattr(value, 'to'):
            self._timeout = value.to(u.s).value
        else:
            self._timeout = value

    def request(self, session, cache_location=None, stream=False,
                auth=None, verify=True, allow_redirects=True,
                json=None):
        """
        Send this request through ``session`` and return the response.

        Parameters
        ----------
        session : `requests.Session`
        cache_location : str or None
            Accepted for API compatibility; not used here (caching is done
            by the caller).
        stream, auth, verify, allow_redirects
            Passed through unchanged to ``session.request``.
        json : None or dict
            JSON payload to send. When None, the payload given at
            construction time (``self.json``) is used.
        """
        if json is None:
            # Fall back to the stored payload, consistent with how params,
            # data, headers and files are always taken from the instance.
            json = self.json
        return session.request(self.method, self.url, params=self.params,
                               data=self.data, headers=self.headers,
                               files=self.files, timeout=self.timeout,
                               stream=stream, auth=auth, verify=verify,
                               allow_redirects=allow_redirects,
                               json=json)

    def hash(self):
        """
        Return a SHA-224 hex digest identifying this request.

        The digest covers the method, URL, params, data, json, headers and
        files. File-like values inside dict payloads are read so that their
        contents contribute to the key, and are then rewound so the request
        can still send them. The digest is computed once and memoized.

        Raises
        ------
        TypeError
            If a payload is not a dict, tuple, list, str or None.
        """
        if self._hash is None:
            request_key = (self.method, self.url)
            for k in (self.params, self.data, self.json,
                      self.headers, self.files):
                if isinstance(k, dict):
                    items = tuple(sorted(k.items(),
                                         key=_replace_none_iterable))
                    entry = tuple((k_, v_.read()) if hasattr(v_, 'read')
                                  else (k_, v_) for k_, v_ in items)
                    # Rewind the *original* file objects: ``entry`` holds
                    # their contents, so seeking on it would do nothing and
                    # the request would later upload empty files.
                    for _, v_ in items:
                        if hasattr(v_, 'read') and hasattr(v_, 'seek'):
                            v_.seek(0)

                    request_key += entry
                elif isinstance(k, (tuple, list)):
                    request_key += (tuple(sorted(k,
                                                 key=_replace_none_iterable)),)
                elif k is None:
                    request_key += (None,)
                elif isinstance(k, six.string_types):
                    request_key += (k,)
                else:
                    raise TypeError("{0} must be a dict, tuple, str, or "
                                    "list".format(k))
            self._hash = hashlib.sha224(pickle.dumps(request_key)).hexdigest()
        return self._hash

    def request_file(self, cache_location):
        """Return the path of this request's pickle file in ``cache_location``."""
        fn = os.path.join(cache_location, self.hash() + ".pickle")
        return fn

    def from_cache(self, cache_location):
        """
        Load the cached response for this request, if any.

        Returns
        -------
        response : `requests.Response` or None
            The unpickled response, or None when the cache file is missing,
            unreadable, truncated/corrupt, or does not contain a Response.

        Notes
        -----
        ``pickle.load`` can execute arbitrary code stored in the file, so the
        cache directory must only contain trusted data.
        """
        request_file = self.request_file(cache_location)
        try:
            with open(request_file, "rb") as f:
                response = pickle.load(f)
            if not isinstance(response, requests.Response):
                response = None
        except IOError:  # TODO: change to FileNotFoundError once drop py2 support
            response = None
        except (EOFError, pickle.UnpicklingError):
            # A truncated or corrupt cache file (e.g. from an interrupted
            # write) is a cache miss: the caller re-issues and re-caches.
            response = None
        if response:
            log.debug("Retrieving data from {0}".format(request_file))
        return response
117

118

119 1
class LoginABCMeta(abc.ABCMeta):
    """
    The goal of this metaclass is to copy the docstring and signature from
    ._login methods, implemented in subclasses, to a .login method that is
    visible by the users.

    It also inherits from the ABCMeta metaclass as _login is an abstract
    method.
    """

    def __new__(cls, name, bases, attrs):
        newcls = super(LoginABCMeta, cls).__new__(cls, name, bases, attrs)

        if '_login' in attrs and name not in ('BaseQuery', 'QueryWithLogin'):
            # Skip these two classes, BaseQuery and QueryWithLogin, so that
            # bases[0] below is expected to be QueryWithLogin (or a subclass
            # of it) whose ``login`` does the real work.
            parent = bases[0]
            login_impl = attrs['_login']

            def login(*args, **kwargs):
                # Propagate the parent's return value (the authentication
                # status) instead of silently discarding it.
                return parent.login(*args, **kwargs)

            login.__doc__ = login_impl.__doc__
            # ``inspect.signature`` only exists on Python 3; on Python 2 the
            # docstring alone is copied.
            if hasattr(inspect, 'signature'):
                login.__signature__ = inspect.signature(login_impl)
            setattr(newcls, login.__name__, login)

        return newcls
145

146

147 1
@six.add_metaclass(LoginABCMeta)
class BaseQuery(object):
    """
    This is the base class for all the query classes in astroquery. It
    is implemented as an abstract class and must not be directly instantiated.
    """

    def __init__(self):
        """
        Create the HTTP session (with an astroquery User-Agent) and make sure
        the per-class cache directory exists.
        """
        S = self._session = requests.session()
        S.headers['User-Agent'] = (
            'astroquery/{vers} {olduseragent}'
            .format(vers=version.version,
                    olduseragent=S.headers['User-Agent']))

        # One cache directory per service class; e.g. a class named
        # ``SimbadClass`` uses the directory ``Simbad``.
        self.cache_location = os.path.join(
            paths.get_cache_dir(), 'astroquery',
            self.__class__.__name__.split("Class")[0])
        if not os.path.exists(self.cache_location):
            try:
                os.makedirs(self.cache_location)
            except OSError:
                # Tolerate the directory being created concurrently between
                # the existence check and makedirs (no ``exist_ok`` on py2).
                if not os.path.isdir(self.cache_location):
                    raise
        self._cache_active = True

    def __call__(self, *args, **kwargs):
        """ init a fresh copy of self """
        return self.__class__(*args, **kwargs)

    def _request(self, method, url,
                 params=None, data=None, headers=None,
                 files=None, save=False, savedir='', timeout=None, cache=True,
                 stream=False, auth=None, continuation=True, verify=True,
                 allow_redirects=True,
                 json=None):
        """
        A generic HTTP request method, similar to `requests.Session.request`
        but with added caching-related tools

        This is a low-level method not generally intended for use by astroquery
        end-users.  However, it should _always_ be used by astroquery
        developers; direct uses of `urllib` or `requests` are almost never
        correct.

        Parameters
        ----------
        method : str
            'GET' or 'POST'
        url : str
        params : None or dict
        data : None or dict
        json : None or dict
        headers : None or dict
        auth : None or dict
        files : None or dict
            See `requests.request`
        save : bool
            Whether to save the file to a local directory.  Caching will happen
            independent of this parameter if `BaseQuery.cache_location` is set,
            but the save location can be overridden if ``save==True``
        savedir : str
            The location to save the local file if you want to save it
            somewhere other than `BaseQuery.cache_location`
        timeout : int
        cache : bool
        verify : bool
            Verify the server's TLS certificate?
            (see http://docs.python-requests.org/en/master/_modules/requests/sessions/?highlight=verify)
            Honoured both for in-memory requests and for ``save=True``
            downloads.
        continuation : bool
            If the file is partly downloaded to the target location, this
            parameter will try to continue the download where it left off.
            See `_download_file`.
        stream : bool

        Returns
        -------
        response : `requests.Response`
            The response from the server if ``save`` is False
        local_filepath : list
            a list of strings containing the downloaded local paths if ``save``
            is True
        """
        req_kwargs = dict(
            params=params,
            data=data,
            headers=headers,
            files=files,
            timeout=timeout,
            json=json
        )
        if save:
            local_filename = url.split('/')[-1]
            if os.name == 'nt':
                # Windows doesn't allow special characters in filenames like
                # ":" so replace them with an underscore
                local_filename = local_filename.replace(':', '_')
            local_filepath = os.path.join(savedir or self.cache_location or '.', local_filename)

            self._download_file(url, local_filepath, cache=cache,
                                continuation=continuation, method=method,
                                allow_redirects=allow_redirects,
                                auth=auth, verify=verify, **req_kwargs)
            return local_filepath
        else:
            query = AstroQuery(method, url, **req_kwargs)
            if ((self.cache_location is None) or (not self._cache_active) or (not cache)):
                with suspend_cache(self):
                    response = query.request(self._session, stream=stream,
                                             auth=auth, verify=verify,
                                             allow_redirects=allow_redirects,
                                             json=json)
            else:
                response = query.from_cache(self.cache_location)
                # from_cache returns None on a miss; requests.Response objects
                # are also falsy for error status codes, so those are
                # re-fetched rather than served from the cache.
                if not response:
                    response = query.request(self._session,
                                             self.cache_location,
                                             stream=stream,
                                             auth=auth,
                                             allow_redirects=allow_redirects,
                                             verify=verify,
                                             json=json)
                    to_cache(response, query.request_file(self.cache_location))
            self._last_query = query
            return response

    def _download_file(self, url, local_filepath, timeout=None, auth=None,
                       continuation=True, cache=False, method="GET",
                       head_safe=False, **kwargs):
        """
        Download a file.  Resembles `astropy.utils.data.download_file` but uses
        the local ``_session``

        Parameters
        ----------
        url : string
        local_filepath : string
        timeout : int
        auth : dict or None
        continuation : bool
            If the file has already been partially downloaded *and* the server
            supports HTTP "range" requests, the download will be continued
            where it left off.
        cache : bool
            If True and ``local_filepath`` already exists with the expected
            size (or the size is unknown), it is kept and nothing is
            downloaded.
        method : "GET" or "POST"
        head_safe : bool
            If True, headers are first inspected with a HEAD request and
            ``method`` is only issued when the body is actually needed.
        **kwargs
            Extra keyword arguments passed to ``requests.Session.request``
            (e.g. ``params``, ``data``, ``headers``, ``verify``).

        Returns
        -------
        response : `requests.Response` or None
            The (closed) response whose body was written, or None when an
            existing local file was kept.
        """

        def _send(http_method, extra_headers=None):
            # Issue one streamed request. Extra headers are merged into a
            # copy of the per-request headers so the shared session headers
            # are never mutated (a leaked Range header would corrupt every
            # later request on this session).
            request_kwargs = dict(kwargs)
            if extra_headers:
                merged_headers = dict(request_kwargs.get('headers') or {})
                merged_headers.update(extra_headers)
                request_kwargs['headers'] = merged_headers
            resp = self._session.request(http_method, url,
                                         timeout=timeout, stream=True,
                                         auth=auth, **request_kwargs)
            resp.raise_for_status()
            return resp

        response = _send("HEAD" if head_safe else method)
        # A HEAD response carries no body; remember whether we must re-fetch
        # with ``method`` before writing anything.
        has_body = not head_safe

        if 'content-length' in response.headers:
            length = int(response.headers['content-length'])
            if length == 0:
                log.warning('URL {0} has length=0'.format(url))
        else:
            length = None

        open_mode = 'wb'
        bytes_read = 0

        if ((os.path.exists(local_filepath)
             and ('Accept-Ranges' in response.headers)
             and continuation)):
            existing_file_length = os.stat(local_filepath).st_size
            if length is not None and existing_file_length >= length:
                # all done!
                log.info("Found cached file {0} with expected size {1}."
                         .format(local_filepath, existing_file_length))
                response.close()
                return
            elif existing_file_length > 0:
                if length is not None:
                    log.info("Continuing download of file {0}, with {1} bytes to "
                             "go ({2}%)".format(local_filepath,
                                                length - existing_file_length,
                                                (length-existing_file_length)/length*100))
                    # bytes are indexed from 0:
                    # https://en.wikipedia.org/wiki/List_of_HTTP_header_fields#range-request-header
                    end = "{0}".format(length-1)
                else:
                    # Total size unknown: request an open-ended range.
                    log.info("Continuing download of file {0} from byte {1}"
                             .format(local_filepath, existing_file_length))
                    end = ""

                response.close()
                response = _send(method, {'Range': "bytes={0}-{1}".format(
                    existing_file_length, end)})
                has_body = True
                if response.status_code == 206:
                    # Partial content: append to the bytes we already have.
                    open_mode = 'ab'
                    bytes_read = existing_file_length
                # Any other success status means the server ignored the range
                # and sent the whole file, so it overwrites the partial copy.

        elif cache and os.path.exists(local_filepath):
            if length is None:
                log.info("Found cached file {0}.".format(local_filepath))
                response.close()
                return
            statinfo = os.stat(local_filepath)
            if statinfo.st_size == length:
                log.info("Found cached file {0} with expected size {1}."
                         .format(local_filepath, statinfo.st_size))
                response.close()
                return
            log.warning("Found cached file {0} with size {1} that is "
                        "different from expected size {2}"
                        .format(local_filepath,
                                statinfo.st_size,
                                length))

        if not has_body:
            # Only headers were fetched so far (HEAD); get the real content.
            response.close()
            response = _send(method)

        blocksize = astropy.utils.data.conf.download_block_size

        # Only show progress bar if logging level is INFO or lower.
        if log.getEffectiveLevel() <= 20:
            progress_stream = None  # Astropy default
        else:
            progress_stream = io.StringIO()

        try:
            with ProgressBarOrSpinner(
                    length, ('Downloading URL {0} to {1} ...'
                             .format(url, local_filepath)),
                    file=progress_stream) as pb:
                with open(local_filepath, open_mode) as f:
                    for block in response.iter_content(blocksize):
                        f.write(block)
                        # Count bytes actually received; the final block is
                        # usually shorter than ``blocksize``.
                        bytes_read += len(block)
                        if length is not None:
                            pb.update(min(bytes_read, length))
                        else:
                            pb.update(bytes_read)
        finally:
            # Release the connection even if writing fails part-way.
            response.close()
        return response
390

391

392 1
class suspend_cache:
    """
    A context manager that suspends caching.

    While the ``with`` block runs, ``obj._cache_active`` is False. On exit
    the previous value is restored rather than forced to True, so a nested
    suspension (e.g. a non-cached ``_request`` issued inside ``login()``)
    does not re-enable caching while the outer suspension is still active.
    Exceptions raised inside the block are never suppressed.
    """

    def __init__(self, obj):
        self.obj = obj
        # Stack of saved states, so even the same instance can be nested.
        self._saved_states = []

    def __enter__(self):
        # Objects that never set the flag are treated as cache-enabled.
        self._saved_states.append(getattr(self.obj, '_cache_active', True))
        self.obj._cache_active = False

    def __exit__(self, exc_type, exc_value, traceback):
        self.obj._cache_active = self._saved_states.pop()
        return False
406

407

408 1
class QueryWithLogin(BaseQuery):
    """
    This is the base class for all the query classes which are required to
    have a login to access the data.

    The abstract method _login() must be implemented. It is wrapped by the
    login() method, which turns off the cache. This way, login credentials
    are not stored in the cache.
    """

    def __init__(self):
        super(QueryWithLogin, self).__init__()
        # Value returned by the most recent ``login()`` call.
        self._authenticated = False

    def _get_password(self, service_name, username, reenter=False):
        """
        Get password from keyring or prompt.

        Parameters
        ----------
        service_name : str
            Keyring service name the password is looked up under.
        username : str
        reenter : bool
            If True, the keyring is not consulted and the user is always
            prompted.

        Returns
        -------
        password : str
            The password from the keyring, or the one typed at the prompt.
        password_from_keyring : str or None
            The keyring password, or None if it was not found, the keyring
            failed, or the keyring was not consulted.
        """

        password_from_keyring = None
        if reenter is False:
            try:
                password_from_keyring = keyring.get_password(
                    service_name, username)
            except keyring.errors.KeyringError as exc:
                log.warning("Failed to get a valid keyring for password "
                            "storage: {}".format(exc))
            # Only report a missing keychain entry when we actually looked.
            if password_from_keyring is None:
                log.warning("No password was found in the keychain for the "
                            "provided username.")

        if password_from_keyring is None:
            if system_tools.in_ipynb():
                log.warning("You may be using an ipython notebook:"
                            " the password form will appear in your terminal.")
            password = getpass.getpass("{0}, enter your password:\n"
                                       .format(username))
        else:
            password = password_from_keyring

        return password, password_from_keyring

    @abc.abstractmethod
    def _login(self, *args, **kwargs):
        """
        login to non-public data as a known user

        Parameters
        ----------
        Keyword arguments that can be used to create
        the data payload(dict) sent via `requests.post`
        """
        pass

    def login(self, *args, **kwargs):
        """
        Call ``_login`` with caching suspended so credentials sent during
        login are not written to the cache; store and return its result.
        """
        with suspend_cache(self):
            self._authenticated = self._login(*args, **kwargs)
        return self._authenticated

    def authenticated(self):
        """Return the result recorded by the most recent ``login()`` call."""
        return self._authenticated

Read our documentation on viewing source code .

Loading