aio-libs / aiohttp
1 10
import asyncio
2 10
import collections
3 10
import warnings
4 10
from typing import Awaitable, Callable, Generic, List, Optional, Tuple, TypeVar
5

6 10
from typing_extensions import Final
7

8 10
from .base_protocol import BaseProtocol
9 10
from .helpers import BaseTimerContext, set_exception, set_result
10 10
from .log import internal_logger
11

12
try:  # pragma: no cover
13
    from typing import Deque
14 0
except ImportError:
15 0
    from typing_extensions import Deque
16

17 10
# Names exported from this module via ``from ... import *``.
__all__ = (
    "EMPTY_PAYLOAD",
    "EofStream",
    "StreamReader",
    "DataQueue",
    "FlowControlDataQueue",
)

# Generic item type carried by DataQueue and AsyncStreamIterator.
_T = TypeVar("_T")
26

27

28 10
class EofStream(Exception):
    """Raised internally to signal that the end of the stream was reached."""
30

31

32 10
class AsyncStreamIterator(Generic[_T]):
    """Async iterator driven by a zero-argument coroutine function.

    Iteration ends when the callable raises EofStream or produces an
    empty bytes object.
    """

    def __init__(self, read_func: Callable[[], Awaitable[_T]]) -> None:
        self.read_func = read_func

    def __aiter__(self) -> "AsyncStreamIterator[_T]":
        return self

    async def __anext__(self) -> _T:
        try:
            item = await self.read_func()
        except EofStream:
            raise StopAsyncIteration
        if item == b"":
            # An empty read marks exhaustion of the underlying stream.
            raise StopAsyncIteration
        return item
47

48

49 10
class ChunkTupleAsyncStreamIterator:
    """Async iterator yielding (data, end_of_http_chunk) pairs from a stream."""

    def __init__(self, stream: "StreamReader") -> None:
        self._stream = stream

    def __aiter__(self) -> "ChunkTupleAsyncStreamIterator":
        return self

    async def __anext__(self) -> Tuple[bytes, bool]:
        pair = await self._stream.readchunk()
        if pair == (b"", False):
            # (b"", False) is the stream's sentinel for "fully drained".
            raise StopAsyncIteration
        return pair
61

62

63 10
class AsyncStreamReaderMixin:
    """Adds ``async for`` iteration helpers to stream reader classes."""

    def __aiter__(self) -> AsyncStreamIterator[bytes]:
        # Default iteration is line by line.
        return AsyncStreamIterator(self.readline)  # type: ignore[attr-defined]

    def iter_chunked(self, n: int) -> AsyncStreamIterator[bytes]:
        """Return an asynchronous iterator that yields chunks of size n."""

        async def read_sized() -> bytes:
            return await self.read(n)  # type: ignore[attr-defined,no-any-return]

        return AsyncStreamIterator(read_sized)

    def iter_any(self) -> AsyncStreamIterator[bytes]:
        """Return an asynchronous iterator that yields all data as soon as
        it is received.
        """
        return AsyncStreamIterator(self.readany)  # type: ignore[attr-defined]

    def iter_chunks(self) -> ChunkTupleAsyncStreamIterator:
        """Return an asynchronous iterator over chunks as the server sends
        them. Items are (bytes, bool) tuples as produced by the
        StreamReader.readchunk method.
        """
        return ChunkTupleAsyncStreamIterator(self)  # type: ignore[arg-type]
92

93

