1
"""
2
Common models for QCPortal/Fractal
3
"""
4 4
import json
5

6
# For compression
7 4
import lzma
8 4
import bz2
9 4
import gzip
10

11 4
from enum import Enum
12 4
from typing import Any, Dict, Optional
13

14 4
from pydantic import Field, validator
15 4
from qcelemental.models import AutodocBaseSettings, Molecule, ProtoModel, Provenance
16 4
from qcelemental.models.procedures import OptimizationProtocols
17 4
from qcelemental.models.results import ResultProtocols
18

19 4
from .model_utils import hash_dictionary, prepare_basis, recursive_normalizer
20

21 4
# Public API of this module, declared in a single literal.
__all__ = [
    # Models defined in this module
    "QCSpecification",
    "OptimizationSpecification",
    "KeywordSet",
    "ObjectId",
    "DriverEnum",
    "Citation",
    # Add in QCElemental models
    "Molecule",
    "Provenance",
    "ProtoModel",
    # Autodoc
    "OptimizationProtocols",
    "ResultProtocols",
]
28

29

30 4
class ObjectId(str):
    """
    The Id of the object in the data.

    Accepted inputs (see :meth:`validate`):

    * a 24-character lowercase hexadecimal string,
    * a non-empty string made only of the ASCII digits ``0-9``,
    * an ``int`` (converted to its decimal string form).
    """

    _valid_hex = set("0123456789abcdef")
    _valid_digits = set("0123456789")

    @classmethod
    def __get_validators__(cls):
        """Yield the validators for this custom type (pydantic custom-type hook)."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Validate ``v`` as an ObjectId and return it as a string.

        Parameters
        ----------
        v
            Candidate id: a 24-character lowercase hex string, a string of ASCII digits, or an ``int``.

        Returns
        -------
        str
            ``v`` unchanged if it was a valid string, or ``str(v)`` if it was an integer.

        Raises
        ------
        TypeError
            If ``v`` is none of the accepted forms.
        """
        if isinstance(v, str):
            chars = set(v)
            if len(v) == 24 and chars <= cls._valid_hex:
                return v
            # ``str.isdigit`` also accepts non-ASCII digits such as superscripts ("12\u00b3"),
            # which are not integers; restrict integer-like ids to plain ASCII digits.
            if v and chars <= cls._valid_digits:
                return v
        elif isinstance(v, int) and not isinstance(v, bool):
            # ``bool`` is a subclass of ``int``; without the exclusion ``True`` would become
            # the string "True", which is not a valid id.
            return str(v)

        raise TypeError("The string {} is not a valid 24-character hexadecimal or integer ObjectId!".format(v))
51

52

53 4
class DriverEnum(str, Enum):
    """
    The type of calculation that is being performed (e.g., energy, gradient, Hessian, ...).
    """

    # NOTE: the docstring above is read at runtime -- QCSpecification uses
    # ``str(DriverEnum.__doc__)`` as the description of its ``driver`` field --
    # so editing it changes that description.

    # Members mix in ``str``; each member's value is identical to its name.
    energy = "energy"
    gradient = "gradient"
    hessian = "hessian"
    properties = "properties"
62

63

64 4
class CompressionEnum(str, Enum):
    """
    Compression method applied to stored data.

    Only the method is recorded here (e.g. gzip, bzip2); the compression level is kept separately.
    """

    # Data is stored as-is
    none = "none"

    # Methods backed by the standard-library modules imported at the top of this file
    gzip = "gzip"
    bzip2 = "bzip2"
    lzma = "lzma"
73

74

