1 22
import threading
2 22
import queue as stdlib_queue
3 22
import time
4

5 22
import pytest
6

7 22
from .. import _core
8 22
from .. import Event, CapacityLimiter, sleep
9 22
from ..testing import wait_all_tasks_blocked
10 22
from .._threads import (
11
    to_thread_run_sync,
12
    current_default_thread_limiter,
13
    from_thread_run,
14
    from_thread_run_sync,
15
)
16

17 22
from .._core.tests.test_ki import ki_self
18

19

20 22
async def test_do_in_trio_thread():
    """Call from_thread_run_sync / from_thread_run from a foreign thread.

    Each case starts a child thread that calls back into the Trio thread and
    then checks (a) that the callback ran in the Trio thread and (b) that its
    return value or exception type was observed by the child thread.
    """
    trio_thread = threading.current_thread()
    token = _core.current_trio_token()

    async def check_case(do_in_trio_thread, fn, expected, trio_token=None):
        record = []

        def run_in_child():
            try:
                record.append(("start", threading.current_thread()))
                result = do_in_trio_thread(fn, record, trio_token=trio_token)
                record.append(("got", result))
            except BaseException as exc:
                print(exc)
                record.append(("error", type(exc)))

        child_thread = threading.Thread(target=run_in_child, daemon=True)
        child_thread.start()
        # Poll with a Trio sleep so the Trio thread stays free to service the
        # child's callback.
        while child_thread.is_alive():
            print("yawn")
            await sleep(0.01)
        assert record == [("start", child_thread), ("f", trio_thread), expected]

    def sync_returns(record):
        assert not _core.currently_ki_protected()
        record.append(("f", threading.current_thread()))
        return 2

    def sync_raises(record):
        assert not _core.currently_ki_protected()
        record.append(("f", threading.current_thread()))
        raise ValueError

    async def async_returns(record):
        assert not _core.currently_ki_protected()
        await _core.checkpoint()
        record.append(("f", threading.current_thread()))
        return 3

    async def async_raises(record):
        assert not _core.currently_ki_protected()
        await _core.checkpoint()
        record.append(("f", threading.current_thread()))
        raise KeyError

    cases = [
        (from_thread_run_sync, sync_returns, ("got", 2)),
        (from_thread_run_sync, sync_raises, ("error", ValueError)),
        (from_thread_run, async_returns, ("got", 3)),
        (from_thread_run, async_raises, ("error", KeyError)),
    ]
    for runner, callback, expected in cases:
        await check_case(runner, callback, expected, trio_token=token)
73

74

75 22
async def test_do_in_trio_thread_from_trio_thread():
    """Both from_thread_* entry points raise RuntimeError when called here."""

    async def foo():  # pragma: no cover
        pass

    with pytest.raises(RuntimeError):
        from_thread_run_sync(lambda: None)  # pragma: no branch

    with pytest.raises(RuntimeError):
        from_thread_run(foo)
84

85

86 22
def test_run_in_trio_thread_ki():
    """A control-C raised inside a from_thread_* callback reaches the caller.

    The external thread expects KeyboardInterrupt from both the sync and the
    async entry point and records a marker for each one it catches.
    """
    record = set()

    async def check_run_in_trio_thread():
        token = _core.current_trio_token()

        def trio_thread_fn():
            import sys

            print("in Trio thread")
            assert not _core.currently_ki_protected()
            print("ki_self")
            try:
                ki_self()
            finally:
                print("finally", sys.exc_info())

        async def trio_thread_afn():
            trio_thread_fn()

        def external_thread_fn():
            print("running")
            attempts = (
                (from_thread_run_sync, trio_thread_fn, "ok1"),
                (from_thread_run, trio_thread_afn, "ok2"),
            )
            for call, target, marker in attempts:
                try:
                    call(target, trio_token=token)
                except KeyboardInterrupt:
                    print(marker)
                    record.add(marker)

        worker = threading.Thread(target=external_thread_fn)
        worker.start()
        print("waiting")
        while worker.is_alive():
            await sleep(0.01)
        print("waited, joining")
        worker.join()
        print("done")

    _core.run(check_run_in_trio_thread)
    assert record == {"ok1", "ok2"}
