MolSSI / QCFractal
1
"""
A model for Compute Records
"""
4

5 14
import abc
6 14
import datetime
7 14
from enum import Enum
8 14
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
9

10 14
import numpy as np
11 14
import qcelemental as qcel
12 14
from pydantic import Field, constr, validator
13

14 14
from ..visualization import scatter_plot
15 14
from .common_models import DriverEnum, ObjectId, ProtoModel, QCSpecification
16 14
from .model_utils import hash_dictionary, prepare_basis, recursive_normalizer
17

18
if TYPE_CHECKING:  # pragma: no cover
19
    from qcelemental.models import OptimizationInput, ResultInput
20

21
    from .common_models import KeywordSet, Molecule
22

23 14
# Public API of this module. Each name is listed exactly once (the original
# list repeated "OptimizationRecord").
__all__ = ["OptimizationRecord", "ResultRecord", "RecordBase"]
24

25

26 14
# NOTE: the docstring below is reused at runtime as the description of the
# ``status`` field on ``RecordBase`` (``description=str(RecordStatusEnum.__doc__)``),
# so its text is intentionally left unchanged.
class RecordStatusEnum(str, Enum):
    """
    The state of a record object. The states which are available are a finite set.
    """

    # Members are also ``str`` instances, so they compare equal to their
    # upper-case string values (e.g. ``RecordStatusEnum.complete == "COMPLETE"``).
    complete = "COMPLETE"
    incomplete = "INCOMPLETE"
    running = "RUNNING"
    error = "ERROR"
35

36

