python-trio / trio
1
# coding: utf-8
2

3
# Little utilities we use internally
4

5 22
from abc import ABCMeta
6 22
import os
7 22
import signal
8 22
import sys
9 22
import pathlib
10 22
from functools import wraps, update_wrapper
11 22
import typing as t
12 22
import threading
13 22
import collections
14

15 22
from async_generator import isasyncgen
16

17 22
from ._deprecate import warn_deprecated
18

19 22
import trio
20

21
# Equivalent to the C function raise(), which Python doesn't wrap
if os.name != "nt":

    def signal_raise(signum):
        """Send *signum* to the calling thread (the POSIX analogue of raise())."""
        signal.pthread_kill(threading.get_ident(), signum)

else:
    # On Windows, os.kill exists but behaves very strangely:
    #
    # - Given CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver the event
    #   via GenerateConsoleCtrlEvent. Tests using that froze while waiting,
    #   unless print statements were added, in which case they suddenly
    #   worked -- apparently the event only arrives if/when the console is
    #   accessed. A later, longer attempt to make GenerateConsoleCtrlEvent
    #   produce synthetic control-C events failed utterly; the details live
    #   in the code and comments removed/added at this commit:
    #     https://github.com/python-trio/trio/commit/95843654173e3e826c34d70a90b369ba6edf2c23
    # - Given any *other* signal number, CPython just calls TerminateProcess.
    #
    # So os.kill is no good for our purposes. Instead we call the C runtime's
    # raise() directly:
    #
    #   https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx
    #
    # Some more background: https://bugs.python.org/issue26350
    #
    # We need this for two things, and raise() works fine for both:
    # - redelivering unhandled signals
    # - generating synthetic signals for tests
    #
    # cffi is imported inside this branch because we only depend on it on
    # Windows. (Switching this to ctypes would be easy if we ever drop the
    # cffi dependency.)
    import cffi

    _ffi = cffi.FFI()
    _ffi.cdef("int raise(int);")
    _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll")
    # "raise" is a Python keyword, so the attribute must be fetched by name.
    signal_raise = getattr(_lib, "raise")
68

69

70
# See: #461 as to why this is needed.
# The gist is that threading.main_thread() has the capability to lie to us
# if somebody else edits the threading ident cache to replace the main
# thread, causing threading.current_thread() to return a _DummyThread,
# causing the C-c check to fail, and so on.
# Trying to use signal out of the main thread will fail, so we can then
# reliably check if this is the main thread without relying on a
# potentially modified threading.
def is_main_thread():
    """Attempt to reliably check if we are in the main thread.

    Probes by re-installing the current SIGINT handler unchanged:
    ``signal.signal`` raises ``ValueError`` when called outside the main
    thread.

    Returns:
        True if the handler could be re-installed, False otherwise.
    """
    try:
        signal.signal(signal.SIGINT, signal.getsignal(signal.SIGINT))
        return True
    except (TypeError, ValueError):
        # ValueError: we are not in the main thread.
        # TypeError: signal.getsignal() returns None when the current SIGINT
        # handler was not installed from Python (e.g. an embedding app set
        # it), and signal.signal() rejects None as a handler. We cannot probe
        # safely in that case, so report False instead of crashing.
        return False
85

86

87
######
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes. Returns coroutine object.
######
def coroutine_or_error(async_fn, *args):
    """Call ``async_fn(*args)`` and return the coroutine object it produces.

    Instead of letting common mistakes surface as opaque errors, this
    explains them.

    Raises:
        TypeError: if *async_fn* is already a coroutine object; if it (or
            what it returns) looks like an object from another async
            framework (asyncio/twisted/tornado); if it returns an async
            generator; or if it returns anything else that is not a
            coroutine. A ``TypeError`` raised by the call itself is
            re-raised unchanged when none of the special cases apply.
    """

    def looks_foreign(candidate):
        # Heuristic: was *candidate* produced by a different async library?
        # This is only consulted once something has already gone wrong, so
        # a false positive merely changes the wording of an error.
        return (
            # Legacy @asyncio.coroutine functions (and a surprising share of
            # asyncio builtins) return plain generators.
            isinstance(candidate, collections.abc.Generator)
            # asyncio's protocol marker for Future-like objects.
            or getattr(candidate, "_asyncio_future_blocking", None) is not None
            # Janky name check that catches tornado Futures and twisted
            # Deferreds.
            or candidate.__class__.__name__ in ("Future", "Deferred")
        )

    try:
        coro = async_fn(*args)
    except TypeError:
        # nursery.start_soon(trio.sleep(1)): the caller already called the
        # function and handed us the resulting coroutine object.
        if isinstance(async_fn, collections.abc.Coroutine):
            # Close it explicitly so Python does not additionally emit a
            # "coroutine was never awaited" RuntimeWarning.
            async_fn.close()
            message = (
                "Trio was expecting an async function, but instead it got "
                "a coroutine object {async_fn!r}\n"
                "\n"
                "Probably you did something like:\n"
                "\n"
                "  trio.run({async_fn.__name__}(...))            # incorrect!\n"
                "  nursery.start_soon({async_fn.__name__}(...))  # incorrect!\n"
                "\n"
                "Instead, you want (notice the parentheses!):\n"
                "\n"
                "  trio.run({async_fn.__name__}, ...)            # correct!\n"
                "  nursery.start_soon({async_fn.__name__}, ...)  # correct!"
            ).format(async_fn=async_fn)
            raise TypeError(message) from None

        # nursery.start_soon(future): a foreign awaitable instead of a function.
        if looks_foreign(async_fn):
            message = (
                "Trio was expecting an async function, but instead it got "
                "{!r} – are you trying to use a library written for "
                "asyncio/twisted/tornado or similar? That won't work "
                "without some sort of compatibility shim."
            ).format(async_fn)
            raise TypeError(message) from None

        raise

    # We can't check iscoroutinefunction(async_fn) up front: that rejects
    # things like functools.partial objects wrapping an async function. So we
    # call it and inspect the result instead.
    if isinstance(coro, collections.abc.Coroutine):
        return coro

    # nursery.start_soon(func_returning_future)
    if looks_foreign(coro):
        message = (
            "Trio got unexpected {!r} – are you trying to use a "
            "library written for asyncio/twisted/tornado or similar? "
            "That won't work without some sort of compatibility shim."
        ).format(coro)
        raise TypeError(message)

    if isasyncgen(coro):
        message = (
            "start_soon expected an async function but got an async "
            "generator {!r}"
        ).format(coro)
        raise TypeError(message)

    # nursery.start_soon(some_sync_fn)
    message = (
        "Trio expected an async function, but {!r} appears to be "
        "synchronous"
    ).format(getattr(async_fn, "__qualname__", async_fn))
    raise TypeError(message)
