1
"""
2
QCPortal Database ODM
3
"""
4 4
import itertools as it
5 4
from enum import Enum
6 4
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Union
7

8 4
import numpy as np
9 4
import pandas as pd
10 4
from qcelemental import constants
11

12 4
from ..models import Molecule, ProtoModel
13 4
from ..util import replace_dict_keys
14 4
from .collection_utils import nCr, register_collection
15 4
from .dataset import Dataset
16

17
if TYPE_CHECKING:  # pragma: no cover
18
    from .. import FractalClient
19
    from ..models import ComputeResponse
20

21

22 4
class _ReactionTypeEnum(str, Enum):
23
    """Helper class for locking the reaction type into one or the other"""
24

25 4
    rxn = "rxn"
26 4
    ie = "ie"
27

28

29 4
class ReactionEntry(ProtoModel):
    """Data model for the `reactions` list in Dataset"""

    # Free-form attributes attached to the reaction (might be overloaded key types).
    attributes: Dict[str, Union[int, float, str]]
    # Result dictionaries keyed by a label such as "default".
    reaction_results: Dict[str, dict]
    # Name of the reaction; used to look entries up again.
    name: str
    # Stoichiometry label -> {molecule key: stoichiometric coefficient}.
    stoichiometry: Dict[str, Dict[str, float]]
    # Additional user-defined fields.
    extras: Dict[str, Any] = {}
37

38

39 4
class ReactionDataset(Dataset):
40
    """
41
    The ReactionDataset class for homogeneous computations on many reactions.
42

43
    Attributes
44
    ----------
45
    client : client.FractalClient
46
        A FractalClient connected to a server
47
    data : ReactionDataset.DataModel
48
        A Model representation of the database backbone
49
    df : pd.DataFrame
50
        The underlying dataframe for the Dataset object
51
    rxn_index : pd.Index
52
        The unrolled reaction index for all reactions in the Dataset
53
    """
54

55 4
    def __init__(self, name: str, client: Optional["FractalClient"] = None, ds_type: str = "rxn", **kwargs) -> None:
        """
        Set up a ReactionDataset by delegating to the ``Dataset`` initializer.

        The dataset type is lower-cased before it is handed on. As with the
        parent initializer, a blank database is created when no client is
        supplied or when the name is not present on the connected server.

        Parameters
        ----------
        name : str
            The name of the Dataset
        client : client.FractalClient, optional
            A FractalClient connected to a server
        ds_type : str, optional
            The type of Dataset involved (case-insensitive)
        **kwargs
            Passed through unchanged to ``Dataset.__init__``

        """
        super().__init__(name, client=client, ds_type=ds_type.lower(), **kwargs)
73

74 4
    class DataModel(Dataset.DataModel):
        # Reaction-specific extension of the base dataset data model.

        # Dataset type, restricted to the members of _ReactionTypeEnum.
        ds_type: _ReactionTypeEnum = _ReactionTypeEnum.rxn
        # Reaction entries; None means they have not been populated yet.
        records: Optional[List[ReactionEntry]] = None

        # Known computations; each tuple presumably follows the order of
        # `history_keys` below -- TODO confirm against the base Dataset.
        history: Set[Tuple[str, str, str, Optional[str], Optional[str], str]] = set()
        history_keys: Tuple[str, str, str, str, str, str] = ("driver", "program", "method", "basis", "keywords", "stoichiometry")
88

89 4
    def _entry_index(self, subset: Optional[List[str]] = None) -> pd.DataFrame:
        """Unroll every reaction's stoichiometry into one flat table.

        Records are fetched through ``_get_data_records_from_db`` first when
        ``self.data.records`` is still ``None``.

        Parameters
        ----------
        subset : Optional[List[str]], optional
            Reaction names to keep. All reactions are returned if None.

        Returns
        -------
        pd.DataFrame
            One row per (reaction, stoichiometry label, molecule) with the columns
            ``name``, ``stoichiometry``, ``molecule`` and ``coefficient``.
        """
        if self.data.records is None:
            self._get_data_records_from_db()

        # Unroll the index
        rows = []
        for rxn in self.data.records:
            for stoich_name, molecules in rxn.stoichiometry.items():
                for mol_hash, coef in molecules.items():
                    rows.append([rxn.name, stoich_name, mol_hash, coef])
        ret = pd.DataFrame(rows, columns=["name", "stoichiometry", "molecule", "coefficient"])

        if subset is None:
            return ret

        # Select by reaction name while keeping the original positional index on the result.
        return ret.reset_index().set_index("name").loc[subset].reset_index().set_index("index")