37 14
class RecordBase(ProtoModel, abc.ABC):
    """
    A BaseRecord object for Result and Procedure records. Contains all basic
    fields common to the all records.
    """

    # Classdata
    # Names of the subclass-specific fields that feed the hash index; each
    # concrete record class assigns this set.
    _hash_indices: Set[str]

    # Helper data
    # ``client`` and ``cache`` are runtime-only helpers: ``dict()`` below always
    # excludes them from the serialized output.
    client: Any = Field(None, description="The client object which the records are fetched from.")
    cache: Dict[str, Any] = Field(
        {},
        description="Object cache from expensive queries. It should be very rare that this needs to be set manually "
        "by the user.",
    )

    # Base identification
    id: ObjectId = Field(
        None, description="Id of the object on the database. This is assigned automatically by the database."
    )
    hash_index: Optional[str] = Field(
        None, description="Hash of this object used to detect duplication and collisions in the database."
    )
    procedure: str = Field(..., description="Name of the procedure which this Record targets.")
    program: str = Field(
        ...,
        description="The quantum chemistry program which carries out the individual quantum chemistry calculations.",
    )
    version: int = Field(..., description="The version of this record object describes.")
    protocols: Optional[Dict[str, Any]] = Field(
        None, description="Protocols that change the data stored in top level fields."
    )

    # Extra fields
    extras: Dict[str, Any] = Field({}, description="Extra information to associate with this record.")
    stdout: Optional[ObjectId] = Field(
        None,
        description="The Id of the stdout data stored in the database which was used to generate this record from the "
        "various programs which were called in the process.",
    )
    stderr: Optional[ObjectId] = Field(
        None,
        description="The Id of the stderr data stored in the database which was used to generate this record from the "
        "various programs which were called in the process.",
    )
    error: Optional[ObjectId] = Field(
        None,
        description="The Id of the error data stored in the database in the event that an error was generated in the "
        "process of carrying out the process this record targets. If no errors were raised, this field "
        "will be empty.",
    )

    # Compute status
    manager_name: Optional[str] = Field(None, description="Name of the Queue Manager which generated this record.")
    status: RecordStatusEnum = Field(RecordStatusEnum.incomplete, description=str(RecordStatusEnum.__doc__))
    modified_on: datetime.datetime = Field(None, description="Last time the data this record points to was modified.")
    created_on: datetime.datetime = Field(None, description="Time the data this record points to was first created.")

    # Carry-ons
    provenance: Optional[qcel.models.Provenance] = Field(
        None,
        description="Provenance information tied to the creation of this record. This includes things such as every "
        "program which was involved in generating the data for this record.",
    )

    class Config(ProtoModel.Config):
        # When True, ``__init__`` fills in ``hash_index`` if it was not supplied.
        build_hash_index = True

    @validator("program")
    def check_program(cls, v):
        """Normalize the program name to lower case."""
        return v.lower()

    def __init__(self, **data):
        """Build the record, defaulting the timestamps and the hash index.

        Parameters
        ----------
        **data
            Field values for the record. ``modified_on`` and ``created_on``
            default to the current (naive) UTC time when not supplied.
        """

        # Use a single timestamp for both defaults so that a freshly created
        # record has identical ``created_on`` and ``modified_on`` values
        # (two separate ``utcnow()`` calls differ by a few microseconds).
        now = datetime.datetime.utcnow()
        data.setdefault("modified_on", now)
        data.setdefault("created_on", now)

        super().__init__(**data)

        # Set hash index if not present. Written through ``__dict__`` rather
        # than by attribute assignment, as in the original implementation.
        if self.Config.build_hash_index and (self.hash_index is None):
            self.__dict__["hash_index"] = self.get_hash_index()

    def __repr_args__(self):
        """Return the (name, value) pairs shown in the record's repr: id and status."""
        return [("id", f"{self.id}"), ("status", f"{self.status}")]

    ### Serialization helpers

    @classmethod
    def get_hash_fields(cls) -> Set[str]:
        """Provides a description of the fields to be used in the hash
        that uniquely defines this object.

        Returns
        -------
        Set[str]
            A set of all fields that are used in the hash: the class'
            ``_hash_indices`` plus ``"procedure"`` and ``"program"``.

        """
        return cls._hash_indices | {"procedure", "program"}

    def get_hash_index(self) -> str:
        """Builds (or rebuilds) the hash of this
        object using the internally known hash fields.

        Returns
        -------
        str
            The objects unique hash index.
        """
        data = self.dict(include=self.get_hash_fields(), encoding="json")

        return hash_dictionary(data)

    def dict(self, *args, **kwargs):
        """Serialize the record, always leaving out ``client`` and ``cache``.

        Parameters
        ----------
        *args, **kwargs
            Forwarded to the parent ``dict`` implementation. A caller-supplied
            ``exclude`` (set-like, sequence, or mapping) is extended with the
            two helper fields rather than replaced.

        Returns
        -------
        Dict[str, Any]
            The dictionary produced by the parent implementation.
        """
        exclude = kwargs.pop("exclude", None)
        if isinstance(exclude, dict):
            # Mapping-style excludes (nested exclusion) cannot be unioned with a
            # set; merge instead, using ``...`` to mark whole-field exclusion.
            kwargs["exclude"] = {**exclude, "client": ..., "cache": ...}
        else:
            kwargs["exclude"] = set(exclude or ()) | {"client", "cache"}
        # kwargs["skip_defaults"] = True
        return super().dict(*args, **kwargs)

    ### Checkers

    def check_client(self, noraise: bool = False) -> bool:
        """Checks whether this object owns a FractalClient or not.
        This is often done so that objects pulled from a server using
        a FractalClient still possess a connection to the server so that
        additional data related to this object can be queried.

        Parameters
        ----------
        noraise : bool, optional
            Does not raise an error if this is True and instead returns
            a boolean depending if a client exists or not.

        Returns
        -------
        bool
            If True, the object owns a connection to a server. False otherwise.

        Raises
        ------
        ValueError
            If this object does not own a client and ``noraise`` is False.
        """
        if self.client is not None:
            return True

        if noraise:
            return False

        raise ValueError("Requested method requires a client, but client was '{}'.".format(self.client))

    ### KVStore Getters

    def _kvstore_getter(self, field_name):
        """
        Internal KVStore getting object.

        Looks up the id stored in ``field_name``, fetches the matching entry
        through ``client.query_kvstore`` on first use, and caches the decoded
        value in ``self.cache[field_name]``.

        Parameters
        ----------
        field_name : str
            Name of the field holding the KVStore id
            (used here with "stdout", "stderr" and "error").

        Returns
        -------
        Any
            ``None`` when the field holds no id; otherwise the decoded entry
            (``get_json()`` for "error", ``get_string()`` for everything else).

        Raises
        ------
        ValueError
            If the record has no client.
        """
        self.check_client()

        oid = self.__dict__[field_name]
        if oid is None:
            return None

        if field_name not in self.cache:
            # Decompress here, rather than later
            # that way, it is decompressed in the cache
            kv = self.client.query_kvstore([oid])[oid]

            if field_name == "error":
                self.cache[field_name] = kv.get_json()
            else:
                self.cache[field_name] = kv.get_string()

        return self.cache[field_name]

    def get_stdout(self) -> Optional[str]:
        """Pulls the stdout from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[str]
            The requested stdout, none if no stdout present.
        """
        return self._kvstore_getter("stdout")

    def get_stderr(self) -> Optional[str]:
        """Pulls the stderr from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[str]
            The requested stderr, none if no stderr present.
        """

        return self._kvstore_getter("stderr")

    def get_error(self) -> Optional[qcel.models.ComputeError]:
        """Pulls the error from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[qcel.models.ComputeError]
            The requested compute error, none if no error present.
        """
        value = self._kvstore_getter("error")
        if value:
            return qcel.models.ComputeError(**value)
        else:
            # Falsy payload (e.g. ``None``) is handed back unchanged.
            return value
