"""
Utilities and base functions for Services.
"""

5 4
import abc
6 4
import datetime
7 4
from typing import Any, Dict, List, Optional, Set, Tuple
8

9 4
from pydantic import validator
10 4
from qcelemental.models import ComputeError
11

12 4
from ..interface.models import ObjectId, ProtoModel
13 4
from ..interface.models.rest_models import TaskQueuePOSTBody
14 4
from ..interface.models.task_models import PriorityEnum
15 4
from ..procedures import get_procedure_parser
16

17

18 4
class TaskManager(ProtoModel):
    """
    Tracks the tasks submitted on behalf of a Service.

    The manager remembers the ids of the tasks it submitted (``required_tasks``)
    and uses ``storage_socket`` to submit them, poll their status, and fetch
    their records.
    """

    # Runtime handles; both are listed in ``Config.serialize_default_excludes``.
    storage_socket: Optional[Any] = None
    logger: Optional[Any] = None

    # Maps the caller-supplied task key to the id returned at submission time.
    required_tasks: Dict[str, str] = {}
    # Tag and priority stamped onto the ``meta`` of every submitted packet.
    tag: Optional[str] = None
    priority: PriorityEnum = PriorityEnum.HIGH

    class Config(ProtoModel.Config):
        allow_mutation = True
        serialize_default_excludes = {"storage_socket", "logger"}

    def done(self) -> bool:
        """
        Check if requested tasks are complete.

        Returns
        -------
        bool
            True if there are no required tasks or every queried task has
            status ``"COMPLETE"``; False if no task has errored but not all
            statuses are ``"COMPLETE"``.

        Raises
        ------
        KeyError
            If any queried task has status ``"ERROR"``.
        """

        if len(self.required_tasks) == 0:
            return True

        task_query = self.storage_socket.get_procedures(
            id=list(self.required_tasks.values()), include=["status", "error"]
        )

        status_values = set(x["status"] for x in task_query["data"])
        if status_values == {"COMPLETE"}:
            return True

        elif "ERROR" in status_values:
            self.logger.error("Error in service compute as follows:")
            # NOTE(review): this reads the whole queue rather than only the
            # tasks in ``required_tasks`` -- confirm that this is intended.
            tasks = self.storage_socket.get_queue()["data"]
            for x in tasks:
                if "error" not in x:
                    continue

                self.logger.error(x["error"]["error_message"])

            raise KeyError("All tasks did not execute successfully.")
        else:
            return False

    def get_tasks(self) -> Dict[str, Any]:
        """
        Pulls currently held tasks.

        Returns
        -------
        Dict[str, Any]
            For each key in ``required_tasks``, the first record returned by
            ``storage_socket.get_procedures`` for the associated id.
        """

        ret = {}
        for k, task_id in self.required_tasks.items():
            ret[k] = self.storage_socket.get_procedures(id=task_id)["data"][0]

        return ret

    def submit_tasks(self, procedure_type: str, tasks: Dict[str, Any]) -> bool:
        """
        Submits new tasks to the queue and records their ids in ``required_tasks``.

        Parameters
        ----------
        procedure_type : str
            Name handed to ``get_procedure_parser`` to select the parser.
        tasks : Dict[str, Any]
            Maps a caller-chosen key to a packet dictionary. Each packet's
            ``"meta"`` entry is updated in place with this manager's tag and
            priority before the packet is turned into a ``TaskQueuePOSTBody``.

        Returns
        -------
        bool
            Always True when every submission succeeded.

        Raises
        ------
        KeyError
            If a submission reports any errors in its ``meta``.
        """
        procedure_parser = get_procedure_parser(procedure_type, self.storage_socket, self.logger)

        required_tasks = {}

        # Add in all new tasks
        for key, packet in tasks.items():
            packet["meta"].update({"tag": self.tag, "priority": self.priority})
            packet = TaskQueuePOSTBody(**packet)

            # Turn packet into a full task, if there are duplicates, get the ID
            r = procedure_parser.submit_tasks(packet)

            errors = r["meta"]["errors"]
            if len(errors):
                raise KeyError("Problem submitting task: {}.".format(errors))

            required_tasks[key] = r["data"]["ids"][0]

        # Only replace the tracked ids once every submission has succeeded.
        self.required_tasks = required_tasks

        return True
101

102