94 10
class StreamReader(AsyncStreamReaderMixin):
    """An enhancement of asyncio.StreamReader.

    Supports asynchronous iteration by line, chunk or as available::

        async for line in reader:
            ...
        async for chunk in reader.iter_chunked(1024):
            ...
        async for slice in reader.iter_any():
            ...

    """

    # Total number of bytes ever fed into this reader (monotonic; reads do
    # not decrease it -- used as the logical write position for HTTP chunks).
    total_bytes = 0

    def __init__(
        self,
        protocol: BaseProtocol,
        limit: int,
        *,
        timer: Optional[BaseTimerContext] = None,
        loop: asyncio.AbstractEventLoop,
    ) -> None:
        self._protocol = protocol
        # Flow-control watermarks: pause the transport above 2 * limit
        # buffered bytes, resume once the buffer drains below limit.
        self._low_water = limit
        self._high_water = limit * 2
        if loop is None:
            loop = asyncio.get_event_loop()
        self._loop = loop
        self._size = 0
        self._cursor = 0
        self._http_chunk_splits = None  # type: Optional[List[int]]
        self._buffer = collections.deque()  # type: Deque[bytes]
        self._buffer_offset = 0
        self._eof = False
        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._eof_waiter = None  # type: Optional[asyncio.Future[None]]
        self._exception = None  # type: Optional[BaseException]
        self._timer = timer
        self._eof_callbacks = []  # type: List[Callable[[], None]]

    def __repr__(self) -> str:
        info = [self.__class__.__name__]
        if self._size:
            info.append("%d bytes" % self._size)
        if self._eof:
            info.append("eof")
        if self._low_water != 2 ** 16:  # default limit
            info.append("low=%d high=%d" % (self._low_water, self._high_water))
        if self._waiter:
            info.append("w=%r" % self._waiter)
        if self._exception:
            info.append("e=%r" % self._exception)
        return "<%s>" % " ".join(info)

    def get_read_buffer_limits(self) -> Tuple[int, int]:
        """Return the (low watermark, high watermark) flow-control pair."""
        return (self._low_water, self._high_water)

    def exception(self) -> Optional[BaseException]:
        return self._exception

    def set_exception(self, exc: BaseException) -> None:
        """Store *exc* and propagate it to any coroutine waiting on data/EOF."""
        self._exception = exc
        self._eof_callbacks.clear()

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_exception(waiter, exc)

        waiter = self._eof_waiter
        if waiter is not None:
            self._eof_waiter = None
            set_exception(waiter, exc)

    def on_eof(self, callback: Callable[[], None]) -> None:
        """Register *callback* to run at EOF (invoked now if already at EOF)."""
        if self._eof:
            try:
                callback()
            except Exception:
                internal_logger.exception("Exception in eof callback")
        else:
            self._eof_callbacks.append(callback)

    def feed_eof(self) -> None:
        """Mark end of stream, wake all waiters, and fire EOF callbacks."""
        self._eof = True

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_result(waiter, None)

        waiter = self._eof_waiter
        if waiter is not None:
            self._eof_waiter = None
            set_result(waiter, None)

        for cb in self._eof_callbacks:
            try:
                cb()
            except Exception:
                internal_logger.exception("Exception in eof callback")

        self._eof_callbacks.clear()

    def is_eof(self) -> bool:
        """Return True if 'feed_eof' was called."""
        return self._eof

    def at_eof(self) -> bool:
        """Return True if the buffer is empty and 'feed_eof' was called."""
        return self._eof and not self._buffer

    async def wait_eof(self) -> None:
        """Block until 'feed_eof' (or 'set_exception') is called."""
        if self._eof:
            return

        assert self._eof_waiter is None
        self._eof_waiter = self._loop.create_future()
        try:
            await self._eof_waiter
        finally:
            self._eof_waiter = None

    def unread_data(self, data: bytes) -> None:
        """rollback reading some data from stream, inserting it to buffer head."""
        warnings.warn(
            "unread_data() is deprecated "
            "and will be removed in future releases (#3260)",
            DeprecationWarning,
            stacklevel=2,
        )
        if not data:
            return

        if self._buffer_offset:
            # Materialize the partially consumed head chunk before prepending.
            self._buffer[0] = self._buffer[0][self._buffer_offset :]
            self._buffer_offset = 0
        self._size += len(data)
        self._cursor -= len(data)
        self._buffer.appendleft(data)
        self._eof_counter = 0

    # TODO: size is ignored, remove the param later
    def feed_data(self, data: bytes, size: int = 0) -> None:
        """Append *data* to the buffer, waking a pending reader."""
        assert not self._eof, "feed_data after feed_eof"

        if not data:
            return

        self._size += len(data)
        self._buffer.append(data)
        self.total_bytes += len(data)

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_result(waiter, None)

        # Apply backpressure once the high watermark is exceeded.
        if self._size > self._high_water and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    def begin_http_chunk_receiving(self) -> None:
        """Switch on HTTP chunk boundary tracking (idempotent)."""
        if self._http_chunk_splits is None:
            if self.total_bytes:
                # Fixed: the two string fragments previously concatenated
                # without a separating space ("...whensome data...").
                raise RuntimeError(
                    "Called begin_http_chunk_receiving when "
                    "some data was already fed"
                )
            self._http_chunk_splits = []

    def end_http_chunk_receiving(self) -> None:
        """Record the logical end position of the HTTP chunk just received."""
        if self._http_chunk_splits is None:
            raise RuntimeError(
                "Called end_chunk_receiving without calling "
                "begin_chunk_receiving first"
            )

        # self._http_chunk_splits contains logical byte offsets from start of
        # the body transfer. Each offset is the offset of the end of a chunk.
        # "Logical" means bytes, accessible for a user.
        # If no chunks containing logical data were received, current position
        # is definitely zero.
        pos = self._http_chunk_splits[-1] if self._http_chunk_splits else 0

        if self.total_bytes == pos:
            # We should not add empty chunks here. So we check for that.
            # Note, when chunked + gzip is used, we can receive a chunk
            # of compressed data, but that data may not be enough for gzip FSM
            # to yield any uncompressed data. That's why current position may
            # not change after receiving a chunk.
            return

        self._http_chunk_splits.append(self.total_bytes)

        # wake up readchunk when end of http chunk received
        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_result(waiter, None)

    async def _wait(self, func_name: str) -> None:
        # StreamReader uses a future to link the protocol feed_data() method
        # to a read coroutine. Running two read coroutines at the same time
        # would have an unexpected behaviour. It would not possible to know
        # which coroutine would get the next data.
        if self._waiter is not None:
            raise RuntimeError(
                "%s() called while another coroutine is "
                "already waiting for incoming data" % func_name
            )

        waiter = self._waiter = self._loop.create_future()
        try:
            if self._timer:
                with self._timer:
                    await waiter
            else:
                await waiter
        finally:
            self._waiter = None

    async def readline(self) -> bytes:
        return await self.readuntil()

    async def readuntil(self, separator: bytes = b"\n") -> bytes:
        """Read data up to and including *separator*, or to EOF.

        Raises ValueError if *separator* is empty or the accumulated
        data exceeds the high watermark.
        """
        seplen = len(separator)
        if seplen == 0:
            raise ValueError("Separator should be at least one-byte string")

        if self._exception is not None:
            raise self._exception

        chunk = b""
        chunk_size = 0
        not_enough = True

        while not_enough:
            while self._buffer and not_enough:
                offset = self._buffer_offset
                # NOTE(review): find() only scans the head buffer chunk, so a
                # multi-byte separator spanning two chunks is not detected;
                # only single-byte separators are fully reliable here.
                ichar = self._buffer[0].find(separator, offset) + 1
                # Read from current offset to found separator or to the end.
                data = self._read_nowait_chunk(ichar - offset if ichar else -1)
                chunk += data
                chunk_size += len(data)
                if ichar:
                    not_enough = False

                if chunk_size > self._high_water:
                    raise ValueError("Chunk too big")

            if self._eof:
                break

            if not_enough:
                await self._wait("readuntil")

        return chunk

    async def read(self, n: int = -1) -> bytes:
        """Read up to *n* bytes; n == -1 reads until EOF, n == 0 reads nothing."""
        if self._exception is not None:
            raise self._exception

        if not n:
            return b""

        if n < 0:
            # This used to just loop creating a new waiter hoping to
            # collect everything in self._buffer, but that would
            # deadlock if the subprocess sends more than self.limit
            # bytes.  So just call self.readany() until EOF.
            blocks = []
            while True:
                block = await self.readany()
                if not block:
                    break
                blocks.append(block)
            return b"".join(blocks)

        # TODO: should be `if` instead of `while`
        # because waiter maybe triggered on chunk end,
        # without feeding any data
        while not self._buffer and not self._eof:
            await self._wait("read")

        return self._read_nowait(n)

    async def readany(self) -> bytes:
        """Read whatever data is buffered, waiting only when nothing is."""
        if self._exception is not None:
            raise self._exception

        # TODO: should be `if` instead of `while`
        # because waiter maybe triggered on chunk end,
        # without feeding any data
        while not self._buffer and not self._eof:
            await self._wait("readany")

        return self._read_nowait(-1)

    async def readchunk(self) -> Tuple[bytes, bool]:
        """Returns a tuple of (data, end_of_http_chunk). When chunked transfer
        encoding is used, end_of_http_chunk is a boolean indicating if the end
        of the data corresponds to the end of a HTTP chunk, otherwise it is
        always False.
        """
        while True:
            if self._exception is not None:
                raise self._exception

            while self._http_chunk_splits:
                pos = self._http_chunk_splits.pop(0)
                if pos == self._cursor:
                    return (b"", True)
                if pos > self._cursor:
                    return (self._read_nowait(pos - self._cursor), True)
                internal_logger.warning(
                    "Skipping HTTP chunk end due to data "
                    "consumption beyond chunk boundary"
                )

            if self._buffer:
                return (self._read_nowait_chunk(-1), False)

            if self._eof:
                # Special case for signifying EOF.
                # (b'', True) is not a final return value actually.
                return (b"", False)

            await self._wait("readchunk")

    async def readexactly(self, n: int) -> bytes:
        """Read exactly *n* bytes; raise IncompleteReadError on early EOF."""
        if self._exception is not None:
            raise self._exception

        blocks = []  # type: List[bytes]
        while n > 0:
            block = await self.read(n)
            if not block:
                partial = b"".join(blocks)
                raise asyncio.IncompleteReadError(partial, len(partial) + n)
            blocks.append(block)
            n -= len(block)

        return b"".join(blocks)

    def read_nowait(self, n: int = -1) -> bytes:
        # default was changed to be consistent with .read(-1)
        #
        # I believe the most users don't know about the method and
        # they are not affected.
        if self._exception is not None:
            raise self._exception

        if self._waiter and not self._waiter.done():
            raise RuntimeError(
                "Called while some coroutine is waiting for incoming data."
            )

        return self._read_nowait(n)

    def _read_nowait_chunk(self, n: int) -> bytes:
        """Remove and return up to *n* bytes from the head buffer chunk."""
        first_buffer = self._buffer[0]
        offset = self._buffer_offset
        if n != -1 and len(first_buffer) - offset > n:
            # Partial consumption: advance the offset within the head chunk.
            data = first_buffer[offset : offset + n]
            self._buffer_offset += n

        elif offset:
            self._buffer.popleft()
            data = first_buffer[offset:]
            self._buffer_offset = 0

        else:
            data = self._buffer.popleft()

        self._size -= len(data)
        self._cursor += len(data)

        chunk_splits = self._http_chunk_splits
        # Prevent memory leak: drop useless chunk splits
        while chunk_splits and chunk_splits[0] < self._cursor:
            chunk_splits.pop(0)

        if self._size < self._low_water and self._protocol._reading_paused:
            self._protocol.resume_reading()
        return data

    def _read_nowait(self, n: int) -> bytes:
        """Read not more than n bytes, or whole buffer if n == -1."""
        chunks = []

        while self._buffer:
            chunk = self._read_nowait_chunk(n)
            chunks.append(chunk)
            if n != -1:
                n -= len(chunk)
                if n == 0:
                    break

        return b"".join(chunks) if chunks else b""