104

105 4
    def _molecule_indexer(
        self,
        stoich: Union[str, List[str]],
        subset: Optional[Union[str, Set[str]]] = None,
        coefficients: bool = False,
        force: bool = False,
    ) -> Tuple[Dict[Tuple[str, ...], "ObjectId"], Tuple[str, ...]]:
        """Provides a {index: molecule_id} mapping for a given subset.

        Parameters
        ----------
        stoich : Union[str, List[str]]
            The stoichiometries, or list of stoichiometries to return
        subset : Optional[Union[str, Set[str]]], optional
            The indices of the desired subset. Return all indices if subset is None.
        coefficients : bool, optional
            Appends the stoichiometric coefficient to each index key if True
        force : bool, optional
            Passed on to ``get_entries``

        Returns
        -------
        Tuple[Dict[Tuple[str, ...], 'ObjectId'], Tuple[str, ...]]
            Molecule index to molecule ObjectId map, and index names. Each key is
            ``(name, stoichiometry, idx)``, plus the coefficient when requested, where
            ``idx`` counts the molecules within one (name, stoichiometry) group.
        """
        if isinstance(stoich, str):
            stoich = [stoich]

        index = self.get_entries(subset=subset, force=force)
        matched_rows = index[index["stoichiometry"].isin(stoich)]

        if subset:
            # A bare string means a single name. Sets are turned into a list because
            # array-based membership tests treat a set as one opaque element, which
            # would silently match nothing.
            wanted_names = [subset] if isinstance(subset, str) else list(subset)
            matched_rows = matched_rows[matched_rows["name"].isin(wanted_names)]

        names = ("name", "stoichiometry", "idx")
        if coefficients:
            names = names + ("coefficient",)

        ret = {}
        for gb_idx, group in matched_rows.groupby(["name", "stoichiometry"]):
            for cnt, (_, row) in enumerate(group.iterrows()):
                if coefficients:
                    ret[gb_idx + (cnt, row["coefficient"])] = row["molecule"]
                else:
                    ret[gb_idx + (cnt,)] = row["molecule"]

        return ret, names
155

156 4
    def valid_stoich(self, subset=None, force: bool = False) -> Set[str]:
        """Return the distinct stoichiometry labels found in the dataset entries.

        Parameters
        ----------
        subset : optional
            Passed on to ``get_entries`` to restrict the entries considered.
        force : bool, optional
            Passed on to ``get_entries``.

        Returns
        -------
        Set[str]
            Every unique value of the ``stoichiometry`` column.
        """
        stoich_column = self.get_entries(subset=subset, force=force)["stoichiometry"]
        return set(stoich_column.unique())
159

160 4
    def _validate_stoich(self, stoich: Union[List[str], str], subset=None, force: bool = False) -> None:
        """Check that every requested stoichiometry label exists in the dataset.

        Parameters
        ----------
        stoich : Union[List[str], str]
            One label or a list of labels; each is lower-cased before the lookup.
        subset : optional
            Entry subset used to collect the valid labels. A bare string is
            treated as a single-element list.
        force : bool, optional
            Passed on to ``valid_stoich``.

        Raises
        ------
        KeyError
            On the first label that is not among the valid ones.
        """
        requested = [stoich] if isinstance(stoich, str) else stoich
        entry_names = [subset] if isinstance(subset, str) else subset

        valid_stoich = self.valid_stoich(subset=entry_names, force=force)
        for label in requested:
            if label.lower() in valid_stoich:
                continue
            raise KeyError("Stoichiometry not understood, valid keys are {}.".format(valid_stoich))
169

170 4
    def _pre_save_prep(self, client: "FractalClient") -> None:
        """Move pending molecules and reactions into the data model before a save.

        Parameters
        ----------
        client : FractalClient
            Client handed to ``_canonical_pre_save`` and ``_add_molecules_by_dict``.
        """
        self._canonical_pre_save(client)

        server_ids = self._add_molecules_by_dict(client, self._new_molecules)

        # Update internal molecule UUID's to servers UUID's
        for pending in self._new_records:
            remapped = replace_dict_keys(pending.stoichiometry, server_ids)
            self.data.records.append(pending.copy(update={"stoichiometry": remapped}))

        # Everything pending has been transferred; start with empty queues again.
        self._new_records: List[ReactionEntry] = []
        self._new_molecules = {}

        self._entry_index()
185

