1
|
|
"""
|
2
|
|
A model for Compute Records
|
3
|
|
"""
|
4
|
|
|
5
|
14
|
import abc
|
6
|
14
|
import datetime
|
7
|
14
|
from enum import Enum
|
8
|
14
|
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Union
|
9
|
|
|
10
|
14
|
import numpy as np
|
11
|
14
|
import qcelemental as qcel
|
12
|
14
|
from pydantic import Field, constr, validator
|
13
|
|
|
14
|
14
|
from ..visualization import scatter_plot
|
15
|
14
|
from .common_models import DriverEnum, ObjectId, ProtoModel, QCSpecification
|
16
|
14
|
from .model_utils import hash_dictionary, prepare_basis, recursive_normalizer
|
17
|
|
|
18
|
|
if TYPE_CHECKING: # pragma: no cover
|
19
|
|
from qcelemental.models import OptimizationInput, ResultInput
|
20
|
|
|
21
|
|
from .common_models import KeywordSet, Molecule
|
22
|
|
|
23
|
14
|
# Public API of this module (each name listed once).
__all__ = ["OptimizationRecord", "ResultRecord", "RecordBase"]
|
24
|
|
|
25
|
|
|
26
|
14
|
class RecordStatusEnum(str, Enum):
    """
    The state of a record object. The states which are available are a finite set.
    """

    # NOTE: the docstring above is also consumed at runtime
    # (``str(RecordStatusEnum.__doc__)`` is the ``status`` field description),
    # so its wording is intentionally left untouched.
    #
    # Because the enum mixes in ``str``, each member compares equal to its
    # upper-case string value. Declaration order is the iteration order.

    complete = "COMPLETE"
    incomplete = "INCOMPLETE"  # default status assigned by RecordBase
    running = "RUNNING"
    error = "ERROR"