170

171

172 22
class ConflictDetector:
    """Detect when two tasks are about to perform operations that would
    conflict.

    Use it as a synchronous context manager. Entering it while it is already
    held raises ``trio.BusyResourceError`` carrying the message given at
    construction. It is meant for code paths that *would* need a lock if they
    ever overlapped, but that should never actually overlap. For example, it
    can make sure two different tasks don't call sendall simultaneously on
    the same stream.
    """

    def __init__(self, msg):
        # Message for the BusyResourceError raised on a conflicting entry.
        self._msg = msg
        # True between __enter__ and __exit__.
        self._held = False

    def __enter__(self):
        if not self._held:
            self._held = True
            return
        raise trio.BusyResourceError(self._msg)

    def __exit__(self, *exc_info):
        # Release unconditionally; returning None lets any exception raised
        # inside the ``with`` body propagate.
        self._held = False
198

199

200 22
def async_wraps(cls, wrapped_cls, attr_name):
    """Similar to wraps, but for async wrappers of non-async functions.

    Returns a decorator that sets the decorated function's ``__name__`` to
    *attr_name* and its ``__qualname__`` to ``"<cls.__qualname__>.<attr_name>"``.
    It also gives the function a short docstring pointing at *attr_name* on
    *wrapped_cls*. The function object itself is returned.
    """

    def decorator(func):
        owner = cls.__qualname__
        target = f"{wrapped_cls.__module__}.{wrapped_cls.__qualname__}.{attr_name}"
        func.__name__ = attr_name
        func.__qualname__ = f"{owner}.{attr_name}"
        # The trailing blank line and eight-space indent are part of the
        # generated docstring text.
        func.__doc__ = f"Like :meth:`~{target}`, but async.\n\n        "
        return func

    return decorator
216

217

218 22
def fixup_module_metadata(module_name, namespace):
    """Make public trio-internal objects report *module_name* as their home.

    For every public (non-underscore) entry of *namespace* whose
    ``__module__`` starts with ``"trio."``, this rewrites ``__module__`` to
    *module_name*. Where the object's ``__name__`` contains no ``"."``, it
    also rewrites ``__name__``/``__qualname__`` to the name it is exported
    under. Classes are processed recursively through their ``__dict__``, so
    their members get qualnames like ``"Outer.Inner.method"``.
    """
    seen_ids = set()

    def fix_one(qualname, name, obj):
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if isinstance(mod, str) and mod.startswith("trio."):
            obj.__module__ = module_name
            # Modules, unlike everything else in Python, put fully-qualified
            # names into their __name__ attribute. We check for "." to avoid
            # rewriting these.
            if hasattr(obj, "__name__") and "." not in obj.__name__:
                obj.__name__ = name
                obj.__qualname__ = qualname
            if isinstance(obj, type):
                for attr_name, attr_value in obj.__dict__.items():
                    # Build on *this* object's qualname (not the top-level
                    # export name) so members of nested classes are correct.
                    fix_one(qualname + "." + attr_name, attr_name, attr_value)

    for objname, obj in namespace.items():
        if not objname.startswith("_"):  # ignore private attributes
            fix_one(objname, objname, obj)
244

245