186 4
    def get_values(
        self,
        method: Optional[Union[str, List[str]]] = None,
        basis: Optional[Union[str, List[str]]] = None,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        driver: Optional[str] = None,
        stoich: str = "default",
        name: Optional[Union[str, List[str]]] = None,
        native: Optional[bool] = None,
        subset: Optional[Union[str, List[str]]] = None,
        force: bool = False,
    ) -> pd.DataFrame:
        """
        Obtains values from the known history from the search parameters provided for the expected `return_result` values.
        Defaults to the standard programs and keywords if not provided.

        Note that unlike `get_records`, `get_values` will automatically expand searches and return multiple method
        and basis combinations simultaneously.

        `None` is a wildcard selector. To search for `None`, use `"None"`.

        Parameters
        ----------
        method : Optional[Union[str, List[str]]], optional
            The computational method (B3LYP)
        basis : Optional[Union[str, List[str]]], optional
            The computational basis (6-31G)
        keywords : Optional[str], optional
            The keyword alias
        program : Optional[str], optional
            The underlying QC program
        driver : Optional[str], optional
            The type of calculation (e.g. energy, gradient, hessian, dipole...)
        stoich : str, optional
            Stoichiometry of the reaction.
        name : Optional[Union[str, List[str]]], optional
            Canonical name of the record. Overrides the above selectors.
        native: Optional[bool], optional
            True: only include data computed with QCFractal
            False: only include data contributed from outside sources
            None: include both
        subset: Optional[List[str]], optional
            The indices of the desired subset. Return all indices if subset is None.
        force : bool, optional
            Data is typically cached, forces a new query if True

        Returns
        -------
        DataFrame
           A DataFrame of values with columns corresponding to methods and rows corresponding to reaction entries.
           Contributed (native=False) columns are marked with "(contributed)" and may include units in square brackets
           if their units differ in dimensionality from the ReactionDataset's default units.
        """
        # Pure pass-through: every selector is forwarded by keyword to `_get_values`.
        selectors = {
            "method": method,
            "basis": basis,
            "keywords": keywords,
            "program": program,
            "driver": driver,
            "stoich": stoich,
            "name": name,
            "native": native,
            "subset": subset,
            "force": force,
        }
        return self._get_values(**selectors)
252

