1
"""
2
All procedure tasks involved in on-node computation.
3
"""
4

5 4
from typing import List, Union
6

7 4
import qcelemental as qcel
8

9 4
import qcengine as qcng
10

11 4
from ..interface.models import Molecule, OptimizationRecord, QCSpecification, ResultRecord, TaskRecord, KVStore
12 4
from .procedures_util import parse_single_tasks
13

14 4
# Names listed in WavefunctionProperties._return_results_names; used below to
# build the "return_map" entry and excluded from the stored wavefunction data.
_wfn_return_names = set(qcel.models.results.WavefunctionProperties._return_results_names)
# Every field name declared on WavefunctionProperties; stored wavefunction data
# is limited to these keys.
_wfn_all_fields = set(qcel.models.results.WavefunctionProperties.__fields__.keys())
16

17

18 4
class BaseTasks:
    """Shared scaffolding for the procedure task handlers.

    Subclasses are expected to override ``verify_input``, ``parse_input`` and
    ``parse_output``; the versions defined here only raise ``TypeError``.
    """

    def __init__(self, storage, logger):
        """Store the collaborators used by the handler.

        Parameters
        ----------
        storage
            Storage object; ``submit_tasks`` calls its ``queue_submit`` method.
        logger
            Logger object, kept on the instance as ``self.logger``.
        """
        self.storage = storage
        self.logger = logger

    def submit_tasks(self, data):
        """Parse ``data`` into tasks, queue the new ones and summarise the outcome.

        Parameters
        ----------
        data
            Request object handed unchanged to ``self.parse_input``.

        Returns
        -------
        dict
            A dictionary with a ``"meta"`` section (counters, status flags and
            the errors reported by ``parse_input``) and a ``"data"`` section
            holding the result ids, the ``base_result`` of every newly queued
            task and the ids that already existed.
        """
        new_tasks, results_ids, existing_ids, errors = self.parse_input(data)

        self.storage.queue_submit(new_tasks)

        # Every non-None id is counted, whether the record is new or already
        # existed; None entries are inputs that produced no record.
        n_inserted = sum(1 for result_id in results_ids if result_id is not None)

        return {
            "meta": {
                "n_inserted": n_inserted,
                "duplicates": [],
                "validation_errors": [],
                "success": True,
                "error_description": False,
                "errors": errors,
            },
            "data": {
                "ids": results_ids,
                "submitted": [task.base_result for task in new_tasks],
                "existing": existing_ids,
            },
        }

    def verify_input(self, data):
        """Validate a request; must be overridden by subclasses.

        Raises
        ------
        TypeError
            Always, because the base class provides no implementation.
        """
        raise TypeError("verify_input not defined")

    def parse_input(self, data):
        """Convert a request into tasks; must be overridden by subclasses.

        Raises
        ------
        TypeError
            Always, because the base class provides no implementation.
        """
        raise TypeError("parse_input not defined")

    def parse_output(self, data):
        """Store finished task outputs; must be overridden by subclasses.

        Raises
        ------
        TypeError
            Always, because the base class provides no implementation.
        """
        raise TypeError("parse_output not defined")
59

60