75 4
class KVStore(ProtoModel):
    """
    Storage of outputs and error messages, with optional compression
    """

    id: int = Field(
        None, description="Id of the object on the database. This is assigned automatically by the database."
    )

    compression: CompressionEnum = Field(CompressionEnum.none, description="Compression method (such as gzip)")
    compression_level: int = Field(0, description="Level of compression (typically 0-9)")
    data: bytes = Field(..., description="Compressed raw data of output/errors, etc")

    @validator("data", pre=True)
    def _set_data(cls, data, values):
        """Handles special data types

        Strings are converted to byte arrays, and dicts are converted via json.dumps. If a string or
        dictionary is given, then compression & compression level must be none/0 (the defaults).
        Any other input is passed through unchanged.

        Will check that compression and compression level are none/0. Since ``data`` is the last
        field defined, this validator runs after the others, so that is safe.

        (According to pydantic docs, validators are run in the order of field definition)

        Raises
        ------
        ValueError
            If a string or dictionary is given together with a non-default compression
            or compression level.
        """
        if isinstance(data, dict):
            kind = "a dictionary"
        elif isinstance(data, str):
            kind = "a string"
        else:
            return data

        # A field that failed its own validation is absent from ``values``. Fall back to
        # the field defaults so that failure is reported by pydantic instead of being
        # masked by a KeyError raised from this validator.
        if values.get("compression", CompressionEnum.none) != CompressionEnum.none:
            raise ValueError("Compression is set, but input is {}".format(kind))
        if values.get("compression_level", 0) != 0:
            raise ValueError("Compression level is set, but input is {}".format(kind))

        if isinstance(data, dict):
            return json.dumps(data).encode()
        return data.encode()

    @validator("compression", pre=True)
    def _set_compression(cls, compression):
        """Sets the compression type to CompressionEnum.none if compression is None

        Needed as older entries in the database have null for compression/compression_level
        """
        if compression is None:
            return CompressionEnum.none
        return compression

    @validator("compression_level", pre=True)
    def _set_compression_level(cls, compression_level):
        """Sets the compression_level to zero if compression_level is None

        Needed as older entries in the database have null for compression/compression_level
        """
        if compression_level is None:
            return 0
        return compression_level

    @classmethod
    def compress(
        cls,
        input_str: str,
        compression_type: CompressionEnum = CompressionEnum.none,
        compression_level: Optional[int] = None,
    ):
        """Compresses a string given a compression scheme and level

        Parameters
        ----------
        input_str
            The text to store. It is encoded with ``str.encode()`` before compression.
        compression_type
            A :class:`CompressionEnum` member, or its string value (e.g. ``"gzip"``).
        compression_level
            Level passed to the compressor. If None, but a compression_type is specified,
            an appropriate default level is chosen. Ignored (stored as 0) when no
            compression is requested.

        Returns
        -------
        An object of type `cls` holding the (possibly compressed) data together with the
        compression type and level that were used.

        Raises
        ------
        TypeError
            If ``compression_type`` is not a known compression type.
        """

        # Accept the plain string values ("gzip", ...) as well as enum members. The
        # identity comparisons below would otherwise reject an equal-valued string.
        try:
            compression_type = CompressionEnum(compression_type)
        except ValueError:
            raise TypeError("Unknown compression type??") from None

        data = input_str.encode()

        # No compression
        if compression_type is CompressionEnum.none:
            compression_level = 0

        # gzip compression
        elif compression_type is CompressionEnum.gzip:
            if compression_level is None:
                compression_level = 6
            data = gzip.compress(data, compresslevel=compression_level)

        # bzip2 compression
        elif compression_type is CompressionEnum.bzip2:
            if compression_level is None:
                compression_level = 6
            data = bz2.compress(data, compresslevel=compression_level)

        # LZMA compression
        # By default, use level = 1 for larger files (>15MB or so)
        elif compression_type is CompressionEnum.lzma:
            if compression_level is None:
                compression_level = 1 if len(data) > 15 * 1048576 else 6
            data = lzma.compress(data, preset=compression_level)
        else:
            # Shouldn't ever happen, unless we change CompressionEnum but not the rest of this function
            raise TypeError("Unknown compression type??")

        return cls(data=data, compression=compression_type, compression_level=compression_level)

    def get_string(self):
        """
        Returns the string representing the output

        The stored bytes are decompressed according to ``self.compression`` and then decoded.

        Raises
        ------
        TypeError
            If ``self.compression`` is not a known compression type.
        """
        if self.compression is CompressionEnum.none:
            return self.data.decode()
        elif self.compression is CompressionEnum.gzip:
            return gzip.decompress(self.data).decode()
        elif self.compression is CompressionEnum.bzip2:
            return bz2.decompress(self.data).decode()
        elif self.compression is CompressionEnum.lzma:
            return lzma.decompress(self.data).decode()
        else:
            # Shouldn't ever happen, unless we change CompressionEnum but not the rest of this function
            raise TypeError("Unknown compression type??")

    def get_json(self):
        """
        Returns a dict if the data stored is a JSON string

        (errors are stored as JSON. stdout/stderr are just strings)
        """
        return json.loads(self.get_string())