|
35
|
|
|
36
|
|
|
37
|
14
|
class RecordBase(ProtoModel, abc.ABC):
    """
    A BaseRecord object for Result and Procedure records. Contains all basic
    fields common to all records.
    """

    # Classdata
    # Subclasses assign the set of field names that, together with "procedure"
    # and "program", form the record's hash (see ``get_hash_fields``).
    _hash_indices: Set[str]

    # Helper data
    client: Any = Field(None, description="The client object which the records are fetched from.")
    cache: Dict[str, Any] = Field(
        {},
        description="Object cache from expensive queries. It should be very rare that this needs to be set manually "
        "by the user.",
    )

    # Base identification
    id: ObjectId = Field(
        None, description="Id of the object on the database. This is assigned automatically by the database."
    )
    hash_index: Optional[str] = Field(
        None, description="Hash of this object used to detect duplication and collisions in the database."
    )
    procedure: str = Field(..., description="Name of the procedure which this Record targets.")
    program: str = Field(
        ...,
        description="The quantum chemistry program which carries out the individual quantum chemistry calculations.",
    )
    version: int = Field(..., description="The version of this record object describes.")
    protocols: Optional[Dict[str, Any]] = Field(
        None, description="Protocols that change the data stored in top level fields."
    )

    # Extra fields
    extras: Dict[str, Any] = Field({}, description="Extra information to associate with this record.")
    stdout: Optional[ObjectId] = Field(
        None,
        description="The Id of the stdout data stored in the database which was used to generate this record from the "
        "various programs which were called in the process.",
    )
    stderr: Optional[ObjectId] = Field(
        None,
        description="The Id of the stderr data stored in the database which was used to generate this record from the "
        "various programs which were called in the process.",
    )
    error: Optional[ObjectId] = Field(
        None,
        description="The Id of the error data stored in the database in the event that an error was generated in the "
        "process of carrying out the process this record targets. If no errors were raised, this field "
        "will be empty.",
    )

    # Compute status
    manager_name: Optional[str] = Field(None, description="Name of the Queue Manager which generated this record.")
    status: RecordStatusEnum = Field(RecordStatusEnum.incomplete, description=str(RecordStatusEnum.__doc__))
    modified_on: datetime.datetime = Field(None, description="Last time the data this record points to was modified.")
    created_on: datetime.datetime = Field(None, description="Time the data this record points to was first created.")

    # Carry-ons
    provenance: Optional[qcel.models.Provenance] = Field(
        None,
        description="Provenance information tied to the creation of this record. This includes things such as every "
        "program which was involved in generating the data for this record.",
    )

    class Config(ProtoModel.Config):
        # When True, ``__init__`` fills in ``hash_index`` if it was not supplied.
        build_hash_index = True

    @validator("program")
    def check_program(cls, v):
        """Normalize the program name to lower case."""
        return v.lower()

    def __init__(self, **data):
        # Set datetime defaults if not supplied. A single timestamp is used for
        # both so a freshly built record never reports a ``created_on`` that
        # is later than its ``modified_on``.
        now = datetime.datetime.utcnow()
        data.setdefault("modified_on", now)
        data.setdefault("created_on", now)

        super().__init__(**data)

        # Set hash index if not present
        if self.Config.build_hash_index and (self.hash_index is None):
            # Written straight into __dict__ rather than via attribute assignment.
            # NOTE(review): presumably because ProtoModel disallows/validates
            # normal assignment — confirm against ProtoModel.Config.
            self.__dict__["hash_index"] = self.get_hash_index()

    def __repr_args__(self):
        # Keep the repr short: only the database id and the status.
        return [("id", f"{self.id}"), ("status", f"{self.status}")]

    ### Serialization helpers

    @classmethod
    def get_hash_fields(cls) -> Set[str]:
        """Provides a description of the fields to be used in the hash
        that uniquely defines this object.

        Returns
        -------
        Set[str]
            The set of all field names that are used in the hash: the class's
            ``_hash_indices`` plus "procedure" and "program".
        """
        return cls._hash_indices | {"procedure", "program"}

    def get_hash_index(self) -> str:
        """Builds (or rebuilds) the hash of this
        object using the internally known hash fields.

        Returns
        -------
        str
            The objects unique hash index.
        """
        data = self.dict(include=self.get_hash_fields(), encoding="json")

        return hash_dictionary(data)

    def dict(self, *args, **kwargs):
        """Serialize the record, always excluding ``client`` and ``cache``.

        Accepts the same arguments as the parent ``dict``. A caller-supplied
        ``exclude`` may be any iterable of field names or a mapping in
        pydantic's nested include/exclude form. In either case ``client`` and
        ``cache`` are added to it.
        """
        exclude = kwargs.pop("exclude", None) or set()
        if isinstance(exclude, dict):
            # In pydantic's mapping form, ``...`` excludes the whole field.
            exclude = {**exclude, "client": ..., "cache": ...}
        else:
            exclude = set(exclude) | {"client", "cache"}
        kwargs["exclude"] = exclude
        # kwargs["skip_defaults"] = True
        return super().dict(*args, **kwargs)

    ### Checkers

    def check_client(self, noraise: bool = False) -> bool:
        """Checks whether this object owns a FractalClient or not.
        This is often done so that objects pulled from a server using
        a FractalClient still posses a connection to the server so that
        additional data related to this object can be queried.

        Raises
        ------
        ValueError
            If this object does not own a client and ``noraise`` is False.

        Parameters
        ----------
        noraise : bool, optional
            Does not raise an error if this is True and instead returns
            a boolean depending if a client exists or not.

        Returns
        -------
        bool
            True if the object owns a connection to a server. False only when
            ``noraise`` is True and no client is present.
        """
        if self.client is None:
            if noraise:
                return False

            raise ValueError("Requested method requires a client, but client was '{}'.".format(self.client))

        return True

    ### KVStore Getters

    def _kvstore_getter(self, field_name):
        """
        Fetch, cache and return the KVStore entry referenced by ``field_name``.

        The Id stored in ``field_name`` is resolved through the client. The
        "error" field is decoded as JSON and every other field as a string.
        Returns None when the field holds no Id. Raises ValueError (via
        ``check_client``) if the record has no client.
        """
        self.check_client()

        oid = self.__dict__[field_name]
        if oid is None:
            return None

        if field_name not in self.cache:
            # Decompress here, rather than later
            # that way, it is decompressed in the cache
            kv = self.client.query_kvstore([oid])[oid]

            if field_name == "error":
                self.cache[field_name] = kv.get_json()
            else:
                self.cache[field_name] = kv.get_string()

        return self.cache[field_name]

    def get_stdout(self) -> Optional[str]:
        """Pulls the stdout from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[str]
            The requested stdout, none if no stdout present.
        """
        return self._kvstore_getter("stdout")

    def get_stderr(self) -> Optional[str]:
        """Pulls the stderr from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[str]
            The requested stderr, none if no stderr present.
        """

        return self._kvstore_getter("stderr")

    def get_error(self) -> Optional[qcel.models.ComputeError]:
        """Pulls the compute error from the denormalized KVStore and returns it to the user.

        Returns
        -------
        Optional[qcel.models.ComputeError]
            The requested compute error, none if no error present.
        """
        value = self._kvstore_getter("error")
        if value:
            return qcel.models.ComputeError(**value)
        else:
            return value