61 4
class SingleResultTasks(BaseTasks):
    """A task generator for a single Result.

    Unique by: driver, method, basis, option (the name in the options table),
    and program.
    """

    def verify_input(self, data):
        """Check that the requested program and protocols are usable.

        Parameters
        ----------
        data
            Request object exposing ``meta.program``, ``meta.dict()`` and,
            when protocols are supplied, ``meta.protocols``.

        Returns
        -------
        bool or str
            ``True`` when the request is acceptable, otherwise a message
            describing why it was rejected.
        """
        program = data.meta.program.lower()
        if program not in qcng.list_all_programs():
            return f"Program '{program}' not available in QCEngine."

        if data.meta.dict().get("protocols", None) is not None:
            try:
                qcel.models.results.ResultProtocols(**data.meta.protocols)
            except Exception as e:
                # Any construction failure is reported back as a message
                # instead of being raised.
                return f"Could not validate protocols: {str(e)}"

        return True

    def parse_input(self, data):
        """Parse input json into internally appropriate format

        Format of the input data::

            data = {
                "meta": {
                    "procedure": "single",
                    "driver": "energy",
                    "method": "HF",
                    "basis": "sto-3g",
                    "keywords": "default",
                    "program": "psi4"
                },
                "data": ["mol_id_1", "mol_id_2", ...],
            }

        Returns
        -------
        tuple
            ``(new_tasks, results_ids, existing_ids, errors)`` where
            ``results_ids`` has one entry per molecule returned by storage
            (``None`` for molecules that came back as ``None``),
            ``existing_ids`` lists ids storage reported as duplicates and
            ``errors`` is always an empty list.
        """

        # Unpack all molecules
        molecule_list = self.storage.get_add_molecules_mixed(data.data)["data"]

        # NOTE(review): unlike OptimizationTasks.parse_input, a keyword lookup
        # that comes back as None is not rejected here - confirm this is intended.
        if data.meta.keywords:
            keywords = self.storage.get_add_keywords_mixed([data.meta.keywords])["data"][0]
        else:
            keywords = None

        # Tag and priority are passed to the task, so they are removed from
        # the fields used to build the record.
        meta = data.meta.dict()
        tag = meta.pop("tag", None)
        priority = meta.pop("priority", None)

        # Construct full tasks
        new_tasks = []
        results_ids = []
        existing_ids = []
        for mol in molecule_list:
            if mol is None:
                results_ids.append(None)
                continue

            record = ResultRecord(**meta, molecule=mol.id)
            inp = record.build_schema_input(mol, keywords)
            # Tag the schema input with the record's program and keywords.
            inp.extras["_qcfractal_tags"] = {"program": record.program, "keywords": record.keywords}

            ret = self.storage.add_results([record])

            base_id = ret["data"][0]
            results_ids.append(base_id)

            # Storage reported the record as a duplicate: no new task is built.
            if len(ret["meta"]["duplicates"]):
                existing_ids.append(base_id)
                continue

            # Build task object
            task = TaskRecord(
                **{
                    "spec": {
                        "function": "qcengine.compute",  # todo: add defaults in models
                        "args": [inp.dict(), data.meta.program],
                        "kwargs": {},  # todo: add defaults in models
                    },
                    "parser": "single",
                    "program": data.meta.program,
                    "tag": tag,
                    "priority": priority,
                    "base_result": base_id,
                }
            )

            new_tasks.append(task)

        return new_tasks, results_ids, existing_ids, []

    def parse_output(self, result_outputs):
        """Store the outputs of finished single-result tasks.

        Parameters
        ----------
        result_outputs
            Iterable of dictionaries, each providing ``"base_result"``,
            ``"task_id"`` and ``"result"`` (the computation output).

        Returns
        -------
        tuple
            ``(completed_tasks, [], [])`` where ``completed_tasks`` lists the
            ``"task_id"`` of every processed output.
        """

        # Add new runs to database
        completed_tasks = []
        updates = []
        for data in result_outputs:
            result = self.storage.get_results(id=data["base_result"])["data"][0]
            result = ResultRecord(**result)

            rdata = data["result"]

            # Replace the raw stdout/stderr/error payloads by the values
            # returned from the key-value store.
            outputs = [rdata["stdout"], rdata["stderr"], rdata["error"]]
            kvstores = [KVStore(data=x) if x is not None else None for x in outputs]
            stdout, stderr, error = self.storage.add_kvstore(kvstores)["data"]
            rdata["stdout"] = stdout
            rdata["stderr"] = stderr
            rdata["error"] = error

            # Store Wavefunction data
            wfn = rdata.get("wavefunction", False)
            if wfn:
                available = set(wfn.keys()) - {"restricted", "basis"}
                return_map = {k: wfn[k] for k in wfn.keys() & _wfn_return_names}

                rdata["wavefunction"] = {
                    "available": list(available),
                    "restricted": wfn["restricted"],
                    "return_map": return_map,
                }

                # Extra fields are trimmed as we have a column *per* wavefunction structure.
                # Any key outside the declared fields triggers the trim; a
                # proper-superset comparison would only fire when *every*
                # declared field was present as well.
                available_keys = wfn.keys() - _wfn_return_names
                if available_keys - _wfn_all_fields:
                    self.logger.warning(
                        f"Too much wavefunction data for result {data['base_result']}, removing extra data."
                    )
                    available_keys &= _wfn_all_fields

                wavefunction_save = {k: wfn[k] for k in available_keys}
                wfn_data_id = self.storage.add_wavefunction_store([wavefunction_save])["data"][0]
                rdata["wavefunction_data_id"] = wfn_data_id

            result._consume_output(rdata)
            updates.append(result)
            completed_tasks.append(data["task_id"])

        # TODO: sometimes this should be an update and at other times an add
        self.storage.update_results(updates)

        return completed_tasks, [], []
