aio-libs / aiohttp
1 10
import asyncio
2 10
import dataclasses
3 10
import functools
4 10
import logging
5 10
import random
6 10
import sys
7 10
import traceback
8 10
import warnings
9 10
from collections import defaultdict, deque
10 10
from contextlib import suppress
11 10
from http.cookies import SimpleCookie
12 10
from itertools import cycle, islice
13 10
from time import monotonic
14 10
from types import TracebackType
15 10
from typing import (  # noqa
16
    TYPE_CHECKING,
17
    Any,
18
    Awaitable,
19
    Callable,
20
    DefaultDict,
21
    Dict,
22
    Iterator,
23
    List,
24
    Optional,
25
    Set,
26
    Tuple,
27
    Type,
28
    Union,
29
    cast,
30
)
31

32 10
from . import hdrs, helpers
33 10
from .abc import AbstractResolver
34 10
from .client_exceptions import (
35
    ClientConnectionError,
36
    ClientConnectorCertificateError,
37
    ClientConnectorError,
38
    ClientConnectorSSLError,
39
    ClientHttpProxyError,
40
    ClientProxyConnectionError,
41
    ServerFingerprintMismatch,
42
    UnixClientConnectorError,
43
    cert_errors,
44
    ssl_errors,
45
)
46 10
from .client_proto import ResponseHandler
47 10
from .client_reqrep import SSL_ALLOWED_TYPES, ClientRequest, Fingerprint
48 10
from .helpers import _SENTINEL, ceil_timeout, is_ip_address, sentinel
49 10
from .http import RESPONSES
50 10
from .locks import EventResultOrError
51 10
from .resolver import DefaultResolver
52

53 10
try:
54 10
    import ssl
55

56 10
    SSLContext = ssl.SSLContext
57
except ImportError:  # pragma: no cover
58
    ssl = None  # type: ignore[assignment]
59
    SSLContext = object  # type: ignore[misc,assignment]
60

61

62 10
__all__ = ("BaseConnector", "TCPConnector", "UnixConnector", "NamedPipeConnector")
63

64

65
if TYPE_CHECKING:  # pragma: no cover
66
    from .client import ClientTimeout
67
    from .client_reqrep import ConnectionKey
68
    from .tracing import Trace
69

70

