"""
QCPortal Database ODM
"""
import gzip
import tempfile
import warnings
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
import pandas as pd
import requests
from pydantic import Field, validator
from qcelemental import constants
from qcelemental.models.types import Array
from tqdm import tqdm

from ..models import Citation, ComputeResponse, ObjectId, ProtoModel
from ..statistics import wrap_statistics
from ..visualization import bar_plot, violin_plot
from .collection import Collection
from .collection_utils import composition_planner, register_collection

if TYPE_CHECKING:  # pragma: no cover
    from .. import FractalClient
    from ..models import KeywordSet, Molecule, ResultRecord
    from . import DatasetView


class MoleculeEntry(ProtoModel):
    name: str = Field(..., description="The name of the entry.")
    molecule_id: ObjectId = Field(..., description="The id of the Molecule the entry references.")
    comment: Optional[str] = Field(None, description="A comment for the entry.")
    local_results: Dict[str, Any] = Field({}, description="Additional local values.")


class ContributedValues(ProtoModel):
    name: str = Field(..., description="The name of the contributed values.")
    values: Any = Field(..., description="The values in the contributed values.")
    index: Array[str] = Field(
        ..., description="The entry index for the contributed values, matches the order of the `values` array."
    )
    values_structure: Dict[str, Any] = Field(
        {}, description="A machine-readable description of the values structure. Typically not needed."
    )

    theory_level: Union[str, Dict[str, str]] = Field(..., description="A string representation of the theory level.")
    units: str = Field(..., description="The units of the values, can be any valid QCElemental unit.")
    theory_level_details: Optional[Union[str, Dict[str, Optional[str]]]] = Field(
        None, description="A detailed representation of the theory level."
    )

    citations: Optional[List[Citation]] = Field(None, description="Citations associated with the contributed values.")
    external_url: Optional[str] = Field(None, description="An external URL to the raw contributed values data.")
    doi: Optional[str] = Field(None, description="A DOI for the contributed values data.")

    comments: Optional[str] = Field(None, description="Additional comments about the contributed values.")

    @validator("values")
    def _make_array(cls, v):
        # Guard against empty sequences before inspecting the first element
        if isinstance(v, (list, tuple)) and len(v) > 0 and isinstance(v[0], (float, int, str, bool)):
            v = np.array(v)

        return v


class Dataset(Collection):
    """
    The Dataset class for homogeneous computations on many molecules.

    Attributes
    ----------
    client : client.FractalClient
        A FractalClient connected to a server
    data : dict
        JSON representation of the database backbone
    df : pd.DataFrame
        The underlying dataframe for the Dataset object
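
    Examples
    --------
    A minimal usage sketch; ``client`` is assumed to be a FractalClient
    connected to a server, and "S22" is a hypothetical collection name:

    >>> ds = client.get_collection("dataset", "S22")
    >>> ds.list_values()                              # see which data columns are available
    >>> ds.get_values(method="b3lyp", basis="6-31g")  # illustrative method/basis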
    """

    def __init__(self, name: str, client: Optional["FractalClient"] = None, **kwargs: Any) -> None:
        """
        Initializer for the Dataset object. If no FractalClient is supplied, or the Dataset name
        is not present on the server that the client is connected to, a blank Dataset will be
        created.

        Parameters
        ----------
        name : str
            The name of the Dataset
        client : Optional['FractalClient'], optional
            A FractalClient connected to a server
        **kwargs : Dict[str, Any]
            Additional kwargs to pass to the Collection
        """
        super().__init__(name, client=client, **kwargs)

        self._units = self.data.default_units

        # If we are making a new dataset we may need new hashes and JSON objects
        self._new_molecules: Dict[str, Molecule] = {}
        self._new_keywords: Dict[Tuple[str, str], KeywordSet] = {}
        self._new_records: List[Dict[str, Any]] = []
        self._updated_state = False

        self._view: Optional[DatasetView] = None
        if self.data.view_available:
            from . import RemoteView

            self._view = RemoteView(client, self.data.id)
        self._disable_view: bool = False  # for debugging and testing
        self._disable_query_limit: bool = False  # for debugging and testing

        # Initialize internal data frames and load in contrib
        self.df = pd.DataFrame()
        self._column_metadata: Dict[str, Any] = {}

        # If this is a brand new dataset, initialize the records and cv fields
        if self.data.id == "local":
            if self.data.records is None:
                self.data.__dict__["records"] = []
            if self.data.contributed_values is None:
                self.data.__dict__["contributed_values"] = {}

    class DataModel(Collection.DataModel):

        # Defaults
        default_program: Optional[str] = None
        default_keywords: Dict[str, str] = {}
        default_driver: str = "energy"
        default_units: str = "kcal / mol"
        default_benchmark: Optional[str] = None

        alias_keywords: Dict[str, Dict[str, str]] = {}

        # Data
        records: Optional[List[MoleculeEntry]] = None
        contributed_values: Optional[Dict[str, ContributedValues]] = None

        # History: driver, program, method (basis, keywords)
        history: Set[Tuple[str, str, str, Optional[str], Optional[str]]] = set()
        history_keys: Tuple[str, str, str, str, str] = ("driver", "program", "method", "basis", "keywords")

    def set_view(self, path: Union[str, Path]) -> None:
        """
        Set a dataset to use a local view.

        Parameters
        ----------
        path: Union[str, Path]
            Path to an HDF5 file representing a view for this dataset
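
        Examples
        --------
        A minimal sketch; the file name is hypothetical:

        >>> ds.set_view("s22_view.hdf5")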
        """
        from . import HDF5View

        self._view = HDF5View(path)

    def download(
        self, local_path: Optional[Union[str, Path]] = None, verify: bool = True, progress_bar: bool = True
    ) -> None:
        """
        Download a remote view if available. The dataset will use this view to avoid server queries for calls to:
        - get_entries
        - get_molecules
        - get_values
        - list_values

        Parameters
        ----------
        local_path: Optional[Union[str, Path]], optional
            Local path to store the downloaded view. If None, the view will be stored in a temporary file and deleted on exit.
        verify: bool, optional
            Verify download checksum. Default: True.
        progress_bar: bool, optional
            Display a download progress bar. Default: True
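
        Examples
        --------
        A minimal sketch; assumes the server hosts a view for this dataset:

        >>> ds.download()                           # store in a temporary file
        >>> ds.download(local_path="my_view.hdf5")  # hypothetical local path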
        """
        chunk_size = 8192
        if self.data.view_url_hdf5 is None:
            raise ValueError("A view for this dataset is not available on the server.")

        if local_path is not None:
            local_path = Path(local_path)
        else:
            self._view_tempfile = tempfile.NamedTemporaryFile()  # keep temp file alive until self is destroyed
            local_path = self._view_tempfile.name

        r = requests.get(self.data.view_url_hdf5, stream=True)
        pbar = None
        if progress_bar:
            try:
                file_length = int(r.headers.get("content-length"))
                pbar = tqdm(total=file_length, initial=0, unit="B", unit_scale=True)
            except Exception:  # content-length may be absent or malformed
                warnings.warn("Failed to create download progress bar", RuntimeWarning)

        with open(local_path, "wb") as fd:
            for chunk in r.iter_content(chunk_size=chunk_size):
                fd.write(chunk)
                if pbar is not None:
                    pbar.update(len(chunk))  # the final chunk may be smaller than chunk_size

        with open(local_path, "rb") as f:
            magic = f.read(2)
            gzipped = magic == b"\x1f\x8b"
        if gzipped:
            extract_tempfile = tempfile.NamedTemporaryFile()  # keep temp file alive until self is destroyed
            with gzip.open(local_path, "rb") as fgz:
                with open(extract_tempfile.name, "wb") as f:
                    f.write(fgz.read())
            self._view_tempfile = extract_tempfile
            local_path = self._view_tempfile.name

        if verify:
            remote_checksum = self.data.view_metadata["blake2b_checksum"]
            from . import HDF5View

            local_checksum = HDF5View(local_path).hash()
            if remote_checksum != local_checksum:
                raise ValueError(f"Checksum verification failed. Expected: {remote_checksum}, Got: {local_checksum}")

        self.set_view(local_path)

    def to_file(self, path: Union[str, Path], encoding: str) -> None:
        """
        Writes a view of the dataset to a file.

        Parameters
        ----------
        path: Union[str, Path]
            Where to write the file
        encoding: str
            Options: plaintext, hdf5
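
        Examples
        --------
        A minimal sketch; the file names are hypothetical:

        >>> ds.to_file("dataset_view.hdf5", encoding="hdf5")
        >>> ds.to_file("dataset_view.txt", encoding="plaintext")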
        """
        if encoding.lower() == "plaintext":
            from . import PlainTextView

            PlainTextView(path).write(self)
        elif encoding.lower() in ["hdf5", "h5"]:
            from . import HDF5View

            HDF5View(path).write(self)
        else:
            raise NotImplementedError(f"Unsupported encoding: {encoding}")

    def _get_data_records_from_db(self):
        self._check_client()
        # This is hacky. What we want is to get records and contributed values correctly unpacked into pydantic
        # objects, so we call get_collection with an include filter. We must also include collection and
        # name in the query because they are required by the collection DataModel, but they also let us check
        # that we got back the right data.
        response = self.client.get_collection(
            self.__class__.__name__.lower(),
            self.name,
            full_return=False,
            include=["records", "contributed_values", "collection", "name", "id"],
        )
        if not (response.data.id == self.data.id and response.data.name == self.name):
            raise ValueError("Got the wrong records and contributed values from the server.")
        # This works because get_collection builds a validated Dataset object
        self.data.__dict__["records"] = response.data.records
        self.data.__dict__["contributed_values"] = response.data.contributed_values

    def _entry_index(self, subset: Optional[List[str]] = None) -> pd.DataFrame:
        # TODO: make this fast for subsets
        if self.data.records is None:
            self._get_data_records_from_db()

        ret = pd.DataFrame(
            [[entry.name, entry.molecule_id] for entry in self.data.records], columns=["name", "molecule_id"]
        )
        if subset is None:
            return ret
        else:
            return ret.reset_index().set_index("name").loc[subset].reset_index().set_index("index")

    def _check_state(self) -> None:
        if self._new_molecules or self._new_keywords or self._new_records or self._updated_state:
            raise ValueError("New molecules, keywords, or records detected; run save before submitting new tasks.")

    def _canonical_pre_save(self, client: "FractalClient") -> None:
        self._ensure_contributed_values()
        if self.data.records is None:
            self._get_data_records_from_db()
        for k in list(self._new_keywords.keys()):
            ret = client.add_keywords([self._new_keywords[k]])
            assert len(ret) == 1, "KeywordSet added incorrectly"
            self.data.alias_keywords[k[0]][k[1]] = ret[0]
            del self._new_keywords[k]
        self._updated_state = False

    def _pre_save_prep(self, client: "FractalClient") -> None:
        self._canonical_pre_save(client)

        # Prep any new molecules introduced to the Dataset before storing data.
        mol_ret = self._add_molecules_by_dict(client, self._new_molecules)

        # Update internal molecule UUIDs to the server's UUIDs
        for record in self._new_records:
            molecule_hash = record.pop("molecule_hash")
            new_record = MoleculeEntry(molecule_id=mol_ret[molecule_hash], **record)
            self.data.records.append(new_record)

        self._new_records = []
        self._new_molecules = {}

    def get_entries(self, subset: Optional[List[str]] = None, force: bool = False) -> pd.DataFrame:
        """
        Provides a list of entries for the dataset.

        Parameters
        ----------
        subset: Optional[List[str]], optional
            The indices of the desired subset. Return all indices if subset is None.
        force: bool, optional
            If True, skip the cache and query the server directly.

        Returns
        -------
        pd.DataFrame
            A dataframe containing entry names and specifications.
            For Dataset, specifications are molecule ids.
            For ReactionDataset, specifications describe reaction stoichiometry.
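
        Examples
        --------
        A minimal sketch; the entry names are hypothetical:

        >>> ds.get_entries()  # all entries
        >>> ds.get_entries(subset=["water", "ammonia"], force=True)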
        """
        if self._use_view(force):
            ret = self._view.get_entries(subset)
        else:
            ret = self._entry_index(subset)
        return ret.copy()

    def _molecule_indexer(
        self, subset: Optional[Union[str, Set[str]]] = None, force: bool = False
    ) -> Dict[str, ObjectId]:
        """Provides a {index: molecule_id} mapping for a given subset.

        Parameters
        ----------
        subset : Optional[Union[str, Set[str]]], optional
            The indices of the desired subset. Return all indices if subset is None.
        force : bool, optional
            If True, skip the cache and query the server directly.

        Returns
        -------
        Dict[str, 'ObjectId']
            Molecule index to molecule ObjectId map
        """
        if subset:
            if isinstance(subset, str):
                subset = {subset}
        index = self.get_entries(force=force, subset=subset)
        # index = index[index.name.isin(subset)]

        return {row["name"]: row["molecule_id"] for row in index.to_dict("records")}

    def _add_history(self, **history: Optional[str]) -> None:
        """
        Adds compute history to the dataset.
        """
        if history.keys() != set(self.data.history_keys):
            raise KeyError("Internal error: Incorrect history keys passed in.")

        new_history = []
        for key in self.data.history_keys:

            value = history[key]
            if value is not None:
                value = value.lower()

            new_history.append(value)

        self.data.history.add(tuple(new_history))

    def list_values(
        self,
        method: Optional[Union[str, List[str]]] = None,
        basis: Optional[Union[str, List[str]]] = None,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        driver: Optional[str] = None,
        name: Optional[Union[str, List[str]]] = None,
        native: Optional[bool] = None,
        force: bool = False,
    ) -> pd.DataFrame:
        """
        Lists available data that may be queried with get_values.
        Results may be narrowed by providing search keys.
        `None` is a wildcard selector. To search for `None`, use `"None"`.

        Parameters
        ----------
        method : Optional[Union[str, List[str]]], optional
            The computational method (e.g. B3LYP)
        basis : Optional[Union[str, List[str]]], optional
            The computational basis (e.g. 6-31G)
        keywords : Optional[str], optional
            The keyword alias
        program : Optional[str], optional
            The underlying QC program
        driver : Optional[str], optional
            The type of calculation (e.g. energy, gradient, hessian, dipole)
        name : Optional[Union[str, List[str]]], optional
            The canonical name of the data column
        native: Optional[bool], optional
            True: only include data computed with QCFractal
            False: only include data contributed from outside sources
            None: include both
        force : bool, optional
            Data is typically cached; forces a new query if True

        Returns
        -------
        DataFrame
            A DataFrame of the matching data specifications
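
        Examples
        --------
        A minimal sketch; the search values are illustrative:

        >>> ds.list_values()                               # every available column
        >>> ds.list_values(method="b3lyp", basis="6-31g")
        >>> ds.list_values(basis="None")                   # columns computed without a basis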
        """
        spec: Dict[str, Optional[Union[str, bool, List[str]]]] = {
            "method": method,
            "basis": basis,
            "keywords": keywords,
            "program": program,
            "name": name,
            "driver": driver,
        }

        if self._use_view(force):
            ret = self._view.list_values()
            spec["native"] = native
        else:
            ret = []
            if native in {True, None}:
                df = self._list_records(dftd3=False)
                df["native"] = True
                ret.append(df)

            if native in {False, None}:
                df = self._list_contributed_values()
                df["native"] = False
                ret.append(df)

            ret = pd.concat(ret)

        # Filter
        ret.fillna("None", inplace=True)
        ret = self._filter_records(ret, **spec)

        # Sort
        sort_index = ["native"] + list(self.data.history_keys[:-1])
        if "stoichiometry" in ret.columns:
            sort_index += ["stoichiometry", "name"]
        ret.set_index(sort_index, inplace=True)
        ret.sort_index(inplace=True)
        ret.reset_index(inplace=True)
        ret.set_index(["native"] + list(self.data.history_keys[:-1]), inplace=True)

        return ret

    @staticmethod
    def _filter_records(
        df: pd.DataFrame, **spec: Optional[Union[str, bool, List[Union[str, bool]], Tuple]]
    ) -> pd.DataFrame:
        """
        Helper for filtering records on a spec. Note that `None` is a wildcard while `"None"` matches `None` and NaN.
        """
        ret = df.copy()

        if len(ret) == 0:  # workaround for pandas empty-dataframe sharp edges
            return ret

        for key, value in spec.items():
            if value is None:
                continue
            if isinstance(value, bool):
                ret = ret[ret[key] == value]
            elif isinstance(value, str):
                value = value.lower()
                ret = ret[ret[key].fillna("None").str.lower() == value]
            elif isinstance(value, (list, tuple)):
                query = [x.lower() for x in value]
                ret = ret[ret[key].fillna("None").str.lower().isin(query)]
            else:
                raise TypeError(f"Search type {type(value)} not understood.")
        return ret

    def list_records(
        self, dftd3: bool = False, pretty: bool = True, **search: Optional[Union[str, List[str]]]
    ) -> pd.DataFrame:
        """
        Lists specifications of available records, i.e. method, program, basis set, keyword set, driver combinations.
        `None` is a wildcard selector. To search for `None`, use `"None"`.

        Parameters
        ----------
        dftd3: bool, optional
            Include dftd3 program record specifications in addition to composite DFT-D3 record specifications
        pretty: bool
            Replace NaN with "None" in the returned DataFrame
        **search : Dict[str, Optional[str]]
            Allows searching to narrow down the return.

        Returns
        -------
        DataFrame
            Record specifications matching **search.
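
        Examples
        --------
        A minimal sketch; the search values are illustrative:

        >>> ds.list_records()                              # every record specification
        >>> ds.list_records(method="b3lyp", basis="None")  # records computed without a basis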

        """
        ret = self._list_records(dftd3=dftd3)
        ret = self._filter_records(ret, **search)
        if pretty:
            ret.fillna("None", inplace=True)
        return ret

    def _list_records(self, dftd3: bool = False) -> pd.DataFrame:
        """
        Lists specifications of available records, i.e. method, program, basis set, keyword set, driver combinations.
        `None` is a wildcard selector. To search for `None`, use `"None"`.

        Parameters
        ----------
        dftd3: bool, optional
            Include dftd3 program record specifications in addition to composite DFT-D3 record specifications

        Returns
        -------
        DataFrame
            Record specifications available in the dataset.

        """
        show_dftd3 = dftd3

        history = pd.DataFrame(list(self.data.history), columns=self.data.history_keys)

        # Short circuit because the merge and apply below require data
        if history.shape[0] == 0:
            ret = history.copy()
            ret["name"] = None
            return ret

        # Build out -D3 combos
        dftd3 = history[history["program"] == "dftd3"].copy()
        dftd3["base"] = [x.split("-d3")[0] for x in dftd3["method"]]

        nondftd3 = history[history["program"] != "dftd3"]
        dftd3combo = nondftd3.merge(dftd3[["method", "base"]], left_on="method", right_on="base")
        dftd3combo["method"] = dftd3combo["method_y"]
        dftd3combo.drop(["method_x", "method_y", "base"], axis=1, inplace=True)

        history = pd.concat([history, dftd3combo], sort=False)
        history = history.reset_index()
        history.drop("index", axis=1, inplace=True)

        # Drop duplicates due to stoich in some instances; this could be handled with multiple merges,
        # but it is simpler to do it this way.
        history.drop_duplicates(inplace=True)

        # Find the returned subset
        ret = history.copy()

        # Add name column
        ret["name"] = ret.apply(
            lambda row: self._canonical_name(
                program=row["program"],
                method=row["method"],
                basis=row["basis"],
                keywords=row["keywords"],
                stoich=row.get("stoichiometry", None),
                driver=row["driver"],
            ),
            axis=1,
        )
        if show_dftd3 is False:
            ret = ret[ret["program"] != "dftd3"]

        return ret

    def get_values(
        self,
        method: Optional[Union[str, List[str]]] = None,
        basis: Optional[Union[str, List[str]]] = None,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        driver: Optional[str] = None,
        name: Optional[Union[str, List[str]]] = None,
        native: Optional[bool] = None,
        subset: Optional[Union[str, List[str]]] = None,
        force: bool = False,
    ) -> pd.DataFrame:
        """
        Obtains values matching the search parameters provided for the expected `return_result` values.
        Defaults to the standard programs and keywords if not provided.

        Note that unlike `get_records`, `get_values` will automatically expand searches and return multiple method
        and basis combinations simultaneously.

        `None` is a wildcard selector. To search for `None`, use `"None"`.

        Parameters
        ----------
        method : Optional[Union[str, List[str]]], optional
            The computational method (e.g. B3LYP)
        basis : Optional[Union[str, List[str]]], optional
            The computational basis (e.g. 6-31G)
        keywords : Optional[str], optional
            The keyword alias
        program : Optional[str], optional
            The underlying QC program
        driver : Optional[str], optional
            The type of calculation (e.g. energy, gradient, hessian, dipole)
        name : Optional[Union[str, List[str]]], optional
            Canonical name of the record. Overrides the above selectors.
        native: Optional[bool], optional
            True: only include data computed with QCFractal
            False: only include data contributed from outside sources
            None: include both
        subset: Optional[List[str]], optional
            The indices of the desired subset. Return all indices if subset is None.
        force : bool, optional
            Data is typically cached; forces a new query if True

        Returns
        -------
        DataFrame
            A DataFrame of values with columns corresponding to methods and rows corresponding to molecule entries.
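
        Examples
        --------
        A minimal sketch; the method, basis, and entry names are illustrative:

        >>> ds.get_values(method="b3lyp", basis="6-31g")
        >>> ds.get_values(method=["b3lyp", "mp2"], subset=["water"])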
        """
        return self._get_values(
            method=method,
            basis=basis,
            keywords=keywords,
            program=program,
            driver=driver,
            name=name,
            native=native,
            subset=subset,
            force=force,
        )

    def _get_values(
        self,
        native: Optional[bool] = None,
        force: bool = False,
        subset: Optional[Union[str, List[str]]] = None,
        **spec: Union[List[str], str, None],
    ) -> pd.DataFrame:
        ret = []

        if subset is None:
            subset_set = set(self.get_index(force=force))
        elif isinstance(subset, str):
            subset_set = {subset}
        elif isinstance(subset, list):
            subset_set = set(subset)
        else:
            raise ValueError(f"Subset must be str, List[str], or None. Got {type(subset)}")

        if native in {True, None}:
            spec_nodriver = spec.copy()
            driver = spec_nodriver.pop("driver")
            if driver is not None and driver != self.data.default_driver:
                raise KeyError(
                    f"For native values, driver ({driver}) must be the same as the dataset's default driver "
                    f"({self.data.default_driver}). Consider using get_records instead."
                )
            df = self._get_native_values(subset=subset_set, force=force, **spec_nodriver)
            ret.append(df)

        if native in {False, None}:
            df = self._get_contributed_values(subset=subset_set, force=force, **spec)
            ret.append(df)
        ret_df = pd.concat(ret, axis=1)
        ret_df = ret_df.loc[subset if subset is not None else self.get_index()]

        return ret_df

    def _get_native_values(
        self,
        subset: Set[str],
        method: Optional[Union[str, List[str]]] = None,
        basis: Optional[Union[str, List[str]]] = None,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        name: Optional[Union[str, List[str]]] = None,
        force: bool = False,
    ) -> pd.DataFrame:
        """
        Obtains records matching the provided search criteria.
        Defaults to the standard programs and keywords if not provided.

        Parameters
        ----------
        subset: Set[str]
            The indices of the desired subset.
        method : Optional[Union[str, List[str]]], optional
            The computational method to compute (e.g. B3LYP)
        basis : Optional[Union[str, List[str]]], optional
            The computational basis to compute (e.g. 6-31G)
        keywords : Optional[str], optional
            The keyword alias for the requested compute
        program : Optional[str], optional
            The underlying QC program
        name : Optional[Union[str, List[str]]], optional
            Canonical name of the record. Overrides the above selectors.
        force : bool, optional
            Data is typically cached; forces a new query if True.

        Returns
        -------
        DataFrame
            A DataFrame of the queried parameters
        """
        au_units = {"energy": "hartree", "gradient": "hartree/bohr", "hessian": "hartree/bohr**2"}

        # So that datasets with no records do not require a default program and default keywords
        if len(self.list_records()) == 0:
            return pd.DataFrame(index=self.get_index(subset))

        queries = self._form_queries(method=method, basis=basis, keywords=keywords, program=program, name=name)
        names = []
        new_queries = []
        for _, query in queries.iterrows():

            query = query.replace({np.nan: None}).to_dict()
            if "stoichiometry" in query:
                query["stoich"] = query.pop("stoichiometry")

            qname = query["name"]
            names.append(qname)
            if force or not self._subset_in_cache(qname, subset):
                self._column_metadata[qname] = query
                new_queries.append(query)

        new_data = pd.DataFrame(index=subset)

        if not self._use_view(force):
            units: Dict[str, str] = {}
            for query in new_queries:
                driver = query.pop("driver")
                qname = query.pop("name")
                data = self.get_records(
                    query.pop("method").upper(), include=["return_result"], merge=True, subset=subset, **query
                )
                new_data[qname] = data["return_result"]
                units[qname] = au_units[driver]
                query["name"] = qname
        else:
            for query in new_queries:
                query["native"] = True
            new_data, units = self._view.get_values(new_queries, subset)

        for query in new_queries:
            qname = query["name"]
            new_data[qname] *= constants.conversion_factor(units[qname], self.units)
            self._column_metadata[qname].update({"native": True, "units": self.units})

        self._update_cache(new_data)
        return self.df.loc[subset, names]

    def _form_queries(
        self,
        method: Optional[Union[str, List[str]]] = None,
        basis: Optional[Union[str, List[str]]] = None,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        stoich: Optional[str] = None,
        name: Optional[Union[str, List[str]]] = None,
    ) -> pd.DataFrame:
        if name is None:
            _, _, history = self._default_parameters(program, "nan", "nan", keywords, stoich=stoich)
            for k, v in [("method", method), ("basis", basis)]:

                if v is not None:
                    history[k] = v
                else:
                    history.pop(k, None)
            queries = self.list_records(**history, dftd3=True, pretty=False)
        else:
            # Use a tuple rather than a set literal: method and basis may be (unhashable) lists
            if any(field is not None for field in (program, method, basis, keywords)):
                warnings.warn(
                    "Name and additional fields were provided. Only name will be used as a selector.", RuntimeWarning
                )
            queries = self.list_records(name=name, dftd3=True, pretty=False)

        if queries.shape[0] > 10 and self._disable_query_limit is False:
            raise TypeError("More than 10 queries formed; please narrow the search.")
        return queries

    def _visualize(
        self,
        metric,
        bench,
        query: Dict[str, Union[Optional[str], List[str]]],
        groupby: Optional[str] = None,
        return_figure=None,
        digits=3,
        kind="bar",
        show_incomplete: bool = False,
    ) -> "plotly.Figure":

        # Validate query dimensions
        list_queries = [k for k, v in query.items() if isinstance(v, (list, tuple))]
        if len(list_queries) > 2:
            raise TypeError("A maximum of two lists are allowed.")

        # Check kind
        kind = kind.lower()
        if kind not in ["bar", "violin"]:
            raise KeyError(f"Visualization kind must either be 'bar' or 'violin', found {kind}")

        # Check metric
        metric = metric.upper()
        if metric == "UE":
            ylabel = f"UE [{self.units}]"
        elif metric == "URE":
            ylabel = "URE [%]"
        else:
            raise KeyError('Metric {} not understood, available metrics: "UE", "URE"'.format(metric))

        if kind == "bar":
            ylabel = "M" + ylabel
            metric = "M" + metric

        # Are we a groupby?
        _valid_groupby = {"method", "basis", "keywords", "program", "stoich", "d3"}
        if groupby is not None:
            groupby = groupby.lower()
            if groupby not in _valid_groupby:
                raise KeyError(f"Groupby option {groupby} not understood.")
            if (groupby != "d3") and (groupby not in query):
                raise KeyError(f"Groupby option {groupby} not found in query; must provide a search on this parameter.")

            if (groupby != "d3") and (not isinstance(query[groupby], (tuple, list))):
                raise KeyError(f"Groupby option {groupby} must be a list.")

            query_names = []
            queries = []
            if groupby == "d3":
                base = [method.upper().split("-D3")[0] for method in query["method"]]
                d3types = [method.upper().replace(b, "").replace("-D", "D") for method, b in zip(query["method"], base)]

                # Preserve order of first unique appearance
                seen: Set[str] = set()
                unique_d3types = [x for x in d3types if not (x in seen or seen.add(x))]

                for d3type in unique_d3types:
                    gb_query = query.copy()
                    gb_query["method"] = []
                    for i in range(len(base)):
                        method = query["method"][i]
                        if method.upper().replace(base[i], "").replace("-D", "D") == d3type:
                            gb_query["method"].append(method)
                    queries.append(gb_query)
                    if d3type == "":
                        query_names.append("No -D3")
                    else:
                        query_names.append(d3type.upper())
            else:
                for gb in query[groupby]:
                    gb_query = query.copy()
                    gb_query[groupby] = gb

                    queries.append(gb_query)
                    query_names.append(self._canonical_name(**{groupby: gb}))

            if (kind == "violin") and (len(queries) != 2):
                raise KeyError("Groupby option for violin plots must have two entries.")

        else:
            queries = [query]
            query_names = ["Stats"]

        title = f"{self.data.name} Dataset Statistics"

        series = []
        for q, name in zip(queries, query_names):

            if len(q) == 0:
                raise KeyError("No query matches, nothing to visualize!")

            # Pull the values
            if "stoichiometry" in q:
                q["stoich"] = q.pop("stoichiometry")
            values = self.get_values(**q)

            if not show_incomplete:
                values = values.dropna(axis=1, how="any")

            # Create the statistics
            stat = self.statistics(metric, values, bench=bench)
            stat = stat.round(digits)
            stat.sort_index(inplace=True)
            stat.name = name

            # Munge the column names based on the groupby parameter
            col_names = {}
            for k, v in stat.items():
                record = self._column_metadata[k].copy()
                if groupby == "d3":
                    record["method"] = record["method"].upper().split("-D3")[0]

                elif groupby:
                    record[groupby] = None

                index_name = self._canonical_name(
                    record["program"],
                    record["method"],
                    record["basis"],
                    record["keywords"],
                    stoich=record.get("stoich"),
                )

                col_names[k] = index_name

            if kind == "bar":
                stat.index = [col_names[x] for x in stat.index]
            else:
                stat.columns = [col_names[x] for x in stat.columns]

            series.append(stat)

        if kind == "bar":
            return bar_plot(series, title=title, ylabel=ylabel, return_figure=return_figure)
        else:
            negative = None
            if groupby:
                negative = series[1]

            return violin_plot(series[0], negative=negative, title=title, ylabel=ylabel, return_figure=return_figure)

    def visualize(
        self,
        method: Optional[str] = None,
        basis: Optional[str] = None,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        groupby: Optional[str] = None,
        metric: str = "UE",
        bench: Optional[str] = None,
        kind: str = "bar",
        return_figure: Optional[bool] = None,
        show_incomplete: bool = False,
    ) -> "plotly.Figure":
        """
        Visualizes dataset statistics as a bar or violin plot.

        Parameters
        ----------
        method : Optional[str], optional
            Methods to query
        basis : Optional[str], optional
            Bases to query
        keywords : Optional[str], optional
            Keyword aliases to query
        program : Optional[str], optional
            Program aliases to query
        groupby : Optional[str], optional
            Groups the plot by this index.
        metric : str, optional
            The metric to use, either UE (unsigned error) or URE (unsigned relative error)
        bench : Optional[str], optional
            The benchmark level of theory to use
        kind : str, optional
            The kind of chart to produce, either 'bar' or 'violin'
        return_figure : Optional[bool], optional
            If True, return the raw plotly figure. If False, returns a hosted iPlot.
            If None, return an iPlot display in a Jupyter notebook and a raw plotly figure in all other circumstances.
        show_incomplete: bool, optional
            Display statistics for method/basis set combinations where results are incomplete

        Returns
        -------
        plotly.Figure
            The requested figure.
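
        Examples
        --------
        A minimal sketch; the methods and basis are illustrative and a default
        benchmark is assumed to be set for the dataset:

        >>> ds.visualize(method=["b3lyp", "mp2"], basis="6-31g", kind="bar")
        >>> ds.visualize(method=["b3lyp", "b3lyp-d3"], basis=["6-31g"], groupby="d3", kind="violin")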
        """

        query = {"method": method, "basis": basis, "keywords": keywords, "program": program}
        query = {k: v for k, v in query.items() if v is not None}

        return self._visualize(
            metric, bench, query=query, groupby=groupby, return_figure=return_figure, kind=kind,
            show_incomplete=show_incomplete,
        )

    def _canonical_name(
        self,
        program: Optional[str] = None,
        method: Optional[str] = None,
        basis: Optional[str] = None,
        keywords: Optional[str] = None,
        stoich: Optional[str] = None,
        driver: Optional[str] = None,
    ) -> str:
        """
        Attempts to build a canonical name for a DataFrame column.
        """

        name = ""
        if method:
            name = method.upper()

        if basis and name:
            name = f"{name}/{basis.lower()}"
        elif basis:
            name = f"{basis.lower()}"

        if keywords and (keywords != self.data.default_keywords.get(program, None)):
            name = f"{name}-{keywords}"

        if program and (program.lower() != self.data.default_program):
            name = f"{name}-{program.title()}"

        if stoich:
            if name == "":
                name = stoich.lower()
            elif stoich.lower() != "default":
                name = f"{stoich.lower()}-{name}"

        return name

    def _default_parameters(
        self,
        program: Optional[str],
        method: str,
        basis: Optional[str],
        keywords: Optional[str],
        stoich: Optional[str] = None,
    ) -> Tuple[str, Dict[str, Union[str, "KeywordSet"]], Dict[str, str]]:
        """
        Takes raw input parameters and applies defaults to them.
        """

        # Handle default program
        if program is None:
            if self.data.default_program is None:
                raise KeyError("No default program was set and none was provided.")
            program = self.data.default_program
        else:
            program = program.lower()

        driver = self.data.default_driver

        # Handle keywords
        keywords_alias = keywords
        if keywords is None:
            if program in self.data.default_keywords:
                keywords_alias = self.data.default_keywords[program]
                keywords = self.data.alias_keywords[program][keywords_alias]
        else:
            if (program not in self.data.alias_keywords) or (keywords not in self.data.alias_keywords[program]):
                raise KeyError("KeywordSet alias '{}' not found for program '{}'.".format(keywords, program))

            keywords_alias = keywords
            keywords = self.data.alias_keywords[program][keywords]

        # Form database and history keys
        dbkeys = {"driver": driver, "program": program, "method": method, "basis": basis, "keywords": keywords}
        history = {**dbkeys, **{"keywords": keywords_alias}}
        if stoich is not None:
            history["stoichiometry"] = stoich

        name = self._canonical_name(program, method, basis, keywords_alias, stoich)

        return name, dbkeys, history

    def _get_molecules(self, indexer: Dict[Any, ObjectId], force: bool = False) -> pd.DataFrame:
        """Queries a list of molecules using a molecule indexer.

        Parameters
        ----------
        indexer : Dict[str, 'ObjectId']
            A key/value index of molecules to query
        force : bool, optional
            Force pull of molecules from the server

        Returns
        -------
        pd.DataFrame
            A table of Molecules, indexed by Entry names

        Raises
        ------
        KeyError
            If no records match the query
        """

        molecule_ids = list(set(indexer.values()))
        if not self._use_view(force):
            molecules: List["Molecule"] = []
            for i in range(0, len(molecule_ids), self.client.query_limit):
                molecules.extend(self.client.query_molecules(id=molecule_ids[i : i + self.client.query_limit]))
            # XXX: molecules = pd.DataFrame({"molecule_id": molecule_ids, "molecule": molecules}) fails
            #      test_gradient_dataset_get_molecules and I don't know why
            molecules = pd.DataFrame({"molecule_id": molecule.id, "molecule": molecule} for molecule in molecules)
        else:
            molecules = self._view.get_molecules(molecule_ids)
            molecules = pd.DataFrame({"molecule_id": molecule_ids, "molecule": molecules})

        if len(molecules) == 0:
            raise KeyError("Query matched 0 records.")

        df = pd.DataFrame.from_dict(indexer, orient="index", columns=["molecule_id"])

        df.reset_index(inplace=True)

        # Outer join on left to merge duplicate molecules
        df = df.merge(molecules, how="left", on="molecule_id")
        df.set_index("index", inplace=True)
        df.drop("molecule_id", axis=1, inplace=True)

        return df

    def _get_records(
        self,
        indexer: Dict[Any, ObjectId],
        query: Dict[str, Any],
        include: Optional[List[str]] = None,
        merge: bool = False,
        raise_on_plan: Union[str, bool] = False,
    ) -> "pd.Series":
        """
        Runs a query based on an indexer which is index : molecule_id.

        Parameters
        ----------
        indexer : Dict[str, ObjectId]
            A key/value index of molecules to query
        query : Dict[str, Any]
            A results query
        include : Optional[List[str]], optional
            The attributes to return. Otherwise returns ResultRecord objects.
        merge : bool, optional
            Sum compound queries together, useful for mixing results
        raise_on_plan : Union[str, bool], optional
            Raises a KeyError if True, or with this string as the message, when a multi-stage plan is detected.

        Returns
        -------
        pd.Series
            A Series of the data results

        """
        self._check_client()
        self._check_state()

        ret = []
        plan = composition_planner(**query)
        if raise_on_plan and (len(plan) > 1):
            if raise_on_plan is True:
                raise KeyError("Received a multi-stage plan when this function does not support multi-stage plans.")
            else:
                raise KeyError(raise_on_plan)

        for query_set in plan:

            query_set["keywords"] = self.get_keywords(query_set["keywords"], query_set["program"], return_id=True)
            # Set the index to remove duplicates
            molecules = list(set(indexer.values()))
            if include:
                proj = [k.lower() for k in include]
                if "molecule" not in proj:
                    proj.append("molecule")
                query_set["include"] = proj

            # Chunk up the queries
            records: List[ResultRecord] = []
            for i in range(0, len(molecules), self.client.query_limit):
                query_set["molecule"] = molecules[i : i + self.client.query_limit]
                records.extend(self.client.query_results(**query_set))

            if include is None:
                records = [{"molecule": x.molecule, "record": x} for x in records]

            records = pd.DataFrame.from_dict(records)

            df = pd.DataFrame.from_dict(indexer, orient="index", columns=["molecule"])
            df.reset_index(inplace=True)

            if records.shape[0] > 0:
                # Outer join on left to merge duplicate molecules
                df = df.merge(records, how="left", on="molecule")
            else:
                # No results, fill NaN values
                if include is None:
                    df["record"] = None
                else:
                    for k in include:
                        df[k] = np.nan

            df.set_index("index", inplace=True)
            df.drop("molecule", axis=1, inplace=True)

            ret.append(df)

        if len(molecules) == 0:
            raise KeyError("Query matched 0 records.")

        if merge:
            retdf = ret[0]
            for df in ret[1:]:
                retdf += df
            return retdf
        else:
            return ret

    def _compute(
        self,
        compute_keys: Dict[str, Union[str, None]],
        molecules: Union[List[str], pd.Series],
        tag: Optional[str] = None,
        priority: Optional[str] = None,
        protocols: Optional[Dict[str, Any]] = None,
    ) -> ComputeResponse:
        """
        Internal compute function.
        """

        name, dbkeys, history = self._default_parameters(
            compute_keys["program"],
            compute_keys["method"],
            compute_keys["basis"],
            compute_keys["keywords"],
            stoich=compute_keys.get("stoich", None),
        )

        self._check_client()
        self._check_state()

        umols = list(set(molecules))

        ids: List[Optional[ObjectId]] = []
        submitted: List[ObjectId] = []
        existing: List[ObjectId] = []
        for compute_set in composition_planner(**dbkeys):

            for i in range(0, len(umols), self.client.query_limit):
                chunk_mols = umols[i : i + self.client.query_limit]
                ret = self.client.add_compute(
                    **compute_set, molecule=chunk_mols, tag=tag, priority=priority, protocols=protocols
                )

                ids.extend(ret.ids)
                submitted.extend(ret.submitted)
                existing.extend(ret.existing)

            qhistory = history.copy()
            qhistory["program"] = compute_set["program"]
            qhistory["method"] = compute_set["method"]
            qhistory["basis"] = compute_set["basis"]
            self._add_history(**qhistory)

        return ComputeResponse(ids=ids, submitted=submitted, existing=existing)

    @property
    def units(self):
        return self._units

    @units.setter
    def units(self, value):
        for column in self.df.columns:
            try:
                self.df[column] *= constants.conversion_factor(self._column_metadata[column]["units"], value)

                # Cast units to quantities so that `kcal / mol` == `kilocalorie / mole`
                metadata_quantity = constants.Quantity(self._column_metadata[column]["units"])
                self_quantity = constants.Quantity(self._units)
                if metadata_quantity != self_quantity:
                    warnings.warn(
                        f"Data column '{column}' did not have the same units as the dataset. "
                        f"This has been corrected."
                    )
                self._column_metadata[column]["units"] = value
            except (ValueError, TypeError) as e:
                # This is meant to catch pint.errors.DimensionalityError without importing pint, which is too slow.
                # In pint <=0.9, DimensionalityError is a ValueError.
                # In pint >=0.10, DimensionalityError is a TypeError.
                if e.__class__.__name__ == "DimensionalityError":
                    pass
                else:
                    raise
        self._units = value

    def set_default_program(self, program: str) -> bool:
        """
        Sets the default program.

        Parameters
        ----------
        program : str
            The program to default to.
        """

        self.data.__dict__["default_program"] = program.lower()
        return True

    def set_default_benchmark(self, benchmark: str) -> bool:
        """
        Sets the default benchmark value.

        Parameters
        ----------
        benchmark : str
            The benchmark to default to.
        """

        self.data.__dict__["default_benchmark"] = benchmark
        return True

    def add_keywords(self, alias: str, program: str, keyword: "KeywordSet", default: bool = False) -> bool:
        """
        Adds an option alias to the dataset. Note that keywords are not present
        on the server until a save call has been completed.

        Parameters
        ----------
        alias : str
            The alias of the option
        program : str
            The compute program the alias is for
        keyword : KeywordSet
            The Keywords object to use.
        default : bool, optional
            Sets this option as the default for the program
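
        Examples
        --------
        A minimal sketch; the alias, program, and keyword values are hypothetical:

        >>> from qcportal.models import KeywordSet
        >>> kw = KeywordSet(values={"scf_type": "df"})
        >>> ds.add_keywords("df", "psi4", kw, default=True)
        >>> ds.save()  # keywords are only committed to the server on save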

        """

        alias = alias.lower()
        program = program.lower()
        if program not in self.data.alias_keywords:
            self.data.alias_keywords[program] = {}

        if alias in self.data.alias_keywords[program]:
            raise KeyError("Alias '{}' already set for program {}.".format(alias, program))

        self._new_keywords[(program, alias)] = keyword

        if default:
            self.data.default_keywords[program] = alias
        return True

    def list_keywords(self) -> pd.DataFrame:
        """Lists keyword aliases for each program in the dataset.

        Returns
        -------
        pd.DataFrame
            A dataframe containing programs, keyword aliases, KeywordSet ids, and whether those keywords are the
            default for a program. Indexed on program.
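
        Examples
        --------
        A minimal sketch:

        >>> ds.list_keywords()  # one row per (program, alias) pair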
        """
        data = []
        for program, kwaliases in self.data.alias_keywords.items():
            prog_default_kw = self.data.default_keywords.get(program, None)
            for kwalias, kwid in kwaliases.items():
                data.append(
                    {
                        "program": program,
                        "keywords": kwalias,
                        "id": kwid,
                        "default": prog_default_kw == kwalias,
                    }
                )
        return pd.DataFrame(data).set_index("program")

    def get_keywords(self, alias: str, program: str, return_id: bool = False) -> Union["KeywordSet", str]:
        """Pulls the keywords alias from the server for inspection.

        Parameters
        ----------
        alias : str
            The keywords alias.
        program : str
            The program the keywords correspond to.
        return_id : bool, optional
            If True, returns the ``id`` rather than the ``KeywordSet`` object.

        Returns
        -------
        Union['KeywordSet', str]
            The requested ``KeywordSet`` or ``KeywordSet`` ``id``.
1366

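        Examples
        --------
        A hypothetical lookup, mirroring the ``add_keywords`` sketch above:

        >>> kw = ds.get_keywords("df-scf", "psi4")                    # full KeywordSet
        >>> kwid = ds.get_keywords("df-scf", "psi4", return_id=True)  # just the id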
        """
        self._check_client()
        if alias is None:
            if return_id:
                return None
            else:
                return {}

        alias = alias.lower()
        program = program.lower()
        if (program not in self.data.alias_keywords) or (alias not in self.data.alias_keywords[program]):
            raise KeyError("Keywords {}: {} not found.".format(program, alias))

        kwid = self.data.alias_keywords[program][alias]
        if return_id:
            return kwid
        else:
            return self.client.query_keywords([kwid])[0]

    def add_contributed_values(self, contrib: ContributedValues, overwrite: bool = False) -> None:
        """
        Adds a ContributedValues to the database. Be sure to call save() to commit changes to the server.

        Parameters
        ----------
        contrib : ContributedValues
            The ContributedValues to add.
        overwrite : bool, optional
            Overwrites pre-existing values
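
        Examples
        --------
        A minimal sketch with hypothetical names and values; the index must match
        the entry names in the dataset:

        >>> cv = ContributedValues(
        ...     name="Experimental energies",
        ...     theory_level="experiment",
        ...     units="kcal / mol",
        ...     values=[-1.5, -2.0],
        ...     index=["mol1", "mol2"],
        ... )
        >>> ds.add_contributed_values(cv)
        >>> ds.save()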
        """
1397 4
        self.get_entries(force=True)
1398 4
        self._ensure_contributed_values()
1399

1400
        # Convert and validate
1401 4
        if isinstance(contrib, ContributedValues):
1402 0
            contrib = contrib.copy()
1403
        else:
1404 4
            contrib = ContributedValues(**contrib)
1405

1406 4
        if set(contrib.index) != set(self.get_index()):
1407 4
            raise ValueError("Contributed values indices do not match the entries in the dataset.")
1408

1409
        # Check the key
1410 4
        key = contrib.name.lower()
1411 4
        if (key in self.data.contributed_values) and (overwrite is False):
1412 4
            raise KeyError(
1413
                "Key '{}' already found in contributed values. Use `overwrite=True` to force an update.".format(key)
1414
            )
1415

1416 4
        self.data.contributed_values[key] = contrib
1417 4
        self._updated_state = True
1418

1419 4
    def _ensure_contributed_values(self) -> None:
        if self.data.contributed_values is None:
            self._get_data_records_from_db()

    def _list_contributed_values(self) -> pd.DataFrame:
        """
        Lists all specifications of contributed data, i.e., method, program, basis set, keyword set, and driver combinations.

        Returns
        -------
        DataFrame
            Contributed value specifications.
        """
        self._ensure_contributed_values()
        ret = pd.DataFrame(columns=self.data.history_keys + tuple(["name"]))

        cvs = (
            (cv_data.name, cv_data.theory_level_details) for (cv_name, cv_data) in self.data.contributed_values.items()
        )

        for cv_name, theory_level_details in cvs:
            spec = {"name": cv_name}
            for k in self.data.history_keys:
                spec[k] = "Unknown"
            # ReactionDataset uses "default" as a default value for stoich,
            # but many contributed datasets lack a stoich field
            if "stoichiometry" in self.data.history_keys:
                spec["stoichiometry"] = "default"
            if isinstance(theory_level_details, dict):
                spec.update(**theory_level_details)
            ret = ret.append(spec, ignore_index=True)

        return ret

    def _subset_in_cache(self, column_name: str, subset: Set[str]) -> bool:
        try:
            return not self.df.loc[subset, column_name].isna().any()
        except KeyError:
            return False

    def _update_cache(self, new_data: pd.DataFrame) -> None:
        new_df = pd.DataFrame(
            index=set(self.df.index) | set(new_data.index), columns=set(self.df.columns) | set(new_data.columns)
        )
        new_df.update(new_data)
        new_df.update(self.df)
        self.df = new_df

    def _get_contributed_values(self, subset: Set[str], force: bool = False, **spec) -> pd.DataFrame:
        cv_list = self.list_values(native=False, force=force).reset_index()
        queries = self._filter_records(cv_list.rename(columns={"stoichiometry": "stoich"}), **spec)
        column_names: List[str] = []
        new_queries = []

        for query in queries.to_dict("records"):
            column_name = query["name"]
            column_names.append(column_name)
            if force or not self._subset_in_cache(column_name, subset):
                self._column_metadata[column_name] = query
                new_queries.append(query)

        new_data = pd.DataFrame(index=subset)
        if not self._use_view(force):
            self._ensure_contributed_values()
            units: Dict[str, str] = {}

            for query in new_queries:
                data = self.data.contributed_values[query["name"].lower()].copy()
                column_name = data.name

                # Annoying workaround to prevent some pandas magic
                if isinstance(data.values[0], (int, float, bool, np.number)):
                    values = data.values
                else:
                    # TODO temporary patch until msgpack collections
                    if isinstance(data.theory_level_details, dict) and "driver" in data.theory_level_details:
                        cv_driver = data.theory_level_details["driver"]
                    else:
                        cv_driver = self.data.default_driver

                    if cv_driver == "gradient":
                        values = [np.array(v).reshape(-1, 3) for v in data.values]
                    else:
                        values = [np.array(v) for v in data.values]

                new_data[column_name] = pd.Series(values, index=data.index)[subset]
                units[column_name] = data.units
        else:
            for query in new_queries:
                query["native"] = False
            new_data, units = self._view.get_values(new_queries, subset)

        # Convert units
        for query in new_queries:
            column_name = query["name"]
            metadata = {"native": False}
            try:
                new_data[column_name] *= constants.conversion_factor(units[column_name], self.units)
                metadata["units"] = self.units
            except (ValueError, TypeError) as e:
                # This is meant to catch pint.errors.DimensionalityError without importing pint, which is too slow.
                # In pint <=0.9, DimensionalityError is a ValueError.
                # In pint >=0.10, DimensionalityError is a TypeError.
                if e.__class__.__name__ == "DimensionalityError":
                    metadata["units"] = units[column_name]
                else:
                    raise
            self._column_metadata[column_name].update(metadata)

        self._update_cache(new_data)
        return self.df.loc[subset, column_names]

    def get_molecules(
        self, subset: Optional[Union[str, Set[str]]] = None, force: bool = False
    ) -> Union[pd.DataFrame, "Molecule"]:
        """Queries full Molecules from the database.

        Parameters
        ----------
        subset : Optional[Union[str, Set[str]]], optional
            The index subset to query on
        force : bool, optional
            Force pull of molecules from server

        Returns
        -------
        Union[pd.DataFrame, 'Molecule']
            Either a DataFrame of indexed Molecules or a single Molecule if a single subset string was provided.
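
        Examples
        --------
        Hypothetical queries; the entry name is illustrative:

        >>> mol = ds.get_molecules(subset="water")  # a single Molecule
        >>> df = ds.get_molecules()                 # a DataFrame over the full index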
        """
1549 4
        indexer = self._molecule_indexer(subset=subset, force=force)
1550 4
        df = self._get_molecules(indexer, force)
1551

1552 4
        if isinstance(subset, str):
1553 1
            return df.iloc[0, 0]
1554
        else:
1555 4
            return df
1556

1557 4
    def get_records(
1558
        self,
1559
        method: str,
1560
        basis: Optional[str] = None,
1561
        *,
1562
        keywords: Optional[str] = None,
1563
        program: Optional[str] = None,
1564
        include: Optional[List[str]] = None,
1565
        subset: Optional[Union[str, Set[str]]] = None,
1566
        merge: bool = False,
1567
    ) -> Union[pd.DataFrame, "ResultRecord"]:
        """
        Queries full ResultRecord objects from the database.

        Parameters
        ----------
        method : str
            The computational method to query on (e.g., B3LYP)
        basis : Optional[str], optional
            The computational basis to query on (e.g., 6-31G)
        keywords : Optional[str], optional
            The option token desired
        program : Optional[str], optional
            The program to query on
        include : Optional[List[str]], optional
            The attributes to return. Otherwise returns ResultRecord objects.
        subset : Optional[Union[str, Set[str]]], optional
            The index subset to query on
        merge : bool
            Merge multiple results into one (as in the case of DFT-D3).
            This only works when include=['return_results'], as in get_values.

        Returns
        -------
        Union[pd.DataFrame, 'ResultRecord']
            Either a DataFrame of indexed ResultRecords or a single ResultRecord if a single subset string was provided.
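
        Examples
        --------
        Hypothetical queries; the method, basis, and entry names are illustrative:

        >>> rec = ds.get_records("b3lyp", "6-31g", program="psi4", subset="water")
        >>> df = ds.get_records("b3lyp", "6-31g", program="psi4")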
        """
1594 4
        name, _, history = self._default_parameters(program, method, basis, keywords)
1595 4
        if len(self.list_records(**history)) == 0:
1596 1
            raise KeyError(f"Requested query ({name}) did not match a known record.")
1597

1598 4
        indexer = self._molecule_indexer(subset=subset, force=True)
1599 4
        df = self._get_records(indexer, history, include=include, merge=merge)
1600

1601 4
        if not merge and len(df) == 1:
1602 1
            df = df[0]
1603

1604 4
        if len(df) == 0:
1605 0
            raise KeyError("Query matched no records!")
1606

1607 4
        if isinstance(subset, str):
1608 1
            return df.iloc[0, 0]
1609
        else:
1610 4
            return df
1611

1612 4
    def add_entry(self, name: str, molecule: "Molecule", **kwargs: Dict[str, Any]) -> None:
1613
        """Adds a new entry to the Dataset
1614

1615
        Parameters
1616
        ----------
1617
        name : str
1618
            The name of the record
1619
        molecule : Molecule
1620
            The Molecule associated with this record
1621
        **kwargs : Dict[str, Any]
1622
            Additional arguments to pass to the record
1623
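
        Examples
        --------
        A minimal sketch; remember to call ``save()`` to commit new entries:

        >>> from qcelemental.models import Molecule
        >>> ds.add_entry("helium", Molecule.from_data("He 0 0 0"))
        >>> ds.save()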
        """
1624 4
        mhash = molecule.get_hash()
1625 4
        self._new_molecules[mhash] = molecule
1626 4
        self._new_records.append({"name": name, "molecule_hash": mhash, **kwargs})
1627

1628 4
    def compute(
1629
        self,
1630
        method: str,
1631
        basis: Optional[str] = None,
1632
        *,
1633
        keywords: Optional[str] = None,
1634
        program: Optional[str] = None,
1635
        tag: Optional[str] = None,
1636
        priority: Optional[str] = None,
1637
        protocols: Optional[Dict[str, Any]] = None,
1638
    ) -> ComputeResponse:
        """Executes a computational method for all entries in the Dataset.
        Previously completed computations are not repeated.

        Parameters
        ----------
        method : str
            The computational method to compute (e.g., B3LYP)
        basis : Optional[str], optional
            The computational basis to compute (e.g., 6-31G)
        keywords : Optional[str], optional
            The keyword alias for the requested compute
        program : Optional[str], optional
            The underlying QC program
        tag : Optional[str], optional
            The queue tag to use when submitting compute requests.
        priority : Optional[str], optional
            The priority of the jobs: low, medium, or high.
        protocols : Optional[Dict[str, Any]], optional
            Protocols for storing more or less data per field. Currently valid
            protocols: {'wavefunction'}

        Returns
        -------
        ComputeResponse
            An object that contains the submitted ObjectIds of the new compute. This object has the following fields:
              - ids: The ObjectIds of the tasks, in the order of the input molecules
              - submitted: A list of ObjectIds that were submitted to the compute queue
              - existing: A list of ObjectIds of tasks already in the database
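
        Examples
        --------
        A hypothetical submission; the method, basis, and program are illustrative:

        >>> response = ds.compute("b3lyp", "6-31g", program="psi4")
        >>> new_ids = response.submitted  # ObjectIds of the newly queued tasks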
        """
1668 4
        self.get_entries(force=True)
1669 4
        compute_keys = {"program": program, "method": method, "basis": basis, "keywords": keywords}
1670

1671 4
        molecule_idx = [e.molecule_id for e in self.data.records]
1672

1673 4
        ret = self._compute(compute_keys, molecule_idx, tag, priority, protocols)
1674 4
        self.save()
1675

1676 4
        return ret
1677

1678 4
    def get_index(self, subset: Optional[List[str]] = None, force: bool = False) -> List[str]:
1679
        """
1680
        Returns the current index of the database.
1681

1682
        Returns
1683
        -------
1684
        ret : List[str]
1685
            The names of all reactions in the database
1686
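
        Examples
        --------
        A small sketch:

        >>> names = ds.get_index()  # all entry names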
        """
1687 4
        return list(self.get_entries(subset=subset, force=force)["name"].unique())
1688

1689
    # Statistical quantities
1690 4
    def statistics(
1691
        self, stype: str, value: str, bench: Optional[str] = None, **kwargs: Dict[str, Any]
1692
    ) -> Union[np.ndarray, pd.Series, np.float64]:
1693
        """Provides statistics for various columns in the underlying dataframe.
1694

1695
        Parameters
1696
        ----------
1697
        stype : str
1698
            The type of statistic in question
1699
        value : str
1700
            The method string to compare
1701
        bench : str, optional
1702
            The benchmark method for the comparison, defaults to `default_benchmark`.
1703
        kwargs: Dict[str, Any]
1704
            Additional kwargs to pass to the statistics functions
1705

1706

1707
        Returns
1708
        -------
1709
        np.ndarray, pd.Series, float
1710
            Returns an ndarray, Series, or float with the requested statistics depending on input.
1711
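
        Examples
        --------
        A hypothetical comparison; the statistic name follows ``qcportal.statistics``
        (e.g., ``"MUE"``) and the method columns are illustrative:

        >>> mue = ds.statistics("MUE", "b3lyp/6-31g", bench="ccsd(t)/cbs")
        >>> me = ds.statistics("ME", "b3lyp/6-31g")  # falls back to default_benchmark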
        """
1712

1713 4
        if bench is None:
1714 4
            bench = self.data.default_benchmark
1715

1716 4
        if bench is None:
1717 0
            raise KeyError("No benchmark provided and default_benchmark is None!")
1718

1719 4
        return wrap_statistics(stype.upper(), self, value, bench, **kwargs)
1720

1721 4
    def _use_view(self, force: bool = False) -> bool:
1722
        """Helper function to decide whether to use a locally available HDF5 view"""
1723 4
        return (force is False) and (self._view is not None) and (self._disable_view is False)
1724

1725 4
    def _clear_cache(self) -> None:
1726 4
        self.df = pd.DataFrame()
1727 4
        self.data.__dict__["records"] = None
1728 4
        self.data.__dict__["contributed_values"] = None
1729

1730
    # Getters
1731 4
    def __getitem__(self, args: str) -> pd.Series:
1732
        """A wrapped to the underlying pd.DataFrame to access columnar data
1733

1734
        Parameters
1735
        ----------
1736
        args : str
1737
            The column to access
1738

1739
        Returns
1740
        -------
1741
        ret : pd.Series, pd.DataFrame
1742
            A view of the underlying dataframe data
1743
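
        Examples
        --------
        A hypothetical column access; equivalent to indexing ``ds.df`` directly:

        >>> energies = ds["B3LYP/6-31g"]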
        """
1744 0
        return self.df[args]
1745

1746

1747 4
register_collection(Dataset)