207

208

209
# ----------------------------------------------------------------------------
210

211

212 4
class OptimizationTasks(BaseTasks):
    """
    Optimization task manipulation
    """

    def verify_input(self, data):
        """Check that both the optimizer and the QC program are available.

        Parameters
        ----------
        data
            Request object exposing ``meta.program`` (the optimization
            procedure) and ``meta.qc_spec["program"]`` (the QC program).

        Returns
        -------
        bool or str
            ``True`` when both are known to QCEngine, otherwise a message
            naming the one that is missing.
        """
        procedure = data.meta.program.lower()
        if procedure not in qcng.list_all_procedures():
            return f"Procedure '{procedure}' not available in QCEngine."

        program = data.meta.qc_spec["program"].lower()
        if program not in qcng.list_all_programs():
            return f"Program '{program}' not available in QCEngine."

        return True

    def parse_input(self, data, duplicate_id="hash_index"):
        """Parse input json into internally appropriate format

        Illustrative request layout::

            json_data = {
                "meta": {
                    "procedure": "optimization",
                    "program": "geometric",
                    "keywords": {
                        "coordsys": "tric",
                        "maxiter": 100
                    },
                    "qc_spec": {
                        "driver": "gradient",
                        "method": "HF",
                        "basis": "sto-3g",
                        "keywords": "default",
                        "program": "psi4"
                    },
                },
                "data": ["mol_id_1", "mol_id_2", ...],
            }

        Parameters
        ----------
        data
            Request object; this method reads ``data.data`` and, from
            ``data.meta``, the attributes ``program``, ``keywords``,
            ``qc_spec``, ``tag``, ``priority`` and (when present) ``protocols``.
        duplicate_id
            Not used by this method; kept so existing callers keep working.

        Returns
        -------
        tuple
            ``(new_tasks, results_ids, existing_ids, errors)`` where
            ``results_ids`` has one entry per molecule returned by storage
            (``None`` for molecules that came back as ``None``),
            ``existing_ids`` lists ids storage reported as duplicates and
            ``errors`` is always an empty list.

        Raises
        ------
        KeyError
            If QC keywords were requested but storage returned ``None`` for them.
        """

        # Unpack all molecules
        initial_molecule_list = self.storage.get_add_molecules_mixed(data.data)["data"]

        # Unpack keywords. They are copied before "program" is added so the
        # caller's request object is left untouched.
        opt_keywords = {} if data.meta.keywords is None else dict(data.meta.keywords)
        opt_keywords["program"] = data.meta.qc_spec["program"]

        qc_spec = QCSpecification(**data.meta.qc_spec)
        if qc_spec.keywords:
            qc_keywords = self.storage.get_add_keywords_mixed([qc_spec.keywords])["data"][0]
            if qc_keywords is None:
                raise KeyError("Could not find requested KeywordsSet from id key.")
        else:
            qc_keywords = None

        tag = data.meta.tag
        priority = data.meta.priority

        new_tasks = []
        results_ids = []
        existing_ids = []
        for initial_molecule in initial_molecule_list:
            if initial_molecule is None:
                results_ids.append(None)
                continue

            doc_data = {
                "initial_molecule": initial_molecule.id,
                "qc_spec": qc_spec,
                "keywords": opt_keywords,
                "program": data.meta.program,
            }
            if hasattr(data.meta, "protocols"):
                doc_data["protocols"] = data.meta.protocols
            doc = OptimizationRecord(**doc_data)

            inp = doc.build_schema_input(initial_molecule=initial_molecule, qc_keywords=qc_keywords)
            # Tag the inner QC input with the QC program and keywords.
            inp.input_specification.extras["_qcfractal_tags"] = {
                "program": qc_spec.program,
                "keywords": qc_spec.keywords,
            }

            ret = self.storage.add_procedures([doc])
            base_id = ret["data"][0]
            results_ids.append(base_id)

            # Storage reported the record as a duplicate: no new task is built.
            if len(ret["meta"]["duplicates"]):
                existing_ids.append(base_id)
                continue

            # Build task object
            task = TaskRecord(
                **{
                    "spec": {
                        "function": "qcengine.compute_procedure",
                        "args": [inp.dict(), data.meta.program],
                        "kwargs": {},
                    },
                    "parser": "optimization",
                    "program": qc_spec.program,
                    "procedure": data.meta.program,
                    "tag": tag,
                    "priority": priority,
                    "base_result": base_id,
                }
            )

            new_tasks.append(task)

        return new_tasks, results_ids, existing_ids, []

    def parse_output(self, opt_outputs):
        """Save the results of the procedure.

        It must make sure to save the results in the results table
        including the task_id in the TaskQueue table

        Parameters
        ----------
        opt_outputs
            Iterable of dictionaries, each providing ``"base_result"``,
            ``"task_id"`` and ``"result"`` (the optimization output).

        Returns
        -------
        tuple
            ``(completed_tasks, [], [])`` where ``completed_tasks`` lists the
            ``"task_id"`` of every processed output.

        Raises
        ------
        AssertionError
            If the initial molecule stored for an output does not match the
            one on the existing record.
        """

        completed_tasks = []
        updates = []
        for output in opt_outputs:
            rec = self.storage.get_procedures(id=output["base_result"])["data"][0]
            rec = OptimizationRecord(**rec)

            procedure = output["result"]

            # Add initial and final molecules
            update_dict = {}
            initial_mol, final_mol = self.storage.add_molecules(
                [Molecule(**procedure["initial_molecule"]), Molecule(**procedure["final_molecule"])]
            )["data"]
            # Raised explicitly (rather than via ``assert``) so the check is
            # still performed when Python runs with -O.
            if initial_mol != rec.initial_molecule:
                raise AssertionError(
                    f"Initial molecule {initial_mol} does not match record value {rec.initial_molecule}."
                )
            update_dict["final_molecule"] = final_mol

            # Parse trajectory computations and add task_id
            traj_dict = dict(enumerate(procedure["trajectory"]))
            results = parse_single_tasks(self.storage, traj_dict)
            for k, v in results.items():
                v["task_id"] = output["task_id"]
                results[k] = ResultRecord(**v)

            ret = self.storage.add_results(list(results.values()))
            update_dict["trajectory"] = ret["data"]
            update_dict["energies"] = procedure["energies"]

            # Save stdout/stderr
            outputs = [procedure["stdout"], procedure["stderr"], procedure["error"]]
            kvstores = [KVStore(data=x) if x is not None else None for x in outputs]
            stdout, stderr, error = self.storage.add_kvstore(kvstores)["data"]
            update_dict["stdout"] = stdout
            update_dict["stderr"] = stderr
            update_dict["error"] = error
            update_dict["provenance"] = procedure["provenance"]

            # Rebuild the record with the new values layered over the old ones.
            rec = OptimizationRecord(**{**rec.dict(), **update_dict})
            updates.append(rec)
            completed_tasks.append(output["task_id"])

        self.storage.update_procedures(updates)

        return completed_tasks, [], []