249

250

251 14
# Record for a single quantum chemistry computation (procedure "single").
# (Kept as a comment rather than a class docstring so the model's generated
# description is not altered.)
class ResultRecord(RecordBase):

    # Classdata
    _hash_indices = {"driver", "method", "basis", "molecule", "keywords", "program"}

    # Version data
    version: int = Field(1, description="Version of the ResultRecord Model which this data was created with.")
    procedure: constr(strip_whitespace=True, regex="single") = Field(
        "single", description='Procedure is fixed as "single" because this is single quantum chemistry result.'
    )

    # Input data
    driver: DriverEnum = Field(..., description=str(DriverEnum.__doc__))
    method: str = Field(..., description="The quantum chemistry method the driver runs with.")
    molecule: ObjectId = Field(
        ..., description="The Id of the molecule in the Database which the result is computed on."
    )
    basis: Optional[str] = Field(
        None,
        description="The quantum chemistry basis set to evaluate (e.g., 6-31g, cc-pVDZ, ...). Can be ``None`` for "
        "methods without basis sets.",
    )
    keywords: Optional[ObjectId] = Field(
        None,
        description="The Id of the :class:`KeywordSet` which was passed into the quantum chemistry program that "
        "performed this calculation.",
    )
    protocols: Optional[qcel.models.results.ResultProtocols] = Field(
        qcel.models.results.ResultProtocols(), description=""
    )

    # Output data
    return_result: Union[float, qcel.models.types.Array[float], Dict[str, Any]] = Field(
        None, description="The primary result of the calculation, output is a function of the specified ``driver``."
    )
    properties: qcel.models.ResultProperties = Field(
        None, description="Additional data and results computed as part of the ``return_result``."
    )
    wavefunction: Optional[Dict[str, Any]] = Field(None, description="Wavefunction data generated by the Result.")
    wavefunction_data_id: Optional[ObjectId] = Field(None, description="The id of the wavefunction")

    class Config(RecordBase.Config):
        """A hash index is not used for ResultRecords as they can be
        uniquely determined with queryable keys.
        """

        build_hash_index = False

    @validator("method")
    def check_method(cls, v):
        """Methods should have a lower string to match the database."""
        return v.lower()

    @validator("basis")
    def check_basis(cls, v):
        """Normalize the basis name through ``prepare_basis``."""
        return prepare_basis(v)

    def get_wavefunction(self, key: Union[str, List[str]]) -> Any:
        """
        Pulls down the Wavefunction data associated with the computation.

        Parameters
        ----------
        key : Union[str, List[str]]
            One key or a list of keys to fetch. Keys are lower-cased and then
            translated through the record's ``wavefunction["return_map"]``.

        Returns
        -------
        Any
            The single value when ``key`` is a string, otherwise a dictionary
            of ``{requested key (lower-cased): value}``.

        Raises
        ------
        AttributeError
            If the record carries no wavefunction data.
        KeyError
            If a requested key is neither cached nor listed as available.
        ValueError
            If data must be fetched but the record has no client.
        """

        if self.wavefunction is None:
            raise AttributeError("This Record was not computed with Wavefunction data.")

        single_return = False
        if isinstance(key, str):
            key = [key]
            single_return = True

        keys = [x.lower() for x in key]
        return_map = self.wavefunction["return_map"]

        wfn_cache = self.cache.setdefault("wavefunction", {})

        # Work in terms of the mapped (stored) names from here on.
        mapped_keys = {return_map.get(x, x) for x in keys}
        missing = mapped_keys - wfn_cache.keys()

        unknown = missing - set(self.wavefunction["available"] + ["basis", "restricted"])
        if unknown:
            raise KeyError(
                f"Wavefunction Key(s) `{unknown}` not understood, available keys are: {self.wavefunction['available']}"
            )

        if missing:
            # A server round trip is needed: fail early with a clear error when
            # there is no client (cached-only lookups above still work offline).
            self.check_client()

            # ``missing`` already holds mapped names; do not translate them a
            # second time. Sorted only to make the request deterministic.
            proj = sorted(missing)

            wfn_cache.update(
                self.client.custom_query(
                    "wavefunctionstore", None, {"id": self.wavefunction_data_id}, meta={"include": proj}
                )
            )

            if "basis" in missing:
                # Promote the raw basis payload to a BasisSet model inside the cache.
                wfn_cache["basis"] = qcel.models.BasisSet(**wfn_cache["basis"])

        # Remap once more, so the result is keyed by what the caller asked for
        ret = {k: wfn_cache[return_map.get(k, k)] for k in keys}

        if single_return:
            return ret[keys[0]]
        else:
            return ret

    def get_molecule(self) -> "Molecule":
        """
        Pulls the Result's Molecule from the connected database.

        The molecule is fetched once and then served from ``self.cache``.

        Returns
        -------
        Molecule
            The requested Molecule (``None`` if the record has no molecule id).

        Raises
        ------
        ValueError
            If the record has no client.
        """
        self.check_client()

        if self.molecule is None:
            return None

        if "molecule" not in self.cache:
            self.cache["molecule"] = self.client.query_molecules(id=self.molecule)[0]

        return self.cache["molecule"]
