aio-libs / aiohttp
1
"""Various helper functions"""
2

3 10
import asyncio
4 10
import base64
5 10
import binascii
6 10
import cgi
7 10
import dataclasses
8 10
import datetime
9 10
import functools
10 10
import netrc
11 10
import os
12 10
import platform
13 10
import re
14 10
import sys
15 10
import time
16 10
import weakref
17 10
from collections import namedtuple
18 10
from contextlib import suppress
19 10
from http.cookies import SimpleCookie
20 10
from math import ceil
21 10
from pathlib import Path
22 10
from types import TracebackType
23 10
from typing import (
24
    Any,
25
    Callable,
26
    Dict,
27
    Generator,
28
    Generic,
29
    Iterable,
30
    Iterator,
31
    List,
32
    Mapping,
33
    NewType,
34
    Optional,
35
    Pattern,
36
    Tuple,
37
    Type,
38
    TypeVar,
39
    Union,
40
    cast,
41
)
42 10
from urllib.parse import quote
43 10
from urllib.request import getproxies, proxy_bypass
44

45 10
import async_timeout
46 10
from multidict import CIMultiDict, MultiDict, MultiDictProxy
47 10
from typing_extensions import Protocol, final
48 10
from yarl import URL
49

50 10
from . import hdrs
51 10
from .log import client_logger
52 10
from .typedefs import PathLike  # noqa
53

54 10
# Names exported by ``from <this module> import *``.
__all__ = ("BasicAuth", "ChainMapProxy", "ETag")

# True on Python 3.8+; used to choose the iscoroutinefunction implementation.
PY_38 = sys.version_info >= (3, 8)


# Use typing.ContextManager when it can be imported, otherwise fall back to
# the typing_extensions backport.
try:
    from typing import ContextManager
except ImportError:
    from typing_extensions import ContextManager
63

64

65 10
_T = TypeVar("_T")
_S = TypeVar("_S")

# Distinct static type for the ``sentinel`` marker so annotations such as
# ``Union[str, _SENTINEL]`` can tell it apart from ordinary values.
_SENTINEL = NewType("_SENTINEL", object)


sentinel: _SENTINEL = _SENTINEL(object())
# Any non-empty AIOHTTP_NO_EXTENSIONS value disables the compiled helpers.
NO_EXTENSIONS: bool = bool(os.environ.get("AIOHTTP_NO_EXTENSIONS"))

# Debug mode: on when Python runs in development mode (``-X dev``), or --
# unless the environment is ignored (``-E``) -- when PYTHONASYNCIODEBUG is
# set to a non-empty value.  getattr keeps this safe if the flag is missing.
DEBUG: bool = getattr(sys.flags, "dev_mode", False) or (
    not sys.flags.ignore_environment and bool(os.environ.get("PYTHONASYNCIODEBUG"))
)
78

79

80 10
# Character classes from the HTTP/1.1 grammar (RFC 2616 section 2.2).
CHAR = {chr(i) for i in range(0, 128)}
CTL = {chr(i) for i in range(0, 32)} | {
    chr(127),
}
SEPARATORS = {
    "(",
    ")",
    "<",
    ">",
    "@",
    ",",
    ";",
    ":",
    "\\",
    '"',
    "/",
    "[",
    "]",
    "?",
    "=",
    "{",
    "}",
    " ",
    chr(9),
}
# token = 1*<any CHAR except CTLs or separators>.  This must be a plain set
# difference: HT (chr(9)) belongs to both CTL and SEPARATORS, so XOR-ing
# CHAR with both sets would toggle HT back into the token alphabet.
TOKEN = CHAR - CTL - SEPARATORS
106

107

108 10
class noop:
    """Awaitable that gives control back to the event loop exactly once.

    ``await noop()`` suspends the current coroutine for a single loop
    iteration and then resumes with ``None``.
    """

    def __await__(self) -> Generator[None, None, None]:
        # A single bare yield is the suspension point.
        yield None
111

112

113 10
if PY_38:
    iscoroutinefunction = asyncio.iscoroutinefunction
else:

    def iscoroutinefunction(func: Callable[..., Any]) -> bool:
        """Return True if ``func`` is a coroutine function.

        Fallback for interpreters older than 3.8: any (possibly nested)
        ``functools.partial`` wrappers are peeled off before asking
        ``asyncio.iscoroutinefunction``.
        """
        unwrapped = func
        while isinstance(unwrapped, functools.partial):
            unwrapped = unwrapped.func
        return asyncio.iscoroutinefunction(unwrapped)
121

122

123 10
# Matches "application/json" and "+json" structured-syntax types such as
# "application/vnd.api+json".  Only the start is anchored, so trailing
# parameters like "; charset=utf-8" still match.
json_re: Pattern[str] = re.compile(r"^application/(?:[\w.+-]+?\+)?json")
124

125

