"""
A model for Compute Records

Pydantic models describing compute records: the shared ``RecordBase`` plus the
single-computation ``ResultRecord`` and the geometry ``OptimizationRecord``.
"""

import abc
import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union

import numpy as np
import qcelemental as qcel
from pydantic import Field, constr, validator

from ..visualization import scatter_plot
from .common_models import DriverEnum, ObjectId, ProtoModel, QCSpecification
from .model_utils import hash_dictionary, prepare_basis, recursive_normalizer

if TYPE_CHECKING:  # pragma: no cover
    # Only needed for string annotations; importing at runtime is unnecessary.
    from qcelemental.models import OptimizationInput, ResultInput

    from .common_models import KeywordSet, Molecule

# Public API of this module; each exported name is listed exactly once.
__all__ = ["OptimizationRecord", "ResultRecord", "RecordBase"]


class RecordStatusEnum(str, Enum):
    """
    The state of a record object. The states which are available are a finite set.
    """

    # NOTE: the docstring above is read at runtime (``RecordBase.status`` uses it as
    # the field description), so it is kept verbatim.
    # Because the class mixes in ``str``, each member compares equal to its plain
    # string value (e.g. ``RecordStatusEnum.complete == "COMPLETE"``).

    complete = "COMPLETE"
    incomplete = "INCOMPLETE"
    running = "RUNNING"
    error = "ERROR"