253 4
    def _get_native_values(
        self,
        subset: Set[str],
        method: Optional[str] = None,
        basis: Optional[str] = None,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        stoich: Optional[str] = None,
        name: Optional[str] = None,
        force: bool = False,
    ) -> pd.DataFrame:
        """Collect natively computed reaction values for the matching queries.

        For each matching query the per-molecule ``return_result`` values are
        multiplied by their stoichiometric coefficients and summed per reaction.
        For "ie" datasets the monomer sum is subtracted from the complex sum.
        Values are converted to ``self.units`` and merged into the cache.

        Parameters
        ----------
        subset : Set[str]
            Reaction names to return values for.
        method, basis, keywords, program, stoich, name : Optional[str], optional
            Selectors handed to ``_form_queries``; ``stoich`` is validated first.
        force : bool, optional
            Recompute columns even when they are already cached.

        Returns
        -------
        pd.DataFrame
            ``self.df`` restricted to ``subset`` and the matching column names, or
            an empty frame indexed by ``get_index(subset)`` when there are no
            records or no matching queries.

        Raises
        ------
        ValueError
            If the dataset type is neither "rxn" nor "ie".
        """
        self._validate_stoich(stoich, subset=subset, force=force)

        # So that datasets with no records do not require a default program and default keywords
        if len(self.list_records()) == 0:
            return pd.DataFrame(index=self.get_index(subset))

        queries = self._form_queries(
            method=method, basis=basis, keywords=keywords, program=program, stoich=stoich, name=name
        )
        if len(queries) == 0:
            return pd.DataFrame(index=self.get_index(subset))

        # The monomer label is the complex label with its digits stripped and "1" appended.
        stoich_complex = queries.pop("stoichiometry").values[0]
        stoich_monomer = "".join(ch for ch in stoich_complex if not ch.isdigit()) + "1"

        def _weighted_reaction_sum(stoich_label, spec):
            """Coefficient-weighted sum of `return_result` per reaction name."""
            # Build the starting table
            indexer, index_names = self._molecule_indexer(stoich=stoich_label, coefficients=True, force=force)
            table = self._get_records(indexer, spec, include=["return_result"], merge=True)
            table.index = pd.MultiIndex.from_tuples(table.index, names=index_names)
            table.reset_index(inplace=True)

            # Block out null values `groupby.sum()` will return 0 rather than NaN in all cases
            incomplete = table[["name", "return_result"]].copy()
            incomplete["return_result"] = incomplete["return_result"].isnull()
            incomplete = incomplete.groupby(["name"])["return_result"].sum() != 0

            # Multiply by coefficients and sum
            table["return_result"] *= table["coefficient"]
            totals = table.groupby(["name"])["return_result"].sum()
            totals[incomplete] = np.nan
            return totals

        names = []
        new_queries = []
        new_data = pd.DataFrame(index=subset)

        for _, query_row in queries.iterrows():
            spec = query_row.replace({np.nan: None}).to_dict()
            column = spec["name"]
            names.append(column)

            # Only columns that are missing from the cache (or forced) are fetched again.
            if force or not self._subset_in_cache(column, subset):
                self._column_metadata[column] = spec
                new_queries.append(spec)

        if self._use_view(force):
            for spec in new_queries:
                spec["native"] = True
            new_data, units = self._view.get_values(new_queries)
            for spec in new_queries:
                column = spec["name"]
                new_data[column] = new_data[column] * constants.conversion_factor(units[column], self.units)
        else:
            units: Dict[str, str] = {}
            for spec in new_queries:
                # "name" is taken out while the spec is used as a record query and put back afterwards.
                column = spec.pop("name")
                if self.data.ds_type == _ReactionTypeEnum.rxn:
                    data = _weighted_reaction_sum(stoich_complex, spec)
                elif self.data.ds_type == _ReactionTypeEnum.ie:
                    # This implements 1-body counterpoise correction
                    # TODO: this will need to contain the logic for VMFC or other method-of-increments strategies
                    data_complex = _weighted_reaction_sum(stoich_complex, spec)
                    data_monomer = _weighted_reaction_sum(stoich_monomer, spec)
                    data = data_complex - data_monomer
                else:
                    raise ValueError(
                        f"ReactionDataset ds_type is not a member of _ReactionTypeEnum. (Got {self.data.ds_type}.)"
                    )

                new_data[column] = data * constants.conversion_factor("hartree", self.units)
                spec["name"] = column
                units[column] = self.units

        for spec in new_queries:
            column = spec["name"]
            self._column_metadata[column].update({"native": True, "units": units[column]})

        self._update_cache(new_data)
        return self.df.loc[subset, names]
347

348 4
    def visualize(
        self,
        method: Optional[str] = None,
        basis: Optional[str] = None,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        stoich: str = "default",
        groupby: Optional[str] = None,
        metric: str = "UE",
        bench: Optional[str] = None,
        kind: str = "bar",
        return_figure: Optional[bool] = None,
        show_incomplete: bool = False,
    ) -> "plotly.Figure":
        """
        Plot error statistics for the selected levels of theory against a benchmark.

        Parameters
        ----------
        method : Optional[str], optional
            Methods to query
        basis : Optional[str], optional
            Bases to query
        keywords : Optional[str], optional
            Keyword aliases to query
        program : Optional[str], optional
            Programs aliases to query
        stoich : str, optional
            Stoichiometry to query
        groupby : Optional[str], optional
            Groups the plot by this index.
        metric : str, optional
            The metric to use either UE (unsigned error) or URE (unsigned relative error)
        bench : Optional[str], optional
            The benchmark level of theory to use
        kind : str, optional
            The kind of chart to produce, either 'bar' or 'violin'
        return_figure : Optional[bool], optional
            If True, return the raw plotly figure. If False, returns a hosted iPlot. If None, return a iPlot display in Jupyter notebook and a raw plotly figure in all other circumstances.
        show_incomplete: bool, optional
            Display statistics method/basis set combinations where results are incomplete

        Returns
        -------
        plotly.Figure
            The requested figure.
        """
        # Selectors left at None are omitted from the query altogether.
        candidates = (
            ("method", method),
            ("basis", basis),
            ("keywords", keywords),
            ("program", program),
            ("stoichiometry", stoich),
        )
        query = {}
        for key, value in candidates:
            if value is not None:
                query[key] = value

        return self._visualize(
            metric,
            bench,
            query=query,
            groupby=groupby,
            return_figure=return_figure,
            kind=kind,
            show_incomplete=show_incomplete,
        )
406