126 10
class BasicAuth(namedtuple("BasicAuth", ["login", "password", "encoding"])):
    """Http basic authentication helper.

    An immutable ``(login, password, encoding)`` tuple.  It can be built
    directly, parsed from an ``Authorization`` header (:meth:`decode`) or
    taken from a URL's userinfo (:meth:`from_url`), and rendered back into a
    header value with :meth:`encode`.
    """

    def __new__(
        cls, login: str, password: str = "", encoding: str = "latin1"
    ) -> "BasicAuth":
        """Validate and build the credentials tuple.

        Raises:
            ValueError: if ``login`` or ``password`` is ``None``, or if
                ``login`` contains ``":"`` (decode splits on the first colon,
                so such a login could not round-trip).
        """
        if login is None:
            raise ValueError("None is not allowed as login value")

        if password is None:
            raise ValueError("None is not allowed as password value")

        if ":" in login:
            raise ValueError('A ":" is not allowed in login (RFC 1945#section-11.1)')

        return super().__new__(cls, login, password, encoding)

    @classmethod
    def decode(cls, auth_header: str, encoding: str = "latin1") -> "BasicAuth":
        """Create a BasicAuth object from an Authorization HTTP header.

        ``auth_header`` has the form ``"Basic <base64(login:password)>"``.
        The scheme is compared case-insensitively, and the decoded bytes are
        interpreted with ``encoding``.

        Raises:
            ValueError: if the header has no space, uses another scheme, is
                not valid base64, or the decoded text contains no ``":"``.
        """
        try:
            auth_type, encoded_credentials = auth_header.split(" ", 1)
        except ValueError:
            raise ValueError("Could not parse authorization header.")

        if auth_type.lower() != "basic":
            raise ValueError("Unknown authorization method %s" % auth_type)

        try:
            decoded = base64.b64decode(
                encoded_credentials.encode("ascii"), validate=True
            ).decode(encoding)
        except binascii.Error:
            raise ValueError("Invalid base64 encoding.")

        try:
            # RFC 2617 HTTP Authentication
            # https://www.ietf.org/rfc/rfc2617.txt
            # the colon must be present, but the username and password may be
            # otherwise blank.
            username, password = decoded.split(":", 1)
        except ValueError:
            raise ValueError("Invalid credentials.")

        return cls(username, password, encoding=encoding)

    @classmethod
    def from_url(
        cls, url: "URL", *, encoding: str = "latin1"
    ) -> Optional["BasicAuth"]:
        """Create BasicAuth from the user/password part of ``url``.

        Returns ``None`` when the URL has no user.  A missing password
        becomes ``""``.

        Raises:
            TypeError: if ``url`` is not a ``yarl.URL``.
        """
        if not isinstance(url, URL):
            raise TypeError("url should be yarl.URL instance")
        if url.user is None:
            return None
        return cls(url.user, url.password or "", encoding=encoding)

    def encode(self) -> str:
        """Encode credentials as an ``Authorization`` header value.

        ``login:password`` is encoded with ``self.encoding`` and then
        base64-encoded.  Base64 output is always ASCII, so it is decoded as
        ASCII.  Decoding it with ``self.encoding`` would garble the header
        for encodings that are not ASCII-compatible (e.g. UTF-16).
        """
        creds = f"{self.login}:{self.password}".encode(self.encoding)
        return "Basic %s" % base64.b64encode(creds).decode("ascii")
185

186