class RecordBase(ProtoModel, abc.ABC):
    """
    A BaseRecord object for Result and Procedure records. Contains all basic
    fields common to the all records.
    """

    # Classdata
    # Concrete subclasses assign the set of field names that feed ``get_hash_index``.
    _hash_indices: Set[str]

    # Helper data (always excluded from ``dict()`` output, see ``dict`` below)
    client: Any = Field(None, description="The client object which the records are fetched from.")
    cache: Dict[str, Any] = Field(
        {},
        description="Object cache from expensive queries. It should be very rare that this needs to be set manually "
        "by the user.",
    )

    # Base identification
    id: ObjectId = Field(
        None, description="Id of the object on the database. This is assigned automatically by the database."
    )
    hash_index: Optional[str] = Field(
        None, description="Hash of this object used to detect duplication and collisions in the database."
    )
    procedure: str = Field(..., description="Name of the procedure which this Record targets.")
    program: str = Field(
        ...,
        description="The quantum chemistry program which carries out the individual quantum chemistry calculations.",
    )
    version: int = Field(..., description="The version of this record object describes.")
    protocols: Optional[Dict[str, Any]] = Field(
        None, description="Protocols that change the data stored in top level fields."
    )

    # Extra fields
    extras: Dict[str, Any] = Field({}, description="Extra information to associate with this record.")
    stdout: Optional[ObjectId] = Field(
        None,
        description="The Id of the stdout data stored in the database which was used to generate this record from the "
        "various programs which were called in the process.",
    )
    stderr: Optional[ObjectId] = Field(
        None,
        description="The Id of the stderr data stored in the database which was used to generate this record from the "
        "various programs which were called in the process.",
    )
    error: Optional[ObjectId] = Field(
        None,
        description="The Id of the error data stored in the database in the event that an error was generated in the "
        "process of carrying out the process this record targets. If no errors were raised, this field "
        "will be empty.",
    )

    # Compute status
    task_id: Optional[ObjectId] = Field(  # TODO: not used in SQL
        None, description="Id of the compute task tracked by Fractal in its TaskTable."
    )
    manager_name: Optional[str] = Field(None, description="Name of the Queue Manager which generated this record.")
    status: RecordStatusEnum = Field(RecordStatusEnum.incomplete, description=str(RecordStatusEnum.__doc__))
    # Both timestamps default to "now" in ``__init__`` when not supplied.
    modified_on: datetime.datetime = Field(None, description="Last time the data this record points to was modified.")
    created_on: datetime.datetime = Field(None, description="Time the data this record points to was first created.")

    # Carry-ons
    provenance: Optional[qcel.models.Provenance] = Field(
        None,
        description="Provenance information tied to the creation of this record. This includes things such as every "
        "program which was involved in generating the data for this record.",
    )

    class Config(ProtoModel.Config):
        # When True, ``__init__`` fills in ``hash_index`` if it was not supplied.
        build_hash_index = True

    @validator("program")
    def check_program(cls, v):
        """Normalize the program name to lower case."""
        return v.lower()

    def __init__(self, **data):
        # Take a single clock reading so a freshly built record never ends up with
        # ``created_on`` later than ``modified_on``.
        now = datetime.datetime.utcnow()
        data.setdefault("modified_on", now)
        data.setdefault("created_on", now)

        super().__init__(**data)

        # Set hash index if not present. It is written straight into ``__dict__``
        # rather than via attribute assignment (presumably because the model is
        # immutable; ProtoModel's config is not visible here).
        if self.Config.build_hash_index and (self.hash_index is None):
            self.__dict__["hash_index"] = self.get_hash_index()

    def __repr_args__(self):
        # Keep the repr short: only the database id and the status.
        return [("id", f"{self.id}"), ("status", f"{self.status}")]

    ### Serialization helpers

    @classmethod
    def get_hash_fields(cls) -> Set[str]:
        """Provides a description of the fields to be used in the hash
        that uniquely defines this object.

        Returns
        -------
        Set[str]
            The set of all fields that are used in the hash: the subclass's
            ``_hash_indices`` plus ``procedure`` and ``program``.
        """
        return cls._hash_indices | {"procedure", "program"}

    def get_hash_index(self) -> str:
        """Builds (or rebuilds) the hash of this
        object using the internally known hash fields.

        Returns
        -------
        str
            The objects unique hash index.
        """
        data = self.dict(include=self.get_hash_fields(), encoding="json")

        return hash_dictionary(data)

    def dict(self, *args, **kwargs):
        """Serialize the record, always excluding the ``client`` and ``cache`` helpers.

        Accepts the same arguments as the parent ``dict``. A caller-supplied
        ``exclude`` may be either a collection of field names (set, list, ...)
        or a pydantic-style mapping; in both cases ``client`` and ``cache``
        are added to it.
        """
        exclude = kwargs.pop("exclude", None) or set()
        if isinstance(exclude, dict):
            # Mapping form: an ellipsis value excludes the whole field.
            exclude = {**exclude, "client": ..., "cache": ...}
        else:
            exclude = set(exclude) | {"client", "cache"}
        kwargs["exclude"] = exclude
        return super().dict(*args, **kwargs)

    ### Checkers

    def check_client(self, noraise: bool = False) -> bool:
        """Checks whether this object owns a FractalClient or not.
        This is often done so that objects pulled from a server using
        a FractalClient still possess a connection to the server so that
        additional data related to this object can be queried.

        Raises
        ------
        ValueError
            If this object does not own a client and ``noraise`` is False.

        Parameters
        ----------
        noraise : bool, optional
            Does not raise an error if this is True and instead returns
            a boolean depending if a client exists or not.

        Returns
        -------
        bool
            If True, the object owns a connection to a server. False otherwise.
        """
        if self.client is None:
            if noraise:
                return False

            raise ValueError("Requested method requires a client, but client was '{}'.".format(self.client))

        return True

    ### KVStore Getters

    def _kvstore_getter(self, field_name):
        """Fetch (once) and return the KVStore entry whose id is stored in ``field_name``.

        The ``error`` entry is decoded with ``get_json()``; every other entry with
        ``get_string()``. The decoded value is kept in ``self.cache`` under
        ``field_name``. Returns None when the id field itself is None.

        Raises
        ------
        ValueError
            If this record has no client.
        """
        self.check_client()

        oid = self.__dict__[field_name]
        if oid is None:
            return None

        if field_name not in self.cache:
            # Decompress here, rather than later
            # that way, it is decompressed in the cache
            kv = self.client.query_kvstore([oid])[oid]

            if field_name == "error":
                self.cache[field_name] = kv.get_json()
            else:
                self.cache[field_name] = kv.get_string()

        return self.cache[field_name]

    def get_stdout(self) -> Optional[str]:
        """Pulls the stdout from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[str]
            The requested stdout, none if no stdout present.
        """
        return self._kvstore_getter("stdout")

    def get_stderr(self) -> Optional[str]:
        """Pulls the stderr from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[str]
            The requested stderr, none if no stderr present.
        """
        return self._kvstore_getter("stderr")

    def get_error(self) -> Optional[qcel.models.ComputeError]:
        """Pulls the error from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[qcel.models.ComputeError]
            The requested compute error. If the stored value is empty or absent,
            that falsy value (usually None) is returned unchanged.
        """
        value = self._kvstore_getter("error")
        if value:
            return qcel.models.ComputeError(**value)
        else:
            return value