|
249
|
|
|
250
|
|
|
251
|
14
|
class ResultRecord(RecordBase):
    # Record of a single quantum chemistry computation (procedure "single").

    # Classdata
    _hash_indices = {"driver", "method", "basis", "molecule", "keywords", "program"}

    # Version data
    version: int = Field(1, description="Version of the ResultRecord Model which this data was created with.")
    procedure: constr(strip_whitespace=True, regex="single") = Field(
        "single", description='Procedure is fixed as "single" because this is single quantum chemistry result.'
    )

    # Input data
    driver: DriverEnum = Field(..., description=str(DriverEnum.__doc__))
    method: str = Field(..., description="The quantum chemistry method the driver runs with.")
    molecule: ObjectId = Field(
        ..., description="The Id of the molecule in the Database which the result is computed on."
    )
    basis: Optional[str] = Field(
        None,
        description="The quantum chemistry basis set to evaluate (e.g., 6-31g, cc-pVDZ, ...). Can be ``None`` for "
        "methods without basis sets.",
    )
    keywords: Optional[ObjectId] = Field(
        None,
        description="The Id of the :class:`KeywordSet` which was passed into the quantum chemistry program that "
        "performed this calculation.",
    )
    protocols: Optional[qcel.models.results.ResultProtocols] = Field(
        qcel.models.results.ResultProtocols(), description=""
    )

    # Output data
    return_result: Union[float, qcel.models.types.Array[float], Dict[str, Any]] = Field(
        None, description="The primary result of the calculation, output is a function of the specified ``driver``."
    )
    properties: qcel.models.ResultProperties = Field(
        None, description="Additional data and results computed as part of the ``return_result``."
    )
    wavefunction: Optional[Dict[str, Any]] = Field(None, description="Wavefunction data generated by the Result.")
    wavefunction_data_id: Optional[ObjectId] = Field(None, description="The id of the wavefunction")

    class Config(RecordBase.Config):
        """A hash index is not used for ResultRecords as they can be
        uniquely determined with queryable keys.
        """

        build_hash_index = False

    @validator("method")
    def check_method(cls, v):
        """Methods should have a lower string to match the database."""
        return v.lower()

    @validator("basis")
    def check_basis(cls, v):
        """Normalize the basis name with the shared ``prepare_basis`` helper."""
        return prepare_basis(v)

    def get_wavefunction(self, key: Union[str, List[str]]) -> Any:
        """
        Pulls down the Wavefunction data associated with the computation.

        Values already present in ``self.cache["wavefunction"]`` are reused;
        only the missing ones are requested from the server.

        Parameters
        ----------
        key : Union[str, List[str]]
            One wavefunction key or a list of keys. Keys are lower-cased and
            translated through ``wavefunction["return_map"]`` before lookup.

        Returns
        -------
        Any
            The value for ``key`` when a single string is given. Otherwise a
            dictionary mapping each lower-cased requested key to its value.

        Raises
        ------
        AttributeError
            If this record was not computed with wavefunction data.
        KeyError
            If a requested key is not an available wavefunction field.
        ValueError
            If data must be fetched but this record has no client.
        """
        if self.wavefunction is None:
            raise AttributeError("This Record was not computed with Wavefunction data.")

        single_return = isinstance(key, str)
        keys = [x.lower() for x in ([key] if single_return else key)]
        return_map = self.wavefunction["return_map"]

        cache = self.cache.setdefault("wavefunction", {})

        mapped_keys = {return_map.get(x, x) for x in keys}
        missing = mapped_keys - cache.keys()

        unknown = missing - set(self.wavefunction["available"] + ["basis", "restricted"])
        if unknown:
            raise KeyError(
                f"Wavefunction Key(s) `{unknown}` not understood, available keys are: {self.wavefunction['available']}"
            )

        if missing:
            # A server round-trip is needed, so fail clearly without a client.
            self.check_client()

            # ``missing`` already holds return_map-translated names. They are
            # requested as-is: translating a second time could fetch a key
            # different from the one looked up in the cache below.
            cache.update(
                self.client.custom_query(
                    "wavefunctionstore", None, {"id": self.wavefunction_data_id}, meta={"include": list(missing)}
                )
            )

            if "basis" in missing:
                cache["basis"] = qcel.models.BasisSet(**cache["basis"])

        # Report values under the caller's (lower-cased) key names.
        ret = {k: cache[return_map.get(k, k)] for k in keys}

        if single_return:
            return ret[keys[0]]
        return ret

    def get_molecule(self) -> Optional["Molecule"]:
        """
        Pulls the Result's Molecule from the connected database.

        The molecule is cached in ``self.cache["molecule"]`` after the first query.

        Returns
        -------
        Optional[Molecule]
            The requested Molecule, or None if this record has no molecule Id.

        Raises
        ------
        ValueError
            If this record has no client.
        """
        self.check_client()

        if self.molecule is None:
            return None

        if "molecule" not in self.cache:
            self.cache["molecule"] = self.client.query_molecules(id=self.molecule)[0]

        return self.cache["molecule"]
