import re
import time
from http.cookiejar import CookieJar as _CookieJar, DefaultCookiePolicy

from scrapy.utils.httpobj import urlparse_cached
from scrapy.utils.python import to_unicode


# Defined in the http.cookiejar module, but undocumented:
# https://github.com/python/cpython/blob/v3.9.0/Lib/http/cookiejar.py#L527
IPV4_RE = re.compile(r"\.\d+$", re.ASCII)


class CookieJar:
    """Thin wrapper over ``http.cookiejar.CookieJar`` for scrapy requests/responses.

    Scrapy Request/Response objects are adapted through WrappedRequest and
    WrappedResponse so the stdlib jar can read and set headers on them.
    Unlike the stdlib implementation, ``add_cookie_header`` only looks at
    domains that can actually match the request host instead of iterating
    every stored domain.
    """

    def __init__(self, policy=None, check_expired_frequency=10000):
        # policy: an http.cookiejar CookiePolicy; falls back to DefaultCookiePolicy.
        # check_expired_frequency: purge expired cookies once every this many
        # requests processed by add_cookie_header.
        self.policy = policy or DefaultCookiePolicy()
        self.jar = _CookieJar(self.policy)
        # The jar is used from a single thread here, so swap the real lock
        # for a no-op one to skip locking overhead (private stdlib attribute).
        self.jar._cookies_lock = _DummyLock()
        self.check_expired_frequency = check_expired_frequency
        self.processed = 0  # count of requests seen by add_cookie_header

    def extract_cookies(self, response, request):
        """Store the cookies that ``response`` sets (Set-Cookie) in the jar."""
        wreq = WrappedRequest(request)
        wrsp = WrappedResponse(response)
        return self.jar.extract_cookies(wrsp, wreq)

    def add_cookie_header(self, request):
        """Add a Cookie header to ``request`` with the cookies valid for it."""
        wreq = WrappedRequest(request)
        # _now is an undocumented attribute the stdlib jar and policy consult
        # as "current time" when checking cookie expiry.
        self.policy._now = self.jar._now = int(time.time())

        # the cookiejar implementation iterates through all domains
        # instead we restrict to potential matches on the domain
        req_host = urlparse_cached(request).hostname
        if not req_host:
            return

        if not IPV4_RE.search(req_host):
            hosts = potential_domain_matches(req_host)
            if '.' not in req_host:
                # dotless hostnames are also tried as "<host>.local",
                # mirroring http.cookiejar's request_host handling
                hosts += [req_host + ".local"]
        else:
            # an IP address can only match itself
            hosts = [req_host]

        cookies = []
        for host in hosts:
            if host in self.jar._cookies:
                # private stdlib helper: returns the cookies for this domain
                # that the policy allows for this request
                cookies += self.jar._cookies_for_domain(host, wreq)

        attrs = self.jar._cookie_attrs(cookies)
        if attrs:
            if not wreq.has_header("Cookie"):
                # don't clobber an explicitly set Cookie header
                wreq.add_unredirected_header("Cookie", "; ".join(attrs))

        self.processed += 1
        if self.processed % self.check_expired_frequency == 0:
            # This is still quite inefficient for large number of cookies
            self.jar.clear_expired_cookies()

    @property
    def _cookies(self):
        # underlying jar storage: domain -> path -> name -> Cookie
        return self.jar._cookies

    def clear_session_cookies(self, *args, **kwargs):
        """Discard session cookies (those without an expiry), as in the stdlib jar."""
        return self.jar.clear_session_cookies(*args, **kwargs)

    def clear(self, domain=None, path=None, name=None):
        """Clear cookies, optionally narrowed by domain/path/name (stdlib semantics)."""
        return self.jar.clear(domain, path, name)

    def __iter__(self):
        return iter(self.jar)

    def __len__(self):
        return len(self.jar)

    def set_policy(self, pol):
        return self.jar.set_policy(pol)

    def make_cookies(self, response, request):
        """Return the Cookie objects ``response`` would set, without storing them."""
        wreq = WrappedRequest(request)
        wrsp = WrappedResponse(response)
        return self.jar.make_cookies(wrsp, wreq)

    def set_cookie(self, cookie):
        self.jar.set_cookie(cookie)

    def set_cookie_if_ok(self, cookie, request):
        """Store ``cookie`` only if the policy allows it for ``request``."""
        self.jar.set_cookie_if_ok(cookie, WrappedRequest(request))