407 4
    def get_molecules(
        self,
        subset: Optional[Union[str, Set[str]]] = None,
        stoich: Union[str, List[str]] = "default",
        force: bool = False,
    ) -> pd.DataFrame:
        """Queries full Molecules from the database.

        Parameters
        ----------
        subset : Optional[Union[str, Set[str]]], optional
            The index subset to query on; a bare string is treated as one name
        stoich : Union[str, List[str]], optional
            The stoichiometries to pull from, either a single or multiple stoichiometries
        force : bool, optional
            Force pull of molecules from server

        Returns
        -------
        pd.DataFrame
            Molecules matching the stoich and subset selection, indexed by the
            names returned from ``_molecule_indexer``.
        """
        self._check_client()
        self._check_state()

        entry_names = [subset] if isinstance(subset, str) else subset
        self._validate_stoich(stoich, subset=entry_names, force=force)

        indexer, level_names = self._molecule_indexer(stoich=stoich, subset=entry_names, force=force)
        molecules = self._get_molecules(indexer, force=force)
        # The indexer keys are tuples; expose them as a named MultiIndex.
        molecules.index = pd.MultiIndex.from_tuples(molecules.index, names=level_names)

        return molecules
442

443 4
    def get_records(
        self,
        method: str,
        basis: Optional[str] = None,
        *,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        stoich: Union[str, List[str]] = "default",
        include: Optional[List[str]] = None,
        subset: Optional[Union[str, Set[str]]] = None,
    ) -> Union[pd.DataFrame, "ResultRecord"]:
        """
        Queries the local Portal for the requested keys and stoichiometry.

        Parameters
        ----------
        method : str
            The computational method to query on (B3LYP); upper-cased before use
        basis : Optional[str], optional
            The computational basis to query on (6-31G)
        keywords : Optional[str], optional
            The option token desired
        program : Optional[str], optional
            The program to query on
        stoich : Union[str, List[str]], optional
            The stoichiometry, or list of stoichiometries, to query.
        include : Optional[List[str]], optional
            The attribute projection to perform on the query, otherwise returns ResultRecord objects.
        subset : Optional[Union[str, Set[str]]], optional
            The index subset to query on

        Returns
        -------
        Union[pd.DataFrame, 'ResultRecord']
            The per-stoichiometry query results concatenated and sorted by index.

        """
        self._check_client()
        self._check_state()

        method = method.upper()
        stoich_labels = [stoich] if isinstance(stoich, str) else stoich
        composite_error = "`get_records` can only be used for non-composite quantities. You likely queried a DFT+D method or similar that requires a combination of DFT and -D. Please query each piece separately."

        frames = []
        for label in stoich_labels:
            _, _, history = self._default_parameters(program, method, basis, keywords, stoich=label)
            # The stoichiometry is applied through the molecule indexer, not through the record query.
            del history["stoichiometry"]
            indexer, level_names = self._molecule_indexer(stoich=label, subset=subset, force=True)
            results = self._get_records(
                indexer,
                history,
                include=include,
                merge=False,
                raise_on_plan=composite_error,
            )
            frame = results[0]
            frame.index = pd.MultiIndex.from_tuples(frame.index, names=level_names)
            frames.append(frame)

        return pd.concat(frames).sort_index()
508

509 4
    def compute(
        self,
        method: str,
        basis: Optional[str] = None,
        *,
        keywords: Optional[str] = None,
        program: Optional[str] = None,
        stoich: str = "default",
        ignore_ds_type: bool = False,
        tag: Optional[str] = None,
        priority: Optional[str] = None,
    ) -> "ComputeResponse":
        """Executes a computational method for all reactions in the Dataset.
        Previously completed computations are not repeated.

        Parameters
        ----------
        method : str
            The computational method to compute (B3LYP)
        basis : Optional[str], optional
            The computational basis to compute (6-31G)
        keywords : Optional[str], optional
            The keyword alias for the requested compute
        program : Optional[str], optional
            The underlying QC program
        stoich : str, optional
            The stoichiometry of the requested compute (cp/nocp/etc)
        ignore_ds_type : bool, optional
            If True, only the molecules of ``stoich`` itself are submitted, even
            for an interaction energy ("ie") dataset.
        tag : Optional[str], optional
            The queue tag to use when submitting compute requests.
        priority : Optional[str], optional
            The priority of the jobs low, medium, or high.

        Returns
        -------
        ComputeResponse
            An object that contains the submitted ObjectIds of the new compute. This object has the following fields:
              - ids: The ObjectId's of the task in the order of input molecules
              - submitted: A list of ObjectId's that were submitted to the compute queue
              - existing: A list of ObjectId's of tasks already in the database

        """
        self._check_client()
        self._check_state()

        entry_index = self.get_entries(force=True)

        self._validate_stoich(stoich, subset=None, force=True)
        compute_keys = {"program": program, "method": method, "basis": basis, "keywords": keywords, "stoich": stoich}

        def _molecules_for(label):
            """Molecule column of the entries carrying the given stoichiometry label."""
            selected = entry_index[entry_index["stoichiometry"] == label].copy()
            return selected["molecule"]

        # Figure out molecules that we need
        needs_monomers = not ignore_ds_type and self.data.ds_type.lower() == "ie"
        if needs_monomers:
            # The monomer label is the requested label with its digits stripped and "1" appended.
            monomer_stoich = "".join(ch for ch in stoich if not ch.isdigit()) + "1"

            monomer_response = self._compute(compute_keys, _molecules_for(monomer_stoich), tag, priority)
            complex_response = self._compute(compute_keys, _molecules_for(stoich), tag, priority)

            ret = monomer_response.merge(complex_response)
        else:
            ret = self._compute(compute_keys, _molecules_for(stoich), tag, priority)

        # Update the record that this was computed
        self.save()

        return ret