187 10
def strip_auth_from_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Split credentials out of ``url``.

    Returns ``(url, auth)``.  When ``url`` carries a user, ``auth`` is the
    matching :class:`BasicAuth` and the returned URL has its user part
    cleared via ``with_user(None)``.  Otherwise the URL is returned
    unchanged together with ``None``.
    """
    auth = BasicAuth.from_url(url)
    if auth is not None:
        url = url.with_user(None)
    return url, auth
193

194

195 10
def netrc_from_env() -> Optional[netrc.netrc]:
    """Load the user's netrc file, if one can be found and parsed.

    The path comes from the ``NETRC`` environment variable when it is set.
    Otherwise ``_netrc`` (Windows) or ``.netrc`` (other platforms) in the
    user's home directory is used.

    Returns the parsed :class:`netrc.netrc`, or ``None`` when the home
    directory cannot be resolved or the file cannot be read or parsed.
    Read failures are only logged when ``NETRC`` was set or the default
    file exists.
    """
    env_path = os.environ.get("NETRC")

    if env_path is None:
        try:
            home = Path.home()
        except RuntimeError as exc:  # pragma: no cover
            # pathlib signals an unresolvable home directory with RuntimeError
            client_logger.debug(
                "Could not resolve home directory when "
                "trying to look for .netrc file: %s",
                exc,
            )
            return None
        filename = "_netrc" if platform.system() == "Windows" else ".netrc"
        netrc_path = home / filename
    else:
        netrc_path = Path(env_path)

    try:
        return netrc.netrc(str(netrc_path))
    except netrc.NetrcParseError as exc:
        client_logger.warning("Could not parse .netrc file: %s", exc)
    except OSError as exc:
        # Unreadable file (missing, permissions, ...).  Only complain when the
        # user explicitly asked for it or the default file really exists.
        if env_path or netrc_path.is_file():
            client_logger.warning("Could not read .netrc file: %s", exc)

    return None
233

234

235 10
@dataclasses.dataclass(frozen=True)
class ProxyInfo:
    """A proxy URL together with the credentials to use for it, if any."""

    # Proxy URL; proxies_from_env stores it with any userinfo already stripped.
    proxy: URL
    # Credentials from the URL or the netrc file, or None when none are known.
    proxy_auth: Optional[BasicAuth]
239

240

241 10
def proxies_from_env() -> Dict[str, ProxyInfo]:
    """Collect proxy settings from the environment.

    Only the ``http``, ``https``, ``ws`` and ``wss`` entries reported by
    :func:`urllib.request.getproxies` are considered.  A proxy whose own
    URL uses ``https``/``wss`` is skipped with a warning.  When a proxy URL
    carries no credentials, the netrc file (see :func:`netrc_from_env`) is
    consulted for the proxy host.

    Returns a mapping from protocol name to :class:`ProxyInfo`.
    """
    wanted = ("http", "https", "ws", "wss")
    proxy_urls = {
        proto: URL(raw_url)
        for proto, raw_url in getproxies().items()
        if proto in wanted
    }
    netrc_obj = netrc_from_env()
    # Separate credentials from every proxy URL before inspecting any of them.
    stripped = {proto: strip_auth_from_url(url) for proto, url in proxy_urls.items()}

    result: Dict[str, ProxyInfo] = {}
    for proto, (proxy, auth) in stripped.items():
        if proxy.scheme in ("https", "wss"):
            client_logger.warning(
                "%s proxies %s are not supported, ignoring", proxy.scheme.upper(), proxy
            )
            continue

        if netrc_obj and auth is None and proxy.host is not None:
            entry = netrc_obj.authenticators(proxy.host)
            if entry is not None:
                # netrc entries are (login, account, password); either of the
                # first two may hold the user name, so fall back to account
                # when login is empty.
                login, account, password = entry
                auth = BasicAuth(cast(str, login or account), cast(str, password))

        result[proto] = ProxyInfo(proxy, auth)
    return result
270

271

272 10
def get_env_proxy_for_url(url: URL) -> Tuple[URL, Optional[BasicAuth]]:
    """Get a permitted proxy for the given URL from the env.

    Returns ``(proxy_url, proxy_auth)`` for ``url``'s scheme.

    Raises:
        LookupError: if the host is excluded by the environment's no-proxy
            rules, or no proxy is configured for the URL's scheme.
    """
    host = url.host
    if host is not None and proxy_bypass(host):
        raise LookupError(f"Proxying is disallowed for `{host!r}`")

    available = proxies_from_env()
    try:
        info = available[url.scheme]
    except KeyError:
        raise LookupError(f"No proxies found for `{url!s}` in the env")
    return info.proxy, info.proxy_auth
284

285

286 10
@dataclasses.dataclass(frozen=True)
class MimeType:
    """Components of a MIME type as produced by ``parse_mimetype``.

    For ``'application/vnd.api+json; charset=utf-8'`` the fields are
    ``type='application'``, ``subtype='vnd.api'``, ``suffix='json'`` and
    ``parameters`` holding ``charset='utf-8'``.
    """

    type: str  # major type, lower-cased (e.g. "text")
    subtype: str  # subtype with any "+suffix" removed, lower-cased
    suffix: str  # structured-syntax suffix after "+", or ""
    parameters: "MultiDictProxy[str]"  # read-only view of ";"-separated params
292

293

294 10
@functools.lru_cache(maxsize=56)
def parse_mimetype(mimetype: str) -> MimeType:
    """Split a MIME type string into its components.

    The part before the first ``;`` is lower-cased and split into type and
    subtype at ``/``.  The subtype is then split at ``+`` into subtype and
    suffix.  A bare ``*`` is treated as ``*/*``.  Every following
    ``;``-separated item becomes a parameter: the key is lower-cased and
    stripped, and the value has spaces and double quotes stripped.  Results
    are memoized in an LRU cache of 56 entries.

    Example:

    >>> mt = parse_mimetype('text/html; charset=utf-8')
    >>> mt.type, mt.subtype, mt.suffix, mt.parameters['charset']
    ('text', 'html', '', 'utf-8')
    """
    if not mimetype:
        return MimeType(
            type="", subtype="", suffix="", parameters=MultiDictProxy(MultiDict())
        )

    head, *raw_params = mimetype.split(";")
    params: "MultiDict[str]" = MultiDict()
    for raw in raw_params:
        if raw:
            # partition yields (raw, "", "") when there is no "=".
            name, _, val = raw.partition("=")
            params.add(name.lower().strip(), val.strip(' "'))

    fulltype = head.strip().lower()
    if fulltype == "*":
        fulltype = "*/*"

    mtype, _, rest = fulltype.partition("/")
    stype, _, suffix = rest.partition("+")

    return MimeType(
        type=mtype, subtype=stype, suffix=suffix, parameters=MultiDictProxy(params)
    )
340

341

342 10
def guess_filename(obj: Any, default: Optional[str] = None) -> Optional[str]:
    """Return the base name of ``obj.name``, or ``default``.

    ``obj.name`` is used only when it is a non-empty string that neither
    starts with ``<`` nor ends with ``>``.  This rules out pseudo-names
    such as ``<stdin>``.
    """
    name = getattr(obj, "name", None)
    if not (name and isinstance(name, str)):
        return default
    if name.startswith("<") or name.endswith(">"):
        return default
    return Path(name).name
347

348

349 10
# Matches any character that is not RFC 5322 ``qtext`` (%d33 / %d35-91 /
# %d93-126); quoted_string backslash-escapes these.
not_qtext_re = re.compile(r"[^\041\043-\133\135-\176]")
# Characters allowed inside a quoted-string: printable ASCII plus TAB.
QCONTENT = {chr(i) for i in range(0x20, 0x7F)} | {"\t"}


def quoted_string(content: str) -> str:
    """Return 7-bit content as quoted-string.

    Format content into a quoted-string as defined in RFC5322 for
    Internet Message Format. Notice that this is not the 8-bit HTTP
    format, but the 7-bit email format. Content must be in usascii or
    a ValueError is raised.

    Every character outside ``qtext`` (for example space, TAB, ``"`` and
    ``\\``) gets a backslash prefix.  The surrounding double quotes are not
    added.

    Raises:
        ValueError: if ``content`` has a character outside ``QCONTENT``.
    """
    # Subset test, not proper-subset: content that happens to use every
    # allowed character must still be accepted.
    if not QCONTENT.issuperset(content):
        raise ValueError(f"bad content for quoted-string {content!r}")
    return not_qtext_re.sub(lambda x: "\\" + x.group(0), content)
364

365

