1 5
import socket
2 5
import struct
3 5
import warnings
4 5
from collections import namedtuple
5 5
from types import TracebackType
6 5
from typing import TYPE_CHECKING, Awaitable, Callable, Optional, Tuple, Type
7

8 5
import aiohttp
9 5
from yarl import URL
10

11 5
from .exceptions import DockerError
12

13

14 5
if TYPE_CHECKING:
15 0
    from .docker import Docker
16

17 5
# One demultiplexed chunk of exec output: ``stream`` identifies the stream the
# chunk came from and ``data`` holds its payload bytes.
Message = namedtuple("Message", ["stream", "data"])
18

19

20 5
class Stream:
    """Stream attached to a Docker exec session over an upgraded connection.

    The HTTP connection is upgraded to Docker's raw ("vendored tcp")
    protocol.  The connection is opened lazily by ``_init()`` — on the first
    ``read_out()``/``write_in()`` call or when entering ``async with``.  The
    ``setup`` coroutine supplied by the caller provides the URL to POST to,
    the request body and whether a TTY is attached.
    """

    def __init__(
        self,
        docker: "Docker",
        setup: Callable[[], Awaitable[Tuple[URL, Optional[bytes], bool]]],
        timeout: Optional[aiohttp.ClientTimeout],
    ) -> None:
        self._setup = setup
        self.docker = docker
        # Upgraded response; stays None until _init() succeeds.
        self._resp: Optional[aiohttp.ClientResponse] = None
        self._closed = False
        self._timeout = timeout
        # Queue of parsed Message objects, filled by _ExecParser.
        self._queue = None

    async def _init(self) -> None:
        """Open and upgrade the connection once; later calls are no-ops.

        Raises:
            DockerError: if the server did not leave the socket open for the
                protocol upgrade (the response has no connection).
        """
        if self._resp is not None:
            return
        url, body, tty = await self._setup()
        timeout = self._timeout
        if timeout is None:
            # total timeout doesn't make sense for streaming
            timeout = aiohttp.ClientTimeout()
        self._resp = resp = await self.docker._do_query(
            url,
            method="POST",
            data=body,
            params=None,
            headers={"Connection": "Upgrade", "Upgrade": "tcp"},
            timeout=timeout,
            chunked=None,
            read_until_eof=False,
            versioned_api=True,
        )
        # read body if present, it can contain an information
        # about disconnection
        body = await resp.read()

        conn = resp.connection
        if conn is None:
            msg = (
                "Cannot upgrade connection to vendored tcp protocol, "
                "the docker server has closed underlying socket."
            )
            msg += f" Status code: {resp.status}."
            msg += f" Headers: {resp.headers}."
            if body:
                if len(body) > 100:
                    # slice, not index: report the first 100 bytes, not byte #100
                    msg = msg + f" First 100 bytes of body: [{body[:100]!r}]..."
                else:
                    msg = msg + f" Body: [{body!r}]"
            # Release the failed response and forget it, so the stream is not
            # reported as unclosed and a later call can retry the setup.
            resp.close()
            self._resp = None
            raise DockerError(
                500,
                {"message": msg},
            )
        protocol = conn.protocol
        loop = resp._loop
        sock = protocol.transport.get_extra_info("socket")
        if sock is not None:
            # set TCP keepalive for vendored socket
            # the socket can be closed in the case of error
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)

        queue: aiohttp.FlowControlDataQueue[Message] = aiohttp.FlowControlDataQueue(
            protocol, limit=2 ** 16, loop=loop
        )
        # From now on, raw bytes on the connection go through _ExecParser.
        protocol.set_parser(_ExecParser(queue, tty=tty), queue)
        protocol.force_close()
        self._queue = queue

    async def read_out(self) -> Optional[Message]:
        """Read from stdout or stderr.

        Returns:
            The next ``Message``, or ``None`` once the stream reached EOF.
        """
        await self._init()
        try:
            assert self._queue is not None
            return await self._queue.read()
        except aiohttp.EofStream:
            return None

    async def write_in(self, data: bytes) -> None:
        """Write into stdin.

        Raises:
            RuntimeError: if the stream has already been closed.
        """
        if self._closed:
            raise RuntimeError("Cannot write to closed transport")
        await self._init()
        assert self._resp is not None
        transport = self._resp.connection.transport
        transport.write(data)
        protocol = self._resp.connection.protocol
        if protocol.transport is not None:
            # respect flow control before allowing the next write
            await protocol._drain_helper()

    async def close(self) -> None:
        """Half-close the write side and release the response.

        Safe to call more than once, and a no-op if the stream was never
        opened.
        """
        if self._resp is None:
            return
        if self._closed:
            return
        self._closed = True
        conn = self._resp.connection
        transport = conn.transport if conn is not None else None
        if transport is not None and transport.can_write_eof():
            # send EOF so the remote side sees the end of our input
            transport.write_eof()
        # ClientResponse.close() is synchronous; it must not be awaited.
        self._resp.close()

    async def __aenter__(self) -> "Stream":
        await self._init()
        return self

    async def __aexit__(
        self,
        exc_typ: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self.close()

    def __del__(self, _warnings=warnings) -> None:
        # ``_warnings`` is bound at definition time so the module is still
        # reachable when this runs during interpreter shutdown.
        if getattr(self, "_resp", None) is None:
            return
        if not self._closed:
            _warnings.warn("Unclosed ExecStream", ResourceWarning)
139

140

141 5
class _ExecParser:
142 5
    def __init__(self, queue, tty=False) -> None:
143 5
        self.queue = queue
144 5
        self.tty = tty
145 5
        self.header_fmt = struct.Struct(">BxxxL")
146 5
        self._buf = bytearray()
147

148 5
    def set_exception(self, exc: BaseException) -> None:
149 0
        self.queue.set_exception(exc)
150

151 5
    def feed_eof(self) -> None:
152 5
        self.queue.feed_eof()
153

154 5
    def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
155 5
        if self.tty:
156 4
            msg = Message(1, data)  # stdout
157 4
            self.queue.feed_data(msg, len(data))
158
        else:
159 5
            self._buf.extend(data)
160 5
            while self._buf:
161
                # Parse the header
162 5
                if len(self._buf) < self.header_fmt.size:
163 0
                    return False, b""
164 5
                fileno, msglen = self.header_fmt.unpack(
165
                    self._buf[: self.header_fmt.size]
166
                )
167 5
                msg_and_header = self.header_fmt.size + msglen
168 5
                if len(self._buf) < msg_and_header:
169 0
                    return False, b""
170 5
                msg = Message(
171
                    fileno, bytes(self._buf[self.header_fmt.size : msg_and_header])
172
                )
173 5
                self.queue.feed_data(msg, msglen)
174 5
                del self._buf[:msg_and_header]
175 5
        return False, b""

Read our documentation on viewing source code .

Loading