71 10
class Connection:
    """A single connection leased out by a :class:`BaseConnector`.

    Wraps a protocol for the duration of one request/response cycle and
    hands it back to the owning connector via :meth:`release` (eligible
    for keep-alive reuse) or :meth:`close` (forced close).
    """

    _source_traceback = None
    _transport = None

    def __init__(
        self,
        connector: "BaseConnector",
        key: "ConnectionKey",
        protocol: ResponseHandler,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._loop = loop
        self._connector = connector
        self._key = key
        self._protocol = protocol  # type: Optional[ResponseHandler]
        self._callbacks = []  # type: List[Callable[[], None]]

        # In asyncio debug mode remember where the connection was created
        # so an eventual leak can be traced back to its origin.
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

    def __repr__(self) -> str:
        return "Connection<{}>".format(self._key)

    def __del__(self, _warnings: Any = warnings) -> None:
        # A protocol still attached at GC time means the caller never
        # released or closed the connection: warn, force-close it, and
        # report the leak to the loop's exception handler.
        if self._protocol is None:
            return
        _warnings.warn(
            f"Unclosed connection {self!r}", ResourceWarning, source=self
        )
        if self._loop.is_closed():
            return

        self._connector._release(self._key, self._protocol, should_close=True)

        context = {"client_connection": self, "message": "Unclosed connection"}
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    @property
    def transport(self) -> Optional[asyncio.Transport]:
        """Transport of the wrapped protocol, or None once detached."""
        protocol = self._protocol
        return None if protocol is None else protocol.transport

    @property
    def protocol(self) -> Optional[ResponseHandler]:
        """The wrapped protocol, or None after release/close."""
        return self._protocol

    def add_callback(self, callback: Callable[[], None]) -> None:
        """Register *callback* to be invoked when the connection is released."""
        if callback is None:
            return
        self._callbacks.append(callback)

    def _notify_release(self) -> None:
        # Swap the list out before iterating so callbacks registered during
        # notification are collected for the next release instead.
        pending, self._callbacks = self._callbacks, []
        for callback in pending:
            with suppress(Exception):
                callback()

    def close(self) -> None:
        """Run release callbacks and hard-close the connection."""
        self._notify_release()

        protocol = self._protocol
        if protocol is not None:
            self._protocol = None
            self._connector._release(self._key, protocol, should_close=True)

    def release(self) -> None:
        """Run release callbacks and hand the connection back to the pool."""
        self._notify_release()

        protocol = self._protocol
        if protocol is not None:
            self._protocol = None
            self._connector._release(
                self._key, protocol, should_close=protocol.should_close
            )

    @property
    def closed(self) -> bool:
        """True when detached or when the transport is no longer connected."""
        protocol = self._protocol
        return protocol is None or not protocol.is_connected()
150

151

152 10
class _TransportPlaceholder:
153
    """ placeholder for BaseConnector.connect function """
154

155 10
    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
156 10
        fut = loop.create_future()
157 10
        fut.set_result(None)
158 10
        self.closed = fut  # type: asyncio.Future[Optional[Exception]]
159

160 10
    def close(self) -> None:
161 10
        pass
162

163

164 10
class BaseConnector:
    """Base connector class.

    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    enable_cleanup_closed - Enables clean-up closed ssl transports.
                            Disabled by default.
    loop - Optional event loop.
    """

    _closed = True  # prevent AttributeError in __del__ if ctor was failed
    _source_traceback = None

    # abort transport after 2 seconds (cleanup broken connections)
    _cleanup_closed_period = 2.0

    def __init__(
        self,
        *,
        keepalive_timeout: Union[_SENTINEL, None, float] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
    ) -> None:

        # keepalive and force_close are mutually exclusive: closing after
        # every request makes a keep-alive timeout meaningless.
        if force_close:
            if keepalive_timeout is not None and keepalive_timeout is not sentinel:
                raise ValueError(
                    "keepalive_timeout cannot " "be set if force_close is True"
                )
        else:
            if keepalive_timeout is sentinel:
                keepalive_timeout = 15.0  # default keep-alive, in seconds

        loop = asyncio.get_running_loop()

        self._closed = False
        if loop.get_debug():
            self._source_traceback = traceback.extract_stack(sys._getframe(1))

        # Pool of idle connections: key -> [(protocol, last-used time), ...]
        self._conns = (
            {}
        )  # type: Dict[ConnectionKey, List[Tuple[ResponseHandler, float]]]
        self._limit = limit
        self._limit_per_host = limit_per_host
        # Protocols currently handed out to callers, globally and per host.
        self._acquired = set()  # type: Set[ResponseHandler]
        self._acquired_per_host = defaultdict(
            set
        )  # type: DefaultDict[ConnectionKey, Set[ResponseHandler]]
        self._keepalive_timeout = cast(float, keepalive_timeout)
        self._force_close = force_close

        # {host_key: FIFO list of waiters}
        self._waiters = defaultdict(deque)  # type: ignore[var-annotated]

        self._loop = loop
        self._factory = functools.partial(ResponseHandler, loop=loop)

        self.cookies = SimpleCookie()  # type: SimpleCookie[str]

        # start keep-alive connection cleanup task
        self._cleanup_handle: Optional[asyncio.TimerHandle] = None

        # start cleanup closed transports task
        self._cleanup_closed_handle: Optional[asyncio.TimerHandle] = None
        self._cleanup_closed_disabled = not enable_cleanup_closed
        self._cleanup_closed_transports = []  # type: List[Optional[asyncio.Transport]]
        self._cleanup_closed()

    def __del__(self, _warnings: Any = warnings) -> None:
        """Warn about (and best-effort close) a connector that leaked."""
        if self._closed:
            return
        if not self._conns:
            return

        # Capture reprs before _close_immediately() clears the pool.
        conns = [repr(c) for c in self._conns.values()]

        self._close_immediately()

        _warnings.warn(f"Unclosed connector {self!r}", ResourceWarning, source=self)
        context = {
            "connector": self,
            "connections": conns,
            "message": "Unclosed connector",
        }
        if self._source_traceback is not None:
            context["source_traceback"] = self._source_traceback
        self._loop.call_exception_handler(context)

    async def __aenter__(self) -> "BaseConnector":
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]] = None,
        exc_value: Optional[BaseException] = None,
        exc_traceback: Optional[TracebackType] = None,
    ) -> None:
        await self.close()

    @property
    def force_close(self) -> bool:
        """Ultimately close connection on releasing if True."""
        return self._force_close

    @property
    def limit(self) -> int:
        """The total number for simultaneous connections.

        If limit is 0 the connector has no limit.
        The default limit size is 100.
        """
        return self._limit

    @property
    def limit_per_host(self) -> int:
        """The limit_per_host for simultaneous connections
        to the same endpoint.

        Endpoints are the same if they are have equal
        (host, port, is_ssl) triple.

        """
        return self._limit_per_host

    def _cleanup(self) -> None:
        """Cleanup unused transports."""
        if self._cleanup_handle:
            self._cleanup_handle.cancel()
            # _cleanup_handle should be unset, otherwise _release() will not
            # recreate it ever!
            self._cleanup_handle = None

        now = self._loop.time()
        timeout = self._keepalive_timeout

        if self._conns:
            connections = {}
            deadline = now - timeout
            for key, conns in self._conns.items():
                alive = []
                for proto, use_time in conns:
                    if proto.is_connected():
                        if use_time - deadline < 0:
                            # Idle past the keep-alive deadline: drop it.
                            transport = proto.transport
                            proto.close()
                            if key.is_ssl and not self._cleanup_closed_disabled:
                                self._cleanup_closed_transports.append(transport)
                        else:
                            alive.append((proto, use_time))
                    else:
                        # Peer already dropped the connection.
                        transport = proto.transport
                        proto.close()
                        if key.is_ssl and not self._cleanup_closed_disabled:
                            self._cleanup_closed_transports.append(transport)

                if alive:
                    connections[key] = alive

            self._conns = connections

        # Re-schedule only while something is left to expire.
        if self._conns:
            self._cleanup_handle = helpers.weakref_handle(
                self, "_cleanup", timeout, self._loop
            )

    def _drop_acquired_per_host(
        self, key: "ConnectionKey", val: ResponseHandler
    ) -> None:
        """Remove *val* from the per-host acquired set, pruning empty keys."""
        acquired_per_host = self._acquired_per_host
        if key not in acquired_per_host:
            return
        conns = acquired_per_host[key]
        conns.remove(val)
        if not conns:
            del self._acquired_per_host[key]

    def _cleanup_closed(self) -> None:
        """Double confirmation for transport close.
        Some broken ssl servers may leave socket open without proper close.
        """
        if self._cleanup_closed_handle:
            self._cleanup_closed_handle.cancel()

        for transport in self._cleanup_closed_transports:
            if transport is not None:
                transport.abort()

        self._cleanup_closed_transports = []

        if not self._cleanup_closed_disabled:
            self._cleanup_closed_handle = helpers.weakref_handle(
                self, "_cleanup_closed", self._cleanup_closed_period, self._loop
            )

    async def close(self) -> None:
        """Close all opened transports."""
        waiters = self._close_immediately()
        if waiters:
            results = await asyncio.gather(*waiters, return_exceptions=True)
            for res in results:
                if isinstance(res, Exception):
                    err_msg = "Error while closing connector: " + repr(res)
                    # NOTE(review): logs via the root logger; a module-level
                    # logger is more conventional — confirm intent.
                    logging.error(err_msg)

    def _close_immediately(self) -> List["asyncio.Future[None]"]:
        """Synchronously close everything; return futures to await on.

        Cancels scheduled cleanup callbacks, closes every pooled and
        acquired protocol, and aborts transports still pending their
        double-close confirmation.
        """
        waiters = []  # type: List['asyncio.Future[None]']

        if self._closed:
            return waiters

        self._closed = True

        try:
            if self._loop.is_closed():
                return waiters

            # cancel cleanup task
            if self._cleanup_handle:
                self._cleanup_handle.cancel()

            # cancel cleanup close task
            if self._cleanup_closed_handle:
                self._cleanup_closed_handle.cancel()

            for data in self._conns.values():
                for proto, t0 in data:
                    proto.close()
                    waiters.append(proto.closed)

            for proto in self._acquired:
                proto.close()
                waiters.append(proto.closed)

            # TODO (A.Yushovskiy, 24-May-2019) collect transp. closing futures
            for transport in self._cleanup_closed_transports:
                if transport is not None:
                    transport.abort()

            return waiters

        finally:
            # Always reset internal state, even if the loop was closed.
            self._conns.clear()
            self._acquired.clear()
            self._waiters.clear()
            self._cleanup_handle = None
            self._cleanup_closed_transports.clear()
            self._cleanup_closed_handle = None

    @property
    def closed(self) -> bool:
        """Is connector closed.

        A readonly property.
        """
        return self._closed

    def _available_connections(self, key: "ConnectionKey") -> int:
        """
        Return number of available connections taking into account
        the limit, limit_per_host and the connection key.

        If it returns less than 1 it means there are no connections
        available.
        """

        if self._limit:
            # total calc available connections
            available = self._limit - len(self._acquired)

            # check limit per host
            if (
                self._limit_per_host
                and available > 0
                and key in self._acquired_per_host
            ):
                acquired = self._acquired_per_host.get(key)
                assert acquired is not None
                available = self._limit_per_host - len(acquired)

        elif self._limit_per_host and key in self._acquired_per_host:
            # check limit per host
            acquired = self._acquired_per_host.get(key)
            assert acquired is not None
            available = self._limit_per_host - len(acquired)
        else:
            available = 1

        return available

    async def connect(
        self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
    ) -> Connection:
        """Get from pool or create new connection."""
        key = req.connection_key
        available = self._available_connections(key)

        # Wait if there are no available connections or if there are/were
        # waiters (i.e. don't steal connection from a waiter about to wake up)
        if available <= 0 or key in self._waiters:
            fut = self._loop.create_future()

            # This connection will now count towards the limit.
            self._waiters[key].append(fut)

            if traces:
                for trace in traces:
                    await trace.send_connection_queued_start()

            try:
                await fut
            except BaseException as e:
                if key in self._waiters:
                    # remove a waiter even if it was cancelled, normally it's
                    #  removed when it's notified
                    try:
                        self._waiters[key].remove(fut)
                    except ValueError:  # fut may no longer be in list
                        pass

                raise e
            finally:
                # Drop the per-key deque once it is empty.
                if key in self._waiters and not self._waiters[key]:
                    del self._waiters[key]

            if traces:
                for trace in traces:
                    await trace.send_connection_queued_end()

        proto = self._get(key)
        if proto is None:
            # No reusable connection: register a placeholder so the slot is
            # counted against the limits while the real connect is awaited.
            placeholder = cast(ResponseHandler, _TransportPlaceholder(self._loop))
            self._acquired.add(placeholder)
            self._acquired_per_host[key].add(placeholder)

            if traces:
                for trace in traces:
                    await trace.send_connection_create_start()

            try:
                proto = await self._create_connection(req, traces, timeout)
                if self._closed:
                    proto.close()
                    raise ClientConnectionError("Connector is closed.")
            except BaseException:
                if not self._closed:
                    self._acquired.remove(placeholder)
                    self._drop_acquired_per_host(key, placeholder)
                    # The slot freed by the failed attempt may unblock a waiter.
                    self._release_waiter()
                raise
            else:
                if not self._closed:
                    self._acquired.remove(placeholder)
                    self._drop_acquired_per_host(key, placeholder)

            if traces:
                for trace in traces:
                    await trace.send_connection_create_end()
        else:
            if traces:
                # Acquire the connection to prevent race conditions with limits
                placeholder = cast(ResponseHandler, _TransportPlaceholder(self._loop))
                self._acquired.add(placeholder)
                self._acquired_per_host[key].add(placeholder)
                for trace in traces:
                    await trace.send_connection_reuseconn()
                self._acquired.remove(placeholder)
                self._drop_acquired_per_host(key, placeholder)

        self._acquired.add(proto)
        self._acquired_per_host[key].add(proto)
        return Connection(self, key, proto, self._loop)

    def _get(self, key: "ConnectionKey") -> Optional[ResponseHandler]:
        """Pop a connected, not-yet-expired idle connection for *key*.

        Prunes dead and expired entries along the way; returns None when
        nothing usable is pooled for the key.
        """
        try:
            conns = self._conns[key]
        except KeyError:
            return None

        t1 = self._loop.time()
        while conns:
            proto, t0 = conns.pop()
            if proto.is_connected():
                if t1 - t0 > self._keepalive_timeout:
                    # Idle longer than the keep-alive window: discard.
                    transport = proto.transport
                    proto.close()
                    # only for SSL transports
                    if key.is_ssl and not self._cleanup_closed_disabled:
                        self._cleanup_closed_transports.append(transport)
                else:
                    if not conns:
                        # The very last connection was reclaimed: drop the key
                        del self._conns[key]
                    return proto
            else:
                # Peer already dropped the connection.
                transport = proto.transport
                proto.close()
                if key.is_ssl and not self._cleanup_closed_disabled:
                    self._cleanup_closed_transports.append(transport)

        # No more connections: drop the key
        del self._conns[key]
        return None

    def _release_waiter(self) -> None:
        """
        Iterates over all waiters until it finds one that is not finished and
        belongs to a host that has available connections.
        """
        if not self._waiters:
            return

        # Having the dict keys ordered this avoids to iterate
        # at the same order at each call.
        queues = list(self._waiters.keys())
        random.shuffle(queues)

        for key in queues:
            if self._available_connections(key) < 1:
                continue

            waiters = self._waiters[key]
            while waiters:
                waiter = waiters.popleft()
                if not waiter.done():
                    waiter.set_result(None)
                    return

    def _release_acquired(self, key: "ConnectionKey", proto: ResponseHandler) -> None:
        """Drop *proto* from the acquired sets and wake one waiter."""
        if self._closed:
            # acquired connection is already released on connector closing
            return

        try:
            self._acquired.remove(proto)
            self._drop_acquired_per_host(key, proto)
        except KeyError:  # pragma: no cover
            # this may be result of nondeterministic order of objects
            # finalization due garbage collection.
            pass
        else:
            self._release_waiter()

    def _release(
        self,
        key: "ConnectionKey",
        protocol: ResponseHandler,
        *,
        should_close: bool = False,
    ) -> None:
        """Return *protocol* to the idle pool, or close it outright."""
        if self._closed:
            # acquired connection is already released on connector closing
            return

        self._release_acquired(key, protocol)

        if self._force_close:
            should_close = True

        if should_close or protocol.should_close:
            transport = protocol.transport
            protocol.close()

            if key.is_ssl and not self._cleanup_closed_disabled:
                self._cleanup_closed_transports.append(transport)
        else:
            # Keep the connection for reuse and stamp its last-used time.
            conns = self._conns.get(key)
            if conns is None:
                conns = self._conns[key] = []
            conns.append((protocol, self._loop.time()))

            if self._cleanup_handle is None:
                self._cleanup_handle = helpers.weakref_handle(
                    self, "_cleanup", self._keepalive_timeout, self._loop
                )

    async def _create_connection(
        self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Establish a new connection; implemented by subclasses."""
        raise NotImplementedError()
648

649

650 10
class _DNSCacheTable:
651 10
    def __init__(self, ttl: Optional[float] = None) -> None:
652 10
        self._addrs_rr = (
653
            {}
654
        )  # type: Dict[Tuple[str, int], Tuple[Iterator[Dict[str, Any]], int]]
655 10
        self._timestamps = {}  # type: Dict[Tuple[str, int], float]
656 10
        self._ttl = ttl
657

658 10
    def __contains__(self, host: object) -> bool:
659 10
        return host in self._addrs_rr
660

661 10
    def add(self, key: Tuple[str, int], addrs: List[Dict[str, Any]]) -> None:
662 10
        self._addrs_rr[key] = (cycle(addrs), len(addrs))
663

664 10
        if self._ttl:
665 10
            self._timestamps[key] = monotonic()
666

667 10
    def remove(self, key: Tuple[str, int]) -> None:
668 10
        self._addrs_rr.pop(key, None)
669

670 10
        if self._ttl:
671 10
            self._timestamps.pop(key, None)
672

673 10
    def clear(self) -> None:
674 10
        self._addrs_rr.clear()
675 10
        self._timestamps.clear()
676

677 10
    def next_addrs(self, key: Tuple[str, int]) -> List[Dict[str, Any]]:
678 10
        loop, length = self._addrs_rr[key]
679 10
        addrs = list(islice(loop, length))
680
        # Consume one more element to shift internal state of `cycle`
681 10
        next(loop)
682 10
        return addrs
683

684 10
    def expired(self, key: Tuple[str, int]) -> bool:
685 10
        if self._ttl is None:
686 10
            return False
687

688 10
        return self._timestamps[key] + self._ttl < monotonic()
689

690

691 10
class TCPConnector(BaseConnector):
    """TCP connector.

    verify_ssl - Set to True to check ssl certifications.
    fingerprint - Pass the binary sha256
        digest of the expected certificate in DER format to verify
        that the certificate the server presents matches. See also
        https://en.wikipedia.org/wiki/Transport_Layer_Security#Certificate_pinning
    resolver - Enable DNS lookups and use this
        resolver
    use_dns_cache - Use memory cache for DNS lookups.
    ttl_dns_cache - Max seconds having cached a DNS entry, None forever.
    family - socket address family
    local_addr - local tuple of (host, port) to bind socket to

    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    enable_cleanup_closed - Enables clean-up closed ssl transports.
                            Disabled by default.
    loop - Optional event loop.
    """

    def __init__(
        self,
        *,
        use_dns_cache: bool = True,
        ttl_dns_cache: Optional[int] = 10,
        family: int = 0,
        ssl: Union[None, bool, Fingerprint, SSLContext] = None,
        local_addr: Optional[Tuple[str, int]] = None,
        resolver: Optional[AbstractResolver] = None,
        keepalive_timeout: Union[None, float, _SENTINEL] = sentinel,
        force_close: bool = False,
        limit: int = 100,
        limit_per_host: int = 0,
        enable_cleanup_closed: bool = False,
    ) -> None:
        super().__init__(
            keepalive_timeout=keepalive_timeout,
            force_close=force_close,
            limit=limit,
            limit_per_host=limit_per_host,
            enable_cleanup_closed=enable_cleanup_closed,
        )

        # Reject anything outside the accepted closed set of ssl values early,
        # with a clear error, instead of failing obscurely at connect time.
        if not isinstance(ssl, SSL_ALLOWED_TYPES):
            raise TypeError(
                "ssl should be SSLContext, bool, Fingerprint, "
                "or None, got {!r} instead.".format(ssl)
            )
        self._ssl = ssl
        if resolver is None:
            resolver = DefaultResolver()
        self._resolver: AbstractResolver = resolver

        self._use_dns_cache = use_dns_cache
        self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
        # One event per in-flight DNS lookup: concurrent resolutions of the
        # same (host, port) wait on the event instead of issuing duplicates.
        self._throttle_dns_events = (
            {}
        )  # type: Dict[Tuple[str, int], EventResultOrError]
        self._family = family
        self._local_addr = local_addr

    def _close_immediately(self) -> List["asyncio.Future[None]"]:
        # Wake everybody waiting on an in-flight DNS lookup so closing the
        # connector cannot hang on a pending resolution.
        for ev in self._throttle_dns_events.values():
            ev.cancel()
        return super()._close_immediately()

    @property
    def family(self) -> int:
        """Socket family like AF_INET."""
        return self._family

    @property
    def use_dns_cache(self) -> bool:
        """True if local DNS caching is enabled."""
        return self._use_dns_cache

    def clear_dns_cache(
        self, host: Optional[str] = None, port: Optional[int] = None
    ) -> None:
        """Remove specified host/port or clear all dns local cache."""
        if host is not None and port is not None:
            self._cached_hosts.remove((host, port))
        elif host is not None or port is not None:
            raise ValueError("either both host and port " "or none of them are allowed")
        else:
            self._cached_hosts.clear()

    async def _resolve_host(
        self, host: str, port: int, traces: Optional[List["Trace"]] = None
    ) -> List[Dict[str, Any]]:
        """Resolve *host*/*port* into a list of addr-info dicts.

        IP literals short-circuit without a lookup.  Otherwise the configured
        resolver is used, optionally via the TTL'd round-robin DNS cache, and
        concurrent lookups of the same key are coalesced onto one query
        through ``_throttle_dns_events``.
        """
        if is_ip_address(host):
            # Already an IP literal: synthesize the addr-info entry directly.
            return [
                {
                    "hostname": host,
                    "host": host,
                    "port": port,
                    "family": self._family,
                    "proto": 0,
                    "flags": 0,
                }
            ]

        if not self._use_dns_cache:

            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_start(host)

            res = await self._resolver.resolve(host, port, family=self._family)

            if traces:
                for trace in traces:
                    await trace.send_dns_resolvehost_end(host)

            return res

        key = (host, port)

        if (key in self._cached_hosts) and (not self._cached_hosts.expired(key)):
            # get result early, before any await (#4014)
            result = self._cached_hosts.next_addrs(key)

            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            return result

        if key in self._throttle_dns_events:
            # get event early, before any await (#4014)
            event = self._throttle_dns_events[key]
            if traces:
                for trace in traces:
                    await trace.send_dns_cache_hit(host)
            await event.wait()
        else:
            # update dict early, before any await (#4014)
            self._throttle_dns_events[key] = EventResultOrError(self._loop)
            if traces:
                for trace in traces:
                    await trace.send_dns_cache_miss(host)
            try:

                if traces:
                    for trace in traces:
                        await trace.send_dns_resolvehost_start(host)

                addrs = await self._resolver.resolve(host, port, family=self._family)
                if traces:
                    for trace in traces:
                        await trace.send_dns_resolvehost_end(host)

                self._cached_hosts.add(key, addrs)
                self._throttle_dns_events[key].set()
            except BaseException as e:
                # any DNS exception, independently of the implementation
                # is set for the waiters to raise the same exception.
                self._throttle_dns_events[key].set(exc=e)
                raise
            finally:
                # The lookup is finished either way; waiters have been woken
                # via set()/set(exc=...), so drop the throttle entry.
                self._throttle_dns_events.pop(key)

        return self._cached_hosts.next_addrs(key)

    async def _create_connection(
        self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Create connection.

        Has same keyword arguments as BaseEventLoop.create_connection.
        """
        if req.proxy:
            _, proto = await self._create_proxy_connection(req, traces, timeout)
        else:
            _, proto = await self._create_direct_connection(req, traces, timeout)

        return proto

    @staticmethod
    @functools.lru_cache(None)
    def _make_ssl_context(verified: bool) -> SSLContext:
        """Build (and memoize via lru_cache) the shared SSL context.

        ``verified=True`` yields the stdlib default (certificate-checking)
        context; ``verified=False`` yields a context without certificate
        verification, used for ssl=False or fingerprint-pinned requests.
        """
        if verified:
            return ssl.create_default_context()
        else:
            sslcontext = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            sslcontext.options |= ssl.OP_NO_SSLv2
            sslcontext.options |= ssl.OP_NO_SSLv3
            try:
                # Disable TLS compression (CRIME mitigation) when the flag
                # exists in the linked OpenSSL.
                sslcontext.options |= ssl.OP_NO_COMPRESSION
            except AttributeError as attr_err:
                warnings.warn(
                    "{!s}: The Python interpreter is compiled "
                    "against OpenSSL < 1.0.0. Ref: "
                    "https://docs.python.org/3/library/ssl.html"
                    "#ssl.OP_NO_COMPRESSION".format(attr_err),
                )
            sslcontext.set_default_verify_paths()
            return sslcontext

    def _get_ssl_context(self, req: "ClientRequest") -> Optional[SSLContext]:
        """Logic to get the correct SSL context

        0. if req.ssl is false, return None

        1. if ssl_context is specified in req, use it
        2. if _ssl_context is specified in self, use it
        3. otherwise:
            1. if verify_ssl is not specified in req, use self.ssl_context
               (will generate a default context according to self.verify_ssl)
            2. if verify_ssl is True in req, generate a default SSL context
            3. if verify_ssl is False in req, generate a SSL context that
               won't verify
        """
        if req.is_ssl():
            if ssl is None:  # pragma: no cover
                raise RuntimeError("SSL is not supported.")
            sslcontext = req.ssl
            if isinstance(sslcontext, ssl.SSLContext):
                return sslcontext
            if sslcontext is not None:
                # not verified or fingerprinted
                return self._make_ssl_context(False)
            sslcontext = self._ssl
            if isinstance(sslcontext, ssl.SSLContext):
                return sslcontext
            if sslcontext is not None:
                # not verified or fingerprinted
                return self._make_ssl_context(False)
            return self._make_ssl_context(True)
        else:
            return None

    def _get_fingerprint(self, req: "ClientRequest") -> Optional["Fingerprint"]:
        """Return the pinned cert fingerprint: request-level wins over
        connector-level; None when neither is a Fingerprint."""
        ret = req.ssl
        if isinstance(ret, Fingerprint):
            return ret
        ret = self._ssl
        if isinstance(ret, Fingerprint):
            return ret
        return None

    async def _wrap_create_connection(
        self,
        *args: Any,
        req: "ClientRequest",
        timeout: "ClientTimeout",
        client_error: Type[Exception] = ClientConnectorError,
        **kwargs: Any,
    ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Run ``loop.create_connection`` under the sock_connect timeout,
        translating low-level cert/ssl/OS errors into aiohttp client
        exceptions (``client_error`` for plain OSError)."""
        try:
            async with ceil_timeout(timeout.sock_connect):
                return await self._loop.create_connection(*args, **kwargs)  # type: ignore[return-value]  # noqa
        except cert_errors as exc:
            raise ClientConnectorCertificateError(req.connection_key, exc) from exc
        except ssl_errors as exc:
            raise ClientConnectorSSLError(req.connection_key, exc) from exc
        except OSError as exc:
            raise client_error(req.connection_key, exc) from exc

    async def _create_direct_connection(
        self,
        req: "ClientRequest",
        traces: List["Trace"],
        timeout: "ClientTimeout",
        *,
        client_error: Type[Exception] = ClientConnectorError,
    ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Resolve the target host and try each resolved address in turn,
        returning the first successful (transport, protocol) pair; the last
        failure is re-raised when every address fails."""
        sslcontext = self._get_ssl_context(req)
        fingerprint = self._get_fingerprint(req)

        host = req.url.raw_host
        assert host is not None
        port = req.port
        assert port is not None
        host_resolved = asyncio.ensure_future(
            self._resolve_host(host, port, traces=traces), loop=self._loop
        )
        try:
            # Cancelling this lookup should not cancel the underlying lookup
            #  or else the cancel event will get broadcast to all the waiters
            #  across all connections.
            hosts = await asyncio.shield(host_resolved)
        except asyncio.CancelledError:

            def drop_exception(fut: "asyncio.Future[List[Dict[str, Any]]]") -> None:
                # Consume the result so a late failure of the shielded lookup
                # is not logged as an unretrieved exception.
                with suppress(Exception, asyncio.CancelledError):
                    fut.result()

            host_resolved.add_done_callback(drop_exception)
            raise
        except OSError as exc:
            # in case of proxy it is not ClientProxyConnectionError
            # it is problem of resolving proxy ip itself
            raise ClientConnectorError(req.connection_key, exc) from exc

        last_exc = None  # type: Optional[Exception]

        for hinfo in hosts:
            host = hinfo["host"]
            port = hinfo["port"]

            try:
                transp, proto = await self._wrap_create_connection(
                    self._factory,
                    host,
                    port,
                    timeout=timeout,
                    ssl=sslcontext,
                    family=hinfo["family"],
                    proto=hinfo["proto"],
                    flags=hinfo["flags"],
                    server_hostname=hinfo["hostname"] if sslcontext else None,
                    local_addr=self._local_addr,
                    req=req,
                    client_error=client_error,
                )
            except ClientConnectorError as exc:
                # Remember the failure and try the next resolved address.
                last_exc = exc
                continue

            if req.is_ssl() and fingerprint:
                try:
                    fingerprint.check(transp)
                except ServerFingerprintMismatch as exc:
                    # Certificate pin mismatch: drop this transport and keep
                    # trying the remaining addresses.
                    transp.close()
                    if not self._cleanup_closed_disabled:
                        self._cleanup_closed_transports.append(transp)
                    last_exc = exc
                    continue

            return transp, proto
        assert last_exc is not None
        raise last_exc

    async def _create_proxy_connection(
        self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
    ) -> Tuple[asyncio.Transport, ResponseHandler]:
        """Connect through an HTTP proxy.

        Plain-HTTP requests are sent straight to the proxy (with
        Proxy-Authorization when given).  HTTPS requests first issue a
        CONNECT to open a tunnel, then the raw socket is duplicated and
        wrapped in TLS toward the real target.
        """
        headers = {}  # type: Dict[str, str]
        if req.proxy_headers is not None:
            headers = req.proxy_headers  # type: ignore[assignment]
        headers[hdrs.HOST] = req.headers[hdrs.HOST]

        url = req.proxy
        assert url is not None
        proxy_req = ClientRequest(
            hdrs.METH_GET,
            url,
            headers=headers,
            auth=req.proxy_auth,
            loop=self._loop,
            ssl=req.ssl,
        )

        # create connection to proxy server
        transport, proto = await self._create_direct_connection(
            proxy_req, [], timeout, client_error=ClientProxyConnectionError
        )

        # Many HTTP proxies has buggy keepalive support.  Let's not
        # reuse connection but close it after processing every
        # response.
        proto.force_close()

        auth = proxy_req.headers.pop(hdrs.AUTHORIZATION, None)
        if auth is not None:
            if not req.is_ssl():
                # Plain HTTP: the proxy reads the request itself, so the
                # credentials travel on the main request.
                req.headers[hdrs.PROXY_AUTHORIZATION] = auth
            else:
                # HTTPS: credentials must go on the CONNECT request instead.
                proxy_req.headers[hdrs.PROXY_AUTHORIZATION] = auth

        if req.is_ssl():
            sslcontext = self._get_ssl_context(req)
            # For HTTPS requests over HTTP proxy
            # we must notify proxy to tunnel connection
            # so we send CONNECT command:
            #   CONNECT www.python.org:443 HTTP/1.1
            #   Host: www.python.org
            #
            # next we must do TLS handshake and so on
            # to do this we must wrap raw socket into secure one
            # asyncio handles this perfectly
            proxy_req.method = hdrs.METH_CONNECT
            proxy_req.url = req.url
            key = dataclasses.replace(
                req.connection_key, proxy=None, proxy_auth=None, proxy_headers_hash=None
            )
            conn = Connection(self, key, proto, self._loop)
            proxy_resp = await proxy_req.send(conn)
            try:
                protocol = conn._protocol
                assert protocol is not None
                protocol.set_response_params()
                resp = await proxy_resp.start(conn)
            except BaseException:
                proxy_resp.close()
                conn.close()
                raise
            else:
                # Detach protocol/transport from the helper Connection so
                # closing it later does not tear down the tunnel socket.
                conn._protocol = None
                conn._transport = None
                try:
                    if resp.status != 200:
                        message = resp.reason
                        if message is None:
                            message = RESPONSES[resp.status][0]
                        raise ClientHttpProxyError(
                            proxy_resp.request_info,
                            resp.history,
                            status=resp.status,
                            message=message,
                            headers=resp.headers,
                        )
                    rawsock = transport.get_extra_info("socket", default=None)
                    if rawsock is None:
                        raise RuntimeError("Transport does not expose socket instance")
                    # Duplicate the socket, so now we can close proxy transport
                    rawsock = rawsock.dup()
                finally:
                    transport.close()

                transport, proto = await self._wrap_create_connection(
                    self._factory,
                    timeout=timeout,
                    ssl=sslcontext,
                    sock=rawsock,
                    server_hostname=req.host,
                    req=req,
                )
            finally:
                proxy_resp.close()

        return transport, proto
1127

1128

1129 10
class UnixConnector(BaseConnector):
    """Connector that talks HTTP over a unix domain socket.

    path - Unix socket path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[_SENTINEL, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
    ) -> None:
        super().__init__(
            limit=limit,
            limit_per_host=limit_per_host,
            force_close=force_close,
            keepalive_timeout=keepalive_timeout,
        )
        self._path = path

    @property
    def path(self) -> str:
        """Path to unix socket."""
        return self._path

    async def _create_connection(
        self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Connect over the configured unix socket, wrapping OS failures."""
        try:
            async with ceil_timeout(timeout.sock_connect):
                _, protocol = await self._loop.create_unix_connection(
                    self._factory, self._path
                )
        except OSError as os_err:
            raise UnixClientConnectorError(
                self.path, req.connection_key, os_err
            ) from os_err

        return cast(ResponseHandler, protocol)
1174

1175

1176 10
class NamedPipeConnector(BaseConnector):
    """Connector that talks HTTP over a Windows named pipe.

    Only supported by the proactor event loop.
    See also: https://docs.python.org/3.7/library/asyncio-eventloop.html

    path - Windows named pipe path.
    keepalive_timeout - (optional) Keep-alive timeout.
    force_close - Set to True to force close and do reconnect
        after each request (and between redirects).
    limit - The total number of simultaneous connections.
    limit_per_host - Number of simultaneous connections to one host.
    loop - Optional event loop.
    """

    def __init__(
        self,
        path: str,
        force_close: bool = False,
        keepalive_timeout: Union[_SENTINEL, float, None] = sentinel,
        limit: int = 100,
        limit_per_host: int = 0,
    ) -> None:
        super().__init__(
            limit=limit,
            limit_per_host=limit_per_host,
            force_close=force_close,
            keepalive_timeout=keepalive_timeout,
        )
        # Named pipes are a proactor-only transport; refuse other loops early.
        if not isinstance(
            self._loop, asyncio.ProactorEventLoop  # type: ignore[attr-defined]
        ):
            raise RuntimeError(
                "Named Pipes only available in proactor " "loop under windows"
            )
        self._path = path

    @property
    def path(self) -> str:
        """Path to the named pipe."""
        return self._path

    async def _create_connection(
        self, req: "ClientRequest", traces: List["Trace"], timeout: "ClientTimeout"
    ) -> ResponseHandler:
        """Open the named pipe and return its response-handler protocol."""
        try:
            async with ceil_timeout(timeout.sock_connect):
                _, protocol = await self._loop.create_pipe_connection(  # type: ignore[attr-defined] # noqa: E501
                    self._factory, self._path
                )
                # Yield to the loop once so connection_made() has run and the
                # transport is attached before the caller's
                # `assert conn.transport is not None` in client.py fires.
                # (Alternative would be assigning proto.transport manually.)
                await asyncio.sleep(0)
        except OSError as os_err:
            raise ClientConnectorError(req.connection_key, os_err) from os_err

        return cast(ResponseHandler, protocol)

Read our documentation on viewing source code .

Loading