1
# coding: utf-8
2

3
# Little utilities we use internally
4

5 27
from abc import ABCMeta
6 27
import os
7 27
import signal
8 27
import sys
9 27
import pathlib
10 27
from functools import wraps, update_wrapper
11 27
import typing as t
12 27
import threading
13 27
import collections
14

15 27
from async_generator import isasyncgen
16

17 27
from ._deprecate import warn_deprecated
18

19 27
import trio
20

21
# Equivalent to the C function raise(), which Python doesn't wrap
if os.name != "nt":

    def signal_raise(signum):
        """Deliver signal *signum* to the calling thread, like C's raise()."""
        # Targeting our own thread id means only this thread receives it.
        signal.pthread_kill(threading.get_ident(), signum)

else:
    # On Windows, os.kill exists but is really weird.
    #
    # If you give it CTRL_C_EVENT or CTRL_BREAK_EVENT, it tries to deliver
    # those using GenerateConsoleCtrlEvent. But I found that when I tried
    # to run my test normally, it would freeze waiting... unless I added
    # print statements, in which case the test suddenly worked. So I guess
    # these signals are only delivered if/when you access the console? I
    # don't really know what was going on there. From reading the
    # GenerateConsoleCtrlEvent docs I don't know how it worked at all.
    #
    # I later spent a bunch of time trying to make GenerateConsoleCtrlEvent
    # work for creating synthetic control-C events, and... failed
    # utterly. There are lots of details in the code and comments
    # removed/added at this commit:
    #     https://github.com/python-trio/trio/commit/95843654173e3e826c34d70a90b369ba6edf2c23
    #
    # OTOH, if you pass os.kill any *other* signal number... then CPython
    # just calls TerminateProcess (wtf).
    #
    # So, anyway, os.kill is not so useful for testing purposes. Instead
    # we use raise():
    #
    #   https://msdn.microsoft.com/en-us/library/dwwzkt4c.aspx
    #
    # We have to import cffi inside this Windows-only branch because we don't
    # depend on cffi on non-Windows platforms. (It would be easy to switch
    # this to ctypes though if we ever remove the cffi dependency.)
    #
    # Some more information:
    #   https://bugs.python.org/issue26350
    #
    # Anyway, we use this for two things:
    # - redelivering unhandled signals
    # - generating synthetic signals for tests
    # and for both of those purposes, 'raise' works fine.
    import cffi

    _ffi = cffi.FFI()
    _ffi.cdef("int raise(int);")
    _lib = _ffi.dlopen("api-ms-win-crt-runtime-l1-1-0.dll")
    # "raise" is a Python keyword, so the C symbol must be looked up by name.
    signal_raise = getattr(_lib, "raise")
68

69

70
# See: #461 as to why this is needed.
# threading.main_thread() can lie to us: if somebody edits threading's ident
# cache to replace the main thread, threading.current_thread() returns a
# _DummyThread, the C-c check fails, and so on. Installing a signal handler,
# however, is only permitted from the main thread, so attempting it gives an
# answer that does not depend on threading's (possibly modified) bookkeeping.
def is_main_thread():
    """Attempt to reliably check if we are in the main thread.

    Re-installs the currently active SIGINT handler (so nothing changes) and
    reports whether ``signal.signal`` accepted the call: it raises
    ``ValueError`` outside the main thread.
    """
    sigint = signal.SIGINT
    try:
        current_handler = signal.getsignal(sigint)
        signal.signal(sigint, current_handler)
    except ValueError:
        in_main = False
    else:
        in_main = True
    return in_main
85

86

