1
"""
2
SQLAlchemy Database class to handle access to Postgres through ORM
3
"""
4

5 4
try:
6 4
    from sqlalchemy import create_engine, and_, or_, case, func
7 4
    from sqlalchemy.exc import IntegrityError
8 4
    from sqlalchemy.orm import sessionmaker, with_polymorphic
9 4
    from sqlalchemy.sql.expression import desc
10 4
    from sqlalchemy.sql.expression import case as expression_case
11 0
except ImportError:
12 0
    raise ImportError(
13
        "SQLAlchemy_socket requires sqlalchemy, please install this python " "module or try a different db_socket."
14
    )
15

16 4
import json
17 4
import logging
18 4
import secrets
19 4
from collections.abc import Iterable
20 4
from contextlib import contextmanager
21 4
from datetime import datetime as dt
22 4
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
23

24 4
import bcrypt
25

26
# pydantic classes
27 4
from qcfractal.interface.models import (
28
    GridOptimizationRecord,
29
    KeywordSet,
30
    Molecule,
31
    ObjectId,
32
    OptimizationRecord,
33
    ResultRecord,
34
    TaskRecord,
35
    TaskStatusEnum,
36
    TorsionDriveRecord,
37
    KVStore,
38
    prepare_basis,
39
)
40 4
from qcfractal.storage_sockets.db_queries import QUERY_CLASSES
41 4
from qcfractal.storage_sockets.models import (
42
    AccessLogORM,
43
    BaseResultORM,
44
    CollectionORM,
45
    DatasetORM,
46
    GridOptimizationProcedureORM,
47
    KeywordsORM,
48
    KVStoreORM,
49
    MoleculeORM,
50
    OptimizationProcedureORM,
51
    QueueManagerLogORM,
52
    QueueManagerORM,
53
    ReactionDatasetORM,
54
    ResultORM,
55
    ServerStatsLogORM,
56
    ServiceQueueORM,
57
    TaskQueueORM,
58
    TorsionDriveProcedureORM,
59
    UserORM,
60
    VersionsORM,
61
    WavefunctionStoreORM,
62
)
63 4
from qcfractal.storage_sockets.storage_utils import add_metadata_template, get_metadata_template
64

65 4
from .models import Base
66

67 4
if TYPE_CHECKING:
68 0
    from ..services.service_util import BaseService
69

70
# for version checking
71 4
import qcelemental, qcfractal, qcengine
72

73 4
# Query keys whose value "null" should be translated to a SQL NULL match.
_null_keys = {"basis", "keywords"}
# Query keys whose values are database ids.
_id_keys = {"id", "molecule", "keywords", "procedure_id"}


def _lower_func(x: str) -> str:
    """Lower-case a query value (PEP 8: a def, not a name-bound lambda)."""
    return x.lower()


# Normalizers applied to query values before they are compared against columns.
_prepare_keys = {"program": _lower_func, "basis": prepare_basis, "method": _lower_func, "procedure": _lower_func}
77

78

79 4
def dict_from_tuple(keys, values):
    """Combine column names with row tuples into a list of row dictionaries."""
    result = []
    for row in values:
        result.append(dict(zip(keys, row)))
    return result
81

82

83 4
def format_query(ORMClass, **query: Dict[str, Union[str, List[str]]]) -> Dict[str, Union[str, List[str]]]:
    """
    Formats a query into a SQLAlchemy format.
    """

    clauses = []
    for key, value in query.items():
        # Skip keys the caller did not constrain.
        if value is None:
            continue

        key = key.lower()
        # The literal string "null" means "match a NULL column" for these keys.
        if value == "null" and key in _null_keys:
            value = None

        # Normalize values (lower-casing, basis preparation) where required.
        prepare = _prepare_keys.get(key)
        if prepare is not None:
            if isinstance(value, (list, tuple)):
                value = [prepare(item) for item in value]
            else:
                value = prepare(value)

        column = getattr(ORMClass, key)
        if isinstance(value, (list, tuple)):
            clauses.append(column.in_(value))
        else:
            clauses.append(column == value)

    return clauses
112

113

114 4
def get_count_fast(query):
    """
    returns total count of the query using:
        Fast: SELECT COUNT(*) FROM TestModel WHERE ...

    Not like q.count():
        Slow: SELECT COUNT(*) FROM (SELECT ... FROM TestModel WHERE ...) ...
    """

    # Strip ORDER BY and project only COUNT(*) so the DB avoids a subquery.
    count_stmt = query.statement.with_only_columns([func.count()]).order_by(None)
    return query.session.execute(count_stmt).scalar()
127

128

129 4
def get_procedure_class(record):
    """Map a pydantic procedure record instance to its ORM class.

    Checks are ordered, mirroring the original isinstance chain.
    """

    dispatch = (
        (OptimizationRecord, OptimizationProcedureORM),
        (TorsionDriveRecord, TorsionDriveProcedureORM),
        (GridOptimizationRecord, GridOptimizationProcedureORM),
    )

    for record_type, orm_class in dispatch:
        if isinstance(record, record_type):
            return orm_class

    raise TypeError("Procedure of type {} is not valid or supported yet.".format(type(record)))
141

142

143 4
def get_collection_class(collection_type):
    """Return the ORM class for a collection type, defaulting to CollectionORM."""

    collection_map = {"dataset": DatasetORM, "reactiondataset": ReactionDatasetORM}

    return collection_map.get(collection_type, CollectionORM)
153

154

155 4
class SQLAlchemySocket:
156
    """
157
    SQLAlchemy QCDB wrapper class.
158
    """
159

160 4
    def __init__(
        self,
        uri: str,
        project: str = "molssidb",
        bypass_security: bool = False,
        allow_read: bool = True,
        logger: "Logger" = None,
        sql_echo: bool = False,
        max_limit: int = 1000,
        skip_version_check: bool = False,
    ):
        """
        Constructs a new SQLAlchemy socket.

        Parameters
        ----------
        uri : str
            Postgres connection URI. "postgresql" is rewritten to
            "postgresql+psycopg2" if no driver is present, and `project`
            is appended as the database name.
        project : str, optional
            Name of the database (project) to connect to.
        bypass_security : bool, optional
            If True, security checks are bypassed (stored for later use).
        allow_read : bool, optional
            Whether unauthenticated reads are permitted (stored for later use).
        logger : Logger, optional
            Logger to use; a default module logger is created if not given.
        sql_echo : bool, optional
            If True, SQLAlchemy echoes emitted SQL into Python logging.
        max_limit : int, optional
            Hard cap on the number of records any query may return.
        skip_version_check : bool, optional
            If True, do not require the DB's recorded fractal version to match
            the running QCFractal version.

        Raises
        ------
        TypeError
            If the DB was created by a different QCFractal version and
            `skip_version_check` is False.
        ValueError
            If table creation fails (wrapped connection error).
        """

        # Logging data
        if logger:
            self.logger = logger
        else:
            # NOTE(review): logger name preserves an existing typo ("SQLAlcehmy");
            # kept byte-identical since external log filters may match on it.
            self.logger = logging.getLogger("SQLAlcehmySocket")

        # Security
        self._bypass_security = bypass_security
        self._allow_read = allow_read

        # Columns that are stored lower-cased for case-insensitive lookups.
        self._lower_results_index = ["method", "basis", "program"]

        # disconnect from any active default connection
        # disconnect()
        # Force the psycopg2 driver unless the URI already names one.
        if "psycopg2" not in uri:
            uri = uri.replace("postgresql", "postgresql+psycopg2")

        if project and not uri.endswith("/"):
            uri = uri + "/"

        uri = uri + project
        self.logger.info(f"SQLAlchemy attempt to connect to {uri}.")

        # Connect to DB and create session
        self.uri = uri
        self.engine = create_engine(
            uri,
            echo=sql_echo,  # echo for logging into python logging
            pool_size=5,  # 5 is the default, 0 means unlimited
        )
        self.logger.info(
            "Connected SQLAlchemy to DB dialect {} with driver {}".format(self.engine.dialect.name, self.engine.driver)
        )

        # Session factory; per-transaction sessions come from session_scope().
        self.Session = sessionmaker(bind=self.engine)

        # check version compatibility
        db_ver = self.check_lib_versions()
        self.logger.info(f"DB versions: {db_ver}")
        if (not skip_version_check) and (db_ver and qcfractal.__version__ != db_ver["fractal_version"]):
            raise TypeError(
                f"You are running QCFractal version {qcfractal.__version__} "
                f'with an older DB version ({db_ver["fractal_version"]}). '
                f'Please run "qcfractal-server upgrade" first before starting the server.'
            )

        # actually create the tables
        try:
            Base.metadata.create_all(self.engine)
            self.check_lib_versions()  # update version if new DB
        except Exception as e:
            raise ValueError(f"SQLAlchemy Connection Error\n {str(e)}") from None

        # Advanced queries objects
        # Maps a REST class name -> query helper instance (see custom_query).
        self._query_classes = {
            cls._class_name: cls(self.engine.url.database, max_limit=max_limit) for cls in QUERY_CLASSES
        }

        # Historical mongoengine connection logic, kept for reference:
        # if expanded_uri["password"] is not None:
        #     # connect to mongoengine
        #     self.client = db.connect(db=project, host=uri, authMechanism=authMechanism, authSource=authSource)
        # else:
        #     # connect to mongoengine
        #     self.client = db.connect(db=project, host=uri)

        # self._url, self._port = expanded_uri["nodelist"][0]

        # try:
        #     version_array = self.client.server_info()['versionArray']
        #
        #     if tuple(version_array) < (3, 2):
        #         raise RuntimeError
        # except AttributeError:
        #     raise RuntimeError(
        #         "Could not detect MongoDB version at URL {}. It may be a very old version or installed incorrectly. "
        #         "Choosing to stop instead of assuming version is at least 3.2.".format(uri))
        # except RuntimeError:
        #     # Trap low version
        #     raise RuntimeError("Connected MongoDB at URL {} needs to be at least version 3.2, found version {}.".
        #                        format(uri, self.client.server_info()['version']))

        self._project_name = project
        self._max_limit = max_limit
259

260 4
    def __str__(self) -> str:
        """Return a short description of this socket.

        BUG FIX: the closing quote was previously a backtick (`), producing
        a mismatched quote pair in the rendered string.
        """
        return f"<SQLAlchemySocket: address='{self.uri}'>"
262

263 4
    @contextmanager
    def session_scope(self):
        """Provide a transactional scope: commit on success, rollback on error.

        The bare ``except`` is deliberate so that even BaseException (e.g.
        KeyboardInterrupt) triggers a rollback before re-raising.
        """

        db_session = self.Session()
        try:
            yield db_session
            db_session.commit()
        except:
            db_session.rollback()
            raise
        finally:
            db_session.close()
276

277 4
    def _clear_db(self, db_name: str = None):
        """Dangerous, make sure you are deleting the right DB"""

        self.logger.warning("SQL: Clearing database '{}' and dropping all tables.".format(db_name))

        # Drop every table known to the declarative metadata, then recreate
        # them empty.
        Base.metadata.drop_all(self.engine)
        Base.metadata.create_all(self.engine)

        # self.client.drop_database(db_name)
287

288
        # self.client.drop_database(db_name)
289