366 10
def content_disposition_header(
    disptype: str, quote_fields: bool = True, _charset: str = "utf-8", **params: str
) -> str:
    """Sets ``Content-Disposition`` header for MIME.

    This is the MIME payload Content-Disposition header from RFC 2183
    and RFC 7578 section 4.2, not the HTTP Content-Disposition from
    RFC 6266.

    disptype is a disposition type: inline, attachment, form-data.
    Should be valid extension token (see RFC 2183)

    quote_fields performs value quoting to 7-bit MIME headers
    according to RFC 7578. Set quote_fields to False if recipient
    can take 8-bit file names and field values.

    _charset specifies the charset to use when quote_fields is True.

    params is a dict with disposition params.

    Raises ValueError if ``disptype`` or any parameter name is empty or
    contains a character that is not allowed in an HTTP token.
    """
    # issuperset (not a proper-superset ``>``): a value that happens to use
    # every token character is still a valid token.
    if not disptype or not TOKEN.issuperset(disptype):
        raise ValueError("bad content disposition type {!r}".format(disptype))

    value = disptype
    if params:
        lparams = []
        for key, val in params.items():
            if not key or not TOKEN.issuperset(key):
                raise ValueError(
                    "bad content disposition parameter {!r}={!r}".format(key, val)
                )
            if quote_fields:
                if key.lower() == "filename":
                    # filename: percent-encoded in _charset, then double-quoted.
                    qval = quote(val, "", encoding=_charset)
                    lparams.append((key, '"%s"' % qval))
                else:
                    try:
                        qval = quoted_string(val)
                    except ValueError:
                        # Not representable as a 7-bit quoted-string: use the
                        # RFC 2231-style  key*=charset''percent-encoded  form.
                        qval = "".join(
                            (_charset, "''", quote(val, "", encoding=_charset))
                        )
                        lparams.append((key + "*", qval))
                    else:
                        lparams.append((key, '"%s"' % qval))
            else:
                # Only backslash and double quote are escaped in this mode.
                qval = val.replace("\\", "\\\\").replace('"', '\\"')
                lparams.append((key, '"%s"' % qval))
        sparams = "; ".join("=".join(pair) for pair in lparams)
        value = "; ".join((value, sparams))
    return value
417

418

419 10
def is_expected_content_type(
    response_content_type: str, expected_content_type: str
) -> bool:
    """Check a response Content-Type against the expected one.

    ``"application/json"`` is special-cased: any type accepted by
    ``json_re`` (including ``+json`` variants) counts as a match.  For every
    other expectation a plain substring test is used.
    """
    if expected_content_type != "application/json":
        return expected_content_type in response_content_type
    return bool(json_re.match(response_content_type))
425

426

427 10
class _TSelf(Protocol, Generic[_T]):
    """Structural type for instances that ``reify`` can cache values on.

    Any object with a ``_cache`` dict (attribute name -> computed value)
    satisfies it.
    """

    _cache: Dict[str, _T]
429

430

431 10
class reify(Generic[_T]):
    """Lazily computed, cached attribute; use as a method decorator.

    It behaves much like ``@property``.  On first access the wrapped method
    is called, and its result is stored in the instance's ``_cache`` dict
    under the method's name.  Later accesses return the cached value without
    calling the method again, so the instance must provide a ``_cache``
    dict.

    Because ``__set__`` is defined this is a data descriptor, and assigning
    to the attribute raises ``AttributeError``.
    """

    def __init__(self, wrapped: Callable[..., _T]) -> None:
        self.wrapped = wrapped
        self.__doc__ = wrapped.__doc__
        self.name = wrapped.__name__

    def __get__(self, inst: _TSelf[_T], owner: Optional[Type[Any]] = None) -> _T:
        try:
            try:
                # Fast path: value already computed for this instance.
                return inst._cache[self.name]
            except KeyError:
                val = self.wrapped(inst)
                inst._cache[self.name] = val
                return val
        except AttributeError:
            # Class-level access (inst is None) has no ``_cache``: return the
            # descriptor itself.  Any other AttributeError -- a missing
            # ``_cache`` or one raised by the wrapped method -- propagates.
            if inst is None:
                return self
            raise

    def __set__(self, inst: _TSelf[_T], value: _T) -> None:
        raise AttributeError("reified property is read-only")
460

461

462 10
# Keep the pure-Python implementation reachable under its own name, since
# ``reify`` may be rebound to the compiled version below.
reify_py = reify

try:
    from ._helpers import reify as reify_c

    # Prefer the compiled implementation unless AIOHTTP_NO_EXTENSIONS is set.
    if not NO_EXTENSIONS:
        reify = reify_c  # type: ignore[misc,assignment]
except ImportError:
    # ._helpers could not be imported (e.g. the extension is not built);
    # keep the pure-Python reify.
    pass
471