377

378

379 14
class OptimizationRecord(RecordBase):
    """
    A OptimizationRecord for all optimization procedure data.
    """

    # Class data
    _hash_indices = {"initial_molecule", "keywords", "qc_spec"}

    # Version data
    version: int = Field(1, description="Version of the OptimizationRecord Model which this data was created with.")
    procedure: constr(strip_whitespace=True, regex="optimization") = Field(
        "optimization", description='A fixed string indication this is a record for an "Optimization".'
    )
    schema_version: int = Field(1, description="The version number of QCSchema under which this record conforms to.")

    # Input data
    initial_molecule: ObjectId = Field(
        ..., description="The Id of the molecule which was passed in as the reference for this Optimization."
    )
    qc_spec: QCSpecification = Field(
        ..., description="The specification of the quantum chemistry calculation to run at each point."
    )
    keywords: Dict[str, Any] = Field(
        {},
        description="The keyword options which were passed into the Optimization program. "
        "Note: These are a dictionary and not a :class:`KeywordSet` object.",
    )
    protocols: Optional[qcel.models.procedures.OptimizationProtocols] = Field(
        qcel.models.procedures.OptimizationProtocols(), description=""
    )

    # Automatting issue currently
    # description=str(qcel.models.procedures.OptimizationProtocols.__doc__))

    # Results
    energies: List[float] = Field(None, description="The ordered list of energies at each step of the Optimization.")
    final_molecule: ObjectId = Field(
        None, description="The ``ObjectId`` of the final, optimized Molecule the Optimization procedure converged to."
    )
    trajectory: List[ObjectId] = Field(
        None,
        # Trailing space added: the two fragments previously ran together as
        # "...optimization.``initial_molecule``...".
        description="The list of Molecule Id's the Optimization procedure generated at each step of the optimization. "
        "``initial_molecule`` will be the first index, and ``final_molecule`` will be the last index.",
    )

    class Config(RecordBase.Config):
        pass

    @validator("keywords")
    def check_keywords(cls, v):
        """Run non-``None`` keywords through ``recursive_normalizer``."""
        if v is not None:
            v = recursive_normalizer(v)
        return v

    ## Standard function

    def get_final_energy(self) -> float:
        """The final energy of the geometry optimization.

        Returns
        -------
        float
            The optimization molecular energy (last entry of ``energies``).
        """
        return self.energies[-1]

    def get_trajectory(self) -> List[ResultRecord]:
        """Returns the Result records for each gradient evaluation in the trajectory.

        The records are fetched once and then served from ``self.cache``.

        Returns
        -------
        List['ResultRecord']
            A ordered list of Result record gradient computations.

        Raises
        ------
        ValueError
            If the records must be fetched but this record has no client.
        """

        if "trajectory" not in self.cache:
            self.check_client()

            # Index by id so the output follows the order of ``self.trajectory``
            # regardless of the order the query returns.
            by_id = {rec.id: rec for rec in self.client.query_results(id=self.trajectory)}

            self.cache["trajectory"] = [by_id[rec_id] for rec_id in self.trajectory]

        return self.cache["trajectory"]

    def get_molecular_trajectory(self) -> List["Molecule"]:
        """Returns the Molecule at each gradient evaluation in the trajectory.

        The molecules are fetched once and then served from ``self.cache``.

        Returns
        -------
        List['Molecule']
            A ordered list of Molecules in the trajectory.

        """

        if "molecular_trajectory" not in self.cache:
            mol_ids = [rec.molecule for rec in self.get_trajectory()]

            # Index by id so the output follows the trajectory order.
            mols = {mol.id: mol for mol in self.client.query_molecules(id=mol_ids)}
            self.cache["molecular_trajectory"] = [mols[mol_id] for mol_id in mol_ids]

        return self.cache["molecular_trajectory"]

    def get_initial_molecule(self) -> "Molecule":
        """Returns the initial molecule

        Returns
        -------
        Molecule
            The initial molecule

        Raises
        ------
        ValueError
            If this record has no client.
        """
        self.check_client()

        ret = self.client.query_molecules(id=[self.initial_molecule])
        return ret[0]

    def get_final_molecule(self) -> "Molecule":
        """Returns the optimized molecule

        Returns
        -------
        Molecule
            The optimized molecule

        Raises
        ------
        ValueError
            If this record has no client.
        """
        self.check_client()

        ret = self.client.query_molecules(id=[self.final_molecule])
        return ret[0]

    ## Show functions

    def show_history(
        self, units: str = "kcal/mol", digits: int = 3, relative: bool = True, return_figure: Optional[bool] = None
    ) -> "plotly.Figure":
        """Plots the energy of the trajectory the optimization took.

        Parameters
        ----------
        units : str, optional
            Units to display the trajectory in.
        digits : int, optional
            The number of valid digits to show.
        relative : bool, optional
            If True, all energies are shifted by the lowest energy in the trajectory. Otherwise provides raw energies.
        return_figure : Optional[bool], optional
            If True, return the raw plotly figure. If False, returns a hosted iPlot. If None, return a iPlot display in
            Jupyter notebook and a raw plotly figure in all other circumstances.

        Returns
        -------
        plotly.Figure
            The requested figure.
        """
        # Stored energies are converted from hartree to the requested units.
        cf = qcel.constants.conversion_factor("hartree", units)

        energies = np.array(self.energies)
        if relative:
            energies = energies - np.min(energies)
            ylabel = f"Relative Energy [{units}]"
        else:
            ylabel = f"Absolute Energy [{units}]"

        # Steps are numbered from 1 on the x axis.
        steps = list(range(1, len(energies) + 1))
        trace = {"mode": "lines+markers", "x": steps, "y": np.around(energies * cf, digits)}

        custom_layout = {
            "title": "Geometry Optimization",
            "yaxis": {"title": ylabel, "zeroline": True},
            "xaxis": {
                "title": "Optimization Step",
                # "zeroline": False,
                "range": [min(trace["x"]), max(trace["x"])],
            },
        }

        return scatter_plot([trace], custom_layout=custom_layout, return_figure=return_figure)

Read our documentation on viewing source code .

Loading