1
"""
Queue adapters for Python executors and Dask.
"""
4

5 4
import traceback
6 4
from typing import Any, Dict, Hashable, Tuple
7

8 4
from qcelemental.models import FailedOperation
9

10 4
from .base_adapter import BaseAdapter
11

12

13 4
def _get_future(future):
14 2
    try:
15 2
        return future.result()
16 0
    except Exception as e:
17 0
        msg = "Caught Executor Error:\n" + traceback.format_exc()
18 0
        ret = FailedOperation(**{"success": False, "error": {"error_type": e.__class__.__name__, "error_message": msg}})
19 0
        return ret
20

21

22 4
class ExecutorAdapter(BaseAdapter):
    """A Queue Adapter for Python Executors.

    ``self.client`` is used as a ``concurrent.futures``-style executor: this class calls its
    ``submit`` and ``shutdown`` methods and reads its private ``_max_workers`` attribute.
    ``self.queue`` is treated as a mapping of task id -> future. Both attributes are
    presumably initialised by ``BaseAdapter`` (TODO confirm there).
    """

    def __repr__(self):
        executor = self.client
        executor_name = executor.__class__.__name__
        return "<ExecutorAdapter client=<{} max_workers={}>>".format(executor_name, executor._max_workers)

    def _submit_task(self, task_spec: Dict[str, Any]) -> Tuple[Hashable, Any]:
        """Submit one task to the executor.

        ``task_spec["spec"]`` supplies the function name (resolved via ``get_function``) plus
        the positional ``args`` and keyword ``kwargs`` to call it with.

        Returns
        -------
        Tuple[Hashable, Any]
            ``(task_spec["id"], future)`` for the submitted call.
        """
        spec = task_spec["spec"]
        function = self.get_function(spec["function"])
        future = self.client.submit(function, *spec["args"], **spec["kwargs"])
        return task_spec["id"], future

    def count_active_task_slots(self) -> int:
        """Return the executor's configured worker count as the number of task slots."""
        return self.client._max_workers

    def acquire_complete(self) -> Dict[str, Any]:
        """Harvest every finished future and remove it from the queue.

        Returns
        -------
        Dict[str, Any]
            Task id -> result (or ``FailedOperation`` if the task raised) for each future that
            reported ``done()`` during this call.
        """
        results = {}
        for key, future in self.queue.items():
            if future.done():
                results[key] = _get_future(future)

        # Deletion happens after iteration; mutating the dict while iterating it would raise.
        for key in results:
            del self.queue[key]

        return results

    def await_results(self) -> bool:
        """Block until every future currently in the queue has finished; always returns True."""
        from concurrent.futures import wait

        pending = list(self.queue.values())
        wait(pending)
        return True

    def close(self) -> bool:
        """Cancel queued futures, then shut the executor down; always returns True."""
        for pending in self.queue.values():
            # cancel() only succeeds for futures that have not started running yet.
            pending.cancel()

        self.client.shutdown()
        return True
65

66

67 4
class DaskAdapter(ExecutorAdapter):
    """A Queue Adapter for Dask.

    ``self.client`` is used as a ``dask.distributed`` client: this class calls its ``submit``,
    ``close`` and ``nthreads`` methods and inspects its ``cluster`` attribute.
    """

    def __repr__(self):
        return "<DaskAdapter client={}>".format(self.client)

    def _submit_task(self, task_spec: Dict[str, Any]) -> Tuple[Hashable, Any]:
        """Submit one task to Dask.

        Each task requests ``resources={"process": 1}``. Presumably workers advertise a
        ``process`` resource so tasks that are not thread-safe do not share a worker
        (TODO confirm against how workers are launched).

        Returns
        -------
        Tuple[Hashable, Any]
            ``(task_spec["id"], future)`` for the submitted call.
        """
        func = self.get_function(task_spec["spec"]["function"])

        # Watch out for thread-unsafe tasks and our own constraints.
        task = self.client.submit(
            func, *task_spec["spec"]["args"], **task_spec["spec"]["kwargs"], resources={"process": 1}
        )
        return task_spec["id"], task

    def count_active_task_slots(self) -> int:
        """Return the number of Dask workers currently available.

        Each Dask worker is counted as one task slot; the manager then multiplies by
        cores_per_task.

        Resolution order:
        1. ``cluster._count_active_workers()`` if the cluster object provides it.
        2. ``len(cluster.scheduler.workers)`` if a cluster with an in-process scheduler exists.
        3. ``len(client.nthreads())`` otherwise. This covers a client connected directly to a
           scheduler address, where ``client.cluster`` is ``None``. The original code raised
           ``AttributeError`` in that case.
        """
        cluster = getattr(self.client, "cluster", None)
        if cluster is not None and hasattr(cluster, "_count_active_workers"):
            return cluster._count_active_workers()

        scheduler = getattr(cluster, "scheduler", None)
        if scheduler is not None:
            return len(scheduler.workers)

        # No local cluster/scheduler object: ask the scheduler for its worker table instead.
        # nthreads() maps every connected worker address to its thread count.
        return len(self.client.nthreads())

    def await_results(self) -> bool:
        """Block until every future currently in the queue has finished; always returns True."""
        from dask.distributed import wait

        wait(list(self.queue.values()))
        return True

    def close(self) -> bool:
        """Close the Dask client; always returns True."""
        self.client.close()
        return True

Read our documentation on viewing source code .

Loading