class ResultRecord(RecordBase):
    # Record of a single quantum chemistry computation (procedure "single").

    # Classdata
    _hash_indices = {"driver", "method", "basis", "molecule", "keywords", "program"}

    # Version data
    version: int = Field(1, description="Version of the ResultRecord Model which this data was created with.")
    procedure: constr(strip_whitespace=True, regex="single") = Field(
        "single", description='Procedure is fixed as "single" because this is single quantum chemistry result.'
    )

    # Input data
    driver: DriverEnum = Field(..., description=str(DriverEnum.__doc__))
    method: str = Field(..., description="The quantum chemistry method the driver runs with.")
    molecule: ObjectId = Field(
        ..., description="The Id of the molecule in the Database which the result is computed on."
    )
    basis: Optional[str] = Field(
        None,
        description="The quantum chemistry basis set to evaluate (e.g., 6-31g, cc-pVDZ, ...). Can be ``None`` for "
        "methods without basis sets.",
    )
    keywords: Optional[ObjectId] = Field(
        None,
        description="The Id of the :class:`KeywordSet` which was passed into the quantum chemistry program that "
        "performed this calculation.",
    )
    protocols: Optional[qcel.models.results.ResultProtocols] = Field(
        qcel.models.results.ResultProtocols(), description=""
    )

    # Output data
    return_result: Union[float, qcel.models.types.Array[float], Dict[str, Any]] = Field(
        None, description="The primary result of the calculation, output is a function of the specified ``driver``."
    )
    properties: qcel.models.ResultProperties = Field(
        None, description="Additional data and results computed as part of the ``return_result``."
    )
    wavefunction: Optional[Dict[str, Any]] = Field(None, description="Wavefunction data generated by the Result.")
    wavefunction_data_id: Optional[ObjectId] = Field(None, description="The id of the wavefunction")

    class Config(RecordBase.Config):
        """A hash index is not used for ResultRecords as they can be
        uniquely determined with queryable keys.
        """

        build_hash_index = False

    @validator("method")
    def check_method(cls, v):
        """Methods should have a lower string to match the database.
        """
        return v.lower()

    @validator("basis")
    def check_basis(cls, v):
        """Normalize the basis set name via ``prepare_basis``."""
        return prepare_basis(v)

    def get_wavefunction(self, key: Union[str, List[str]]) -> Any:
        """
        Pulls down the Wavefunction data associated with the computation.

        Keys are lower-cased and translated through ``self.wavefunction["return_map"]``.
        Values not yet in ``self.cache["wavefunction"]`` are fetched from the
        ``wavefunctionstore`` in a single query; a fetched ``basis`` entry is
        converted to a ``qcel.models.BasisSet``.

        Parameters
        ----------
        key : Union[str, List[str]]
            A single wavefunction key or a list of keys (case-insensitive).

        Returns
        -------
        Any
            The value for ``key`` if a single string was given, otherwise a dict
            mapping each lower-cased requested key to its value.

        Raises
        ------
        AttributeError
            If this record was not computed with wavefunction data.
        KeyError
            If a requested key is neither listed as available nor ``basis``/``restricted``.
        ValueError
            If data must be fetched but this record has no client.
        """
        if self.wavefunction is None:
            raise AttributeError("This Record was not computed with Wavefunction data.")

        single_return = isinstance(key, str)
        if single_return:
            key = [key]

        keys = [x.lower() for x in key]
        return_map = self.wavefunction["return_map"]

        self.cache.setdefault("wavefunction", {})

        # Translate requested names into their stored names.
        mapped_keys = {return_map.get(x, x) for x in keys}
        missing = mapped_keys - self.cache["wavefunction"].keys()

        unknown = missing - set(self.wavefunction["available"] + ["basis", "restricted"])
        if unknown:
            raise KeyError(
                f"Wavefunction Key(s) `{unknown}` not understood, available keys are: {self.wavefunction['available']}"
            )

        if missing:
            self.check_client()

            # ``missing`` already holds translated names; mapping them a second time
            # could project a different key than the one looked up below.
            proj = list(missing)

            self.cache["wavefunction"].update(
                self.client.custom_query(
                    "wavefunctionstore", None, {"id": self.wavefunction_data_id}, meta={"include": proj}
                )
            )

            if "basis" in missing:
                self.cache["wavefunction"]["basis"] = qcel.models.BasisSet(**self.cache["wavefunction"]["basis"])

        # Remap once more
        ret = {k: self.cache["wavefunction"][return_map.get(k, k)] for k in keys}

        if single_return:
            return ret[keys[0]]
        else:
            return ret

    ## QCSchema constructors

    def build_schema_input(
        self, molecule: "Molecule", keywords: Optional["KeywordSet"] = None, checks: bool = True
    ) -> "ResultInput":
        """
        Creates a ResultInput schema.

        Parameters
        ----------
        molecule : Molecule
            The molecule this record points to.
        keywords : Optional[KeywordSet], optional
            The keyword set this record points to; required when ``self.keywords`` is set.
        checks : bool, optional
            If True, assert that ``molecule`` (and ``keywords``) match this record's ids.

        Returns
        -------
        ResultInput
            The QCSchema input for this computation.
        """
        if checks:
            assert self.molecule == molecule.id, "Molecule id does not match this record."
            if self.keywords:
                # Guard against a missing KeywordSet so the check fails cleanly
                # instead of raising AttributeError on ``None.id``.
                assert (
                    keywords is not None and self.keywords == keywords.id
                ), "KeywordSet id does not match this record."

        model = {"method": self.method}
        if self.basis:
            model["basis"] = self.basis

        if not self.keywords:
            keywords = {}
        else:
            keywords = keywords.values

        if not self.protocols:
            protocols = {}
        else:
            protocols = self.protocols

        model = qcel.models.ResultInput(
            id=self.id,
            driver=self.driver.name,
            model=model,
            molecule=molecule,
            keywords=keywords,
            extras=self.extras,
            protocols=protocols,
        )
        return model

    def _consume_output(self, data: Dict[str, Any], checks: bool = True):
        """Copy the outputs of a finished computation into this record in place.

        ``data`` is read with the keys ``model``, ``extras``, ``return_result``,
        ``properties``, ``provenance``, ``error``, ``stdout``, ``stderr`` and,
        optionally, ``wavefunction``/``wavefunction_data_id``. The record's
        status is set to complete. ``checks`` is currently unused.
        """
        assert self.method == data["model"]["method"]
        values = self.__dict__

        # Result specific. Copy ``extras`` so stripping the internal tag does not
        # mutate the caller's ``data``.
        values["extras"] = dict(data["extras"])
        values["extras"].pop("_qcfractal_tags", None)
        values["return_result"] = data["return_result"]
        values["properties"] = data["properties"]

        # Wavefunction data
        values["wavefunction"] = data.get("wavefunction", None)
        values["wavefunction_data_id"] = data.get("wavefunction_data_id", None)

        # Standard blocks
        values["provenance"] = data["provenance"]
        values["error"] = data["error"]
        values["stdout"] = data["stdout"]
        values["stderr"] = data["stderr"]
        # Store the enum member (equal to "COMPLETE") to match the field's declared type.
        values["status"] = RecordStatusEnum.complete

    ## Standard functions

    def get_molecule(self) -> Optional["Molecule"]:
        """
        Pulls the Result's Molecule from the connected database.

        The molecule is fetched once and then served from ``self.cache``.

        Returns
        -------
        Optional[Molecule]
            The requested Molecule, or None if this record has no molecule id.

        Raises
        ------
        ValueError
            If this record has no client.
        """
        self.check_client()

        if self.molecule is None:
            return None

        if "molecule" not in self.cache:
            self.cache["molecule"] = self.client.query_molecules(id=self.molecule)[0]

        return self.cache["molecule"]