132

133

134 22
def test_await_in_trio_thread_while_main_exits():
    """A from_thread_run call still pending when main() returns sees Cancelled."""
    record = []
    ev = Event()

    def abort_succeeds(_):
        return _core.Abort.SUCCEEDED

    async def trio_fn():
        record.append("sleeping")
        ev.set()
        # Park until something aborts us; the abort always succeeds.
        await _core.wait_task_rescheduled(abort_succeeds)

    def thread_fn(token):
        try:
            from_thread_run(trio_fn, trio_token=token)
        except _core.Cancelled:
            record.append("cancelled")

    async def main():
        worker = threading.Thread(
            target=thread_fn, args=(_core.current_trio_token(),)
        )
        worker.start()
        await ev.wait()
        assert record == ["sleeping"]
        return worker

    worker = _core.run(main)
    worker.join()
    assert record == ["sleeping", "cancelled"]
160

161

162 22
async def test_run_in_worker_thread():
    """to_thread_run_sync runs off the Trio thread and relays result or error."""
    trio_thread = threading.current_thread()

    def report_thread(value):
        return (value, threading.current_thread())

    def fail_with_thread():
        raise ValueError(threading.current_thread())

    # Return value comes back, and was computed in a different thread.
    x, child_thread = await to_thread_run_sync(report_thread, 1)
    assert x == 1
    assert child_thread != trio_thread

    # Exceptions come back too, carrying the thread they were raised in.
    with pytest.raises(ValueError) as excinfo:
        await to_thread_run_sync(fail_with_thread)
    print(excinfo.value.args)
    assert excinfo.value.args[0] != trio_thread
179

180

181 22
async def test_run_in_worker_thread_cancellation():
    """Exercise cancellation of to_thread_run_sync in three scenarios.

    1. ``cancellable=True``: the task exits on cancel while the thread is
       still blocked.
    2. ``cancellable=False``: the task keeps waiting after cancel and only
       exits once the thread has been released.
    3. Cancelled before entry: entering the call is itself a cancellation
       point.
    """
    register = [None]

    def f(q):
        # Make the thread block for a controlled amount of time
        register[0] = "blocking"
        q.get()
        register[0] = "finished"

    async def child(q, cancellable):
        record.append("start")
        try:
            return await to_thread_run_sync(f, q, cancellable=cancellable)
        finally:
            record.append("exit")

    record = []
    q = stdlib_queue.Queue()
    async with _core.open_nursery() as nursery:
        nursery.start_soon(child, q, True)
        # Give it a chance to get started. (This is important because
        # to_thread_run_sync does a checkpoint_if_cancelled before
        # blocking on the thread, and we don't want to trigger this.)
        await wait_all_tasks_blocked()
        assert record == ["start"]
        # Then cancel it.
        nursery.cancel_scope.cancel()
    # The task exited, but the thread didn't:
    assert register[0] != "finished"
    # Put the thread out of its misery:
    q.put(None)
    while register[0] != "finished":
        time.sleep(0.01)

    # This one can't be cancelled
    record = []
    register[0] = None
    async with _core.open_nursery() as nursery:
        nursery.start_soon(child, q, False)
        await wait_all_tasks_blocked()
        nursery.cancel_scope.cancel()
        # Shielded so that these checkpoints run despite the cancellation
        # requested just above.
        with _core.CancelScope(shield=True):
            for _ in range(10):
                await _core.checkpoint()
        # It's still running
        assert record == ["start"]
        q.put(None)
        # Now it exits
    # The nursery only closes after the child has finished, so the exit must
    # have been recorded by now.
    assert record == ["start", "exit"]

    # But if we cancel *before* it enters, the entry is itself a cancellation
    # point
    with _core.CancelScope() as scope:
        scope.cancel()
        await child(q, False)
    assert scope.cancelled_caught
236

237