87
######
# Call the function and get the coroutine object, while giving helpful
# errors for common mistakes. Returns coroutine object.
######
def coroutine_or_error(async_fn, *args):
    """Call ``async_fn(*args)`` and return the coroutine object it produces.

    Common mistakes are reported as a ``TypeError`` with a targeted message:
    passing an already-created coroutine object, passing (or getting back) an
    asyncio/twisted/tornado-style awaitable, passing an async generator
    function, or passing a synchronous function (which will already have been
    called by the time this is detected). Any other ``TypeError`` raised by
    the call itself propagates unchanged.
    """

    def _foreign_awaitable(value):
        # Heuristic: does *value* look like it came from asyncio, tornado,
        # twisted or a similar library? Checks are tried in order and stop at
        # the first hit.
        return (
            # Returned by legacy @asyncio.coroutine functions, which includes
            # a surprising proportion of asyncio builtins.
            isinstance(value, collections.abc.Generator)
            # The protocol for detecting an asyncio Future-like object.
            or getattr(value, "_asyncio_future_blocking", None) is not None
            # This janky check catches tornado Futures and twisted Deferreds.
            # By the time we're calling this function, we already know
            # something has gone wrong, so a heuristic is pretty safe.
            or value.__class__.__name__ in ("Future", "Deferred")
        )

    try:
        coro = async_fn(*args)
    except TypeError:
        # Give good error for: nursery.start_soon(trio.sleep(1))
        if isinstance(async_fn, collections.abc.Coroutine):
            # Explicitly close the coroutine to avoid a "never awaited"
            # RuntimeWarning.
            async_fn.close()
            raise TypeError(
                "Trio was expecting an async function, but instead it got "
                f"a coroutine object {async_fn!r}\n"
                "\n"
                "Probably you did something like:\n"
                "\n"
                f"  trio.run({async_fn.__name__}(...))            # incorrect!\n"
                f"  nursery.start_soon({async_fn.__name__}(...))  # incorrect!\n"
                "\n"
                "Instead, you want (notice the parentheses!):\n"
                "\n"
                f"  trio.run({async_fn.__name__}, ...)            # correct!\n"
                f"  nursery.start_soon({async_fn.__name__}, ...)  # correct!"
            ) from None

        # Give good error for: nursery.start_soon(future)
        if _foreign_awaitable(async_fn):
            raise TypeError(
                "Trio was expecting an async function, but instead it got "
                f"{async_fn!r} – are you trying to use a library written for "
                "asyncio/twisted/tornado or similar? That won't work "
                "without some sort of compatibility shim."
            ) from None

        # Not a mistake we recognize: let the original error through.
        raise

    # We can't check iscoroutinefunction(async_fn), because that will fail
    # for things like functools.partial objects wrapping an async
    # function. So we have to just call it and then check whether the
    # return value is a coroutine object.
    if isinstance(coro, collections.abc.Coroutine):
        return coro

    # Give good error for: nursery.start_soon(func_returning_future)
    if _foreign_awaitable(coro):
        raise TypeError(
            f"Trio got unexpected {coro!r} – are you trying to use a "
            "library written for asyncio/twisted/tornado or similar? "
            "That won't work without some sort of compatibility shim."
        )

    if isasyncgen(coro):
        raise TypeError(
            "start_soon expected an async function but got an async "
            f"generator {coro!r}"
        )

    # Give good error for: nursery.start_soon(some_sync_fn)
    fn_label = getattr(async_fn, "__qualname__", async_fn)
    raise TypeError(
        f"Trio expected an async function, but {fn_label!r} appears to be "
        "synchronous"
    )
170

171

class ConflictDetector:
    """Catch two tasks attempting conflicting operations at the same time.

    Wrap the sensitive section in an instance used as a synchronous context
    manager. If a second entry happens while the first is still active, it
    raises ``trio.BusyResourceError`` (carrying the message given to the
    constructor) instead of waiting. This is meant for two pieces of code that
    *would* collide and need a lock if they ever ran concurrently, but that
    should never actually do so.

    For example, we use it to make sure that two different tasks don't call
    sendall simultaneously on the same stream.
    """

    def __init__(self, msg):
        self._held = False
        self._msg = msg

    def __enter__(self):
        if not self._held:
            self._held = True
            return
        raise trio.BusyResourceError(self._msg)

    def __exit__(self, *args):
        # Release unconditionally; returning None lets any exception from the
        # with-body propagate.
        self._held = False
198

199

def async_wraps(cls, wrapped_cls, attr_name):
    """Similar to wraps, but for async wrappers of non-async functions.

    Returns a decorator that renames the decorated function to *attr_name*,
    sets its ``__qualname__`` to ``<cls.__qualname__>.<attr_name>``, and
    replaces its docstring with a cross-reference to the synchronous
    ``wrapped_cls.<attr_name>``. The decorated function itself is returned.
    """

    def apply(func):
        func.__name__ = attr_name
        func.__qualname__ = f"{cls.__qualname__}.{attr_name}"
        func.__doc__ = (
            f"Like :meth:`~{wrapped_cls.__module__}.{wrapped_cls.__qualname__}"
            f".{attr_name}`, but async.\n\n        "
        )
        return func

    return apply
216

217

def fixup_module_metadata(module_name, namespace):
    """Make public objects in *namespace* report *module_name* as their home.

    For every name in *namespace* that does not start with ``_``, the bound
    object is processed if its ``__module__`` starts with ``"trio."``:

    * ``__module__`` is set to *module_name*;
    * ``__name__``/``__qualname__`` are set to the name it is exported under
      (unless ``__name__`` contains a dot, as module objects' names do);
    * for classes, every entry of the class ``__dict__`` is processed
      recursively with the qualified name ``<class qualname>.<attr name>``.

    Each object is processed at most once, even if it is reachable under
    several names or through a reference cycle.
    """
    seen_ids = set()

    def fix_one(qualname, name, obj):
        # avoid infinite recursion (relevant when using
        # typing.Generic, for example)
        if id(obj) in seen_ids:
            return
        seen_ids.add(id(obj))

        mod = getattr(obj, "__module__", None)
        if mod is None or not mod.startswith("trio."):
            return

        obj.__module__ = module_name
        # Modules, unlike everything else in Python, put fully-qualified
        # names into their __name__ attribute. We check for "." to avoid
        # rewriting these.
        if hasattr(obj, "__name__") and "." not in obj.__name__:
            obj.__name__ = name
            obj.__qualname__ = qualname
        if isinstance(obj, type):
            for attr_name, attr_value in obj.__dict__.items():
                # Build on *this* object's qualname rather than the top-level
                # export name, so attributes of nested classes get the full
                # dotted path (e.g. "Outer.Inner.method").
                fix_one(qualname + "." + attr_name, attr_name, attr_value)

    for objname, obj in namespace.items():
        if not objname.startswith("_"):  # ignore private attributes
            fix_one(objname, objname, obj)