495

496

497 10
class EmptyStreamReader(StreamReader):  # lgtm [py/missing-call-to-init]
    """A stream reader that holds no data and is permanently at EOF."""

    def __init__(self) -> None:
        # Intentionally does not call StreamReader.__init__:
        # this reader keeps no buffer, loop, or protocol state.
        pass

    def exception(self) -> Optional[BaseException]:
        return None

    def set_exception(self, exc: BaseException) -> None:
        pass

    def on_eof(self, callback: Callable[[], None]) -> None:
        # Already at EOF, so the callback runs immediately.
        try:
            callback()
        except Exception:
            internal_logger.exception("Exception in eof callback")

    def feed_eof(self) -> None:
        pass

    def is_eof(self) -> bool:
        return True

    def at_eof(self) -> bool:
        return True

    async def wait_eof(self) -> None:
        return

    def feed_data(self, data: bytes, n: int = 0) -> None:
        pass

    async def readline(self) -> bytes:
        return b""

    async def read(self, n: int = -1) -> bytes:
        return b""

    # TODO add async def readuntil

    async def readany(self) -> bytes:
        return b""

    async def readchunk(self) -> Tuple[bytes, bool]:
        # No data, and the (sole, empty) chunk is already complete.
        return (b"", True)

    async def readexactly(self, n: int) -> bytes:
        # Any positive request is unsatisfiable on an empty stream.
        raise asyncio.IncompleteReadError(b"", n)

    def read_nowait(self, n: int = -1) -> bytes:
        return b""