580

581 4
    def get_rxn(self, name: str) -> ReactionEntry:
        """
        Looks up a single reaction entry by name.

        Parameters
        ----------
        name : str
            The name of the reaction to query

        Returns
        -------
        ReactionEntry
            The one record in ``self.data.records`` whose name matches

        Raises
        ------
        KeyError
            If no record, or more than one record, carries the name

        """
        positions = [pos for pos, entry in enumerate(self.data.records) if entry.name == name]

        if not positions:
            raise KeyError("Dataset:get_rxn: Reaction name '{}' not found.".format(name))

        if len(positions) > 1:
            raise KeyError("Dataset:get_rxn: Multiple reactions of name '{}' found. Dataset failure.".format(name))

        return self.data.records[positions[0]]
609

610
    # Visualization
611 4
    def ternary(self, cvals=None):
        """Placeholder for a ternary diagram of the DataBase.

        Plotting is currently unavailable, so every call raises.

        Parameters
        ----------
        cvals : None, optional
            Unused at present.

        Raises
        ------
        Exception
            Always.

        """
        raise Exception("MPL not avail")
621

622
    #        return visualization.Ternary2D(self.df, cvals=cvals)
623

624
    # Adders
625

626 4
    def parse_stoichiometry(self, stoichiometry: List[Tuple[Union[Molecule, str], float]]) -> Dict[str, float]:
        """
        Parses a stoichiometry list.

        Parameters
        ----------
        stoichiometry : list
            A list of (molecule, coefficient) pairs describing the stoichiometry.

        Returns
        -------
        Dict[str, float]
            A dictionary describing the stoichiometry for use in the database.
            Keys are molecule hashes. Values are stoichiometric coefficients;
            coefficients of repeated molecules are summed.

        Raises
        ------
        KeyError
            If an item does not have exactly two elements.
        TypeError
            If a coefficient cannot be cast to float, or a molecule is neither a
            string nor a Molecule.

        Notes
        -----
        This function attempts to convert the molecule into its corresponding hash. The following will happen depending on the form of the Molecule.
            - Molecule hash (a 40 character string) - Used directly in the stoichiometry.
            - Molecule class - Hash is obtained and the molecule will be added to the database upon saving.
            - Molecule string - Molecule will be converted to a Molecule class and the same process as the above will occur.

        """
        ret: Dict[str, float] = {}

        for line in stoichiometry:
            if len(line) != 2:
                raise KeyError("Dataset: Parse stoichiometry: passed in as a list, must be of key : value type")

            # Get the value; only conversion failures are translated, and the cause is kept.
            try:
                coef = float(line[1])
            except (TypeError, ValueError, OverflowError) as exc:
                raise TypeError(
                    "Dataset: Parse stoichiometry: must be able to cast second value to a float."
                ) from exc

            # What kind of molecule is it?
            mol = line[0]

            if isinstance(mol, str) and (len(mol) == 40):
                molecule_hash = mol

            elif isinstance(mol, (str, Molecule)):
                qcf_mol = Molecule.from_data(mol) if isinstance(mol, str) else mol
                molecule_hash = qcf_mol.get_hash()
                # Remember new molecules (first one wins) so they are uploaded on save.
                self._new_molecules.setdefault(molecule_hash, qcf_mol)

            else:
                raise TypeError(
                    "Dataset: Parse stoichiometry: first value must either be a molecule hash, "
                    "a molecule str, or a Molecule class."
                )

            # Sum together the coefficients of duplicates
            if molecule_hash in ret:
                ret[molecule_hash] += coef
            else:
                ret[molecule_hash] = coef

        return ret