244

245

class generic_function:
    """Decorator that lets a function be subscripted with type parameters.

    This exists to communicate generic type parameters that a static type
    checker cannot infer. Given::

        @generic_function
        def open_memory_channel(max_buffer_size: int) -> Tuple[
            SendChannel[T], ReceiveChannel[T]
        ]: ...

    both ``open_memory_channel[bytes](5)`` and ``open_memory_channel(5)`` are
    valid at runtime and behave identically: subscripting returns this same
    wrapper and the subscript itself is ignored. Such calls currently won't
    type-check without a mypy plugin or clever stubs, but at least it becomes
    possible to write those.
    """

    def __init__(self, fn):
        # Copy fn's metadata (__name__, __doc__, ...) onto the wrapper and
        # record it as __wrapped__.
        update_wrapper(self, fn)
        self._fn = fn

    def __call__(self, *args, **kwargs):
        wrapped = self._fn
        return wrapped(*args, **kwargs)

    def __getitem__(self, _):
        # The subscript is only meaningful to type checkers.
        return self
272

273

274
# Any class inheriting from an ABC needs a metaclass derived from ABCMeta,
# and on Python 3.6 any class inheriting from typing.Generic needs a metaclass
# derived from typing.GenericMeta. Some classes that want to use Final or
# NoPublicConstructor inherit from ABCs and generics, so Final has to inherit
# from these metaclasses. GenericMeta itself inherits from ABCMeta, so when it
# exists at all, inheriting from it alone is sufficient; otherwise ABCMeta is.
# Static type checkers (TYPE_CHECKING is true) always see the ABCMeta branch.
if t.TYPE_CHECKING or not hasattr(t, "GenericMeta"):
    BaseMeta = ABCMeta
else:
    BaseMeta = t.GenericMeta
286

287

class Final(BaseMeta):
    """Metaclass that makes a class final, i.e. forbids subclassing it.

    Declare a class with it like this::

        class SomeClass(metaclass=Final):
            pass

    and any later attempt to subclass ``SomeClass`` fails.

    Raises
    ------
    - TypeError if a sub class is created
    """

    def __new__(cls, name, bases, cls_namespace):
        # Only the first final base is named in the error message.
        final_base = next((base for base in bases if isinstance(base, Final)), None)
        if final_base is not None:
            raise TypeError(
                f"{final_base.__module__}.{final_base.__qualname__} does not support subclassing"
            )
        return super().__new__(cls, name, bases, cls_namespace)
310

311

# Module-level type variable; used below so that NoPublicConstructor._create
# is typed as returning an instance of the class it is called on.
T = t.TypeVar("T")


class NoPublicConstructor(Final):
    """Metaclass that makes a class final and hides its constructor.

    If a class uses this metaclass like this::

        class SomeClass(metaclass=NoPublicConstructor):
            pass

    then subclassing ``SomeClass`` is rejected (inherited from ``Final``),
    and so is calling ``SomeClass()`` directly. Internal code builds
    instances with ``SomeClass._create(...)`` instead.

    Raises
    ------
    - TypeError if a sub class or an instance is created.
    """

    def __call__(cls, *args, **kwargs):
        # Invoked by SomeClass(...): direct construction is always refused.
        qualified_name = f"{cls.__module__}.{cls.__qualname__}"
        raise TypeError(f"{qualified_name} has no public constructor")

    def _create(cls: t.Type[T], *args: t.Any, **kwargs: t.Any) -> T:
        # Private constructor: skip our own __call__ and defer to the parent
        # metaclass's __call__, which performs normal instance creation.
        return super().__call__(*args, **kwargs)  # type: ignore
341

342

def name_asyncgen(agen):
    """Return the fully-qualified name of the async generator function
    that produced the async generator iterator *agen*.

    The result is ``"<module>.<qualname>"``. The module is taken from the
    generator frame's globals; if the frame is unavailable (``ag_frame`` can
    be None after the generator finishes) or has no ``__name__``, the code
    object's filename wrapped in angle brackets is used instead. Objects
    without ``ag_code`` are described by their ``repr``.
    """
    if not hasattr(agen, "ag_code"):  # pragma: no cover
        return repr(agen)
    code = agen.ag_code
    try:
        module = agen.ag_frame.f_globals["__name__"]
    except (AttributeError, KeyError):
        module = f"<{code.co_filename}>"
    qualname = getattr(agen, "__qualname__", code.co_name)
    return f"{module}.{qualname}"

Read our documentation on viewing source code .

Loading