103 4
class BaseService(ProtoModel, abc.ABC):
    """
    Abstract base model for Services.

    Subclasses must implement ``initialize_from_api`` and ``iterate``. On
    construction the embedded ``task_manager`` receives this service's
    logger, storage socket, task tag, and task priority.
    """

    # Excluded fields
    storage_socket: Optional[Any]
    logger: Optional[Any]

    # Base identification
    id: Optional[ObjectId] = None
    hash_index: str
    service: str
    program: str
    procedure: str

    # Output data
    output: Any

    # Links
    task_id: Optional[ObjectId] = None
    procedure_id: Optional[ObjectId] = None

    # Task manager
    task_tag: Optional[str] = None
    task_priority: PriorityEnum
    task_manager: TaskManager = TaskManager()

    status: str = "WAITING"
    error: Optional[ComputeError] = None
    tag: Optional[str] = None

    # Sorting and priority
    priority: PriorityEnum = PriorityEnum.NORMAL
    # Both timestamps default to the construction time (see ``__init__``).
    modified_on: Optional[datetime.datetime] = None
    created_on: Optional[datetime.datetime] = None

    class Config(ProtoModel.Config):
        allow_mutation = True
        serialize_default_excludes = {"storage_socket", "logger"}

    def __init__(self, **data):
        """
        Builds the service, filling in missing timestamps and wiring the task manager.

        ``modified_on`` and ``created_on`` are set to the same naive UTC
        timestamp unless the caller supplied them.
        """

        now = datetime.datetime.utcnow()
        data.setdefault("modified_on", now)
        data.setdefault("created_on", now)

        super().__init__(**data)

        # Hand the runtime handles and task settings to the task manager.
        manager = self.task_manager
        manager.logger = self.logger
        manager.storage_socket = self.storage_socket
        manager.tag = self.task_tag
        manager.priority = self.task_priority

    @validator("task_priority", pre=True)
    def munge_priority(cls, v):
        """
        Normalizes ``task_priority`` before field validation.

        ``None`` becomes ``PriorityEnum.HIGH``; a string is upper-cased and
        looked up by name in ``PriorityEnum``; anything else passes through.
        """
        if v is None:
            return PriorityEnum.HIGH
        if isinstance(v, str):
            return PriorityEnum[v.upper()]
        return v

    @classmethod
    @abc.abstractmethod
    def initialize_from_api(cls, storage_socket, meta, molecule, tag=None, priority=None):
        """
        Initializes a Service from the API.
        """

    @abc.abstractmethod
    def iterate(self):
        """
        Takes a "step" of the service. Should return False if not finished.
        """
173

174

175 4
def expand_ndimensional_grid(
    dimensions: Tuple[int, ...], seeds: Set[Tuple[int, ...]], complete: Set[Tuple[int, ...]]
) -> List[Tuple[Tuple[int, ...], Tuple[int, ...]]]:
    """
    Expands an n-dimensional key/value grid.

    For every axis, each seed is displaced by -1 and +1 along that axis. A
    displaced point is kept when it lies inside the grid, is not in
    ``complete``, and has not already been produced during this call.

    Parameters
    ----------
    dimensions : Tuple[int, ...]
        Number of grid points along each axis; valid indices are ``0 .. n - 1``.
    seeds : Set[Tuple[int, ...]]
        Grid points to expand from.
    complete : Set[Tuple[int, ...]]
        Grid points that must not be produced.

    Returns
    -------
    List[Tuple[Tuple[int, ...], Tuple[int, ...]]]
        ``(seed, new_point)`` pairs, ordered by axis, then by iteration order
        of ``seeds``, then by displacement (-1 before +1). Each new point
        appears at most once.

    Example
    -------
    >>> expand_ndimensional_grid((3, 3), {(1, 1)}, set())
    [((1, 1), (0, 1)), ((1, 1), (2, 1)), ((1, 1), (1, 0)), ((1, 1), (1, 2))]
    """

    dimensions = tuple(dimensions)
    compute = set()
    connections = []

    for d, size in enumerate(dimensions):

        # Loop over all compute seeds
        for seed in seeds:

            # Iterate both directions
            for disp in (-1, 1):
                new_dim = seed[d] + disp

                # Bound check
                if not 0 <= new_dim < size:
                    continue

                new = list(seed)
                new[d] = new_dim
                new = tuple(new)

                # Push out duplicates from both new compute and complete
                if new in compute or new in complete:
                    continue

                compute.add(new)
                connections.append((seed, new))

    return connections

Read our documentation on viewing source code.

Loading