def potential_domain_matches(domain):
    """Potential domain matches for a cookie

    >>> potential_domain_matches('www.example.com')
    ['www.example.com', 'example.com', '.www.example.com', '.example.com']

    """
    matches = [domain]
    # Every suffix that starts right after a dot is a candidate, except the
    # bare TLD (i.e. suffixes starting at or after the last dot).
    last_dot = domain.rfind('.')
    for pos, char in enumerate(domain):
        if char == '.' and pos + 1 < last_dot:
            matches.append(domain[pos + 1:])
    # Each candidate also matches in its leading-dot form.
    return matches + ['.' + match for match in matches]


class _DummyLock:
110 7
    def acquire(self):
111 7
        pass
112

113 7
    def release(self):
114 7
        pass


class WrappedRequest:
    """Wraps a scrapy Request class with methods defined by urllib2.Request class to interact with CookieJar class

    see http://docs.python.org/library/urllib2.html#urllib2.Request
    """

    def __init__(self, request):
        self.request = request

    def get_full_url(self):
        return self.request.url

    def get_host(self):
        return urlparse_cached(self.request).netloc

    def get_type(self):
        return urlparse_cached(self.request).scheme

    def is_unverifiable(self):
        """Unverifiable should indicate whether the request is unverifiable, as defined by RFC 2965.

        It defaults to False. An unverifiable request is one whose URL the user did not have the
        option to approve. For example, if the request is for an image in an
        HTML document, and the user had no option to approve the automatic
        fetching of the image, this should be true.
        """
        return self.request.meta.get('is_unverifiable', False)

    def get_origin_req_host(self):
        return urlparse_cached(self.request).hostname

    # python3 uses attributes instead of methods
    @property
    def full_url(self):
        return self.get_full_url()

    @property
    def host(self):
        return self.get_host()

    @property
    def type(self):
        return self.get_type()

    @property
    def unverifiable(self):
        return self.is_unverifiable()

    @property
    def origin_req_host(self):
        return self.get_origin_req_host()

    def has_header(self, name):
        return name in self.request.headers

    def get_header(self, name, default=None):
        """Return the named header decoded to str, or ``default`` if absent.

        Fix: previously a missing header's ``default`` (usually None) was
        passed straight to ``to_unicode``, which raises TypeError on values
        that are neither bytes nor str.
        """
        value = self.request.headers.get(name, default)
        if value is None:
            return None
        return to_unicode(value, errors='replace')

    def header_items(self):
        # header names and value lists are stored as bytes; decode them for
        # consumption by http.cookiejar
        return [
            (to_unicode(k, errors='replace'),
             [to_unicode(x, errors='replace') for x in v])
            for k, v in self.request.headers.items()
        ]

    def add_unredirected_header(self, name, value):
        self.request.headers.appendlist(name, value)


class WrappedResponse:
    """Adapts a scrapy Response to the interface http.cookiejar expects."""

    def __init__(self, response):
        self.response = response

    def info(self):
        # cookiejar calls response.info() to get a headers-like object;
        # this wrapper plays both roles itself.
        return self

    def get_all(self, name, default=None):
        # ``default`` mirrors email.message.Message.get_all's signature but,
        # as before, an absent header yields an empty list.
        raw_values = self.response.headers.getlist(name)
        return [to_unicode(raw, errors='replace') for raw in raw_values]