208

209

210 4
class QCSpecification(ProtoModel):
    """
    The quantum chemistry metadata specification for individual computations such as energy, gradient, and Hessians.
    """

    driver: DriverEnum = Field(..., description=str(DriverEnum.__doc__))
    method: str = Field(..., description="The quantum chemistry method to evaluate (e.g., B3LYP, PBE, ...).")
    basis: Optional[str] = Field(
        None,
        description="The quantum chemistry basis set to evaluate (e.g., 6-31g, cc-pVDZ, ...). Can be ``None`` for "
        "methods without basis sets.",
    )
    keywords: Optional[ObjectId] = Field(
        None,
        description="The Id of the :class:`KeywordSet` registered in the database to run this calculation with. This "
        "Id must exist in the database.",
    )
    protocols: Optional[ResultProtocols] = Field(ResultProtocols(), description=str(ResultProtocols.__base_doc__))
    program: str = Field(
        ...,
        description="The quantum chemistry program to evaluate the computation with. Not all quantum chemistry programs"
        " support all combinations of driver/method/basis.",
    )

    def dict(self, *args, **kwargs):
        """Return the model as a dictionary, dropping an empty ``protocols`` entry.

        Arguments are forwarded unchanged to the parent ``dict()``.
        """
        ret = super().dict(*args, **kwargs)

        # Maintain hash compatibility: an empty protocols entry is removed.
        # ``protocols`` may be missing (e.g. excluded by the caller) or None, in which
        # case there is nothing to drop and the result is returned as-is.
        protocols = ret.get("protocols")
        if protocols is not None and len(protocols) == 0:
            ret.pop("protocols")

        return ret

    @validator("basis")
    def _check_basis(cls, v):
        """Normalize the basis through ``prepare_basis``."""
        return prepare_basis(v)

    @validator("program")
    def _check_program(cls, v):
        """Store the program name in lowercase."""
        return v.lower()

    @validator("method")
    def _check_method(cls, v):
        """Store the method name in lowercase."""
        return v.lower()

    def form_schema_object(self, keywords: Optional["KeywordSet"] = None, checks=True) -> Dict[str, Any]:
        """Build the dictionary form of this specification.

        Parameters
        ----------
        keywords
            The :class:`KeywordSet` whose ``values`` are placed under ``"keywords"``.
            If not given, an empty dictionary is used.
        checks
            If True and this specification has a ``keywords`` id, verify that the id of
            the given ``keywords`` object matches it.

        Returns
        -------
        Dict[str, Any]
            A dictionary with ``driver``, ``program``, ``model`` (``method`` and, if set,
            ``basis``) and ``keywords`` entries.

        Raises
        ------
        AssertionError
            If ``checks`` is True, this specification has a ``keywords`` id, and the given
            ``keywords`` object is missing or has a different id.
        """
        if checks and self.keywords:
            # Raised explicitly (rather than via ``assert``) so the check also runs under
            # ``python -O`` and a missing KeywordSet gives a clear message instead of an
            # AttributeError. The exception type matches the previous ``assert``.
            if keywords is None or keywords.id != self.keywords:
                raise AssertionError(
                    "The given KeywordSet does not match the keywords id ({}) of this specification".format(
                        self.keywords
                    )
                )

        ret = {
            "driver": str(self.driver.name),
            "program": self.program,
            "model": {"method": self.method},
        }  # yapf: disable
        if self.basis:
            ret["model"]["basis"] = self.basis

        ret["keywords"] = keywords.values if keywords else {}

        return ret
273

274