472 10
# Dotted-quad IPv4 literal.  Each octet alternative accepts 0-255; the
# [01]?[0-9][0-9]? branch also accepts leading zeros such as "001".
_ipv4_pattern = (
    r"^(?:(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$"
)
# Hand-written IPv6 literal matcher.  Its alternatives appear to cover the
# full eight-group form, "::"-compressed forms and forms ending in a dotted
# IPv4 address.  Hex digits are written upper-case here; case-insensitivity
# comes from re.IGNORECASE at compile time.
_ipv6_pattern = (
    r"^(?:(?:(?:[A-F0-9]{1,4}:){6}|(?=(?:[A-F0-9]{0,4}:){0,6}"
    r"(?:[0-9]{1,3}\.){3}[0-9]{1,3}$)(([0-9A-F]{1,4}:){0,5}|:)"
    r"((:[0-9A-F]{1,4}){1,5}:|:)|::(?:[A-F0-9]{1,4}:){5})"
    r"(?:(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])\.){3}"
    r"(?:25[0-5]|2[0-4][0-9]|1[0-9][0-9]|[1-9]?[0-9])|(?:[A-F0-9]{1,4}:){7}"
    r"[A-F0-9]{1,4}|(?=(?:[A-F0-9]{0,4}:){0,7}[A-F0-9]{0,4}$)"
    r"(([0-9A-F]{1,4}:){1,7}|:)((:[0-9A-F]{1,4}){1,7}|:)|(?:[A-F0-9]{1,4}:){7}"
    r":|:(:[A-F0-9]{1,4}){7})$"
)
# str and bytes flavours of each pattern, so hosts of either type can be
# checked by _is_ip_address.
# NOTE(review): these patterns end in ``$``, which under ``re.match`` also
# matches just before a trailing "\n" -- callers should use fullmatch.
_ipv4_regex = re.compile(_ipv4_pattern)
_ipv6_regex = re.compile(_ipv6_pattern, flags=re.IGNORECASE)
_ipv4_regexb = re.compile(_ipv4_pattern.encode("ascii"))
_ipv6_regexb = re.compile(_ipv6_pattern.encode("ascii"), flags=re.IGNORECASE)
490

491

492 10
def _is_ip_address(
    regex: Pattern[str],
    regexb: Pattern[bytes],
    host: Optional[Union[str, bytes, bytearray, memoryview]],
) -> bool:
    """Return True if the whole of ``host`` matches the given pattern.

    ``regex`` is used for ``str`` hosts and ``regexb`` for bytes-like hosts
    (``bytes``, ``bytearray``, ``memoryview``).  ``None`` is never an
    address.

    ``fullmatch`` is used instead of ``match`` because a ``$`` anchor also
    matches just before a trailing newline.  With ``match``,
    ``"127.0.0.1\\n"`` would be reported as an IP address.

    Raises:
        TypeError: if ``host`` is of any other type.
    """
    if host is None:
        return False
    if isinstance(host, str):
        return regex.fullmatch(host) is not None
    if isinstance(host, (bytes, bytearray, memoryview)):
        return regexb.fullmatch(host) is not None
    raise TypeError("{} [{}] is not a str or bytes".format(host, type(host)))
503

504

505 10
# Pre-bound checkers: each pairs _is_ip_address with one address family's
# str and bytes patterns.
is_ipv4_address = functools.partial(_is_ip_address, _ipv4_regex, _ipv4_regexb)
is_ipv6_address = functools.partial(_is_ip_address, _ipv6_regex, _ipv6_regexb)


def is_ip_address(host: Optional[Union[str, bytes, bytearray, memoryview]]) -> bool:
    """Return True if ``host`` is an IPv4 or an IPv6 address literal.

    Accepts ``str`` or bytes-like input.  ``None`` gives False; any other
    type raises TypeError from the IPv4 check.
    """
    if is_ipv4_address(host):
        return True
    return is_ipv6_address(host)
511

512

513 10
def next_whole_second() -> datetime.datetime:
    """Return current time rounded up to the next whole second.

    The result is timezone-aware (UTC) and has ``microsecond == 0``.  If
    the current time already lies exactly on a second boundary, it is
    returned unchanged.
    """
    now = datetime.datetime.now(datetime.timezone.utc)
    whole = now.replace(microsecond=0)
    if now.microsecond:
        # A fractional part means ``now`` is past ``whole``: step up one second.
        whole += datetime.timedelta(seconds=1)
    return whole
518

519

520 10
# Per-second cache for rfc822_formatted_time: the integer epoch second that
# was last formatted, and the resulting string.
_cached_current_datetime: Optional[int] = None
_cached_formatted_datetime = ""


def rfc822_formatted_time() -> str:
    """Return the current time formatted as an HTTP date.

    The result looks like ``Sun, 06 Nov 1994 08:49:37 GMT`` and always uses
    English day and month names.  Formatting happens at most once per
    wall-clock second; within the same second the cached string is returned.
    """
    global _cached_current_datetime
    global _cached_formatted_datetime

    now = int(time.time())
    if now == _cached_current_datetime:
        return _cached_formatted_datetime

    # Names are hard-coded because HTTP dates must be English regardless of
    # locale.  Tuple literals are constants stored in the code object.
    weekday_names = ("Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun")
    month_names = (
        "",  # placeholder: tm_mon is 1-based
        "Jan",
        "Feb",
        "Mar",
        "Apr",
        "May",
        "Jun",
        "Jul",
        "Aug",
        "Sep",
        "Oct",
        "Nov",
        "Dec",
    )

    tm = time.gmtime(now)
    _cached_formatted_datetime = "%s, %02d %3s %4d %02d:%02d:%02d GMT" % (
        weekday_names[tm.tm_wday],
        tm.tm_mday,
        month_names[tm.tm_mon],
        tm.tm_year,
        tm.tm_hour,
        tm.tm_min,
        tm.tm_sec,
    )
    _cached_current_datetime = now
    return _cached_formatted_datetime
562

563

564 10
def _weakref_handle(info: "Tuple[weakref.ref[object], str]") -> None:
    """Timer callback: call method ``name`` on a weakly referenced object.

    ``info`` is a ``(ref, name)`` pair.  Nothing happens if the referent has
    been garbage-collected, and any exception raised by the attribute lookup
    or the call itself is swallowed.
    """
    ref, name = info
    target = ref()
    if target is None:
        return
    with suppress(Exception):
        method = getattr(target, name)
        method()
570

571