|
377
|
|
|
378
|
|
|
379
|
14
|
class OptimizationRecord(RecordBase):
    """
    An OptimizationRecord for all optimization procedure data.
    """

    # Class data
    _hash_indices = {"initial_molecule", "keywords", "qc_spec"}

    # Version data
    version: int = Field(1, description="Version of the OptimizationRecord Model which this data was created with.")
    procedure: constr(strip_whitespace=True, regex="optimization") = Field(
        "optimization", description='A fixed string indication this is a record for an "Optimization".'
    )
    schema_version: int = Field(1, description="The version number of QCSchema under which this record conforms to.")

    # Input data
    initial_molecule: ObjectId = Field(
        ..., description="The Id of the molecule which was passed in as the reference for this Optimization."
    )
    qc_spec: QCSpecification = Field(
        ..., description="The specification of the quantum chemistry calculation to run at each point."
    )
    keywords: Dict[str, Any] = Field(
        {},
        description="The keyword options which were passed into the Optimization program. "
        "Note: These are a dictionary and not a :class:`KeywordSet` object.",
    )
    protocols: Optional[qcel.models.procedures.OptimizationProtocols] = Field(
        qcel.models.procedures.OptimizationProtocols(), description=""
    )

    # Automatting issue currently
    # description=str(qcel.models.procedures.OptimizationProtocols.__doc__))

    # Results
    energies: List[float] = Field(None, description="The ordered list of energies at each step of the Optimization.")
    final_molecule: ObjectId = Field(
        None, description="The ``ObjectId`` of the final, optimized Molecule the Optimization procedure converged to."
    )
    trajectory: List[ObjectId] = Field(
        None,
        description="The list of Molecule Id's the Optimization procedure generated at each step of the optimization. "
        "``initial_molecule`` will be the first index, and ``final_molecule`` will be the last index.",
    )

    class Config(RecordBase.Config):
        pass

    @validator("keywords")
    def check_keywords(cls, v):
        """Normalize keywords with ``recursive_normalizer`` (None passes through)."""
        if v is not None:
            v = recursive_normalizer(v)
        return v

    ## Standard function

    def get_final_energy(self) -> float:
        """The final energy of the geometry optimization.

        Returns
        -------
        float
            The last entry of ``energies``.
        """
        return self.energies[-1]

    def get_trajectory(self) -> List[ResultRecord]:
        """Returns the Result records for each gradient evaluation in the trajectory.

        The list is cached in ``self.cache["trajectory"]`` after the first query.

        Returns
        -------
        List['ResultRecord']
            An ordered list of Result record gradient computations.

        Raises
        ------
        ValueError
            If the trajectory is not cached and this record has no client.
        """
        if "trajectory" not in self.cache:
            self.check_client()
            result = {x.id: x for x in self.client.query_results(id=self.trajectory)}

            # Re-order the query results to match the trajectory order.
            self.cache["trajectory"] = [result[x] for x in self.trajectory]

        return self.cache["trajectory"]

    def get_molecular_trajectory(self) -> List["Molecule"]:
        """Returns the Molecule at each gradient evaluation in the trajectory.

        The list is cached in ``self.cache["molecular_trajectory"]`` after the first query.

        Returns
        -------
        List['Molecule']
            An ordered list of Molecules in the trajectory.

        Raises
        ------
        ValueError
            If data must be fetched and this record has no client.
        """
        if "molecular_trajectory" not in self.cache:
            mol_ids = [x.molecule for x in self.get_trajectory()]

            self.check_client()
            mols = {x.id: x for x in self.client.query_molecules(id=mol_ids)}
            self.cache["molecular_trajectory"] = [mols[x] for x in mol_ids]

        return self.cache["molecular_trajectory"]

    def get_initial_molecule(self) -> "Molecule":
        """Returns the initial molecule

        Returns
        -------
        Molecule
            The initial molecule

        Raises
        ------
        ValueError
            If this record has no client.
        """
        self.check_client()
        ret = self.client.query_molecules(id=[self.initial_molecule])
        return ret[0]

    def get_final_molecule(self) -> "Molecule":
        """Returns the optimized molecule

        Returns
        -------
        Molecule
            The optimized molecule

        Raises
        ------
        ValueError
            If this record has no client.
        """
        self.check_client()
        ret = self.client.query_molecules(id=[self.final_molecule])
        return ret[0]

    ## Show functions

    def show_history(
        self, units: str = "kcal/mol", digits: int = 3, relative: bool = True, return_figure: Optional[bool] = None
    ) -> "plotly.Figure":
        """Plots the energy of the trajectory the optimization took.

        Parameters
        ----------
        units : str, optional
            Units to display the trajectory in.
        digits : int, optional
            The number of decimal places to round the plotted energies to.
        relative : bool, optional
            If True, all energies are shifted by the lowest energy in the trajectory. Otherwise provides raw energies.
        return_figure : Optional[bool], optional
            If True, return the raw plotly figure. If False, returns a hosted iPlot. If None, return a iPlot display in
            Jupyter notebook and a raw plotly figure in all other circumstances.

        Returns
        -------
        plotly.Figure
            The requested figure.
        """
        # Energies are stored in hartree; convert to the requested units.
        cf = qcel.constants.conversion_factor("hartree", units)

        energies = np.array(self.energies)
        if relative:
            energies = energies - np.min(energies)

        # Optimization steps are numbered from 1.
        trace = {"mode": "lines+markers", "x": list(range(1, len(energies) + 1)), "y": np.around(energies * cf, digits)}

        if relative:
            ylabel = f"Relative Energy [{units}]"
        else:
            ylabel = f"Absolute Energy [{units}]"

        custom_layout = {
            "title": "Geometry Optimization",
            "yaxis": {"title": ylabel, "zeroline": True},
            "xaxis": {
                "title": "Optimization Step",
                # "zeroline": False,
                "range": [min(trace["x"]), max(trace["x"])],
            },
        }

        return scatter_plot([trace], custom_layout=custom_layout, return_figure=return_figure)
|