290 4
    def _delete_DB_data(self, db_name):
        """TODO: needs more testing

        Deletes all rows from every table, in an order chosen so that
        dependent (foreign-key) rows go before the rows they reference.

        NOTE(review): `db_name` is not used by this method — deletion always
        targets the connected database. Confirm callers before removing it.
        """

        with self.session_scope() as session:
            # Metadata
            session.query(VersionsORM).delete(synchronize_session=False)
            # Task and services
            session.query(TaskQueueORM).delete(synchronize_session=False)
            session.query(QueueManagerLogORM).delete(synchronize_session=False)
            session.query(QueueManagerORM).delete(synchronize_session=False)
            session.query(ServiceQueueORM).delete(synchronize_session=False)

            # Collections
            session.query(CollectionORM).delete(synchronize_session=False)

            # Records (concrete procedure tables before their polymorphic base)
            session.query(TorsionDriveProcedureORM).delete(synchronize_session=False)
            session.query(GridOptimizationProcedureORM).delete(synchronize_session=False)
            session.query(OptimizationProcedureORM).delete(synchronize_session=False)
            session.query(ResultORM).delete(synchronize_session=False)
            session.query(WavefunctionStoreORM).delete(synchronize_session=False)
            session.query(BaseResultORM).delete(synchronize_session=False)

            # Auxiliary tables
            session.query(KVStoreORM).delete(synchronize_session=False)
            session.query(MoleculeORM).delete(synchronize_session=False)
316

317 4
    def get_project_name(self) -> str:
        """Return the name of the project (database) this socket connects to."""
        return self._project_name
319

320 4
    def get_limit(self, limit: Optional[int]) -> int:
        """Get the allowed limit on results to return in queries based on the
        given `limit`. If this number is greater than the
        SQLAlchemySocket.max_limit then the max_limit will be returned instead.
        """

        # Guard clauses: fall back to the server cap when unset or too large.
        if limit is None or limit >= self._max_limit:
            return self._max_limit
        return limit
327

328 4
    def get_query_projection(self, className, query, *, limit=None, skip=0, include=None, exclude=None):
        """Run `query` against `className`'s table, optionally projecting a
        subset of columns (include XOR exclude), and return serialized rows.

        Returns
        -------
        Tuple[List[Dict[str, Any]], int]
            (rows as dictionaries, total count before limit/skip)

        Raises
        ------
        AttributeError
            If both include and exclude are given, or a projected attribute
            does not exist on `className`.
        """

        if include and exclude:
            raise AttributeError(
                f"Either include or exclude can be "
                f"used, not both at the same query. "
                f"Given include: {include}, exclude: {exclude}"
            )

        # Column names, hybrid-property names, and relationship descriptors.
        prop, hybrids, relationships = className._get_col_types()

        # build projection from include or exclude
        _projection = []
        if include:
            _projection = set(include)
        elif exclude:
            _projection = set(className._all_col_names()) - set(exclude) - set(className.db_related_fields)
        _projection = list(_projection)

        proj = []
        join_attrs = {}
        callbacks = []

        # prepare hybrid attributes for callback and joins
        for key in _projection:
            if key in prop:  # normal column
                proj.append(getattr(className, key))

            # if hybrid property, save callback, and relation if any
            elif key in hybrids:
                callbacks.append(key)

                # if it has a relationship
                if key + "_obj" in relationships.keys():

                    # join_class_name = relationships[key + '_obj']
                    join_attrs[key] = relationships[key + "_obj"]
            else:
                raise AttributeError(f"Atrribute {key} is not found in class {className}.")

        # Relationship keys are fetched via joins below, not in the main query.
        for key in join_attrs:
            _projection.remove(key)

        with self.session_scope() as session:
            if _projection or join_attrs:

                if join_attrs and "id" not in _projection:  # if the id is need for joins
                    proj.append(getattr(className, "id"))
                    _projection.append("_id")  # not to be returned to user

                # query with projection, without joins
                data = session.query(*proj).filter(*query)

                n_found = get_count_fast(data)  # before iterating on the data
                data = data.limit(self.get_limit(limit)).offset(skip)
                rdata = [dict(zip(_projection, row)) for row in data]

                # query for joins if any (relationships and hybrids)
                if join_attrs:
                    res_ids = [d.get("id", d.get("_id")) for d in rdata]
                    res_ids.sort()
                    join_data = {res_id: {} for res_id in res_ids}

                    # relations data
                    for key, relation_details in join_attrs.items():
                        ret = (
                            session.query(
                                relation_details["remote_side_column"].label("id"), relation_details["join_class"]
                            )
                            .filter(relation_details["remote_side_column"].in_(res_ids))
                            .order_by(relation_details["remote_side_column"])
                            .all()
                        )
                        # Group joined rows under their parent id.
                        for res_id in res_ids:
                            join_data[res_id][key] = []
                            for res in ret:
                                if res_id == res[0]:
                                    join_data[res_id][key].append(res[1])

                        # NOTE: `data` here shadows the query object above;
                        # harmless because the query is not reused afterwards.
                        for data in rdata:
                            parent_id = data.get("id", data.get("_id"))
                            data[key] = join_data[parent_id][key]
                            data.pop("_id", None)

                # call hybrid methods
                for callback in callbacks:
                    for res in rdata:
                        res[callback] = getattr(className, "_" + callback)(res[callback])

                id_fields = className._get_fieldnames_with_DB_ids_()
                for d in rdata:
                    # Expand extra json into fields
                    if "extra" in d:
                        d.update(d["extra"])
                        del d["extra"]

                    # transform ids from int into str
                    for key in id_fields:
                        if key in d.keys() and d[key] is not None:
                            if isinstance(d[key], Iterable):
                                d[key] = [str(i) for i in d[key]]
                            else:
                                d[key] = str(d[key])
                # print('--------rdata after: ', rdata)
            else:
                # No projection: fetch full ORM objects and serialize them.
                data = session.query(className).filter(*query)

                # from sqlalchemy.dialects import postgresql
                # print(data.statement.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
                n_found = get_count_fast(data)
                data = data.limit(self.get_limit(limit)).offset(skip).all()
                rdata = [d.to_dict() for d in data]

        return rdata, n_found
442

443
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
444

445 4
    def custom_query(self, class_name: str, query_key: str, **kwargs):
        """
        Run advanced or specialized queries on different classes

        Parameters
        ----------
        class_name : str
            REST APIs name of the class (not the actual python name),
             e.g., torsiondrive
        query_key : str
            The feature or attribute to look for, like initial_molecule
        kwargs
            Extra arguments needed by the query, like the id of the torison drive

        Returns
        -------
            Dict[str,Any]:
                Query result dictionary with keys:
                data: returned data by the query (variable format)
                meta:
                    success: True or False
                    error_description: Error msg to show to the user
        """

        ret = {"data": [], "meta": get_metadata_template()}

        try:
            if class_name not in self._query_classes:
                raise AttributeError(f"Class name {class_name} is not found.")

            # BUG FIX: the session was previously created and never closed,
            # leaking a pooled DB connection on every call. Close it in a
            # finally so both success and error paths release it.
            session = self.Session()
            try:
                ret["data"] = self._query_classes[class_name].query(session, query_key, **kwargs)
            finally:
                session.close()
            ret["meta"]["success"] = True
            try:
                ret["meta"]["n_found"] = len(ret["data"])
            except TypeError:
                # Scalar results (no len()) count as a single hit.
                ret["meta"]["n_found"] = 1
        except Exception as err:
            # Errors are reported to the caller via meta, never raised.
            ret["meta"]["error_description"] = str(err)

        return ret
486

487
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Logging ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
488

489 4
    def save_access(self, log_data):
        """Persist one REST access-log entry built from `log_data` kwargs."""

        with self.session_scope() as session:
            session.add(AccessLogORM(**log_data))
            session.commit()
495

496
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Logs (KV store) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
497

498 4
    def add_kvstore(self, outputs: List[KVStore]):
        """
        Adds to the key/value store table.

        Parameters
        ----------
        outputs : List[Any]
            A list of KVStore objects add.

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, data is the ids of added blobs
        """

        meta = add_metadata_template()
        output_ids = []
        with self.session_scope() as session:
            for output in outputs:
                # None entries are passed through as None ids.
                if output is None:
                    output_ids.append(None)
                    continue

                # Commit per entry so each id is assigned immediately.
                kv_orm = KVStoreORM(**output.dict())
                session.add(kv_orm)
                session.commit()
                output_ids.append(str(kv_orm.id))
                meta["n_inserted"] += 1

        meta["success"] = True

        return {"data": output_ids, "meta": meta}
530

531 4
    def get_kvstore(self, id: List[ObjectId] = None, limit: int = None, skip: int = 0):
        """
        Pulls from the key/value store table.

        Parameters
        ----------
        id : List[str]
            A list of ids to query
        limit : Optional[int], optional
            Maximum number of results to return.
        skip : Optional[int], optional
            skip the `skip` results
        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, data is a key-value dictionary of found key-value stored items.
        """

        meta = get_metadata_template()

        query = format_query(KVStoreORM, id=id)

        rdata, meta["n_found"] = self.get_query_projection(KVStoreORM, query, limit=limit, skip=skip)

        meta["success"] = True

        data = {}
        # TODO - after migrating everything, remove the 'value' column in the table
        for d in rdata:
            # Rows written before the migration have their payload in the
            # legacy 'value' column and a NULL 'data' column.
            val = d.pop("value")
            if d["data"] is None:
                # Set the data field to be the string or dictionary
                d["data"] = val

                # Remove these and let the model handle the defaults
                d.pop("compression")
                d.pop("compression_level")

            # The KVStore constructor can handle conversion of strings and dictionaries
            data[d["id"]] = KVStore(**d)

        return {"data": data, "meta": meta}
573

574
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Molecule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
575

576 4
    def get_add_molecules_mixed(self, data: List[Union[ObjectId, Molecule]]) -> List[Molecule]:
        """
        Get or add the given molecules (if they don't exit).
        MoleculeORMs are given in a mixed format, either as a dict of mol data
        or as existing mol id

        TODO: to be split into get by_id and get_by_data

        Returns
        -------
        Dict[str, Any]
            {"meta": ..., "data": [...]} where data preserves the input order;
            entries that could not be resolved are None.
        """

        meta = get_metadata_template()

        # Index each input by its original position so order can be restored.
        ordered_mol_dict = {indx: mol for indx, mol in enumerate(data)}
        new_molecules = {}
        id_mols = {}
        for idx, mol in ordered_mol_dict.items():
            if isinstance(mol, (int, str)):
                id_mols[idx] = mol
            elif isinstance(mol, Molecule):
                new_molecules[idx] = mol
            else:
                meta["errors"].append((idx, "Data type not understood"))

        ret_mols = {}

        # Add all new molecules
        flat_mols = []
        flat_mol_keys = []
        for k, v in new_molecules.items():
            flat_mol_keys.append(k)
            flat_mols.append(v)
        # add_molecules returns ids in the same order as its input.
        flat_mols = self.add_molecules(flat_mols)["data"]

        id_mols.update({k: v for k, v in zip(flat_mol_keys, flat_mols)})

        # Get molecules by index and translate back to dict
        tmp = self.get_molecules(list(id_mols.values()))
        id_mols_list = tmp["data"]
        meta["errors"].extend(tmp["meta"]["errors"])

        # Invert id -> position to map fetched molecules back to inputs.
        inv_id_mols = {v: k for k, v in id_mols.items()}

        for mol in id_mols_list:
            ret_mols[inv_id_mols[mol.id]] = mol

        meta["success"] = True
        meta["n_found"] = len(ret_mols)
        meta["missing"] = list(ordered_mol_dict.keys() - ret_mols.keys())

        # Rewind to flat last
        ret = []
        for ind in range(len(ordered_mol_dict)):
            if ind in ret_mols:
                ret.append(ret_mols[ind])
            else:
                ret.append(None)

        return {"meta": meta, "data": ret}