238
# Make sure that if trio.run exits, and then the thread finishes, then that's
239
# handled gracefully. (Requires that the thread result machinery be prepared
240
# for call_soon to raise RunFinishedError.)
241 22
def test_run_in_worker_thread_abandoned(capfd, monkeypatch):
    """A thread that finishes after the run has exited must do so quietly.

    The worker is cancelled (cancellable=True) and left blocked; only after
    _core.run() returns is it released. Nothing resembling an
    "Exception in thread" traceback may reach stdout or stderr.
    """
    monkeypatch.setattr(_core._thread_cache, "IDLE_TIMEOUT", 0.01)

    release = stdlib_queue.Queue()
    finished = stdlib_queue.Queue()

    def thread_fn():
        release.get()
        finished.put(threading.current_thread())

    async def main():
        async def child():
            await to_thread_run_sync(thread_fn, cancellable=True)

        async with _core.open_nursery() as nursery:
            nursery.start_soon(child)
            await wait_all_tasks_blocked()
            nursery.cancel_scope.cancel()

    _core.run(main)

    release.put(None)
    # Receiving the thread object proves the thread actually ran; waiting on
    # is_alive() proves it has finished before we inspect its output.
    worker = finished.get()
    while worker.is_alive():
        time.sleep(0.01)  # pragma: no cover

    # Make sure we don't have a "Exception in thread ..." dump to the console:
    out, err = capfd.readouterr()
    for captured in (out, err):
        assert "Exception in thread" not in captured
274

275

276 22
@pytest.mark.parametrize("MAX", [3, 5, 10])
@pytest.mark.parametrize("cancel", [False, True])
@pytest.mark.parametrize("use_default_limiter", [False, True])
async def test_run_in_worker_thread_limiter(MAX, cancel, use_default_limiter):
    """Check that at most MAX worker threads ever run concurrently.

    2 * MAX calls to to_thread_run_sync are started against a limiter with
    MAX tokens (either an explicit CapacityLimiter or the default one,
    temporarily resized). The high-water mark of simultaneously running
    threads must equal MAX, including when the calls are cancelled.
    """
    COUNT = 2 * MAX
    gate = threading.Event()
    lock = threading.Lock()
    if not use_default_limiter:
        c = CapacityLimiter(MAX)
        orig_total_tokens = MAX
        limiter_arg = c
    else:
        c = current_default_thread_limiter()
        orig_total_tokens = c.total_tokens
        c.total_tokens = MAX
        limiter_arg = None
    try:
        # Counters live as attributes on a shared object rather than as
        # 'nonlocal' variables: assigning to closed-over variables that are
        # visible in multiple threads turned out not to be safe, at least as
        # of CPython 3.6 and PyPy 5.8:
        #
        #   https://bugs.python.org/issue30744
        #   https://bitbucket.org/pypy/pypy/issues/2591/
        #
        # Mutating them in-place is OK though (as long as you use proper
        # locking etc.).
        class state:
            ran = 0
            high_water = 0
            running = 0
            parked = 0

        token = _core.current_trio_token()

        def thread_fn(cancel_scope):
            print("thread_fn start")
            from_thread_run_sync(cancel_scope.cancel, trio_token=token)
            with lock:
                state.ran += 1
                state.running += 1
                state.high_water = max(state.high_water, state.running)
                # The Trio thread below watches this value and uses it as a
                # signal that all the stats calculations have finished.
                state.parked += 1
            gate.wait()
            with lock:
                state.parked -= 1
                state.running -= 1
            print("thread_fn exiting")

        async def run_thread(event):
            with _core.CancelScope() as cancel_scope:
                await to_thread_run_sync(
                    thread_fn, cancel_scope, limiter=limiter_arg, cancellable=cancel
                )
            print("run_thread finished, cancelled:", cancel_scope.cancelled_caught)
            event.set()

        async with _core.open_nursery() as nursery:
            print("spawning")
            events = []
            for _ in range(COUNT):
                done = Event()
                events.append(done)
                nursery.start_soon(run_thread, done)
                await wait_all_tasks_blocked()
            # In the cancel case, we in particular want to make sure that the
            # cancelled tasks don't release the semaphore. So let's wait until
            # at least one of them has exited, and that everything has had a
            # chance to settle down from this, before we check that everyone
            # who's supposed to be waiting is waiting:
            if cancel:
                print("waiting for first cancellation to clear")
                await events[0].wait()
                await wait_all_tasks_blocked()
            # Then wait until the first MAX threads are parked in gate.wait(),
            # and the next MAX threads are parked on the semaphore, to make
            # sure no-one is sneaking past, and to make sure the high_water
            # check below won't fail due to scheduling issues. (It could still
            # fail if too many threads are let through here.)
            while not (state.parked == MAX and c.statistics().tasks_waiting == MAX):
                await sleep(0.01)  # pragma: no cover
            # Then release the threads
            gate.set()

        assert state.high_water == MAX

        if cancel:
            # Some threads might still be running; need to wait for them to
            # finish before checking that all threads ran. We can do this
            # using the CapacityLimiter.
            while c.borrowed_tokens > 0:
                await sleep(0.01)  # pragma: no cover

        assert state.ran == COUNT
        assert state.running == 0
    finally:
        c.total_tokens = orig_total_tokens