class OptimizationRecord(RecordBase):
    """
    A OptimizationRecord for all optimization procedure data.
    """

    # Class data
    _hash_indices = {"initial_molecule", "keywords", "qc_spec"}

    # Version data
    version: int = Field(1, description="Version of the OptimizationRecord Model which this data was created with.")
    procedure: constr(strip_whitespace=True, regex="optimization") = Field(
        "optimization", description='A fixed string indication this is a record for an "Optimization".'
    )
    schema_version: int = Field(1, description="The version number of QCSchema under which this record conforms to.")

    # Input data
    initial_molecule: ObjectId = Field(
        ..., description="The Id of the molecule which was passed in as the reference for this Optimization."
    )
    qc_spec: QCSpecification = Field(
        ..., description="The specification of the quantum chemistry calculation to run at each point."
    )
    keywords: Dict[str, Any] = Field(
        {},
        description="The keyword options which were passed into the Optimization program. "
        "Note: These are a dictionary and not a :class:`KeywordSet` object.",
    )
    # NOTE: the intended description for ``protocols`` was
    # str(qcel.models.procedures.OptimizationProtocols.__doc__); it is left empty
    # because of a (presumably auto-formatting) issue noted by the original author.
    protocols: Optional[qcel.models.procedures.OptimizationProtocols] = Field(
        qcel.models.procedures.OptimizationProtocols(), description=""
    )

    # Results
    energies: List[float] = Field(None, description="The ordered list of energies at each step of the Optimization.")
    final_molecule: ObjectId = Field(
        None, description="The ``ObjectId`` of the final, optimized Molecule the Optimization procedure converged to."
    )
    # NOTE(review): get_trajectory() resolves these ids with ``query_results`` (i.e. as
    # ResultRecord ids), not as Molecule ids as the description says — confirm and
    # update the description.
    trajectory: List[ObjectId] = Field(
        None,
        description="The list of Molecule Id's the Optimization procedure generated at each step of the optimization."
        "``initial_molecule`` will be the first index, and ``final_molecule`` will be the last index.",
    )

    class Config(RecordBase.Config):
        pass

    @validator("keywords")
    def check_keywords(cls, v):
        """Normalize the optimizer keywords with ``recursive_normalizer`` (None passes through)."""
        if v is not None:
            v = recursive_normalizer(v)
        return v

    ## QCSchema constructors

    def build_schema_input(
        self, initial_molecule: "Molecule", qc_keywords: Optional["KeywordSet"] = None, checks: bool = True
    ) -> "OptimizationInput":
        """
        Creates a OptimizationInput schema.

        Parameters
        ----------
        initial_molecule : Molecule
            The molecule this record starts from.
        qc_keywords : Optional[KeywordSet], optional
            The keyword set referenced by ``qc_spec.keywords``, if any.
        checks : bool, optional
            If True, assert that the supplied objects match this record's ids.

        Returns
        -------
        OptimizationInput
            The QCSchema input for this optimization.
        """
        if checks:
            assert self.initial_molecule == initial_molecule.id, "Initial molecule id does not match this record."
            if self.qc_spec.keywords:
                # Guard against a missing KeywordSet so the check fails cleanly
                # instead of raising AttributeError on ``None.id``.
                assert (
                    qc_keywords is not None and self.qc_spec.keywords == qc_keywords.id
                ), "KeywordSet id does not match this record's qc_spec."

        qcinput_spec = self.qc_spec.form_schema_object(keywords=qc_keywords, checks=checks)
        # The program belongs to the QC spec only; the optimization input does not take it.
        qcinput_spec.pop("program", None)

        model = qcel.models.OptimizationInput(
            id=self.id,
            initial_molecule=initial_molecule,
            keywords=self.keywords,
            extras=self.extras,
            hash_index=self.hash_index,
            input_specification=qcinput_spec,
            protocols=self.protocols,
        )
        return model

    ## Standard function

    def get_final_energy(self) -> float:
        """The final energy of the geometry optimization.

        Returns
        -------
        float
            The optimization molecular energy (the last entry of ``energies``).

        Raises
        ------
        TypeError
            If ``energies`` is None.
        IndexError
            If ``energies`` is empty.
        """
        return self.energies[-1]

    def get_trajectory(self) -> List[ResultRecord]:
        """Returns the Result records for each gradient evaluation in the trajectory.

        The records are fetched once and then served from ``self.cache``.

        Returns
        -------
        List['ResultRecord']
            A ordered list of Result record gradient computations; empty if this
            record has no trajectory.

        Raises
        ------
        ValueError
            If the records must be fetched but this record has no client.
        """
        if "trajectory" not in self.cache:
            if self.trajectory is None:
                # Nothing to fetch yet; do not cache so a later value is still honoured.
                return []

            self.check_client()
            result = {x.id: x for x in self.client.query_results(id=self.trajectory)}

            self.cache["trajectory"] = [result[x] for x in self.trajectory]

        return self.cache["trajectory"]

    def get_molecular_trajectory(self) -> List["Molecule"]:
        """Returns the Molecule at each gradient evaluation in the trajectory.

        The molecules are fetched once and then served from ``self.cache``.

        Returns
        -------
        List['Molecule']
            A ordered list of Molecules in the trajectory; empty if there is no trajectory.

        Raises
        ------
        ValueError
            If data must be fetched but this record has no client.
        """
        if "molecular_trajectory" not in self.cache:
            mol_ids = [x.molecule for x in self.get_trajectory()]
            if not mol_ids:
                return []

            self.check_client()
            mols = {x.id: x for x in self.client.query_molecules(id=mol_ids)}
            self.cache["molecular_trajectory"] = [mols[x] for x in mol_ids]

        return self.cache["molecular_trajectory"]

    def get_initial_molecule(self) -> "Molecule":
        """Returns the initial molecule

        The molecule is fetched once and then served from ``self.cache``.

        Returns
        -------
        Molecule
            The initial molecule

        Raises
        ------
        ValueError
            If the molecule must be fetched but this record has no client.
        """
        if "initial_molecule" not in self.cache:
            self.check_client()
            self.cache["initial_molecule"] = self.client.query_molecules(id=[self.initial_molecule])[0]

        return self.cache["initial_molecule"]

    def get_final_molecule(self) -> Optional["Molecule"]:
        """Returns the optimized molecule

        The molecule is fetched once and then served from ``self.cache``.

        Returns
        -------
        Optional[Molecule]
            The optimized molecule, or None if no final molecule is recorded yet.

        Raises
        ------
        ValueError
            If the molecule must be fetched but this record has no client.
        """
        if self.final_molecule is None:
            return None

        if "final_molecule" not in self.cache:
            self.check_client()
            self.cache["final_molecule"] = self.client.query_molecules(id=[self.final_molecule])[0]

        return self.cache["final_molecule"]

    ## Show functions

    def show_history(
        self, units: str = "kcal/mol", digits: int = 3, relative: bool = True, return_figure: Optional[bool] = None
    ) -> "plotly.Figure":
        """Plots the energy of the trajectory the optimization took.

        Parameters
        ----------
        units : str, optional
            Units to display the trajectory in.
        digits : int, optional
            The number of valid digits to show.
        relative : bool, optional
            If True, all energies are shifted by the lowest energy in the trajectory. Otherwise provides raw energies.
        return_figure : Optional[bool], optional
            If True, return the raw plotly figure. If False, returns a hosted iPlot. If None, return a iPlot display in
            Jupyter notebook and a raw plotly figure in all other circumstances.

        Returns
        -------
        plotly.Figure
            The requested figure.
        """
        # Stored energies are converted from hartree into the requested units.
        cf = qcel.constants.conversion_factor("hartree", units)

        energies = np.array(self.energies)
        if relative:
            energies = energies - np.min(energies)

        # Steps are numbered from 1.
        trace = {"mode": "lines+markers", "x": list(range(1, len(energies) + 1)), "y": np.around(energies * cf, digits)}

        if relative:
            ylabel = f"Relative Energy [{units}]"
        else:
            ylabel = f"Absolute Energy [{units}]"

        custom_layout = {
            "title": "Geometry Optimization",
            "yaxis": {"title": ylabel, "zeroline": True},
            "xaxis": {
                "title": "Optimization Step",
                "range": [min(trace["x"]), max(trace["x"])],
            },
        }

        return scatter_plot([trace], custom_layout=custom_layout, return_figure=return_figure)