633

634 4
    def add_molecules(self, molecules: List[Molecule]):
        """
        Adds molecules to the database.

        Takes a fast bulk-insert path when no input molecule duplicates an
        existing row or another input; otherwise falls back to per-molecule
        inserts with duplicate detection by molecule hash.

        Parameters
        ----------
        molecules : List[Molecule]
            A List of molecule objects to add.

        Returns
        -------
        Dict[str, Any]
            {"data": [ids in input order], "meta": add-metadata with
            n_inserted and duplicates populated}.
        """

        meta = add_metadata_template()

        results = []
        with self.session_scope() as session:

            # Build out the ORMs
            orm_molecules = []
            for dmol in molecules:

                # Re-validate any molecule that was not validated upstream.
                if dmol.validated is False:
                    dmol = Molecule(**dmol.dict(), validate=True)

                mol_dict = dmol.dict(exclude={"id", "validated"})

                # TODO: can set them as defaults in the sql_models, not here
                mol_dict["fix_com"] = True
                mol_dict["fix_orientation"] = True

                # Build fresh indices
                mol_dict["molecule_hash"] = dmol.get_hash()
                mol_dict["molecular_formula"] = dmol.get_molecular_formula()

                mol_dict["identifiers"] = {}
                mol_dict["identifiers"]["molecule_hash"] = mol_dict["molecule_hash"]
                mol_dict["identifiers"]["molecular_formula"] = mol_dict["molecular_formula"]

                # search by index keywords not by all keys, much faster
                orm_molecules.append(MoleculeORM(**mol_dict))

            # Check if we have duplicates
            hash_list = [x.molecule_hash for x in orm_molecules]
            query = format_query(MoleculeORM, molecule_hash=hash_list)
            indices = session.query(MoleculeORM.molecule_hash, MoleculeORM.id).filter(*query)
            previous_id_map = {k: v for k, v in indices}

            # For a bulk add there must be no pre-existing and there must be no duplicates in the add list
            bulk_ok = len(hash_list) == len(set(hash_list))
            bulk_ok &= len(previous_id_map) == 0
            # bulk_ok = False

            if bulk_ok:
                # Bulk save, doesn't update fields for speed
                session.bulk_save_objects(orm_molecules)
                session.commit()

                # Query ID's and reorder based off orm_molecule ordered list
                query = format_query(MoleculeORM, molecule_hash=hash_list)
                indices = session.query(MoleculeORM.molecule_hash, MoleculeORM.id).filter(*query)

                id_map = {k: v for k, v in indices}
                n_inserted = len(orm_molecules)

            else:
                # Start from old ID map
                id_map = previous_id_map

                new_molecules = []
                n_inserted = 0

                for orm_mol in orm_molecules:
                    # False sentinel distinguishes "absent" from a falsy id.
                    duplicate_id = id_map.get(orm_mol.molecule_hash, False)
                    if duplicate_id is not False:
                        meta["duplicates"].append(str(duplicate_id))
                    else:
                        new_molecules.append(orm_mol)
                        # Placeholder is replaced with the real id post-commit.
                        id_map[orm_mol.molecule_hash] = "placeholder_id"
                        n_inserted += 1
                        session.add(orm_mol)

                    # We should make sure there was not a hash collision?
                    # new_mol.compare(old_mol)
                    # raise KeyError("!!! WARNING !!!: Hash collision detected")

                session.commit()

                for new_mol in new_molecules:
                    id_map[new_mol.molecule_hash] = new_mol.id

            results = [str(id_map[x.molecule_hash]) for x in orm_molecules]
            # Sanity check: every placeholder must have been resolved above.
            assert "placeholder_id" not in results
            meta["n_inserted"] = n_inserted

        meta["success"] = True

        ret = {"data": results, "meta": meta}
        return ret
735

736 4
    def get_molecules(self, id=None, molecule_hash=None, molecular_formula=None, limit: int = None, skip: int = 0):
        """Query molecules by id, hash, and/or formula.

        Returns
        -------
        Dict[str, Any]
            {"meta": ..., "data": [Molecule, ...]}
        """
        try:
            # Normalize formulas to canonical element ordering before querying.
            if isinstance(molecular_formula, str):
                molecular_formula = qcelemental.molutil.order_molecular_formula(molecular_formula)
            elif isinstance(molecular_formula, list):
                molecular_formula = [qcelemental.molutil.order_molecular_formula(form) for form in molecular_formula]
        except ValueError:
            # Probably, the user provided an invalid chemical formula
            # (deliberate best-effort: the raw formula is queried as-is).
            pass

        meta = get_metadata_template()

        query = format_query(MoleculeORM, id=id, molecule_hash=molecule_hash, molecular_formula=molecular_formula)

        # Don't include the hash or the molecular_formula in the returned result
        rdata, meta["n_found"] = self.get_query_projection(
            MoleculeORM, query, limit=limit, skip=skip, exclude=["molecule_hash", "molecular_formula"]
        )

        meta["success"] = True

        # This is required for sparse molecules as we don't know which values are spase
        # We are lucky that None is the default and doesn't mean anything in Molecule
        # This strategy does not work for other objects
        data = []
        for mol_dict in rdata:
            mol_dict = {k: v for k, v in mol_dict.items() if v is not None}
            # validated=True skips re-validation of DB-sourced molecules.
            data.append(Molecule(**mol_dict, validate=False, validated=True))

        return {"meta": meta, "data": data}
766

767 4
    def del_molecules(self, id: List[str] = None, molecule_hash: List[str] = None):
        """
        Removes a molecule from the database from its hash.

        Parameters
        ----------
        id : str or List[str], optional
            ids of molecules, can use the hash parameter instead
        molecule_hash : str or List[str]
            The hash of a molecule.

        Returns
        -------
        bool
            Number of deleted molecules.
        """

        filters = format_query(MoleculeORM, id=id, molecule_hash=molecule_hash)

        # The session_scope commit on exit finalizes the bulk delete.
        with self.session_scope() as session:
            return session.query(MoleculeORM).filter(*filters).delete(synchronize_session=False)
790

791
    # ~~~~~~~~~~~~~~~~~~~~~~~ Keywords ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
792

793 4
    def add_keywords(self, keyword_sets: List[KeywordSet]):
        """Add one KeywordSet uniquely identified by 'program' and the 'name'.

        Parameters
        ----------
        keywords_set : List[KeywordSet]
            A list of KeywordSets to be inserted.

        Returns
        -------
        Dict[str, Any]
            (see add_metadata_template())
            The 'data' part is a list of ids of the inserted options
            data['duplicates'] has the duplicate entries

        Notes
        ------
            Duplicates are not considered errors.

        """

        meta = add_metadata_template()

        keywords = []
        with self.session_scope() as session:
            for kw in keyword_sets:

                kw_dict = kw.dict(exclude={"id"})

                # search by index keywords not by all keys, much faster
                found = session.query(KeywordsORM).filter_by(hash_index=kw_dict["hash_index"]).first()
                if not found:
                    doc = KeywordsORM(**kw_dict)
                    session.add(doc)
                    session.commit()
                    keywords.append(str(doc.id))
                    meta["n_inserted"] += 1
                else:
                    meta["duplicates"].append(str(found.id))  # TODO
                    keywords.append(str(found.id))

        # BUG FIX: success was previously assigned inside the loop, so an empty
        # input list reported success=False despite nothing failing. Setting it
        # after the transaction matches add_kvstore/add_molecules behavior.
        meta["success"] = True

        ret = {"data": keywords, "meta": meta}

        return ret
838

839 4
    def get_keywords(
        self,
        id: Union[str, list] = None,
        hash_index: Union[str, list] = None,
        limit: int = None,
        skip: int = 0,
        return_json: bool = False,
        with_ids: bool = True,
    ) -> List[KeywordSet]:
        """Fetch keyword sets by id and/or hash_index.

        Parameters
        ----------
        id : List[str] or str
            Ids of the keywords
        hash_index : List[str] or str
            hash index of keywords
        limit : Optional[int], optional
            Maximum number of results to return.
            If this number is greater than the SQLAlchemySocket.max_limit then
            the max_limit will be returned instead.
            Default is to return the socket's max_limit (when limit=None or 0)
        skip : int, optional
            Skip the first `skip` results (pagination).
        return_json : bool, optional
            If True, return plain dicts rather than KeywordSet objects.
            Default is False.
        with_ids : bool, optional
            Include the DB ids in the returned object (named 'id')
            Default is True

        Returns
        -------
            A dict with keys: 'data' and 'meta'
            (see get_metadata_template())
            The 'data' part is an object of the result or None if not found
        """

        meta = get_metadata_template()
        query = format_query(KeywordsORM, id=id, hash_index=hash_index)

        # Excluding None is a no-op; the id column is dropped only when ids
        # were not requested.
        rdata, meta["n_found"] = self.get_query_projection(
            KeywordsORM, query, limit=limit, skip=skip, exclude=[None if with_ids else "id"]
        )

        meta["success"] = True

        data = rdata if return_json else [KeywordSet(**d) for d in rdata]

        return {"data": data, "meta": meta}
895

896 4
    def get_add_keywords_mixed(self, data):
        """
        Get or add the given keyword sets (adding those that don't exist).

        Entries are given in a mixed format: either an existing keyword id
        (int/str) or a KeywordSet object to be inserted.

        TODO: to be split into get by_id and get_by_data
        """

        meta = get_metadata_template()

        # First pass: resolve every entry to an id (None marks bad input).
        ids = []
        for position, entry in enumerate(data):
            if isinstance(entry, (int, str)):
                ids.append(entry)
            elif isinstance(entry, KeywordSet):
                ids.append(self.add_keywords([entry])["data"][0])
            else:
                meta["errors"].append((position, "Data type not understood"))
                ids.append(None)

        # Second pass: fetch the full KeywordSet for each resolved id.
        missing = []
        found = []
        for position, kw_id in enumerate(ids):
            if kw_id is None:
                found.append(None)
                missing.append(position)
                continue

            hits = self.get_keywords(id=kw_id)["data"]
            found.append(hits[0] if hits else None)

        meta["success"] = True
        meta["n_found"] = len(found) - len(missing)
        meta["missing"] = missing

        return {"meta": meta, "data": found}
938

939 4
    def del_keywords(self, id: str) -> int:
        """Remove a keyword set from the database by its id.

        Parameters
        ----------
        id : str
            id of the keyword set to delete

        Returns
        -------
        int
           number of deleted documents
        """

        with self.session_scope() as session:
            deleted = session.query(KeywordsORM).filter_by(id=id).delete(synchronize_session=False)

        return deleted
959