381

382

383 22
async def test_run_in_worker_thread_custom_limiter():
    """A custom limiter sees one acquire then one release for the same borrower.

    Basically just checking that we only call acquire_on_behalf_of and
    release_on_behalf_of, since that's part of our documented API.
    """
    calls = []

    class CustomLimiter:
        async def acquire_on_behalf_of(self, borrower):
            calls.append("acquire")
            self._borrower = borrower

        def release_on_behalf_of(self, borrower):
            calls.append("release")
            assert borrower == self._borrower

    limiter = CustomLimiter()
    await to_thread_run_sync(lambda: None, limiter=limiter)
    assert calls == ["acquire", "release"]
399

400

401 22
async def test_run_in_worker_thread_limiter_error():
    """An error raised by release_on_behalf_of propagates to the caller."""
    record = []

    class BadCapacityLimiter:
        async def acquire_on_behalf_of(self, borrower):
            record.append("acquire")

        def release_on_behalf_of(self, borrower):
            record.append("release")
            raise ValueError

    bs = BadCapacityLimiter()

    # The function itself succeeds, so the limiter error has no context.
    with pytest.raises(ValueError) as excinfo:
        await to_thread_run_sync(lambda: None, limiter=bs)
    assert excinfo.value.__context__ is None
    assert record == ["acquire", "release"]
    record.clear()

    # If the original function raised an error, then the semaphore error
    # chains with it
    empty = {}
    with pytest.raises(ValueError) as excinfo:
        await to_thread_run_sync(lambda: empty["x"], limiter=bs)
    assert isinstance(excinfo.value.__context__, KeyError)
    assert record == ["acquire", "release"]
427

428

429 22
async def test_run_in_worker_thread_fail_to_spawn(monkeypatch):
    """Test the unlikely but possible case where trying to spawn a thread fails.

    ThreadCache.start_thread_soon is patched to raise; the error must reach
    the caller and the default limiter must end up with no borrowed tokens.
    """

    def bad_start(self, *args):
        raise RuntimeError("the engines canna take it captain")

    monkeypatch.setattr(_core._thread_cache.ThreadCache, "start_thread_soon", bad_start)

    default_limiter = current_default_thread_limiter()
    assert default_limiter.borrowed_tokens == 0

    # We get an appropriate error, and the limiter is cleanly released
    with pytest.raises(RuntimeError) as excinfo:
        await to_thread_run_sync(lambda: None)  # pragma: no cover
    assert "engines" in str(excinfo.value)

    assert default_limiter.borrowed_tokens == 0
445

446

447 22
async def test_trio_to_thread_run_sync_token():
    """to_thread_run_sync injects the current trio token into the spawned thread.

    The thread calls from_thread_run_sync without an explicit token and must
    get back the same token the caller sees.
    """

    def thread_fn():
        return from_thread_run_sync(_core.current_trio_token)

    caller_token = _core.current_trio_token()
    token_seen_by_thread = await to_thread_run_sync(thread_fn)
    assert token_seen_by_thread == caller_token