572 10
def weakref_handle(
    ob: object,
    name: str,
    timeout: Optional[float],
    loop: asyncio.AbstractEventLoop,
    ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``getattr(ob, name)()`` to run after ``timeout`` seconds.

    Only a weak reference to ``ob`` is passed to the timer; the callback
    itself is ``_weakref_handle``.

    Args:
        ob: object whose method should be called.
        name: name of the method to call.
        timeout: delay in seconds.  ``None`` or a non-positive value means
            nothing is scheduled.
        loop: event loop to schedule on.
        ceil_threshold: when ``timeout`` is at least this many seconds, the
            deadline is rounded up to a whole second of ``loop.time()``.
            Defaults to 5.

    Returns:
        The scheduled :class:`asyncio.TimerHandle`, or ``None`` if nothing
        was scheduled.
    """
    if timeout is not None and timeout > 0:
        when = loop.time() + timeout
        if timeout >= ceil_threshold:
            # Rounding up makes long deadlines coarse but never fires early.
            when = ceil(when)

        return loop.call_at(when, _weakref_handle, (weakref.ref(ob), name))
    return None
582

583

584 10
def call_later(
    cb: Callable[[], Any],
    timeout: Optional[float],
    loop: asyncio.AbstractEventLoop,
    ceil_threshold: float = 5,
) -> Optional[asyncio.TimerHandle]:
    """Schedule ``cb`` to run on ``loop`` after ``timeout`` seconds.

    Args:
        cb: zero-argument callable to invoke.
        timeout: delay in seconds.  ``None`` or a non-positive value means
            nothing is scheduled.
        loop: event loop to schedule on.
        ceil_threshold: when ``timeout`` is strictly greater than this many
            seconds, the deadline is rounded up to a whole second of
            ``loop.time()``.  Defaults to 5.

    Returns:
        The scheduled :class:`asyncio.TimerHandle`, or ``None`` if nothing
        was scheduled.
    """
    if timeout is not None and timeout > 0:
        when = loop.time() + timeout
        if timeout > ceil_threshold:
            # Rounding up makes long deadlines coarse but never fires early.
            when = ceil(when)
        return loop.call_at(when, cb)
    return None
593

594

595 10
class TimeoutHandle:
    """Timeout handle.

    Collects callbacks via :meth:`register` and invokes all of them once
    when the handle is called.  :meth:`start` arranges for that call to be
    made by the loop after the configured timeout.
    """

    def __init__(
        self,
        loop: asyncio.AbstractEventLoop,
        timeout: Optional[float],
        ceil_threshold: float = 5,
    ) -> None:
        self._timeout = timeout
        self._loop = loop
        # Timeouts of at least this many seconds get a whole-second deadline.
        self._ceil_threshold = ceil_threshold
        self._callbacks: List[
            Tuple[Callable[..., None], Tuple[Any, ...], Dict[str, Any]]
        ] = []

    def register(
        self, callback: Callable[..., None], *args: Any, **kwargs: Any
    ) -> None:
        """Queue ``callback(*args, **kwargs)`` to run when the timeout fires."""
        self._callbacks.append((callback, args, kwargs))

    def close(self) -> None:
        """Drop all registered callbacks without calling them."""
        self._callbacks.clear()

    def start(self) -> Optional[asyncio.Handle]:
        """Arm the loop timer.

        Returns the timer handle, or ``None`` when the timeout is ``None``
        or not positive.
        """
        timeout = self._timeout
        if timeout is not None and timeout > 0:
            when = self._loop.time() + timeout
            if timeout >= self._ceil_threshold:
                # Rounding up makes long deadlines coarse but never fires early.
                when = ceil(when)
            return self._loop.call_at(when, self.__call__)
        else:
            return None

    def timer(self) -> "BaseTimerContext":
        """Return a timer context tied to this handle.

        With a positive timeout, a ``TimerContext`` whose ``timeout`` method
        is registered as a callback.  Otherwise a ``TimerNoop``.
        """
        if self._timeout is not None and self._timeout > 0:
            timer = TimerContext(self._loop)
            self.register(timer.timeout)
            return timer
        else:
            return TimerNoop()

    def __call__(self) -> None:
        """Run every registered callback, ignoring their exceptions, then forget them."""
        for cb, args, kwargs in self._callbacks:
            with suppress(Exception):
                cb(*args, **kwargs)

        self._callbacks.clear()
639

640

641 10
class BaseTimerContext(ContextManager["BaseTimerContext"]):
    """Common base type for this module's timer context managers.

    It adds no behaviour of its own.  ``TimerNoop`` and ``TimerContext``
    derive from it, and it is the declared return type of
    ``TimeoutHandle.timer``.
    """

    pass
643

644

645 10
class TimerNoop(BaseTimerContext):
    """Timer context used when there is no timeout; it never expires."""

    def __enter__(self) -> BaseTimerContext:
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Returning None never suppresses: exceptions propagate unchanged.
        return None
656

657

658 10
class TimerContext(BaseTimerContext):
    """Low resolution timeout context manager.

    ``timeout()`` cancels every task currently inside the context and marks
    the context as expired.  Once expired, a ``CancelledError`` leaving the
    ``with`` block is re-raised as ``asyncio.TimeoutError``.  Entering an
    already-expired context cancels the current task and raises
    ``asyncio.TimeoutError`` immediately.
    """

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        # Tasks currently inside the ``with`` block, pushed on enter and
        # popped on exit; re-entry by the same task pushes it again.
        self._tasks = []  # type: List[asyncio.Task[Any]]
        # Set once timeout() has fired; never reset afterwards.
        self._cancelled = False

    def __enter__(self) -> BaseTimerContext:
        task = asyncio.current_task(loop=self._loop)

        if task is None:
            raise RuntimeError(
                "Timeout context manager should be used " "inside a task"
            )

        if self._cancelled:
            # The deadline already passed before entering: fail right away.
            task.cancel()
            raise asyncio.TimeoutError from None

        self._tasks.append(task)
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> Optional[bool]:
        if self._tasks:
            self._tasks.pop()

        # After timeout() has fired, any CancelledError leaving the block is
        # reported as TimeoutError.
        # NOTE(review): the task's pending cancellation is not withdrawn here
        # (no task.uncancel() on Python 3.11+) -- confirm callers are fine.
        if exc_type is asyncio.CancelledError and self._cancelled:
            raise asyncio.TimeoutError from None
        return None

    def timeout(self) -> None:
        if not self._cancelled:
            # set() so a task that entered more than once is cancelled once.
            for task in set(self._tasks):
                task.cancel()

            self._cancelled = True
700

701

702 10
def ceil_timeout(
    delay: Optional[float], ceil_threshold: float = 5
) -> async_timeout.Timeout:
    """Return an ``async_timeout`` context that expires after ``delay`` seconds.

    Must be called with a running event loop when ``delay`` is positive.

    Args:
        delay: seconds until expiry.  ``None`` or a non-positive value gives
            a timeout context with no deadline.
        ceil_threshold: when ``delay`` is strictly greater than this many
            seconds, the deadline is rounded up to a whole second of
            ``loop.time()``.  Defaults to 5.
    """
    if delay is None or delay <= 0:
        return async_timeout.timeout(None)

    loop = asyncio.get_running_loop()
    now = loop.time()
    when = now + delay
    if delay > ceil_threshold:
        # Rounding up makes long deadlines coarse but never fires early.
        when = ceil(when)
    return async_timeout.timeout_at(when)
712

713

714 10
class HeadersMixin:
    """Mixin exposing parsed ``Content-Type`` and ``Content-Length`` values.

    The class it is mixed into must provide a ``_headers`` mapping.  It is
    read here but never assigned.  The parsed ``Content-Type`` is cached
    together with the raw header value it came from, and re-parsed whenever
    that header changes.
    """

    __slots__ = ("_content_type", "_content_dict", "_stored_content_type")

    def __init__(self) -> None:
        super().__init__()
        self._content_type: Optional[str] = None
        self._content_dict: Optional[Dict[str, str]] = None
        # Raw header value the cache was built from.  ``sentinel`` means
        # "never parsed", so even an absent header (None) triggers a parse.
        self._stored_content_type: Union[str, None, _SENTINEL] = sentinel

    def _parse_content_type(self, raw: Optional[str]) -> None:
        """Parse ``raw`` (``None`` when the header is absent) into the cache."""
        self._stored_content_type = raw
        if raw is None:
            # default value according to RFC 2616
            self._content_type = "application/octet-stream"
            self._content_dict = {}
        else:
            self._content_type, self._content_dict = cgi.parse_header(raw)

    def _refresh_content_type(self) -> None:
        """Re-parse ``Content-Type`` if the header changed since the last parse."""
        raw = self._headers.get(hdrs.CONTENT_TYPE)  # type: ignore[attr-defined]
        if self._stored_content_type != raw:
            self._parse_content_type(raw)

    @property
    def content_type(self) -> str:
        """The value of content part for Content-Type HTTP header."""
        self._refresh_content_type()
        return self._content_type  # type: ignore[return-value]

    @property
    def charset(self) -> Optional[str]:
        """The value of charset part for Content-Type HTTP header."""
        self._refresh_content_type()
        return self._content_dict.get("charset")  # type: ignore[union-attr]

    @property
    def content_length(self) -> Optional[int]:
        """The value of Content-Length HTTP header.

        Returns ``None`` when the header is absent.

        Raises:
            ValueError: if the header value is not an integer literal.
        """
        content_length = self._headers.get(  # type: ignore[attr-defined]
            hdrs.CONTENT_LENGTH
        )

        if content_length is not None:
            return int(content_length)
        else:
            return None
760

761

762 10
def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
    """Set ``result`` on ``fut`` unless the future is already done or cancelled."""
    if fut.done():
        return
    fut.set_result(result)
765

766

767 10
def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
    """Set ``exc`` on ``fut`` unless the future is already done or cancelled."""
    if fut.done():
        return
    fut.set_exception(exc)
770

771

772 10
@final
class ChainMapProxy(Mapping[str, Any]):
    """Read-only view over a sequence of mappings.

    Lookups try each mapping in order and return the first hit, so earlier
    mappings shadow later ones.  The class is final and refuses to be
    subclassed.
    """

    __slots__ = ("_maps",)

    def __init__(self, maps: Iterable[Mapping[str, Any]]) -> None:
        # Snapshot the sequence of mappings; the mappings are not copied.
        self._maps = tuple(maps)

    def __init_subclass__(cls) -> None:
        raise TypeError(
            f"Inheritance class {cls.__name__} from ChainMapProxy is forbidden"
        )

    def __getitem__(self, key: str) -> Any:
        for mapping in self._maps:
            try:
                return mapping[key]
            except KeyError:
                continue
        raise KeyError(key)

    def get(self, key: str, default: Any = None) -> Any:
        if key not in self:
            return default
        return self[key]

    def __len__(self) -> int:
        # Number of distinct keys; reuses stored hash values if possible.
        distinct_keys = set().union(*self._maps)  # type: ignore[arg-type]
        return len(distinct_keys)

    def __iter__(self) -> Iterator[str]:
        merged: Dict[str, Any] = {}
        # Apply mappings last-to-first; reuses stored hash values if possible.
        for mapping in self._maps[::-1]:
            merged.update(mapping)
        return iter(merged)

    def __contains__(self, key: object) -> bool:
        return any(key in mapping for mapping in self._maps)

    def __bool__(self) -> bool:
        return any(self._maps)

    def __repr__(self) -> str:
        inner = ", ".join(repr(mapping) for mapping in self._maps)
        return "ChainMapProxy({})".format(inner)
816

817

818 10
class CookieMixin:
    """Mixin that keeps outgoing cookies in a :class:`http.cookies.SimpleCookie`."""

    def __init__(self) -> None:
        super().__init__()
        self._cookies = SimpleCookie()  # type: SimpleCookie[str]

    @property
    def cookies(self) -> "SimpleCookie[str]":
        """The cookie jar holding every cookie set on this object."""
        return self._cookies

    def set_cookie(
        self,
        name: str,
        value: str,
        *,
        expires: Optional[str] = None,
        domain: Optional[str] = None,
        max_age: Optional[Union[int, str]] = None,
        path: str = "/",
        secure: Optional[bool] = None,
        httponly: Optional[bool] = None,
        version: Optional[str] = None,
        samesite: Optional[str] = None,
    ) -> None:
        """Set or update response cookie.

        Sets new cookie or updates existent with new value.
        Also updates only those params which are not None.

        Exceptions to that rule: ``path`` is always written; ``max-age`` is
        removed when ``max_age`` is None; an ``expires`` equal to the epoch
        (as written by :meth:`del_cookie`) is removed when ``expires`` is
        None.  A cookie previously emptied (e.g. by :meth:`del_cookie`) is
        discarded first, so none of its attributes carry over.
        """

        old = self._cookies.get(name)
        # A deleted cookie has an empty value.  Compare ``value`` rather than
        # ``coded_value``: SimpleCookie quotes an empty value, so the coded
        # form of "" is '""' and would never equal "".
        if old is not None and old.value == "":
            # deleted cookie
            self._cookies.pop(name, None)

        self._cookies[name] = value
        c = self._cookies[name]

        if expires is not None:
            c["expires"] = expires
        elif c.get("expires") == "Thu, 01 Jan 1970 00:00:00 GMT":
            del c["expires"]

        if domain is not None:
            c["domain"] = domain

        if max_age is not None:
            c["max-age"] = str(max_age)
        elif "max-age" in c:
            del c["max-age"]

        c["path"] = path

        if secure is not None:
            c["secure"] = secure
        if httponly is not None:
            c["httponly"] = httponly
        if version is not None:
            c["version"] = version
        if samesite is not None:
            c["samesite"] = samesite

    def del_cookie(
        self, name: str, *, domain: Optional[str] = None, path: str = "/"
    ) -> None:
        """Delete cookie.

        Creates new empty expired cookie: value ``""``, ``max-age`` 0 and an
        epoch ``expires``, scoped to the given ``domain`` and ``path``.
        """
        # TODO: do we need domain/path here?
        self._cookies.pop(name, None)
        self.set_cookie(
            name,
            "",
            max_age=0,
            expires="Thu, 01 Jan 1970 00:00:00 GMT",
            domain=domain,
            path=path,
        )
896

897

898 10
def populate_with_cookies(
    headers: "CIMultiDict[str]", cookies: "SimpleCookie[str]"
) -> None:
    """Add one ``Set-Cookie`` header to ``headers`` per cookie in ``cookies``."""
    for morsel in cookies.values():
        # With an empty header prefix, Morsel.output() starts with the
        # separating space; drop that first character.
        rendered = morsel.output(header="")
        headers.add(hdrs.SET_COOKIE, rendered[1:])
904

905

906
# https://tools.ietf.org/html/rfc7232#section-2.3
# etagc = %x21 / %x23-7E / obs-text, obs-text = %x80-FF.  The upper bound is
# 0x7E ("~"), so "~" is a legal entity-tag character.
_ETAGC = r"[!\x23-\x7e\x80-\xff]+"
_ETAGC_RE = re.compile(_ETAGC)
# Optional weak marker "W/" (group 1), then the tag in double quotes (group 2).
_QUOTED_ETAG = fr'(W/)?"({_ETAGC})"'
QUOTED_ETAG_RE = re.compile(_QUOTED_ETAG)
# One element of a comma-separated ETag list; group 4 catches any stray
# character so callers can detect malformed input.
LIST_QUOTED_ETAG_RE = re.compile(fr"({_QUOTED_ETAG})(?:\s*,\s*|$)|(.)")

# Wildcard entity tag ("*") as used in If-Match / If-None-Match.
ETAG_ANY = "*"
914

915

916 10
@dataclasses.dataclass(frozen=True)
class ETag:
    """An entity tag and whether it is weak (``W/`` prefix)."""

    # Tag characters, presumably without the surrounding double quotes
    # (validate_etag_value rejects '"'), or ETAG_ANY ("*") -- confirm callers.
    value: str
    # True for a weak validator (written with a "W/" prefix).
    is_weak: bool = False
920

921

922 10
def validate_etag_value(value: str) -> None:
    """Check that ``value`` is ``"*"`` or a valid unquoted entity-tag body.

    Raises:
        ValueError: if ``value`` contains characters not allowed in an
            entity tag (for example a double quote).
    """
    if value == ETAG_ANY or _ETAGC_RE.fullmatch(value):
        return
    raise ValueError(
        f"Value {value!r} is not a valid etag. Maybe it contains '\"'?"
    )

Read our documentation on viewing source code .

Loading