960
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~`
961

962
    ### database functions
963

964 4
    def add_collection(self, data: Dict[str, Any], overwrite: bool = False):
        """Add (or update) a collection to the database.

        Parameters
        ----------
        data : Dict[str, Any]
            should include at least (keys):
            collection : str (immutable)
            name : str (immutable)

        overwrite : bool
            Update existing collection

        Returns
        -------
        Dict[str, Any]
        A dict with keys: 'data' and 'meta'
            (see add_metadata_template())
            The 'data' part is the id of the inserted document or none

        Notes
        -----
        ** Change: The data doesn't have to include the ID, the document
        is identified by the (collection, name) pairs.
        ** Change: New fields will be added to the collection, but existing won't
            be removed.
        """

        meta = add_metadata_template()
        col_id = None
        # try:

        # if ("id" in data) and (data["id"] == "local"):
        #     data.pop("id", None)
        if "id" in data:  # remove the ID in any case
            data.pop("id", None)
        # The lowercased name/collection pair identifies the document.
        lname = data.get("name").lower()
        collection = data.pop("collection").lower()

        # Get collection class if special type is implemented
        collection_class = get_collection_class(collection)

        # Split the input into mapped ORM columns vs. leftover keys; everything
        # not a real column is stashed in the 'extra' field.
        update_fields = {}
        for field in collection_class._all_col_names():
            if field in data:
                update_fields[field] = data.pop(field)

        update_fields["extra"] = data  # todo: check for sql injection

        with self.session_scope() as session:

            try:
                if overwrite:
                    # NOTE(review): if no matching row exists, col is None and the
                    # setattr below raises; the error lands in meta["error_description"]
                    # rather than creating the collection -- confirm this is intended.
                    col = session.query(collection_class).filter_by(collection=collection, lname=lname).first()
                    for key, value in update_fields.items():
                        setattr(col, key, value)
                else:
                    col = collection_class(collection=collection, lname=lname, **update_fields)

                # First commit assigns an id; update_relations may need it.
                session.add(col)
                session.commit()
                col.update_relations(**update_fields)
                session.commit()

                col_id = str(col.id)
                meta["success"] = True
                meta["n_inserted"] = 1

            except Exception as err:
                # Roll back the partial insert/update and report the error in meta.
                session.rollback()
                meta["error_description"] = str(err)

        ret = {"data": col_id, "meta": meta}
        return ret
1038

1039 4
    def get_collections(
        self,
        collection: Optional[str] = None,
        name: Optional[str] = None,
        col_id: Optional[int] = None,
        limit: Optional[int] = None,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        skip: int = 0,
    ) -> Dict[str, Any]:
        """Get collections by type and/or name and/or id.

        Parameters
        ----------
        collection: Optional[str], optional
            Type of the collection, e.g. ReactionDataset
        name: Optional[str], optional
            Name of the collection, e.g. S22
        col_id: Optional[int], optional
            Database id of the collection
        limit: Optional[int], optional
            Maximum number of results to return
        include: Optional[List[str]], optional
            Columns to return
        exclude: Optional[List[str]], optional
            Return all but these columns
        skip: int, optional
            Skip the first `skip` results

        Returns
        -------
        Dict[str, Any]
            A dict with keys: 'data' and 'meta'
            The data is a list of the collections found
        """

        meta = get_metadata_template()

        # Collection type and name are matched case-insensitively.
        lname = name.lower() if name else name
        collection = collection.lower() if collection else collection

        collection_class = get_collection_class(collection)
        query = format_query(collection_class, lname=lname, collection=collection, id=col_id)

        rdata, meta["n_found"] = self.get_query_projection(
            collection_class, query, include=include, exclude=exclude, limit=limit, skip=skip
        )

        meta["success"] = True

        return {"data": rdata, "meta": meta}
1094

1095 4
    def del_collection(
        self, collection: Optional[str] = None, name: Optional[str] = None, col_id: Optional[int] = None
    ) -> int:
        """
        Remove a collection from the database from its keys.

        Parameters
        ----------
        collection: Optional[str], optional
            CollectionORM type
        name : Optional[str], optional
            CollectionORM name
        col_id: Optional[int], optional
            Database id of the collection

        Returns
        -------
        int
            Number of documents deleted

        Raises
        ------
        ValueError
            If neither col_id nor the (collection, name) pair is fully specified.
        """

        # Require either an explicit id, or both collection and name, so a caller
        # cannot accidentally wipe all collections / all datasets at once.
        if not (col_id is not None or (collection is not None and name is not None)):
            # BUG FIX: this message was a plain string, so the literal text
            # "{col_id}" etc. was shown instead of the argument values; it is
            # now an f-string and interpolates them.
            raise ValueError(
                f"Either col_id ({col_id}) must be specified, or collection ({collection}) and name ({name}) must be specified."
            )

        filter_spec = {}
        if collection is not None:
            filter_spec["collection"] = collection.lower()
        if name is not None:
            # Names are stored lowercased in the 'lname' column.
            filter_spec["lname"] = name.lower()
        if col_id is not None:
            filter_spec["id"] = col_id

        with self.session_scope() as session:
            count = session.query(CollectionORM).filter_by(**filter_spec).delete(synchronize_session=False)
        return count
1132

1133
    ## ResultORMs functions
1134

1135 4
    def add_results(self, record_list: List[ResultRecord]):
        """
        Add results from a given list of records, skipping any that already exist.

        Uniqueness of a result is determined by the tuple
        (program, driver, method, basis, keywords, molecule).

        Parameters
        ----------
        record_list : List[ResultRecord]
            Each record must have:
            program, driver, method, basis, options, molecule
            Where molecule is the molecule id in the DB
            In addition, it should have the other attributes that it needs
            to store

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta
            Data is the ids of the inserted/updated/existing docs, in the same order as the
            input record_list
        """

        meta = add_metadata_template()

        results_list = []
        duplicates_list = []

        # Stores indices referring to elements in record_list
        new_record_idx, duplicates_idx = [], []

        # creating condition for a multi-value select
        # This can be used to query for multiple results in a single query
        conds = [
            and_(
                ResultORM.program == res.program,
                ResultORM.driver == res.driver,
                ResultORM.method == res.method,
                ResultORM.basis == res.basis,
                ResultORM.keywords == res.keywords,
                ResultORM.molecule == res.molecule,
            )
            for res in record_list
        ]

        with self.session_scope() as session:
            # Query for all existing
            # TODO: RACE CONDITION: Records could be inserted between this query and inserting later

            # Maps the uniqueness tuple -> existing row (or newly built ORM below)
            existing_results = {}

            for cond in conds:
                doc = (
                    session.query(
                        ResultORM.program,
                        ResultORM.driver,
                        ResultORM.method,
                        ResultORM.basis,
                        ResultORM.keywords,
                        ResultORM.molecule,
                        ResultORM.id,
                    )
                    .filter(cond)
                    .one_or_none()
                )

                if doc is not None:
                    existing_results[
                        (doc.program, doc.driver, doc.method, doc.basis, doc.keywords, str(doc.molecule))
                    ] = doc

            # Loop over all (input) records, keeping track of each record's index in the list
            for i, result in enumerate(record_list):
                # constructing an index from the record to compare against items in existing_results
                idx = (
                    result.program,
                    result.driver.value,
                    result.method,
                    result.basis,
                    int(result.keywords) if result.keywords else None,
                    result.molecule,
                )

                if idx not in existing_results:
                    # Does not exist in the database. Construct a new ResultORM
                    doc = ResultORM(**result.dict(exclude={"id"}))

                    # Store in existing_results in case later records are duplicates
                    existing_results[idx] = doc

                    # add the object to the list for later adding and committing to database.
                    results_list.append(doc)

                    # Store the index of this record (in record_list) as a new_record
                    new_record_idx.append(i)
                    meta["n_inserted"] += 1
                else:
                    # This result already exists in the database
                    doc = existing_results[idx]

                    # Store the index of this record (in record_list) as a duplicate
                    duplicates_idx.append(i)

                    # Store the entire object. Since this may be a duplicate of a record
                    # added in a previous iteration of the loop, and the data hasn't been added/committed
                    # to the database, the id may not be known here
                    duplicates_list.append(doc)

            session.add_all(results_list)
            session.commit()

            # At this point, all ids should be known. So store only the ids in the returned metadata
            meta["duplicates"] = [str(doc.id) for doc in duplicates_list]

            # Construct the ID list to return (in the same order as the input data)
            # Use a placeholder for all, and we will fill later
            result_ids = [None] * len(record_list)

            # At this point:
            #     results_list: ORM objects for all newly-added results
            #     new_record_idx: indices (referring to record_list) of newly-added results
            #     duplicates_idx: indices (referring to record_list) of results that already existed
            #
            # results_list and new_record_idx are in the same order
            # (ie, the index stored at new_record_idx[0] refers to some element of record_list. That
            # newly-added ResultORM is located at results_list[0])
            #
            # Similarly, duplicates_idx and meta["duplicates"] are in the same order

            for idx, new_result in zip(new_record_idx, results_list):
                result_ids[idx] = str(new_result.id)

            # meta["duplicates"] only holds ids at this point
            for idx, existing_result_id in zip(duplicates_idx, meta["duplicates"]):
                result_ids[idx] = existing_result_id

        # Every slot must have been filled by one of the two loops above.
        assert None not in result_ids

        meta["success"] = True

        ret = {"data": result_ids, "meta": meta}
        return ret
1276

1277 4
    def update_results(self, record_list: List[ResultRecord]):
        """
        Update results from a given list of records (replace existing).

        Parameters
        ----------
        record_list : List[ResultRecord]
            Records to update; each must carry its DB id and must already
            exist in the DB.
            Shouldn't update:
            program, driver, method, basis, options, molecule

        Returns
        -------
        int
            number of records updated
        """
        query_ids = [res.id for res in record_list]
        # find duplicates among ids
        duplicates = len(query_ids) != len(set(query_ids))

        with self.session_scope() as session:

            found = session.query(ResultORM).filter(ResultORM.id.in_(query_ids)).all()
            # found items are stored in a dictionary keyed by the stringified id
            found_dict = {str(record.id): record for record in found}

            updated_count = 0
            for result in record_list:

                if result.id is None:
                    self.logger.error("Attempted update without ID, skipping")
                    continue

                data = result.dict(exclude={"id"})
                # retrieve the found item
                # NOTE(review): raises KeyError if result.id is not present in the
                # DB -- callers must pass only existing records (see docstring).
                found_db = found_dict[result.id]

                # updating the found item with input attribute values.
                for attr, val in data.items():
                    setattr(found_db, attr, val)

                # if any duplicate ids are found in the input, commit should be called each iteration
                # so a later record with the same id sees (and overwrites) the earlier update
                if duplicates:
                    session.commit()
                updated_count += 1
            # if no duplicates found, only commit at the end of the loop.
            if not duplicates:
                session.commit()

        return updated_count
1328

1329 4
    def get_results_count(self):
        """
        TODO: just return the count, used for big queries

        NOTE(review): not yet implemented -- currently a stub that returns None.

        Returns
        -------
        None
        """
1337

1338 4
    def get_results(
        self,
        id: Union[str, List] = None,
        program: str = None,
        method: str = None,
        basis: str = None,
        molecule: str = None,
        driver: str = None,
        keywords: str = None,
        task_id: Union[str, List] = None,
        manager_id: Union[str, List] = None,
        status: str = "COMPLETE",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
        with_ids=True,
    ):
        """Query results by any combination of attributes.

        Parameters
        ----------
        id : str or List[str]
        program : str
        method : str
        basis : str
        molecule : str
            MoleculeORM id in the DB
        driver : str
        keywords : str
            The id of the option in the DB
        task_id: str or List[str]
            id or a list of ids of tasks
        manager_id: str or List[str]
            id or a list of ids of queue_mangers
        status : bool, optional
            The status of the result: 'COMPLETE', 'INCOMPLETE', or 'ERROR'
            Default is 'COMPLETE'; ignored when id is given.
        include : Optional[List[str]], optional
            The fields to return, default to return all
        exclude : Optional[List[str]], optional
            The fields to not return, default to return all
        limit : Optional[int], optional
            Maximum number of results to return, capped at self._max_limit
            (to avoid overloading the server).
        skip : int, optional
            Skip the first 'skip' results. Used to paginate. Default is 0.
        return_json : bool, optional
            Return the results as a list of json instead of objects
            default is True
        with_ids : bool, optional
            Include the ids in the returned objects/dicts
            default is True

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta
            Data is the objects found
        """

        # Queries by task id follow a different path (join through the task queue).
        if task_id:
            return self._get_results_by_task_id(task_id)

        meta = get_metadata_template()

        # An explicit id wins over the status filter.
        if id is not None:
            status = None

        filter_kwargs = dict(
            id=id,
            program=program,
            method=method,
            basis=basis,
            molecule=molecule,
            driver=driver,
            keywords=keywords,
            manager_id=manager_id,
            status=status,
        )
        query = format_query(ResultORM, **filter_kwargs)

        data, meta["n_found"] = self.get_query_projection(
            ResultORM, query, include=include, exclude=exclude, limit=limit, skip=skip
        )
        meta["success"] = True

        return {"data": data, "meta": meta}
1431

1432 4
    def _get_results_by_task_id(self, task_id: Union[str, List] = None, return_json=True):
        """Fetch the base results attached to the given task-queue ids.

        Parameters
        ----------
        task_id : str or List[str]

        return_json : bool, optional
            Return the results as a list of json instead of objects
            Default is True

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta
            Data is the objects found
        """

        meta = get_metadata_template()

        # Normalize a scalar id to a one-element list.
        task_ids = [task_id] if isinstance(task_id, (int, str)) else task_id

        data = []
        with self.session_scope() as session:
            # Join BaseResultORM to TaskQueueORM through base_result_id.
            found = (
                session.query(BaseResultORM)
                .filter(BaseResultORM.id == TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.id.in_(task_ids))
            )
            meta["n_found"] = get_count_fast(found)
            data = [record.to_dict() for record in found.all()]
            meta["success"] = True

        return {"data": data, "meta": meta}
1468

1469 4
    def del_results(self, ids: List[str]):
        """
        Removes results from the database using their ids
        (Should be cautious! other tables may be referencing results)

        Parameters
        ----------
        ids : List[str]
            The Ids of the results to be deleted

        Returns
        -------
        int
            number of results deleted
        """

        with self.session_scope() as session:
            matches = session.query(ResultORM).filter(ResultORM.id.in_(ids)).all()
            # Delete through the session (not a bulk delete) so the rows in
            # base_result are removed correctly as well.
            for match in matches:
                session.delete(match)
            session.commit()
            deleted = len(matches)

        return deleted
1494

1495 4
    def add_wavefunction_store(self, blobs_list: List[Dict[str, Any]]):
        """
        Adds to the wavefunction key/value store table.

        Parameters
        ----------
        blobs_list : List[Dict[str, Any]]
            A list of wavefunction data blobs to add. None entries are passed
            through as None ids.

        Returns
        -------
        Dict[str, Any]
            Dict with keys data and meta, where data represent the blob_ids of inserted wavefuction data blobs.
        """

        meta = add_metadata_template()
        blob_ids = []

        with self.session_scope() as session:
            for blob in blobs_list:
                if blob is None:
                    # Keep positional correspondence with the input list.
                    blob_ids.append(None)
                    continue

                orm = WavefunctionStoreORM(**blob)
                session.add(orm)
                session.commit()
                blob_ids.append(str(orm.id))
                meta["n_inserted"] += 1

        meta["success"] = True

        return {"data": blob_ids, "meta": meta}
1527

1528 4
    def get_wavefunction_store(
        self,
        id: List[str] = None,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        limit: int = None,
        skip: int = 0,
    ) -> Dict[str, Any]:
        """
        Pulls from the wavefunction key/value store table.

        Parameters
        ----------
        id : List[str], optional
            A list of ids to query
        include : Optional[List[str]], optional
            The fields to return, default to return all
        exclude : Optional[List[str]], optional
            The fields to not return, default to return all
        limit : int, optional
            Maximum number of results to return.
        skip : int, optional
            Skips a number of results in the query, used for pagination
            Default is set to 0

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, where data is the found wavefunction items
        """

        meta = get_metadata_template()

        query = format_query(WavefunctionStoreORM, id=id)
        found, meta["n_found"] = self.get_query_projection(
            WavefunctionStoreORM, query, limit=limit, skip=skip, include=include, exclude=exclude
        )
        meta["success"] = True

        return {"data": found, "meta": meta}
1570

1571
    ### Mongo procedure/service functions
1572

1573 4
    def add_procedures(self, record_list: List["BaseRecord"]):
        """
        Add procedure records, skipping any whose hash_index already exists.

        Parameters
        ----------
        record_list : List["BaseRecord"]
            Each record must have:
            procedure, program, keywords, qc_meta, hash_index
            In addition, it should have the other attributes that it needs
            to store

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, data is the ids of the inserted/updated/existing docs
        """

        meta = add_metadata_template()

        if not record_list:
            return {"data": [], "meta": meta}

        # All records in one call share the same procedure type.
        procedure_class = get_procedure_class(record_list[0])

        procedure_ids = []
        with self.session_scope() as session:
            for procedure in record_list:
                matches = session.query(procedure_class).filter_by(hash_index=procedure.hash_index)

                if get_count_fast(matches) == 0:
                    record_data = procedure.dict(exclude={"id"})
                    orm = procedure_class(**record_data)
                    # First commit assigns an id; update_relations may need it.
                    session.add(orm)
                    session.commit()
                    orm.update_relations(**record_data)
                    session.commit()
                    procedure_ids.append(str(orm.id))
                    meta["n_inserted"] += 1
                else:
                    existing_id = str(matches.first().id)
                    meta["duplicates"].append(existing_id)  # TODO
                    procedure_ids.append(existing_id)
            meta["success"] = True

        return {"data": procedure_ids, "meta": meta}
1621

1622 4
    def get_procedures(
        self,
        id: Union[str, List] = None,
        procedure: str = None,
        program: str = None,
        hash_index: str = None,
        task_id: Union[str, List] = None,
        manager_id: Union[str, List] = None,
        status: str = "COMPLETE",
        include=None,
        exclude=None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
        with_ids=True,
    ):
        """Query procedure records (optimization, torsiondrive, gridoptimization).

        Parameters
        ----------
        id : str or List[str]
        procedure : str
            Procedure type selecting the ORM class to query; when omitted,
            all procedure classes are queried and ``id`` becomes mandatory.
        program : str
        hash_index : str
        task_id : str or List[str]
        status : bool, optional
            The status of the result: 'COMPLETE', 'INCOMPLETE', or 'ERROR'.
            Default is 'COMPLETE'. Ignored when ``id`` or ``task_id`` is given.
        include : Optional[List[str]], optional
            The fields to return, default to return all
        exclude : Optional[List[str]], optional
            The fields to not return, default to return all
        limit : Optional[int], optional
            Maximum number of results to return; values above the global
            setting self._max_limit are capped at self._max_limit
            (this avoids overloading the server).
        skip : int, optional
            Skip the first 'skip' results. Used to paginate. Default is 0.
        return_json : bool, optional
            Return the results as a list of json instead of objects.
            Default is True.
        with_ids : bool, optional
            Include the ids in the returned objects/dicts. Default is True.

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data and meta. Data is the objects found.
        """

        meta = get_metadata_template()

        # Explicit id / task_id lookups should not be filtered by status
        if (id is not None) or (task_id is not None):
            status = None

        # Map procedure type name -> concrete ORM class
        orm_by_procedure = {
            "optimization": OptimizationProcedureORM,
            "torsiondrive": TorsionDriveProcedureORM,
            "gridoptimization": GridOptimizationProcedureORM,
        }

        if procedure in orm_by_procedure:
            className = orm_by_procedure[procedure]
        else:
            className = BaseResultORM  # all classes, including those with 'selectin'
            program = None  # make sure it's not used
            if id is None:
                self.logger.error(f"Procedure type not specified({procedure}), and ID is not given.")
                raise KeyError("ID is required if procedure type is not specified.")

        query = format_query(
            className,
            id=id,
            procedure=procedure,
            program=program,
            hash_index=hash_index,
            task_id=task_id,
            manager_id=manager_id,
            status=status,
        )

        data = []
        try:
            data, meta["n_found"] = self.get_query_projection(
                className, query, limit=limit, skip=skip, include=include, exclude=exclude
            )
        except Exception as err:
            meta["error_description"] = str(err)
        else:
            meta["success"] = True

        return {"data": data, "meta": meta}
1718

1719 4
    def update_procedures(self, records_list: List["BaseRecord"]):
        """Update existing procedure rows from a list of pydantic records.

        Each record is matched to its DB row by ``id``. Records with no id,
        or with an id that does not exist in the DB, are skipped with a
        logged error instead of raising.

        TODO: needs to be of specific type

        Parameters
        ----------
        records_list : List["BaseRecord"]
            Procedure records to write back to the database.

        Returns
        -------
        int
            Number of procedures successfully updated.
        """

        updated_count = 0
        with self.session_scope() as session:
            for procedure in records_list:

                className = get_procedure_class(procedure)

                # Must have an ID to locate the row to update
                if procedure.id is None:
                    self.logger.error(
                        "No procedure id found on update (hash_index={}), skipping.".format(procedure.hash_index)
                    )
                    continue

                proc_db = session.query(className).filter_by(id=procedure.id).first()

                # Robustness fix: an id not present in the DB previously
                # raised AttributeError on the None returned by .first()
                if proc_db is None:
                    self.logger.error("Procedure id {} not found in the database, skipping.".format(procedure.id))
                    continue

                data = procedure.dict(exclude={"id"})
                proc_db.update_relations(**data)

                # Copy the scalar columns onto the ORM object
                for attr, val in data.items():
                    setattr(proc_db, attr, val)

                session.commit()
                updated_count += 1

        return updated_count
1763

1764 4
    def del_procedures(self, ids: List[str]):
        """Delete procedure records (of any polymorphic subtype) by id.

        Should be cautious! Other tables may be referencing these results.

        Parameters
        ----------
        ids : List[str]
            The Ids of the results to be deleted

        Returns
        -------
        int
            number of results deleted
        """

        # Load every row as its concrete subclass so the ORM cascades the
        # delete through base_result correctly.
        poly_entity = with_polymorphic(
            BaseResultORM,
            [OptimizationProcedureORM, TorsionDriveProcedureORM, GridOptimizationProcedureORM],
        )

        with self.session_scope() as session:
            matches = session.query(poly_entity).filter(BaseResultORM.id.in_(ids)).all()

            # delete through session to delete correctly from base_result
            for record in matches:
                session.delete(record)

            n_deleted = len(matches)

        return n_deleted
1798

1799 4
    def add_services(self, service_list: List["BaseService"]):
        """
        Add services from a given list of dict.

        Each service first gets its underlying procedure inserted via
        ``add_procedures``; if the procedure already exists the service is
        treated as a duplicate and not re-inserted.

        Parameters
        ----------
        service_list : List["BaseService"]
            List of services to be added

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the hash_index of the inserted/existing docs
        """

        meta = add_metadata_template()

        procedure_ids = []
        with self.session_scope() as session:
            for service in service_list:

                # Add the underlying procedure
                new_procedure = self.add_procedures([service.output])

                # ProcedureORM already exists
                proc_id = new_procedure["data"][0]

                if new_procedure["meta"]["duplicates"]:
                    procedure_ids.append(proc_id)
                    meta["duplicates"].append(proc_id)
                    continue

                # search by hash index
                doc = session.query(ServiceQueueORM).filter_by(hash_index=service.hash_index)
                service.procedure_id = proc_id

                if doc.count() == 0:
                    doc = ServiceQueueORM(**service.dict(include=set(ServiceQueueORM.__dict__.keys())))
                    doc.extra = service.dict(exclude=set(ServiceQueueORM.__dict__.keys()))
                    doc.priority = doc.priority.value  # Must be an integer for sorting
                    session.add(doc)
                    session.commit()  # TODO
                    procedure_ids.append(proc_id)
                    meta["n_inserted"] += 1
                else:
                    procedure_ids.append(None)
                    # Bug fix: `doc` is still a Query here, not an ORM row, so
                    # `doc.id` raised AttributeError; use the first row's id.
                    meta["errors"].append((doc.first().id, "Duplicate service, but not caught by procedure."))

        meta["success"] = True

        ret = {"data": procedure_ids, "meta": meta}
        return ret
1850

1851 4
    def get_services(
        self,
        id: Union[List[str], str] = None,
        procedure_id: Union[List[str], str] = None,
        hash_index: Union[List[str], str] = None,
        status: str = None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
    ):
        """Query service queue entries, ordered by priority then creation time.

        Parameters
        ----------
        id / hash_index : List[str] or str, optional
            service id
        procedure_id : List[str] or str, optional
            procedure_id for the specific procedure
        status : str, optional
            status of the record queried for
        limit : Optional[int], optional
            maximum number of results to return
            if 'limit' is greater than the global setting self._max_limit,
            the self._max_limit will be returned instead
            (This is to avoid overloading the server)
        skip : int, optional
            skip the first 'skip' results. Used to paginate
            Default is 0
        return_json : bool, default is True
            Return the results as a list of json instead of objects

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the objects found
        """

        meta = get_metadata_template()
        query = format_query(ServiceQueueORM, id=id, hash_index=hash_index, procedure_id=procedure_id, status=status)

        with self.session_scope() as session:
            data = (
                session.query(ServiceQueueORM)
                .filter(*query)
                .order_by(ServiceQueueORM.priority.desc(), ServiceQueueORM.created_on)
                # Bug fix: the docstring promises clamping to self._max_limit,
                # but the raw limit was passed through; clamp like other getters.
                .limit(self.get_limit(limit))
                .offset(skip)
                .all()
            )
            data = [x.to_dict() for x in data]

        meta["n_found"] = len(data)
        meta["success"] = True

        return {"data": data, "meta": meta}
1909

1910 4
    def update_services(self, records_list: List["BaseService"]) -> int:
        """
        Replace existing service

        Raises exception if the id is invalid

        Parameters
        ----------
        records_list: List[Dict[str, Any]]
            List of Service items to be updated using their id

        Returns
        -------
        int
            number of updated services
        """

        updated_count = 0
        for service in records_list:
            # A missing id means we cannot locate the row; skip it
            if service.id is None:
                self.logger.error("No service id found on update (hash_index={}), skipping.".format(service.hash_index))
                continue

            # Split the record into ORM columns vs. everything else ("extra")
            orm_keys = set(ServiceQueueORM.__dict__.keys())
            payload = service.dict(include=orm_keys)
            payload["extra"] = service.dict(exclude=orm_keys)
            payload["id"] = int(payload["id"])

            with self.session_scope() as session:
                doc_db = session.query(ServiceQueueORM).filter_by(id=service.id).first()

                for field, value in payload.items():
                    setattr(doc_db, field, value)

                session.add(doc_db)
                session.commit()

            # Also push the service's output procedure back to the DB
            procedure = service.output
            procedure.__dict__["id"] = service.procedure_id
            self.update_procedures([procedure])

            updated_count += 1

        return updated_count
1954

1955 4
    def update_service_status(
        self, status: str, id: Union[List[str], str] = None, procedure_id: Union[List[str], str] = None
    ) -> int:
        """
        Update the status of the existing services in the database.

        Raises an exception if any of the ids are invalid.

        Parameters
        ----------
        status : str
            The input status string ready to replace the previous status
        id : Optional[Union[List[str], str]], optional
            ids of all the services requested to be updated, by default None
        procedure_id : Optional[Union[List[str], str]], optional
            procedure_ids for the specific procedures, by default None

        Returns
        -------
        int
            1 indicating that the status update was successful
        """

        if id is None and procedure_id is None:
            raise KeyError("id or procedure_id must not be None.")

        status = status.lower()
        with self.session_scope() as session:

            filters = format_query(ServiceQueueORM, id=id, procedure_id=procedure_id)

            # Update the service row itself
            service = session.query(ServiceQueueORM).filter(*filters).first()
            service.status = status

            # Mirror the status onto the underlying procedure; a "waiting"
            # service maps to an "incomplete" procedure.
            proc_status = "incomplete" if status == "waiting" else status
            session.query(BaseResultORM).filter(BaseResultORM.id == service.procedure_id).update(
                {"status": proc_status}
            )

            session.commit()

        return 1
1997

1998 4
    def services_completed(self, records_list: List["BaseService"]) -> int:
        """
        Delete the services which are completed from the database.

        Parameters
        ----------
        records_list : List["BaseService"]
            List of Service objects which are completed.

        Returns
        -------
        int
            Number of deleted active services from database.
        """

        n_deleted = 0
        for service in records_list:
            # Without an id there is no queue row to delete; skip it
            if service.id is None:
                self.logger.error(
                    "No service id found on completion (hash_index={}), skipping.".format(service.hash_index)
                )
                continue

            # in one transaction
            with self.session_scope() as session:
                # Persist the finished procedure output before dropping the
                # queue entry
                procedure = service.output
                procedure.__dict__["id"] = service.procedure_id
                self.update_procedures([procedure])

                session.query(ServiceQueueORM).filter_by(id=service.id).delete()  # synchronize_session=False)

            n_deleted += 1

        return n_deleted
2032

2033
    ### Task queue handling functions (interface inherited from the original Mongo socket)
2034

2035 4
    def queue_submit(self, data: List[TaskRecord]):
        """Submit a list of tasks to the queue.
        Tasks are unique by their base_result, which should be inserted into
        the DB first before submitting it's corresponding task to the queue
        (with result.status='INCOMPLETE' as the default)
        The default task.status is 'WAITING'

        Duplicate tasks should be a rare case.
        Hooks are merged if the task already exists

        Parameters
        ----------
        data : List[TaskRecord]
            A task is a dict, with the following fields:
            - hash_index: idx, not used anymore
            - spec: dynamic field (dict-like), can have any structure
            - tag: str
            - base_results: tuple (required), first value is the class type
             of the result, {'results' or 'procedure'). The second value is
             the ID of the result in the DB. Example:
             "base_result": ('results', result_id)

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta.
            'data' is a list of the IDs of the tasks IN ORDER, including
            duplicates. An errored task has 'None' in its ID
            meta['duplicates'] has the duplicate tasks
        """

        meta = add_metadata_template()

        # Output ids, one slot per input record; slots are filled in below
        # either with an existing task id (duplicate) or a newly inserted one.
        results = ["placeholder"] * len(data)

        with self.session_scope() as session:
            # preserving all the base results for later check
            all_base_results = [record.base_result for record in data]
            query_res = (
                session.query(TaskQueueORM.id, TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.base_result_id.in_(all_base_results))
                .all()
            )

            # constructing a dict of found tasks and their ids
            found_dict = {str(base_result_id): str(task_id) for task_id, base_result_id in query_res}
            # new_tasks: ORM objects to insert; new_idx: their positions in `results`
            new_tasks, new_idx = [], []
            duplicate_idx = []
            for task_num, record in enumerate(data):

                if found_dict.get(record.base_result):
                    # if found, get id from found_dict
                    # Note: found_dict may return a task object because the duplicate id is of an object in the input.
                    results[task_num] = found_dict.get(record.base_result)
                    # add index of duplicates
                    duplicate_idx.append(task_num)
                    meta["duplicates"].append(task_num)

                else:
                    task_dict = record.dict(exclude={"id"})
                    task = TaskQueueORM(**task_dict)
                    new_idx.append(task_num)
                    # priority must be stored as an integer for sorting
                    task.priority = task.priority.value
                    # append all the new tasks that should be added
                    new_tasks.append(task)
                    # add the (yet to be) inserted object id to dictionary
                    # NOTE: the value here is an ORM object (not a string id),
                    # so duplicates within the same input batch map to it
                    found_dict[record.base_result] = task

            session.add_all(new_tasks)
            session.commit()

            meta["n_inserted"] += len(new_tasks)
            # setting the id for new inserted objects, cannot be done before commiting as new objects do not have ids
            for i, task_idx in enumerate(new_idx):
                results[task_idx] = str(new_tasks[i].id)

            # finding the duplicate items in input, for which ids are found only after insertion
            for i in duplicate_idx:
                if not isinstance(results[i], str):
                    # slot holds an ORM object from this batch; read its id now
                    results[i] = str(results[i].id)

        meta["success"] = True

        ret = {"data": results, "meta": meta}
        return ret
2120

2121 4
    def queue_get_next(
        self, manager, available_programs, available_procedures, limit=100, tag=None, as_json=True
    ) -> List[TaskRecord]:
        """Claim up to ``limit`` waiting tasks for a manager.

        Done in a transaction: selected tasks are marked RUNNING and assigned
        to ``manager`` before the session commits.

        Parameters
        ----------
        manager : str
            Name of the manager claiming the tasks.
        available_programs : list
            Programs the manager supports (passed to format_query).
        available_procedures : list
            Procedures the manager supports; tasks with no procedure also match.
        limit : int, optional
            Maximum number of tasks to claim across all tag queries.
        tag : str or List[str], optional
            If given, tasks are pulled per-tag in the order the tags are listed.
        as_json : bool, optional
            If True, return TaskRecord objects instead of ORM rows.

        Returns
        -------
        List[TaskRecord]
            The claimed tasks.
        """

        # Figure out query, tagless has no requirements
        proc_filt = TaskQueueORM.procedure.in_([p.lower() for p in available_procedures])
        none_filt = TaskQueueORM.procedure == None  # lgtm [py/test-equals-none]

        order_by = []
        if tag is not None:
            if isinstance(tag, str):
                tag = [tag]

        # Highest priority first; oldest first within a priority
        order_by.extend([TaskQueueORM.priority.desc(), TaskQueueORM.created_on])
        queries = []
        if tag is not None:
            # One query per tag, honoring the given tag order
            for t in tag:
                query = format_query(TaskQueueORM, status=TaskStatusEnum.waiting, program=available_programs, tag=t)
                query.append(or_(proc_filt, none_filt))
                queries.append(query)
        else:
            query = format_query(TaskQueueORM, status=TaskStatusEnum.waiting, program=available_programs)
            query.append((or_(proc_filt, none_filt)))
            queries.append(query)

        new_limit = limit
        ids = []
        found = []
        with self.session_scope() as session:
            for q in queries:
                if new_limit == 0:
                    break
                query = session.query(TaskQueueORM).filter(*q).order_by(*order_by).limit(new_limit)
                new_items = query.all()
                found.extend(new_items)
                # Bug fix: decrement the remaining budget instead of resetting
                # it from the original limit (`limit - len(new_items)`); with
                # three or more tag queries the old form could claim more than
                # `limit` tasks in total.
                new_limit -= len(new_items)
                ids.extend([x.id for x in new_items])

            update_fields = {"status": TaskStatusEnum.running, "modified_on": dt.utcnow(), "manager": manager}
            # Bulk update operation in SQL
            update_count = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(ids))
                .update(update_fields, synchronize_session=False)
            )

            if as_json:
                # avoid another trip to the DB to get the updated values, set them here
                found = [TaskRecord(**task.to_dict(exclude=update_fields.keys()), **update_fields) for task in found]
            session.commit()

        if update_count != len(found):
            self.logger.warning("QUEUE: Number of found projects does not match the number of updated projects.")

        return found
2180

2181 4
    def get_queue(
        self,
        id=None,
        hash_index=None,
        program=None,
        status: str = None,
        base_result: str = None,
        tag=None,
        manager=None,
        include=None,
        exclude=None,
        limit: int = None,
        skip: int = 0,
        return_json=False,
        with_ids=True,
    ):
        """Query task queue entries and return them as TaskRecord objects.

        TODO: check what query keys are needs

        Parameters
        ----------
        id : Optional[List[str]], optional
            Ids of the tasks
        hash_index : Optional[List[str]], optional
            hash_index of service, not used
        program : list of str or str, optional
        status : Optional[bool], optional (find all)
            The status of the task: 'COMPLETE', 'RUNNING', 'WAITING', or 'ERROR'
        base_result : Optional[str], optional
            base_result id
        tag, manager : optional
            additional task queue filters
        include : Optional[List[str]], optional
            The fields to return, default to return all
        exclude : Optional[List[str]], optional
            The fields to not return, default to return all
        limit : Optional[int], optional
            maximum number of results to return; values above the global
            setting self._max_limit are capped at self._max_limit
        skip : int, optional
            skip the first 'skip' results. Used to paginate, default is 0
        return_json : bool, optional
            Return the results as a list of json instead of objects, default is True
        with_ids : bool, optional
            Include the ids in the returned objects/dicts, default is True

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the objects found
        """

        meta = get_metadata_template()
        filters = format_query(
            TaskQueueORM,
            program=program,
            id=id,
            hash_index=hash_index,
            status=status,
            base_result_id=base_result,
            tag=tag,
            manager=manager,
        )

        rows = []
        try:
            rows, meta["n_found"] = self.get_query_projection(
                TaskQueueORM, filters, limit=limit, skip=skip, include=include, exclude=exclude
            )
        except Exception as err:
            meta["error_description"] = str(err)
        else:
            meta["success"] = True

        # Convert raw projections into pydantic task records
        data = [TaskRecord(**row) for row in rows]

        return {"data": data, "meta": meta}
2256

2257 4
    def queue_get_by_id(self, id: List[str], limit: int = None, skip: int = 0, as_json: bool = True):
        """Get tasks by their IDs

        Parameters
        ----------
        id : List[str]
            List of the task Ids in the DB
        limit : Optional[int], optional
            max number of returned tasks. If limit > max_limit, max_limit
            will be returned instead (safe query)
        skip : int, optional
            skip the first 'skip' results. Used to paginate, default is 0
        as_json : bool, optional
            Return tasks as JSON, default is True

        Returns
        -------
        List[TaskRecord]
            List of the found tasks
        """

        with self.session_scope() as session:
            # Limit is clamped to the server maximum by get_limit
            matched = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(id))
                .limit(self.get_limit(limit))
                .offset(skip)
            )

            if as_json:
                matched = [TaskRecord(**task.to_dict()) for task in matched]

        return matched
2287

2288 4
    def queue_mark_complete(self, task_ids: List[str]) -> int:
        """Update the given tasks as complete
        Note that each task is already pointing to its result location
        Mark the corresponding result/procedure as complete

        Parameters
        ----------
        task_ids : List[str]
            IDs of the tasks to mark as COMPLETE

        Returns
        -------
        int
            number of TaskRecord objects marked as COMPLETE, and deleted from the database consequtively.
        """

        if not task_ids:
            return 0

        update_fields = dict(status=TaskStatusEnum.complete, modified_on=dt.utcnow())
        with self.session_scope() as session:
            # assuming all task_ids are valid, then managers will be in order by id
            managers = (
                session.query(TaskQueueORM.manager)
                .filter(TaskQueueORM.id.in_(task_ids))
                .order_by(TaskQueueORM.id)
                .all()
            )
            # each row is a 1-tuple (manager,); unwrap it
            managers = [manager[0] if manager else manager for manager in managers]
            # NOTE(review): sorted(task_ids) sorts strings lexicographically
            # while the query above orders by the (presumably numeric) id
            # column — the two orders may disagree for mixed-width ids like
            # "9" vs "10"; TODO confirm alignment.
            task_manger_map = {task_id: manager for task_id, manager in zip(sorted(task_ids), managers)}
            # Per-row manager_name via a SQL CASE keyed on the task id,
            # so one bulk UPDATE can assign different managers per result
            update_fields[BaseResultORM.manager_name] = case(task_manger_map, value=TaskQueueORM.id)

            # Bulk-update the base results joined to the completed tasks
            session.query(BaseResultORM).filter(BaseResultORM.id == TaskQueueORM.base_result_id).filter(
                TaskQueueORM.id.in_(task_ids)
            ).update(update_fields, synchronize_session=False)

            # delete completed tasks
            tasks_c = (
                session.query(TaskQueueORM).filter(TaskQueueORM.id.in_(task_ids)).delete(synchronize_session=False)
            )

        return tasks_c
2330

2331 4
    def queue_mark_error(self, data: List[Tuple[int, str]]):
        """
        update the given tasks as errored
        Mark the corresponding result/procedure as Errored

        Parameters
        ----------
        data : List[Tuple[int, str]]
            List of task ids and their error messages desired to be assigned to them.

        Returns
        -------
        int
            Number of tasks updated as errored.
        """

        if not data:
            return 0

        task_ids = []
        with self.session_scope() as session:
            # Make sure returned results are in the same order as the task ids
            # SQL queries change the order when using "in"
            data_dict = {item[0]: item[1] for item in data}
            sorted_data = {key: data_dict[key] for key in sorted(data_dict.keys())}
            task_objects = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(sorted_data.keys()))
                .order_by(TaskQueueORM.id)
                .all()
            )
            base_results = (
                session.query(BaseResultORM)
                .filter(BaseResultORM.id == TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.id.in_(sorted_data.keys()))
                .order_by(TaskQueueORM.id)
                .all()
            )

            # NOTE(review): the zip below assumes sorted(data_dict.keys())
            # produces the same order as ORDER BY TaskQueueORM.id in both
            # queries above — true for int keys, but string ids would sort
            # lexicographically; TODO confirm key type.
            for (task_id, msg), task_obj, base_result in zip(sorted_data.items(), task_objects, base_results):

                task_ids.append(task_id)
                # update task
                task_obj.status = TaskStatusEnum.error
                task_obj.modified_on = dt.utcnow()

                # update result
                base_result.status = TaskStatusEnum.error
                base_result.manager_name = task_obj.manager
                base_result.modified_on = dt.utcnow()

                # store the error message in the KV store and link it
                err = KVStore(data=msg)
                err_id = self.add_kvstore([err])["data"][0]
                base_result.error = err_id

            session.commit()

        return len(task_ids)
2389

2390 4
    def queue_reset_status(
        self,
        id: Optional[Union[str, List[str]]] = None,
        base_result: Optional[Union[str, List[str]]] = None,
        manager: Optional[str] = None,
        reset_running: bool = False,
        reset_error: bool = False,
    ) -> int:
        """
        Reset the status of the tasks that a manager owns from Running to Waiting

        If reset_error is True, then also reset errored tasks AND its results/proc

        Parameters
        ----------
        id : Optional[Union[str, List[str]]], optional
            The id of the task to modify
        base_result : Optional[Union[str, List[str]]], optional
            The id of the base result to modify
        manager : Optional[str], optional
            The manager name to reset the status of
        reset_running : bool, optional
            If True, reset running tasks to be waiting
        reset_error : bool, optional
            If True, also reset errored tasks to be waiting,
            also update results/proc to be INCOMPLETE

        Returns
        -------
        int
            Updated count

        Raises
        ------
        ValueError
            If no query field (id, base_result, manager) is provided.
        """

        if not (reset_running or reset_error):
            # nothing to do
            return 0

        if sum(x is not None for x in [id, base_result, manager]) == 0:
            raise ValueError("All query fields are None, reset_status must specify queries.")

        # Only statuses selected by the reset_* flags are eligible for reset
        status = []
        if reset_running:
            status.append(TaskStatusEnum.running)
        if reset_error:
            status.append(TaskStatusEnum.error)

        query = format_query(TaskQueueORM, id=id, base_result_id=base_result, manager=manager, status=status)

        # Must have status + something, checking above as well(being paranoid)
        if len(query) < 2:
            raise ValueError("All query fields are None, reset_status must specify queries.")

        with self.session_scope() as session:
            # NOTE(review): base results of *all* matched tasks are reset to INCOMPLETE,
            # not only errored ones; the match is limited to the statuses selected above.
            task_ids = session.query(TaskQueueORM.id).filter(*query)
            session.query(BaseResultORM).filter(TaskQueueORM.base_result_id == BaseResultORM.id).filter(
                TaskQueueORM.id.in_(task_ids)
            ).update(dict(status="INCOMPLETE", modified_on=dt.utcnow()), synchronize_session=False)

            # Reset the tasks themselves to waiting; `updated` is the number of task rows changed
            updated = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(task_ids))
                .update(dict(status=TaskStatusEnum.waiting, modified_on=dt.utcnow()), synchronize_session=False)
            )

        return updated
2455

2456 4
    def del_tasks(self, id: Union[str, list]):
        """
        Delete tasks from the queue by id. Use with caution.

        Parameters
        ----------
        id : str or List
            Ids of the tasks to delete

        Returns
        -------
        int
            Number of tasks deleted
        """

        # Normalize a single id into a one-element list
        if isinstance(id, (int, str)):
            ids = [id]
        else:
            ids = id

        with self.session_scope() as session:
            deleted = (
                session.query(TaskQueueORM).filter(TaskQueueORM.id.in_(ids)).delete(synchronize_session=False)
            )

        return deleted
2475

2476 4
    def _copy_task_to_queue(self, record_list: List[TaskRecord]):
        """
        Copy the given tasks as-is into the DB. Used for data migration.

        Parameters
        ----------
        record_list : List[TaskRecord]
            List of task records to be copied

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the ids of the inserted/updated/existing docs
        """

        meta = add_metadata_template()

        copied_ids = []
        with self.session_scope() as session:
            for task in record_list:
                existing = session.query(TaskQueueORM).filter_by(base_result_id=task.base_result_id)

                if get_count_fast(existing) == 0:
                    # New task: insert a fresh ORM row (without the incoming id)
                    orm_task = TaskQueueORM(**task.dict(exclude={"id"}))
                    orm_task.priority = orm_task.priority.value
                    # Error payloads are stored as JSON text, not dicts
                    if isinstance(orm_task.error, dict):
                        orm_task.error = json.dumps(orm_task.error)

                    session.add(orm_task)
                    session.commit()  # TODO: faster if done in bulk
                    copied_ids.append(str(orm_task.id))
                    meta["n_inserted"] += 1
                else:
                    dup_id = str(existing.first().id)
                    meta["duplicates"].append(dup_id)  # TODO
                    # If new or duplicate, add the id to the return list
                    copied_ids.append(dup_id)
        meta["success"] = True

        return {"data": copied_ids, "meta": meta}
2517

2518
    ### QueueManagerORMs
2519

2520 4
    def manager_update(self, name, **kwargs):
        """
        Update (or create) a queue manager record by name.

        Counter fields (submitted/completed/returned/failures) are incremented
        by the given amounts on the DB side; other known ORM columns present in
        kwargs overwrite the stored values. If ``log=True`` is passed, a
        QueueManagerLogORM snapshot of the manager is also recorded.

        Returns
        -------
        bool
            True when exactly one manager row was updated or created.
        """

        do_log = kwargs.pop("log", False)

        # Counter columns are incremented relative to their current DB value
        inc_count = {
            "submitted": QueueManagerORM.submitted + kwargs.pop("submitted", 0),
            "completed": QueueManagerORM.completed + kwargs.pop("completed", 0),
            "returned": QueueManagerORM.returned + kwargs.pop("returned", 0),
            "failures": QueueManagerORM.failures + kwargs.pop("failures", 0),
        }

        # Keep only kwargs that correspond to actual ORM attributes
        upd = {key: value for key, value in kwargs.items() if key in QueueManagerORM.__dict__}

        with self.session_scope() as session:
            manager_query = session.query(QueueManagerORM).filter_by(name=name)
            if manager_query.count() > 0:  # existing
                upd.update(inc_count, modified_on=dt.utcnow())
                num_updated = manager_query.update(upd)
            else:  # create new, ensures defaults and validations
                new_manager = QueueManagerORM(name=name, **upd)
                session.add(new_manager)
                session.commit()
                num_updated = 1

            if do_log:
                # Pull again in case it was updated
                manager = session.query(QueueManagerORM).filter_by(name=name).first()

                log_entry = QueueManagerLogORM(
                    manager_id=manager.id,
                    completed=manager.completed,
                    submitted=manager.submitted,
                    failures=manager.failures,
                    total_worker_walltime=manager.total_worker_walltime,
                    total_task_walltime=manager.total_task_walltime,
                    active_tasks=manager.active_tasks,
                    active_cores=manager.active_cores,
                    active_memory=manager.active_memory,
                )

                session.add(log_entry)
                session.commit()

        return num_updated == 1
2566

2567 4
    def get_managers(
        self, name: str = None, status: str = None, modified_before=None, modified_after=None, limit=None, skip=0
    ):
        """
        Query queue manager records, optionally filtered by name, status,
        and a modification-time window.

        Returns
        -------
        Dict[str, Any]
            {"data": matching manager records, "meta": query metadata}
        """

        meta = get_metadata_template()

        filters = format_query(QueueManagerORM, name=name, status=status)
        if modified_before:
            filters.append(QueueManagerORM.modified_on <= modified_before)
        if modified_after:
            filters.append(QueueManagerORM.modified_on >= modified_after)

        data, n_found = self.get_query_projection(QueueManagerORM, filters, limit=limit, skip=skip)
        meta["n_found"] = n_found
        meta["success"] = True

        return {"data": data, "meta": meta}
2584

2585 4
    def get_manager_logs(self, manager_ids: Union[List[str], str], timestamp_after=None, limit=None, skip=0):
        """
        Fetch log entries for one or more queue managers, optionally restricted
        to entries newer than ``timestamp_after``.

        Returns
        -------
        Dict[str, Any]
            {"data": log records (ids excluded), "meta": query metadata}
        """

        meta = get_metadata_template()

        filters = format_query(QueueManagerLogORM, manager_id=manager_ids)
        if timestamp_after:
            filters.append(QueueManagerLogORM.timestamp >= timestamp_after)

        data, n_found = self.get_query_projection(QueueManagerLogORM, filters, limit=limit, skip=skip, exclude=["id"])
        meta["n_found"] = n_found
        meta["success"] = True

        return {"data": data, "meta": meta}
2598

2599 4
    def _copy_managers(self, record_list: Dict):
        """
        Copy the given managers as-is to the DB. Used for data migration.

        Parameters
        ----------
        record_list : List[Dict[str, Any]]
            list of dict of managers data

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the names of the inserted/existing managers
        """

        meta = add_metadata_template()

        manager_names = []
        with self.session_scope() as session:
            for manager in record_list:
                doc = session.query(QueueManagerORM).filter_by(name=manager["name"])

                if get_count_fast(doc) == 0:
                    doc = QueueManagerORM(**manager)
                    # Timestamps may arrive as epoch milliseconds from the source DB
                    if isinstance(doc.created_on, float):
                        doc.created_on = dt.fromtimestamp(doc.created_on / 1e3)
                    if isinstance(doc.modified_on, float):
                        doc.modified_on = dt.fromtimestamp(doc.modified_on / 1e3)
                    session.add(doc)
                    session.commit()  # TODO: faster if done in bulk
                    manager_names.append(doc.name)
                    meta["n_inserted"] += 1
                else:
                    name = doc.first().name
                    meta["duplicates"].append(name)  # TODO
                    # If new or duplicate, add the name to the return list
                    # BUGFIX: previously appended the builtin `id` function instead of `name`
                    manager_names.append(name)
        meta["success"] = True

        ret = {"data": manager_names, "meta": meta}
        return ret
2639

2640
    ### UserORMs
2641

2642 4
    # Complete set of permission strings a user may be granted; validated in add_user/modify_user
    _valid_permissions = frozenset({"read", "write", "compute", "queue", "admin"})
2643

2644 4
    @staticmethod
2645 4
    def _generate_password() -> str:
2646
        """
2647
        Generates a random password e.g. for add_user and modify_user.
2648

2649
        Returns
2650
        -------
2651
        str
2652
            An unhashed random password.
2653
        """
2654 4
        return secrets.token_urlsafe(32)
2655

2656 4
    def add_user(
        self,
        username: str,
        password: Optional[str] = None,
        permissions: Optional[List[str]] = None,
        overwrite: bool = False,
    ) -> Tuple[bool, str]:
        """
        Adds a new user and associated permissions.

        Passwords are stored using bcrypt.

        Parameters
        ----------
        username : str
            New user's username
        password : Optional[str], optional
            The user's password. If None, a new password will be generated.
        permissions : Optional[List[str]], optional
            The associated permissions of a user ['read', 'write', 'compute', 'queue', 'admin'].
            Defaults to ['read'].
        overwrite: bool, optional
            Overwrite the user if it already exists.

        Returns
        -------
        Tuple[bool, str]
            A tuple of (success flag, password)

        Raises
        ------
        KeyError
            If any requested permission is not a valid permission string.
        """

        # Avoid a mutable default argument; None selects the default "read" permission
        if permissions is None:
            permissions = ["read"]

        # Make sure permissions are valid
        if not self._valid_permissions >= set(permissions):
            raise KeyError("Permissions settings not understood: {}".format(set(permissions) - self._valid_permissions))

        if password is None:
            password = self._generate_password()

        # Only the bcrypt hash is persisted, never the plaintext password
        hashed = bcrypt.hashpw(password.encode("UTF-8"), bcrypt.gensalt(6))

        blob = {"username": username, "password": hashed, "permissions": permissions}

        success = False
        with self.session_scope() as session:
            if overwrite:
                # Update-in-place; count == 1 only when the user already existed
                count = session.query(UserORM).filter_by(username=username).update(blob)
                success = count == 1

            else:
                try:
                    user = UserORM(**blob)
                    session.add(user)
                    session.commit()
                    success = True
                except IntegrityError as err:
                    # Username already exists (unique-constraint violation)
                    self.logger.warning(str(err))
                    success = False
                    session.rollback()

        return success, password
2710

2711 4
    def verify_user(self, username: str, password: str, permission: str) -> Tuple[bool, str]:
        """
        Verifies if a user has the requested permissions or not.

        Passwords are stored and verified using bcrypt.

        Parameters
        ----------
        username : str
            The username to verify
        password : str
            The password associated with the username
        permission : str
            The associated permissions of a user ['read', 'write', 'compute', 'queue', 'admin']

        Returns
        -------
        Tuple[bool, str]
            A tuple of (success flag, failure string)

        Examples
        --------

        >>> db.add_user("george", "shortpw")

        >>> db.verify_user("george", "shortpw", "read")[0]
        True

        >>> db.verify_user("george", "shortpw", "admin")[0]
        False

        """

        # Security may be bypassed entirely, or reads may be open to everyone
        if self._bypass_security or (self._allow_read and (permission == "read")):
            return (True, "Success")

        with self.session_scope() as session:
            user = session.query(UserORM).filter_by(username=username).first()

            if user is None:
                return (False, "User not found.")

            # Completely general failure
            try:
                password_ok = bcrypt.checkpw(password.encode("UTF-8"), user.password)
            except Exception as e:
                self.logger.warning(f"Password check failure, error: {str(e)}")
                self.logger.warning(
                    f"Error likely caused by encryption salt mismatch, potentially fixed by creating a new password for user {username}."
                )
                return (False, "Password decryption failure, please contact your database administrator.")

            if password_ok is False:
                return (False, "Incorrect password.")

            # Admin has access to everything
            has_permission = (permission.lower() in user.permissions) or ("admin" in user.permissions)
            if not has_permission:
                return (False, "User has insufficient permissions.")

        return (True, "Success")
2771

2772 4
    def modify_user(
2773
        self,
2774
        username: str,
2775
        password: Optional[str] = None,
2776
        reset_password: bool = False,
2777
        permissions: Optional[List[str]] = None,
2778
    ) -> Tuple[bool, str]:
2779
        """
2780
        Alters a user's password, permissions, or both
2781

2782
        Passwords are stored using bcrypt.
2783

2784
        Parameters
2785
        ----------
2786
        username : str
2787
            The username
2788
        password : Optional[str], optional
2789
            The user's new password. If None, the password will not be updated. Excludes reset_password.
2790
        reset_password: bool, optional
2791
            Reset the user's password to a new autogenerated one. The default is False.
2792
        permissions : Optional[List[str]], optional
2793
            The associated permissions of a user ['read', 'write', 'compute', 'queue', 'admin']
2794

2795
        Returns
2796
        -------
2797
        Tuple[bool, str]
2798
            A tuple of (success flag, message)
2799
        """
2800

2801 4
        if reset_password and password is not None:
2802 4
            return False, "only one of reset_password and password may be specified"
2803

2804 4
        with self.session_scope() as session:
2805 4
            data = session.query(UserORM).filter_by(username=username).first()
2806

2807 4
            if data is None:
2808 4
                return False, f"User {username} not found."
2809

2810 4
            blob = {"username": username}
2811

2812 4
            if permissions is not None:
2813
                # Make sure permissions are valid
2814 4
                if not self._valid_permissions >= set(permissions):
2815 0
                    return False, "Permissions not understood: {}".format(set(permissions) - self._valid_permissions)
2816 4
                blob["permissions"] = permissions
2817 4
            if reset_password:
2818 4
                password = self._generate_password()
2819 4
            if password is not None:
2820 4
                blob["password"] = bcrypt.hashpw(password.encode("UTF-8"), bcrypt.gensalt(6))
2821

2822 4
            count = session.query(UserORM).filter_by(username=username).update(blob)
2823 4
            success = count == 1
2824

2825