547

548

549 10
# Shared singleton used wherever an always-empty payload stream is needed.
EMPTY_PAYLOAD: Final[StreamReader] = EmptyStreamReader()
550

551

552 10
class DataQueue(Generic[_T]):
    """DataQueue is a general-purpose blocking queue with one reader."""

    def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
        self._loop = loop
        self._eof = False
        # Future used to park the single reader until data or EOF arrives.
        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._exception = None  # type: Optional[BaseException]
        self._size = 0
        # Each entry pairs an item with the size the feeder reported for it.
        self._buffer = collections.deque()  # type: Deque[Tuple[_T, int]]

    def __len__(self) -> int:
        return len(self._buffer)

    def is_eof(self) -> bool:
        # True once feed_eof() (or set_exception()) has been called.
        return self._eof

    def at_eof(self) -> bool:
        # True when EOF was signalled and the buffer is fully drained.
        return self._eof and not self._buffer

    def exception(self) -> Optional[BaseException]:
        return self._exception

    def set_exception(self, exc: BaseException) -> None:
        # Mark the queue broken; a parked reader is woken with the error.
        self._eof = True
        self._exception = exc

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_exception(waiter, exc)

    def feed_data(self, data: _T, size: int = 0) -> None:
        # Enqueue one item and wake the reader if it is waiting.
        self._size += size
        self._buffer.append((data, size))

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_result(waiter, None)

    def feed_eof(self) -> None:
        self._eof = True

        waiter = self._waiter
        if waiter is not None:
            self._waiter = None
            set_result(waiter, None)

    async def read(self) -> _T:
        # Park until data arrives when the buffer is empty and EOF not seen.
        if not self._buffer and not self._eof:
            assert not self._waiter
            self._waiter = self._loop.create_future()
            try:
                await self._waiter
            except (asyncio.CancelledError, asyncio.TimeoutError):
                # Clear the waiter so a subsequent read() can park again.
                self._waiter = None
                raise

        if self._buffer:
            data, size = self._buffer.popleft()
            self._size -= size
            return data
        else:
            # Buffer exhausted: surface the stored error, or signal EOF.
            if self._exception is not None:
                raise self._exception
            else:
                raise EofStream

    def __aiter__(self) -> AsyncStreamIterator[_T]:
        return AsyncStreamIterator(self.read)
623

624

625 10
class FlowControlDataQueue(DataQueue[_T]):
    """FlowControlDataQueue resumes and pauses an underlying stream.

    It is a destination for parsed data."""

    def __init__(
        self, protocol: BaseProtocol, limit: int, *, loop: asyncio.AbstractEventLoop
    ) -> None:
        super().__init__(loop=loop)

        self._protocol = protocol
        # High watermark: pause the transport once 2 * limit bytes are queued.
        self._limit = limit * 2

    def feed_data(self, data: _T, size: int = 0) -> None:
        super().feed_data(data, size)

        # Apply backpressure when the queued size exceeds the limit.
        if self._size > self._limit and not self._protocol._reading_paused:
            self._protocol.pause_reading()

    async def read(self) -> _T:
        try:
            return await super().read()
        finally:
            # Resume the transport once enough data was consumed -- runs even
            # when read() raises (EofStream or a stored exception).
            if self._size < self._limit and self._protocol._reading_paused:
                self._protocol.resume_reading()

Read our documentation on viewing source code .

Loading