275 4
class OptimizationSpecification(ProtoModel):
    """
    Metadata describing a geometry optimization.
    """

    program: str = Field(..., description="Optimization program to run the optimization with")
    keywords: Optional[Dict[str, Any]] = Field(
        None,
        description="Dictionary of keyword arguments to pass into the ``program`` when the program runs. "
        "Note that unlike :class:`QCSpecification` this is a dictionary of keywords, not the Id for a "
        ":class:`KeywordSet`. ",
    )
    protocols: Optional[OptimizationProtocols] = Field(
        OptimizationProtocols(), description=str(OptimizationProtocols.__base_doc__)
    )

    def dict(self, *args, **kwargs):
        """Return the model as a dictionary, dropping an empty ``protocols`` entry.

        Arguments are forwarded unchanged to the parent ``dict()``.
        """
        ret = super().dict(*args, **kwargs)

        # Maintain hash compatibility: an empty protocols entry is removed.
        # ``protocols`` may be missing (e.g. excluded by the caller) or None, in which
        # case there is nothing to drop and the result is returned as-is.
        protocols = ret.get("protocols")
        if protocols is not None and len(protocols) == 0:
            ret.pop("protocols")

        return ret

    @validator("program")
    def _check_program(cls, v):
        """Store the program name in lowercase."""
        return v.lower()

    @validator("keywords")
    def _check_keywords(cls, v):
        """Normalize the keywords dictionary through ``recursive_normalizer``; ``None`` is kept as-is."""
        if v is None:
            return v
        return recursive_normalizer(v)
309

310

311 4
class KeywordSet(ProtoModel):
    """
    A key:value storage object for Keywords.
    """

    id: Optional[ObjectId] = Field(
        None, description="The Id of this object, will be automatically assigned when added to the database."
    )
    hash_index: str = Field(
        ...,
        description="The hash of this keyword set to store and check for collisions. This string is automatically "
        "computed.",
    )
    values: Dict[str, Optional[Any]] = Field(
        ...,
        description="The key-value pairs which make up this KeywordSet. There is no direct relation between this "
        "dictionary and applicable program/spec it can be used on.",
    )
    lowercase: bool = Field(
        True,
        description="String keys are in the ``values`` dict are normalized to lowercase if this is True. Assists in "
        "matching against other :class:`KeywordSet` objects in the database.",
    )
    exact_floats: bool = Field(
        False,
        description="All floating point numbers are rounded to 1.e-10 if this is False. "
        "Assists in matching against other :class:`KeywordSet` objects in the database.",
    )
    comments: Optional[str] = Field(
        None,
        description="Additional comments for this KeywordSet. Intended for pure human/user consumption " "and clarity.",
    )

    def __init__(self, **data):
        """Create the KeywordSet, normalizing ``values`` and building ``hash_index`` when needed.

        Besides the model fields, ``data`` may carry ``build_index=True`` to force the hash
        index to be recomputed even if a ``hash_index`` was supplied. The index is always
        computed when ``hash_index`` is not supplied.
        """

        # Always consume ``build_index`` before delegating. It is not a model field, and
        # the previous short-circuiting ``or`` left it in ``data`` whenever ``hash_index``
        # was absent, forwarding it to ProtoModel as an unknown field.
        # NOTE(review): ProtoModel presumably rejects unknown fields -- confirm.
        force_rebuild = data.pop("build_index", False)
        build_index = bool(force_rebuild) or ("hash_index" not in data)
        if build_index:
            # ``hash_index`` is a required field; the real value is filled in below.
            data["hash_index"] = "placeholder"

        ProtoModel.__init__(self, **data)

        # Overwrite options with massaged values
        kwargs = {"lowercase": self.lowercase}
        if self.exact_floats:
            kwargs["digits"] = False

        # Written through ``__dict__`` rather than by attribute assignment.
        self.__dict__["values"] = recursive_normalizer(self.values, **kwargs)

        # Build a hash index if we need it
        if build_index:
            self.__dict__["hash_index"] = self.get_hash_index()

    def get_hash_index(self):
        """Return the hash of a copy of ``values``, as computed by ``hash_dictionary``."""
        return hash_dictionary(self.values.copy())
366

367

368 4
class Citation(ProtoModel):
    """ A literature citation.  """

    # Hand-formatted citation in ACS style. In the future, this could be bibtex,
    # rendered to different formats.
    acs_citation: Optional[str] = None

    # bibtex blob for later use with bibtex-renderer
    bibtex: Optional[str] = None

    doi: Optional[str] = None
    url: Optional[str] = None

    def to_acs(self) -> str:
        """Return the stored ``acs_citation`` value (``None`` when no ACS citation was set)."""
        citation = self.acs_citation
        return citation

Read our documentation on viewing source code .

Loading