701

702 4
    def add_rxn(
        self,
        name: str,
        stoichiometry: Union[Dict[str, List[Tuple[Molecule, float]]], List[Tuple[Molecule, float]]],
        reaction_results: Optional[Dict[str, str]] = None,
        attributes: Optional[Dict[str, Union[int, float, str]]] = None,
        other_fields: Optional[Dict[str, Any]] = None,
    ) -> ReactionEntry:
        """
        Adds a reaction to a database object.

        Parameters
        ----------
        name : str
            Name of the reaction.
        stoichiometry : list or dict
            Either a single stoichiometry list (stored under "default") or a
            dictionary of such lists, which must contain a "default" key.
        reaction_results :  dict or None, Optional, Default: None
            A dictionary of the computed total interaction energy results.
            Stored as given when it has a "default" key, otherwise wrapped
            under "default".
        attributes :  dict or None, Optional, Default: None
            A dictionary of attributes to assign to the reaction
        other_fields : dict or None, Optional, Default: None
            A dictionary of additional user defined fields to add to the reaction entry

        Returns
        -------
        ReactionEntry
            A complete specification of the reaction

        Raises
        ------
        KeyError
            If the name already exists or a stoichiometry dict lacks "default".
        TypeError
            If any argument has an unsupported type.
        """
        if reaction_results is None:
            reaction_results = {}
        if attributes is None:
            attributes = {}
        if other_fields is None:
            other_fields = {}

        # Set name
        if name in self.get_index():
            raise KeyError(
                "Dataset: Name '{}' already exists. "
                "Please either delete this entry or call the update function.".format(name)
            )

        # Check the plain-dict arguments before parsing the stoichiometry, because
        # parsing registers new molecules on the dataset as a side effect.
        if not isinstance(attributes, dict):
            raise TypeError("Dataset:add_rxn: attributes must be a dictionary, not '{}'".format(type(attributes)))

        if not isinstance(other_fields, dict):
            raise TypeError("Dataset:add_rxn: other_fields must be a dictionary, not '{}'".format(type(other_fields)))

        if not isinstance(reaction_results, dict):
            raise TypeError("Passed in reaction_results not understood.")
        if "default" not in reaction_results:
            reaction_results = {"default": reaction_results}

        # Set stoich
        if isinstance(stoichiometry, dict):
            if "default" not in stoichiometry:
                raise KeyError("Dataset:add_rxn: Stoichiometry dict must have a 'default' key.")
            parsed_stoich = {label: self.parse_stoichiometry(pairs) for label, pairs in stoichiometry.items()}

        elif isinstance(stoichiometry, (tuple, list)):
            parsed_stoich = {"default": self.parse_stoichiometry(stoichiometry)}

        else:
            raise TypeError(
                "Dataset:add_rxn: Type of stoichiometry input was not recognized: {}".format(type(stoichiometry))
            )

        rxn = ReactionEntry(
            name=name,
            stoichiometry=parsed_stoich,
            attributes=attributes,
            extras=other_fields,
            reaction_results=reaction_results,
        )
        self._new_records.append(rxn)

        return rxn
785

786 4
    def add_ie_rxn(self, name: str, mol: Molecule, **kwargs) -> ReactionEntry:
        """Add an interaction energy reaction entry to the database.

        The stoichiometry is generated from the fragmented molecule by
        `build_ie_fragments` (which builds the CP and no-CP reactions by default)
        and the resulting reaction is stored through `add_rxn`.

        Parameters
        ----------
        name : str
            The name of the reaction
        mol : Molecule
            A molecule with multiple fragments
        **kwargs
            The keys `reaction_results`, `attributes` and `other_fields` are removed
            and forwarded to `add_rxn` (each falls back to an empty dict when absent).
            Every remaining kwarg is passed into `build_ie_fragments`.

        Returns
        -------
        ReactionEntry
            A representation of the new reaction.
        """
        # Pull out the options meant for `add_rxn` so they never reach the fragment builder.
        rxn_options = {
            option: kwargs.pop(option, {}) for option in ("reaction_results", "attributes", "other_fields")
        }

        fragment_stoich = self.build_ie_fragments(mol, name=name, **kwargs)
        return self.add_rxn(name, fragment_stoich, **rxn_options)
812