457

458

459 22
async def test_trio_to_thread_run_sync_expected_error():
    """Passing an async function to to_thread_run_sync raises a TypeError."""

    async def async_fn():  # pragma: no cover
        pass

    expected_message = "expected a sync function"
    with pytest.raises(TypeError, match=expected_message):
        await to_thread_run_sync(async_fn)
466

467

468 22
async def test_trio_from_thread_run_sync():
    """from_thread_run_sync works inside a to_thread_run_sync worker.

    Covers both the happy path (the trio token is handed off implicitly) and
    the TypeError raised when an async function is passed.
    """

    async def async_fn():  # pragma: no cover
        pass

    def fetch_trio_time():
        return from_thread_run_sync(_core.current_time)

    def pass_async_fn():
        from_thread_run_sync(async_fn)

    # Test that to_thread_run_sync correctly "hands off" the trio token to
    # trio.from_thread.run_sync()
    trio_time = await to_thread_run_sync(fetch_trio_time)
    assert isinstance(trio_time, float)

    # Test correct error when passed async function
    with pytest.raises(TypeError, match="expected a sync function"):
        await to_thread_run_sync(pass_async_fn)
487

488

489 22
async def test_trio_from_thread_run():
    """from_thread_run works inside a to_thread_run_sync worker.

    Covers both the happy path (the trio token is handed off implicitly) and
    the TypeError raised when a sync function is passed.
    """
    events = []

    async def back_in_trio_fn():
        _core.current_time()  # implicitly checks that we're in trio
        events.append("back in trio")

    def thread_fn():
        events.append("in thread")
        from_thread_run(back_in_trio_fn)

    def sync_fn():  # pragma: no cover
        pass

    # Test that to_thread_run_sync correctly "hands off" the trio token to
    # trio.from_thread.run()
    await to_thread_run_sync(thread_fn)
    assert events == ["in thread", "back in trio"]

    # Test correct error when passed sync function
    with pytest.raises(TypeError, match="appears to be synchronous"):
        await to_thread_run_sync(from_thread_run, sync_fn)
511

512

513 22
async def test_trio_from_thread_token():
    """to_thread_run_sync and a spawned from_thread_run_sync share one token."""

    def thread_fn():
        return from_thread_run_sync(_core.current_trio_token)

    caller_token = _core.current_trio_token()
    token_seen_by_thread = await to_thread_run_sync(thread_fn)
    assert token_seen_by_thread == caller_token
523

524

525 22
async def test_trio_from_thread_token_kwarg():
    """from_thread_run_sync in a worker thread accepts an explicit trio_token."""

    def thread_fn(token):
        return from_thread_run_sync(_core.current_trio_token, trio_token=token)

    caller_token = _core.current_trio_token()
    token_seen_by_thread = await to_thread_run_sync(thread_fn, caller_token)
    assert token_seen_by_thread == caller_token
535

536

537 22
async def test_from_thread_no_token():
    """A "raw" from_thread_run_sync call with no trio_token raises RuntimeError."""
    raw_call_error = pytest.raises(RuntimeError)
    with raw_call_error:
        from_thread_run_sync(_core.current_time)
543

544

545 22
def test_run_fn_as_system_task_catched_badly_typed_token():
    """A trio_token of the wrong type (here a str) is rejected with RuntimeError."""
    bogus_token = "Not TrioTokentype"
    with pytest.raises(RuntimeError):
        from_thread_run_sync(_core.current_time, trio_token=bogus_token)
548

549

550 22
async def test_from_thread_inside_trio_thread():
    """from_thread_run_sync raises RuntimeError here, even with a valid token."""

    def not_called():  # pragma: no cover
        assert False

    current_token = _core.current_trio_token()
    with pytest.raises(RuntimeError):
        from_thread_run_sync(not_called, trio_token=current_token)

Read our documentation on viewing source code.

Loading