397

398

399
# ----------------------------------------------------------------------------
400

401 4
# Annotation used for the return value of get_procedure_parser.
supported_procedures = Union[SingleResultTasks, OptimizationTasks]
# Lookup table from lowercase procedure name to the handler class for it.
__procedure_map = {"single": SingleResultTasks, "optimization": OptimizationTasks}
403

404

405 4
def check_procedure_available(procedure: str) -> bool:
    """Report whether ``procedure`` names a supported procedure type.

    Parameters
    ----------
    procedure: str
        Procedure name; it is lower-cased before the lookup.

    Returns
    -------
    bool
        ``True`` if the lower-cased name is a key of the procedure map,
        ``False`` otherwise.
    """
    return procedure.lower() in __procedure_map
410

411

412 4
def get_procedure_parser(procedure_type: str, storage, logger) -> supported_procedures:
    """A factory method that returns the appropriate parser class
    for the supported procedure types (like single and optimization)

    Parameters
    ---------
    procedure_type: str, 'single' or 'optimization'
        Looked up case-insensitively.
    storage: storage socket object
        such as MongoengineSocket object
    logger: logger object
        Passed to the parser's constructor together with ``storage``.

    Returns
    -------
    A parser instance corresponding to the procedure_type:
        'single' --> SingleResultTasks
        'optimization' --> OptimizationTasks

    Raises
    ------
    KeyError
        If ``procedure_type`` is not a key of the procedure map.
    """

    # Only the table lookup is guarded, so a KeyError raised while building
    # the parser is not misreported as an unsupported procedure type.
    try:
        parser_cls = __procedure_map[procedure_type.lower()]
    except KeyError:
        raise KeyError("Procedure type ({}) is not suported yet.".format(procedure_type)) from None

    return parser_cls(storage, logger)

Read our documentation on viewing source code .

Loading