246 22
class generic_function:
    """Decorator that makes a function subscriptable, so that generic type
    parameters which can't be inferred can be spelled out for a static type
    checker.

    After writing::

        @generic_function
        def open_memory_channel(max_buffer_size: int) -> Tuple[
            SendChannel[T], ReceiveChannel[T]
        ]: ...

    code may say ``open_memory_channel[bytes](5)`` at runtime. The subscript
    is ignored, so this is exactly ``open_memory_channel(5)``. It won't
    type-check yet without a mypy plugin or clever stubs, but this makes
    writing those possible.
    """

    def __init__(self, fn):
        # Copy __name__, __doc__, __wrapped__, etc. from fn onto the wrapper.
        update_wrapper(self, fn)
        self._fn = fn

    def __call__(self, *args, **kwargs):
        target = self._fn
        return target(*args, **kwargs)

    def __getitem__(self, type_params):
        # Type parameters only matter to static checkers; ignore them here.
        return self
272

273

274
# A class's metaclass must be compatible with the metaclasses of all its
# bases. Inheriting from any ABC requires a metaclass derived from ABCMeta.
# On Python 3.6, inheriting from typing.Generic also requires one derived from
# typing.GenericMeta. Some classes that use Final or NoPublicConstructor
# inherit from ABCs and generics, so those metaclasses must derive from these.
# GenericMeta itself derives from ABCMeta, so using GenericMeta alone suffices
# wherever it exists. When TYPE_CHECKING is true (static analysis), the
# ABCMeta branch is taken.
if t.TYPE_CHECKING or not hasattr(t, "GenericMeta"):
    BaseMeta = ABCMeta
else:
    BaseMeta = t.GenericMeta
286

287

288 22
class Final(BaseMeta):
    """Metaclass that makes a class final, i.e. forbids subclassing it.

    Given::

        class SomeClass(metaclass=Final):
            pass

    any later attempt to define a subclass of ``SomeClass`` fails.

    Raises
    ------
    - TypeError if a sub class is created
    """

    def __new__(cls, name, bases, cls_namespace):
        final_base = next((base for base in bases if isinstance(base, Final)), None)
        if final_base is not None:
            raise TypeError(
                f"{final_base.__module__}.{final_base.__qualname__} does not support subclassing"
            )
        return super().__new__(cls, name, bases, cls_namespace)
310

311

312 22
class SubclassingDeprecatedIn_v0_15_0(BaseMeta):
    """Metaclass that reports subclassing via ``warn_deprecated``.

    Defining a class whose bases include a class using this metaclass calls
    ``warn_deprecated`` once, naming the first such base. The new class is
    still created.
    """

    def __new__(cls, name, bases, cls_namespace):
        deprecated_base = next(
            (b for b in bases if isinstance(b, SubclassingDeprecatedIn_v0_15_0)),
            None,
        )
        if deprecated_base is not None:
            warn_deprecated(
                f"subclassing {deprecated_base.__module__}.{deprecated_base.__qualname__}",
                "0.15.0",
                issue=1044,
                instead="composition or delegation",
            )
        return super().__new__(cls, name, bases, cls_namespace)
324

325

326 22
# Type variable used in NoPublicConstructor._create's annotations so the
# return type matches the class it is called on.
T = t.TypeVar("T")
327

328

329 22
class NoPublicConstructor(Final):
    """Metaclass that makes a class final and hides its constructor.

    Given::

        class SomeClass(metaclass=NoPublicConstructor):
            pass

    subclassing ``SomeClass`` fails (inherited from ``Final``), and so does
    calling ``SomeClass()``. Internal code creates instances with
    ``SomeClass._create(...)`` instead.

    Raises
    ------
    - TypeError if a sub class or an instance is created.
    """

    def __call__(cls, *args, **kwargs):
        qualified = f"{cls.__module__}.{cls.__qualname__}"
        raise TypeError(f"{qualified} has no public constructor")

    def _create(cls: t.Type[T], *args: t.Any, **kwargs: t.Any) -> T:
        # Skip the blocked __call__ above and use the normal type.__call__
        # machinery (__new__ + __init__).
        return super().__call__(*args, **kwargs)  # type: ignore
355

356

357 22
def name_asyncgen(agen):
    """Return the fully-qualified name of the async generator function
    that produced the async generator iterator *agen*.

    The result has the form ``"<module>.<qualname>"``. If the module name
    can't be read from the generator frame's globals (e.g. the frame is gone,
    or there is no ``__name__`` global), ``"<co_filename>"`` is used instead.
    If ``__qualname__`` is missing, the code object's ``co_name`` is used.
    Objects without an ``ag_code`` attribute are described by ``repr()``.
    """
    if not hasattr(agen, "ag_code"):  # pragma: no cover
        return repr(agen)

    code = agen.ag_code
    try:
        module = agen.ag_frame.f_globals["__name__"]
    except (AttributeError, KeyError):
        module = f"<{code.co_filename}>"

    try:
        qualname = agen.__qualname__
    except AttributeError:
        qualname = code.co_name

    return f"{module}.{qualname}"

Read our documentation on viewing source code .

Loading