813 4
    @staticmethod
    def build_ie_fragments(mol: Molecule, **kwargs) -> Dict[str, List[Tuple[Molecule, float]]]:
        """
        Build the stoichiometry for an Interaction Energy.

        Parameters
        ----------
        mol : Molecule class or str
            Molecule to fragment. Input that is not already a `Molecule` is converted with
            `Molecule.from_data`, which also receives any leftover keyword arguments.
        do_default : bool
            Create the default (noCP) stoichiometry. Defaults to True.
        do_cp : bool
            Create the counterpoise (CP) corrected stoichiometry. Defaults to True.
        do_vmfc : bool
            Create the Valiron-Mayer Function Counterpoise (VMFC) corrected stoichiometry.
            Not implemented yet; requesting it raises. Defaults to False.
        max_nbody : int
            The maximum fragment level built, if zero defaults to the maximum number of fragments.

        Returns
        -------
        ret : dict
            Maps a stoichiometry name to a list of ``(molecule, coefficient)`` tuples.
            For every ``nbody`` in ``1 .. max_nbody - 1`` the keys ``"default<nbody>"`` and/or
            ``"cp<nbody>"`` hold the fragment terms; the keys ``"default"`` and/or ``"cp"`` hold
            the full molecule with a coefficient of 1.0.

        Raises
        ------
        AttributeError
            If the molecule has fewer than two fragments.
        KeyError
            If `do_vmfc` is requested.
        """

        do_default = kwargs.pop("do_default", True)
        do_cp = kwargs.pop("do_cp", True)
        do_vmfc = kwargs.pop("do_vmfc", False)
        max_nbody = kwargs.pop("max_nbody", 0)

        if not isinstance(mol, Molecule):
            mol = Molecule.from_data(mol, **kwargs)

        max_frag = len(mol.fragments)
        if max_nbody == 0:
            max_nbody = max_frag
        if max_frag < 2:
            raise AttributeError("Dataset:build_ie_fragments: Molecule must have at least two fragments.")

        # VMFC is a special beast: it is not implemented, so reject the request before
        # spending any time building fragments that would be thrown away.
        # Sketch of the intended implementation, kept for reference:
        #     ret.update({"vmfc" + str(nbody): [] for nbody in range(1, max_nbody)})
        #     nbody_range = list(range(1, max_nbody))
        #     for nbody in nbody_range:
        #         for cp_combos in it.combinations(fragment_range, nbody):
        #             basis_tuple = tuple(cp_combos)
        #             for interior_nbody in nbody_range:
        #                 for x in it.combinations(cp_combos, interior_nbody):
        #                     ghost = list(set(basis_tuple) - set(x))
        #                     ret["vmfc" + str(interior_nbody)].append((mol.get_fragment(x, ghost), 1.0))
        #     ret["vmfc"] = [(mol, 1.0)]
        if do_vmfc:
            raise KeyError("VMFC isnt quite ready for primetime!")

        ret: Dict[str, List[Tuple[Molecule, float]]] = {}

        # Build some info
        fragment_range = list(range(max_frag))
        all_fragments = set(fragment_range)  # loop invariant, used to derive the ghost fragments

        # Loop over the bodies
        for nbody in range(1, max_nbody):
            nocp_tmp = []
            cp_tmp = []
            for k in range(1, nbody + 1):
                # Signed weight shared by every k-fragment combination at this n-body level
                take_nk = nCr(max_frag - k - 1, nbody - k)
                sign = (-1) ** (nbody - k)
                coef = take_nk * sign
                for frag in it.combinations(fragment_range, k):
                    if do_default:
                        nocp_tmp.append((mol.get_fragment(frag, orient=True, group_fragments=True), coef))
                    if do_cp:
                        # Every fragment that is not part of this combination is passed as a ghost
                        ghost = list(all_fragments - set(frag))
                        cp_tmp.append((mol.get_fragment(frag, ghost, orient=True, group_fragments=True), coef))

            if do_default:
                ret["default" + str(nbody)] = nocp_tmp

            if do_cp:
                ret["cp" + str(nbody)] = cp_tmp

        # Add in the maximal position
        if do_default:
            ret["default"] = [(mol, 1.0)]

        if do_cp:
            ret["cp"] = [(mol, 1.0)]

        return ret
907

908

909 4
# Hand the class to `register_collection` (imported from `.collection_utils`) at import time.
# NOTE(review): presumably this makes ReactionDataset discoverable by collection type; confirm
# against the helper's implementation.
register_collection(ReactionDataset)

Read our documentation on viewing source code.

Loading