1
"""
2
SQLAlchemy Database class to handle access to Postgres through ORM
3
"""
4

5 4
try:
6 4
    from sqlalchemy import create_engine, and_, or_, case, func
7 4
    from sqlalchemy.exc import IntegrityError
8 4
    from sqlalchemy.orm import sessionmaker, with_polymorphic
9 4
    from sqlalchemy.sql.expression import desc
10 4
    from sqlalchemy.sql.expression import case as expression_case
11 0
except ImportError:
12 0
    raise ImportError(
13
        "SQLAlchemy_socket requires sqlalchemy, please install this python " "module or try a different db_socket."
14
    )
15

16 4
import json
17 4
import logging
18 4
import secrets
19 4
from collections.abc import Iterable
20 4
from contextlib import contextmanager
21 4
from datetime import datetime as dt
22 4
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
23

24 4
import bcrypt
25

26
# pydantic classes
27 4
from qcfractal.interface.models import (
28
    GridOptimizationRecord,
29
    KeywordSet,
30
    Molecule,
31
    ObjectId,
32
    OptimizationRecord,
33
    ResultRecord,
34
    TaskRecord,
35
    TaskStatusEnum,
36
    TorsionDriveRecord,
37
    KVStore,
38
    prepare_basis,
39
)
40 4
from qcfractal.storage_sockets.db_queries import QUERY_CLASSES
41 4
from qcfractal.storage_sockets.models import (
42
    AccessLogORM,
43
    BaseResultORM,
44
    CollectionORM,
45
    DatasetORM,
46
    GridOptimizationProcedureORM,
47
    KeywordsORM,
48
    KVStoreORM,
49
    MoleculeORM,
50
    OptimizationProcedureORM,
51
    QueueManagerLogORM,
52
    QueueManagerORM,
53
    ReactionDatasetORM,
54
    ResultORM,
55
    ServerStatsLogORM,
56
    ServiceQueueORM,
57
    TaskQueueORM,
58
    TorsionDriveProcedureORM,
59
    UserORM,
60
    VersionsORM,
61
    WavefunctionStoreORM,
62
)
63 4
from qcfractal.storage_sockets.storage_utils import add_metadata_template, get_metadata_template
64

65 4
from .models import Base
66

67 4
if TYPE_CHECKING:
68 0
    from ..services.service_util import BaseService
69

70
# for version checking
71 4
import qcelemental, qcfractal, qcengine
72

73 4
_null_keys = {"basis", "keywords"}
74 4
_id_keys = {"id", "molecule", "keywords", "procedure_id"}
75 4
_lower_func = lambda x: x.lower()
76 4
_prepare_keys = {"program": _lower_func, "basis": prepare_basis, "method": _lower_func, "procedure": _lower_func}
77

78

79 4
def dict_from_tuple(keys, values):
    """Pair *keys* with each row of *values*, returning one dict per row."""
    rows = []
    for row in values:
        rows.append(dict(zip(keys, row)))
    return rows
81

82

83 4
def format_query(ORMClass, **query: Union[None, str, List[str]]) -> List[Any]:
    """
    Formats a keyword query into a list of SQLAlchemy filter clauses.

    Parameters
    ----------
    ORMClass
        The ORM class whose columns are being filtered.
    **query
        Mapping of column name to a value or list of values. ``None`` values
        are skipped entirely. The literal string ``"null"`` on nullable
        columns ("basis", "keywords") is translated into an IS NULL test.

    Returns
    -------
    List[Any]
        SQLAlchemy binary expressions suitable for ``Query.filter(*ret)``.
        (Fixed: the previous annotation claimed a Dict was returned.)
    """

    ret = []
    for k, v in query.items():
        if v is None:
            continue

        # Handle None keys: "null" is the sentinel callers use for IS NULL
        k = k.lower()
        if (k in _null_keys) and (v == "null"):
            v = None

        # Normalize values (lower-case program/method/procedure, canonical basis)
        if k in _prepare_keys:
            f = _prepare_keys[k]
            if isinstance(v, (list, tuple)):
                v = [f(x) for x in v]
            else:
                v = f(v)

        if isinstance(v, (list, tuple)):
            col = getattr(ORMClass, k)
            ret.append(col.in_(v))
        else:
            ret.append(getattr(ORMClass, k) == v)

    return ret
112

113

114 4
def get_count_fast(query):
    """
    returns total count of the query using:
        Fast: SELECT COUNT(*) FROM TestModel WHERE ...

    Not like q.count():
        Slow: SELECT COUNT(*) FROM (SELECT ... FROM TestModel WHERE ...) ...
    """

    # Re-target the existing statement at COUNT(*) only; drop ORDER BY since
    # ordering is irrelevant to the count and would slow the query down.
    count_stmt = query.statement.with_only_columns([func.count()]).order_by(None)
    return query.session.execute(count_stmt).scalar()
127

128

129 4
def get_procedure_class(record):
    """Map a procedure record instance onto its corresponding ORM class."""

    # Order matters: isinstance checks are evaluated top to bottom.
    dispatch = [
        (OptimizationRecord, OptimizationProcedureORM),
        (TorsionDriveRecord, TorsionDriveProcedureORM),
        (GridOptimizationRecord, GridOptimizationProcedureORM),
    ]

    for record_type, orm_class in dispatch:
        if isinstance(record, record_type):
            return orm_class

    raise TypeError("Procedure of type {} is not valid or supported yet.".format(type(record)))
141

142

143 4
def get_collection_class(collection_type):
    """Return the ORM class for *collection_type*, defaulting to CollectionORM."""

    collection_map = {"dataset": DatasetORM, "reactiondataset": ReactionDatasetORM}
    return collection_map.get(collection_type, CollectionORM)
153

154

155 4
class SQLAlchemySocket:
156
    """
157
        SQLAlcehmy QCDB wrapper class.
158
    """
159

160 4
    def __init__(
        self,
        uri: str,
        project: str = "molssidb",
        bypass_security: bool = False,
        allow_read: bool = True,
        logger: "Logger" = None,
        sql_echo: bool = False,
        max_limit: int = 1000,
        skip_version_check: bool = False,
    ):
        """
        Constructs a new SQLAlchemy socket.

        Parameters
        ----------
        uri : str
            Postgres connection URI; "+psycopg2" is appended to the dialect
            if no driver is given, and `project` is appended as the database.
        project : str, optional
            Database name to connect to (appended to the URI).
        bypass_security : bool, optional
            Disable security checks (stored on the instance for later use).
        allow_read : bool, optional
            Allow unauthenticated reads (stored on the instance for later use).
        logger : Logger, optional
            Logger to use; a default one is created when not provided.
        sql_echo : bool, optional
            Echo all SQL statements into Python logging.
        max_limit : int, optional
            Hard cap on the number of records returned by any query.
        skip_version_check : bool, optional
            Skip the DB-vs-package version compatibility check.

        Raises
        ------
        TypeError
            If the DB was written by a different QCFractal version and
            `skip_version_check` is False.
        ValueError
            If table creation / connection fails.
        """

        # Logging data
        if logger:
            self.logger = logger
        else:
            # NOTE(review): "SQLAlcehmySocket" is a typo in the logger name;
            # left untouched since external log filters may match on it.
            self.logger = logging.getLogger("SQLAlcehmySocket")

        # Security
        self._bypass_security = bypass_security
        self._allow_read = allow_read

        # Columns that are stored lower-cased and must be lower-cased in queries
        self._lower_results_index = ["method", "basis", "program"]

        # disconnect from any active default connection
        # disconnect()
        # Force the psycopg2 driver unless the URI already names a driver
        if "psycopg2" not in uri:
            uri = uri.replace("postgresql", "postgresql+psycopg2")

        if project and not uri.endswith("/"):
            uri = uri + "/"

        uri = uri + project
        self.logger.info(f"SQLAlchemy attempt to connect to {uri}.")

        # Connect to DB and create session
        self.uri = uri
        self.engine = create_engine(
            uri,
            echo=sql_echo,  # echo for logging into python logging
            pool_size=5,  # 5 is the default, 0 means unlimited
        )
        self.logger.info(
            "Connected SQLAlchemy to DB dialect {} with driver {}".format(self.engine.dialect.name, self.engine.driver)
        )

        # Session factory; used throughout via session_scope()
        self.Session = sessionmaker(bind=self.engine)

        # check version compatibility
        db_ver = self.check_lib_versions()
        self.logger.info(f"DB versions: {db_ver}")
        if (not skip_version_check) and (db_ver and qcfractal.__version__ != db_ver["fractal_version"]):
            raise TypeError(
                f"You are running QCFractal version {qcfractal.__version__} "
                f'with an older DB version ({db_ver["fractal_version"]}). '
                f'Please run "qcfractal-server upgrade" first before starting the server.'
            )

        # actually create the tables
        try:
            Base.metadata.create_all(self.engine)
            self.check_lib_versions()  # update version if new DB
        except Exception as e:
            raise ValueError(f"SQLAlchemy Connection Error\n {str(e)}") from None

        # Advanced queries objects
        self._query_classes = {
            cls._class_name: cls(self.engine.url.database, max_limit=max_limit) for cls in QUERY_CLASSES
        }

        # if expanded_uri["password"] is not None:
        #     # connect to mongoengine
        #     self.client = db.connect(db=project, host=uri, authMechanism=authMechanism, authSource=authSource)
        # else:
        #     # connect to mongoengine
        #     self.client = db.connect(db=project, host=uri)

        # self._url, self._port = expanded_uri["nodelist"][0]

        # try:
        #     version_array = self.client.server_info()['versionArray']
        #
        #     if tuple(version_array) < (3, 2):
        #         raise RuntimeError
        # except AttributeError:
        #     raise RuntimeError(
        #         "Could not detect MongoDB version at URL {}. It may be a very old version or installed incorrectly. "
        #         "Choosing to stop instead of assuming version is at least 3.2.".format(uri))
        # except RuntimeError:
        #     # Trap low version
        #     raise RuntimeError("Connected MongoDB at URL {} needs to be at least version 3.2, found version {}.".
        #                        format(uri, self.client.server_info()['version']))

        self._project_name = project
        self._max_limit = max_limit
259

260 4
    def __str__(self) -> str:
261 0
        return f"<SQLAlchemySocket: address='{self.uri}`>"
262

263 4
    @contextmanager
    def session_scope(self):
        """Provide a transactional scope around a series of operations.

        Commits on clean exit, rolls back and re-raises on any exception,
        and always closes the session.
        """

        session = self.Session()
        try:
            yield session
            session.commit()
        # Bare except is deliberate: roll back on *any* exception
        # (including KeyboardInterrupt/SystemExit) before re-raising.
        except:
            session.rollback()
            raise
        finally:
            session.close()
276

277 4
    def _clear_db(self, db_name: str = None):
        """Dangerous, make sure you are deleting the right DB"""

        self.logger.warning("SQL: Clearing database '{}' and dropping all tables.".format(db_name))

        # Drop every table the ORM metadata knows about, then recreate them empty.
        Base.metadata.drop_all(self.engine)
        Base.metadata.create_all(self.engine)

        # self.client.drop_database(db_name)
287

288
        # self.client.drop_database(db_name)
289

290 4
    def _delete_DB_data(self, db_name):
        """TODO: needs more testing"""

        # Deletion order respects FK dependencies:
        # metadata, task/service queues, collections, records, then auxiliary tables.
        ordered_tables = [
            VersionsORM,
            TaskQueueORM,
            QueueManagerLogORM,
            QueueManagerORM,
            ServiceQueueORM,
            CollectionORM,
            TorsionDriveProcedureORM,
            GridOptimizationProcedureORM,
            OptimizationProcedureORM,
            ResultORM,
            WavefunctionStoreORM,
            BaseResultORM,
            KVStoreORM,
            MoleculeORM,
        ]

        with self.session_scope() as session:
            for orm_class in ordered_tables:
                session.query(orm_class).delete(synchronize_session=False)
316

317 4
    def get_project_name(self) -> str:
318 4
        return self._project_name
319

320 4
    def get_limit(self, limit: Optional[int]) -> int:
321
        """Get the allowed limit on results to return in queries based on the
322
         given `limit`. If this number is greater than the
323
         SQLAlchemySocket.max_limit then the max_limit will be returned instead.
324
        """
325

326 4
        return limit if limit is not None and limit < self._max_limit else self._max_limit
327

328 4
    def get_query_projection(self, className, query, *, limit=None, skip=0, include=None, exclude=None):
        """Run *query* against *className*, optionally projecting a subset of columns.

        Returns a tuple ``(rdata, n_found)`` where ``rdata`` is a list of dicts
        and ``n_found`` is the total match count before limit/skip are applied.
        ``include`` and ``exclude`` are mutually exclusive column selections;
        with neither, full ORM objects are fetched and converted via to_dict().
        """

        if include and exclude:
            raise AttributeError(
                f"Either include or exclude can be "
                f"used, not both at the same query. "
                f"Given include: {include}, exclude: {exclude}"
            )

        # Column metadata from the ORM class: plain columns, hybrid
        # properties, and relationship descriptors.
        prop, hybrids, relationships = className._get_col_types()

        # build projection from include or exclude
        _projection = []
        if include:
            _projection = set(include)
        elif exclude:
            _projection = set(className._all_col_names()) - set(exclude) - set(className.db_related_fields)
        _projection = list(_projection)

        proj = []
        join_attrs = {}
        callbacks = []

        # prepare hybrid attributes for callback and joins
        for key in _projection:
            if key in prop:  # normal column
                proj.append(getattr(className, key))

            # if hybrid property, save callback, and relation if any
            elif key in hybrids:
                callbacks.append(key)

                # if it has a relationship
                if key + "_obj" in relationships.keys():

                    # join_class_name = relationships[key + '_obj']
                    join_attrs[key] = relationships[key + "_obj"]
            else:
                raise AttributeError(f"Atrribute {key} is not found in class {className}.")

        # Joined keys are fetched separately below, not as scalar columns
        for key in join_attrs:
            _projection.remove(key)

        with self.session_scope() as session:
            if _projection or join_attrs:

                if join_attrs and "id" not in _projection:  # if the id is need for joins
                    proj.append(getattr(className, "id"))
                    _projection.append("_id")  # not to be returned to user

                # query with projection, without joins
                data = session.query(*proj).filter(*query)

                n_found = get_count_fast(data)  # before iterating on the data
                data = data.limit(self.get_limit(limit)).offset(skip)
                rdata = [dict(zip(_projection, row)) for row in data]

                # query for joins if any (relationships and hybrids)
                if join_attrs:
                    res_ids = [d.get("id", d.get("_id")) for d in rdata]
                    res_ids.sort()
                    join_data = {res_id: {} for res_id in res_ids}

                    # relations data
                    for key, relation_details in join_attrs.items():
                        # Fetch all related rows for the matched parent ids in one query
                        ret = (
                            session.query(
                                relation_details["remote_side_column"].label("id"), relation_details["join_class"]
                            )
                            .filter(relation_details["remote_side_column"].in_(res_ids))
                            .order_by(relation_details["remote_side_column"])
                            .all()
                        )
                        # Group the related rows under their parent id
                        for res_id in res_ids:
                            join_data[res_id][key] = []
                            for res in ret:
                                if res_id == res[0]:
                                    join_data[res_id][key].append(res[1])

                        # NOTE(review): this loop reuses the name `data`,
                        # shadowing the Query object above — harmless here
                        # since the query is not used again, but fragile.
                        for data in rdata:
                            parent_id = data.get("id", data.get("_id"))
                            data[key] = join_data[parent_id][key]
                            data.pop("_id", None)

                # call hybrid methods
                for callback in callbacks:
                    for res in rdata:
                        res[callback] = getattr(className, "_" + callback)(res[callback])

                id_fields = className._get_fieldnames_with_DB_ids_()
                for d in rdata:
                    # Expand extra json into fields
                    if "extra" in d:
                        d.update(d["extra"])
                        del d["extra"]

                    # transform ids from int into str
                    for key in id_fields:
                        if key in d.keys() and d[key] is not None:
                            if isinstance(d[key], Iterable):
                                d[key] = [str(i) for i in d[key]]
                            else:
                                d[key] = str(d[key])
                # print('--------rdata after: ', rdata)
            else:
                # No projection requested: fetch full ORM objects
                data = session.query(className).filter(*query)

                # from sqlalchemy.dialects import postgresql
                # print(data.statement.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
                n_found = get_count_fast(data)
                data = data.limit(self.get_limit(limit)).offset(skip).all()
                rdata = [d.to_dict() for d in data]

        return rdata, n_found
442

443
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
444

445 4
    def custom_query(self, class_name: str, query_key: str, **kwargs):
        """
        Run advanced or specialized queries on different classes

        Parameters
        ----------
        class_name : str
            REST APIs name of the class (not the actual python name),
             e.g., torsiondrive
        query_key : str
            The feature or attribute to look for, like initial_molecule
        kwargs
            Extra arguments needed by the query, like the id of the torison drive

        Returns
        -------
            Dict[str,Any]:
                Query result dictionary with keys:
                data: returned data by the query (variable format)
                meta:
                    success: True or False
                    error_description: Error msg to show to the user
        """

        ret = {"data": [], "meta": get_metadata_template()}

        try:
            if class_name not in self._query_classes:
                raise AttributeError(f"Class name {class_name} is not found.")

            # NOTE(review): this session is never closed (no session_scope /
            # finally). Presumably left open because returned ORM data may be
            # lazy-loaded by the caller — confirm before changing.
            session = self.Session()
            ret["data"] = self._query_classes[class_name].query(session, query_key, **kwargs)
            ret["meta"]["success"] = True
            try:
                # Scalar results (e.g. counts) have no len(); report 1 instead
                ret["meta"]["n_found"] = len(ret["data"])
            except TypeError:
                ret["meta"]["n_found"] = 1
        except Exception as err:
            # Any failure is reported to the caller via meta, not raised
            ret["meta"]["error_description"] = str(err)

        return ret
486

487
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Logging ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
488

489 4
    def save_access(self, log_data):
        """Persist a single access-log entry to the access log table."""

        with self.session_scope() as session:
            session.add(AccessLogORM(**log_data))
            session.commit()
495

496
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Logs (KV store) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
497

498 4
    def add_kvstore(self, outputs: List[KVStore]):
        """
        Adds entries to the key/value store table.

        Parameters
        ----------
        outputs : List[Any]
            A list of KVStore objects to add; None entries are preserved as
            None placeholders in the returned id list.

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta; data holds the ids of the
            added blobs (as strings).
        """

        meta = add_metadata_template()
        output_ids = []
        with self.session_scope() as session:
            for output in outputs:
                if output is None:
                    output_ids.append(None)
                    continue

                kv_row = KVStoreORM(**output.dict())
                session.add(kv_row)
                # Commit per entry so the autogenerated id is available
                session.commit()
                output_ids.append(str(kv_row.id))
                meta["n_inserted"] += 1

        meta["success"] = True

        return {"data": output_ids, "meta": meta}
530

531 4
    def get_kvstore(self, id: List[ObjectId] = None, limit: int = None, skip: int = 0):
        """
        Pulls from the key/value store table.

        Parameters
        ----------
        id : List[str]
            A list of ids to query
        limit : Optional[int], optional
            Maximum number of results to return.
        skip : Optional[int], optional
            skip the `skip` results
        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, data is a key-value dictionary of found key-value stored items.
        """

        meta = get_metadata_template()

        query = format_query(KVStoreORM, id=id)

        rdata, meta["n_found"] = self.get_query_projection(KVStoreORM, query, limit=limit, skip=skip)

        meta["success"] = True

        data = {}
        # TODO - after migrating everything, remove the 'value' column in the table
        for d in rdata:
            # Legacy rows store their payload in 'value'; new rows use 'data'
            val = d.pop("value")
            if d["data"] is None:
                # Set the data field to be the string or dictionary
                d["data"] = val

                # Remove these and let the model handle the defaults
                d.pop("compression")
                d.pop("compression_level")

            # The KVStore constructor can handle conversion of strings and dictionaries
            data[d["id"]] = KVStore(**d)

        return {"data": data, "meta": meta}
573

574
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Molecule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
575

576 4
    def get_add_molecules_mixed(self, data: List[Union[ObjectId, Molecule]]) -> List[Molecule]:
        """
        Get or add the given molecules (if they don't exist).
        MoleculeORMs are given in a mixed format: either as existing molecule
        ids, or as Molecule objects to be inserted.

        TODO: to be split into get by_id and get_by_data
        """

        meta = get_metadata_template()

        # Preserve the caller's ordering via index keys
        ordered_mol_dict = dict(enumerate(data))
        new_molecules = {}
        id_mols = {}
        for idx, mol in ordered_mol_dict.items():
            if isinstance(mol, (int, str)):
                id_mols[idx] = mol
            elif isinstance(mol, Molecule):
                new_molecules[idx] = mol
            else:
                meta["errors"].append((idx, "Data type not understood"))

        ret_mols = {}

        # Insert all new Molecule objects in one batch and map their new ids
        # back onto the caller's indices
        flat_mol_keys = list(new_molecules.keys())
        inserted_ids = self.add_molecules(list(new_molecules.values()))["data"]
        id_mols.update(zip(flat_mol_keys, inserted_ids))

        # Fetch everything (pre-existing and newly added) by id
        tmp = self.get_molecules(list(id_mols.values()))
        id_mols_list = tmp["data"]
        meta["errors"].extend(tmp["meta"]["errors"])

        # Invert id->index so fetched molecules can be placed at their indices
        inv_id_mols = {v: k for k, v in id_mols.items()}

        for mol in id_mols_list:
            ret_mols[inv_id_mols[mol.id]] = mol

        meta["success"] = True
        meta["n_found"] = len(ret_mols)
        meta["missing"] = list(ordered_mol_dict.keys() - ret_mols.keys())

        # Rewind to the original flat ordering; missing entries become None
        ret = [ret_mols.get(ind) for ind in range(len(ordered_mol_dict))]

        return {"meta": meta, "data": ret}
633

634 4
    def add_molecules(self, molecules: List[Molecule]):
        """
        Adds molecules to the database, deduplicating by molecule hash.

        Parameters
        ----------
        molecules : List[Molecule]
            A List of molecule objects to add.

        Returns
        -------
        Dict[str, Any]
            Keys 'data' (list of string ids, parallel to the input) and
            'meta' (insert metadata; duplicates are reported by id).
        """

        meta = add_metadata_template()

        results = []
        with self.session_scope() as session:

            # Build out the ORMs
            orm_molecules = []
            for dmol in molecules:

                # Re-validate any molecule that was not validated upstream
                if dmol.validated is False:
                    dmol = Molecule(**dmol.dict(), validate=True)

                mol_dict = dmol.dict(exclude={"id", "validated"})

                # TODO: can set them as defaults in the sql_models, not here
                mol_dict["fix_com"] = True
                mol_dict["fix_orientation"] = True

                # Build fresh indices
                mol_dict["molecule_hash"] = dmol.get_hash()
                mol_dict["molecular_formula"] = dmol.get_molecular_formula()

                mol_dict["identifiers"] = {}
                mol_dict["identifiers"]["molecule_hash"] = mol_dict["molecule_hash"]
                mol_dict["identifiers"]["molecular_formula"] = mol_dict["molecular_formula"]

                # search by index keywords not by all keys, much faster
                orm_molecules.append(MoleculeORM(**mol_dict))

            # Check if we have duplicates (either in the DB or within the input)
            hash_list = [x.molecule_hash for x in orm_molecules]
            query = format_query(MoleculeORM, molecule_hash=hash_list)
            indices = session.query(MoleculeORM.molecule_hash, MoleculeORM.id).filter(*query)
            previous_id_map = {k: v for k, v in indices}

            # For a bulk add there must be no pre-existing and there must be no duplicates in the add list
            bulk_ok = len(hash_list) == len(set(hash_list))
            bulk_ok &= len(previous_id_map) == 0
            # bulk_ok = False

            if bulk_ok:
                # Bulk save, doesn't update fields for speed
                session.bulk_save_objects(orm_molecules)
                session.commit()

                # Query ID's and reorder based off orm_molecule ordered list
                query = format_query(MoleculeORM, molecule_hash=hash_list)
                indices = session.query(MoleculeORM.molecule_hash, MoleculeORM.id).filter(*query)

                id_map = {k: v for k, v in indices}
                n_inserted = len(orm_molecules)

            else:
                # Slow path: insert one-by-one, skipping duplicates.
                # Start from old ID map
                id_map = previous_id_map

                new_molecules = []
                n_inserted = 0

                for orm_mol in orm_molecules:
                    duplicate_id = id_map.get(orm_mol.molecule_hash, False)
                    if duplicate_id is not False:
                        meta["duplicates"].append(str(duplicate_id))
                    else:
                        new_molecules.append(orm_mol)
                        # Placeholder marks in-input duplicates; real id filled
                        # in after commit below
                        id_map[orm_mol.molecule_hash] = "placeholder_id"
                        n_inserted += 1
                        session.add(orm_mol)

                    # We should make sure there was not a hash collision?
                    # new_mol.compare(old_mol)
                    # raise KeyError("!!! WARNING !!!: Hash collision detected")

                session.commit()

                for new_mol in new_molecules:
                    id_map[new_mol.molecule_hash] = new_mol.id

            results = [str(id_map[x.molecule_hash]) for x in orm_molecules]
            assert "placeholder_id" not in results
            meta["n_inserted"] = n_inserted

        meta["success"] = True

        ret = {"data": results, "meta": meta}
        return ret
735

736 4
    def get_molecules(self, id=None, molecule_hash=None, molecular_formula=None, limit: int = None, skip: int = 0):
        """Fetch molecules by id, hash, and/or molecular formula.

        Formulas are canonically reordered before querying; an invalid
        formula is silently ignored (the raw value is queried as-is).
        """
        try:
            if isinstance(molecular_formula, str):
                molecular_formula = qcelemental.molutil.order_molecular_formula(molecular_formula)
            elif isinstance(molecular_formula, list):
                molecular_formula = [qcelemental.molutil.order_molecular_formula(form) for form in molecular_formula]
        except ValueError:
            # Probably, the user provided an invalid chemical formula
            pass

        meta = get_metadata_template()

        query = format_query(MoleculeORM, id=id, molecule_hash=molecule_hash, molecular_formula=molecular_formula)

        # Don't include the hash or the molecular_formula in the returned result
        rdata, meta["n_found"] = self.get_query_projection(
            MoleculeORM, query, limit=limit, skip=skip, exclude=["molecule_hash", "molecular_formula"]
        )

        meta["success"] = True

        # This is required for sparse molecules as we don't know which values are sparse.
        # We are lucky that None is the default and doesn't mean anything in Molecule,
        # so None-valued keys are simply dropped before construction.
        data = [
            Molecule(**{k: v for k, v in mol_dict.items() if v is not None}, validate=False, validated=True)
            for mol_dict in rdata
        ]

        return {"meta": meta, "data": data}
766

767 4
    def del_molecules(self, id: List[str] = None, molecule_hash: List[str] = None):
        """
        Removes molecules from the database by id and/or hash.

        Parameters
        ----------
        id : str or List[str], optional
            ids of molecules, can use the hash parameter instead
        molecule_hash : str or List[str]
            The hash of a molecule.

        Returns
        -------
        bool
            Number of deleted molecules.
        """

        query = format_query(MoleculeORM, id=id, molecule_hash=molecule_hash)

        with self.session_scope() as session:
            deleted_count = session.query(MoleculeORM).filter(*query).delete(synchronize_session=False)

        return deleted_count
790

791
    # ~~~~~~~~~~~~~~~~~~~~~~~ Keywords ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
792

793 4
    def add_keywords(self, keyword_sets: List[KeywordSet]):
        """Add one KeywordSet uniquely identified by 'program' and the 'name'.

        Parameters
        ----------
        keyword_sets : List[KeywordSet]
            A list of KeywordSets to be inserted.

        Returns
        -------
        Dict[str, Any]
            (see add_metadata_template())
            The 'data' part is a list of ids of the inserted options
            data['duplicates'] has the duplicate entries

        Notes
        ------
            Duplicates are not considered errors.

        """

        meta = add_metadata_template()

        keywords = []
        with self.session_scope() as session:
            for kw in keyword_sets:

                kw_dict = kw.dict(exclude={"id"})

                # search by index keywords not by all keys, much faster
                found = session.query(KeywordsORM).filter_by(hash_index=kw_dict["hash_index"]).first()
                if not found:
                    doc = KeywordsORM(**kw_dict)
                    session.add(doc)
                    session.commit()
                    keywords.append(str(doc.id))
                    meta["n_inserted"] += 1
                else:
                    meta["duplicates"].append(str(found.id))  # TODO
                    keywords.append(str(found.id))

        # Mark success once, after the loop (consistent with add_kvstore).
        # Previously this was set inside the loop, so an empty input list
        # incorrectly reported success=False.
        meta["success"] = True

        ret = {"data": keywords, "meta": meta}

        return ret
838

839 4
    def get_keywords(
        self,
        id: Union[str, list] = None,
        hash_index: Union[str, list] = None,
        limit: int = None,
        skip: int = 0,
        return_json: bool = False,
        with_ids: bool = True,
    ) -> List[KeywordSet]:
        """Search for keyword sets by database id and/or hash index.

        Parameters
        ----------
        id : List[str] or str
            Ids of the keywords
        hash_index : List[str] or str
            hash index of keywords
        limit : Optional[int], optional
            Maximum number of results to return.
            If this number is greater than the SQLAlchemySocket.max_limit then
            the max_limit will be returned instead.
            Default is to return the socket's max_limit (when limit=None or 0)
        skip : int, optional
            Skip the first `skip` results
        return_json : bool, optional
            Return the results as json dicts instead of KeywordSet objects
            (Default is False)
        with_ids : bool, optional
            Include the DB ids in the returned object (named 'id')
            Default is True

        Returns
        -------
            A dict with keys: 'data' and 'meta'
            (see get_metadata_template())
            The 'data' part is an object of the result or None if not found
        """

        meta = get_metadata_template()

        found, meta["n_found"] = self.get_query_projection(
            KeywordsORM,
            format_query(KeywordsORM, id=id, hash_index=hash_index),
            limit=limit,
            skip=skip,
            exclude=[None] if with_ids else ["id"],
        )

        meta["success"] = True

        data = found if return_json else [KeywordSet(**d) for d in found]

        return {"data": data, "meta": meta}
895

896 4
    def get_add_keywords_mixed(self, data):
        """
        Get or add the given options (if they don't exist).

        Entries may be given in a mixed format: either as an existing
        keyword id (int/str) or as a KeywordSet object to insert.

        TODO: to be split into get by_id and get_by_data
        """

        meta = get_metadata_template()

        resolved_ids = []
        for position, item in enumerate(data):
            if isinstance(item, (int, str)):
                # Already an id -- pass through untouched
                resolved_ids.append(item)
            elif isinstance(item, KeywordSet):
                # Insert (or find duplicate) and keep the resulting id
                resolved_ids.append(self.add_keywords([item])["data"][0])
            else:
                meta["errors"].append((position, "Data type not understood"))
                resolved_ids.append(None)

        missing = []
        results = []
        for position, kw_id in enumerate(resolved_ids):
            if kw_id is None:
                results.append(None)
                missing.append(position)
                continue

            found = self.get_keywords(id=kw_id)["data"]
            # NOTE(review): ids that resolve but are absent from the DB yield
            # None here without being recorded in 'missing', so n_found below
            # still counts them -- preserved as-is; confirm intent.
            results.append(found[0] if found else None)

        meta["success"] = True
        meta["n_found"] = len(results) - len(missing)
        meta["missing"] = missing

        return {"meta": meta, "data": results}
938

939 4
    def del_keywords(self, id: str) -> int:
        """
        Removes a keyword set from the database based on its id.

        Parameters
        ----------
        id : str
            id of the keyword

        Returns
        -------
        int
           number of deleted documents
        """

        with self.session_scope() as session:
            deleted = session.query(KeywordsORM).filter_by(id=id).delete(synchronize_session=False)

        return deleted
959

960
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~`
961

962
    ### database functions
963

964 4
    def add_collection(self, data: Dict[str, Any], overwrite: bool = False):
        """Add (or update) a collection to the database.

        Parameters
        ----------
        data : Dict[str, Any]
            should include at least (keys):
            collection : str (immutable)
            name : str (immutable)

        overwrite : bool
            Update existing collection

        Returns
        -------
        Dict[str, Any]
        A dict with keys: 'data' and 'meta'
            (see add_metadata_template())
            The 'data' part is the id of the inserted document or none

        Notes
        -----
        ** Change: The data doesn't have to include the ID, the document
        is identified by the (collection, name) pairs.
        ** Change: New fields will be added to the collection, but existing won't
            be removed.
        """

        meta = add_metadata_template()
        col_id = None

        if "id" in data:  # remove the ID in any case
            data.pop("id", None)
        lname = data.get("name").lower()
        collection = data.pop("collection").lower()

        # Get collection class if special type is implemented
        collection_class = get_collection_class(collection)

        # Split input into mapped ORM columns vs. everything else;
        # unmapped keys are stored wholesale under 'extra' below.
        update_fields = {}
        for field in collection_class._all_col_names():
            if field in data:
                update_fields[field] = data.pop(field)

        update_fields["extra"] = data  # todo: check for sql injection

        with self.session_scope() as session:

            try:
                if overwrite:
                    # NOTE(review): if no matching row exists, 'col' is None and the
                    # setattr below raises, landing in the except branch -- confirm intended.
                    col = session.query(collection_class).filter_by(collection=collection, lname=lname).first()
                    for key, value in update_fields.items():
                        setattr(col, key, value)
                else:
                    col = collection_class(collection=collection, lname=lname, **update_fields)

                session.add(col)
                session.commit()
                # update_relations (defined on the ORM class) runs after the first
                # commit, once the row id exists; its changes need a second commit.
                col.update_relations(**update_fields)
                session.commit()

                col_id = str(col.id)
                meta["success"] = True
                meta["n_inserted"] = 1

            except Exception as err:
                # Any failure rolls back the transaction and is reported via metadata
                session.rollback()
                meta["error_description"] = str(err)

        ret = {"data": col_id, "meta": meta}
        return ret
1038

1039 4
    def get_collections(
        self,
        collection: Optional[str] = None,
        name: Optional[str] = None,
        col_id: Optional[int] = None,
        limit: Optional[int] = None,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        skip: int = 0,
    ) -> Dict[str, Any]:
        """Get collections by type, name, and/or database id.

        Parameters
        ----------
        collection: Optional[str], optional
            Type of the collection, e.g. ReactionDataset
        name: Optional[str], optional
            Name of the collection, e.g. S22
        col_id: Optional[int], optional
            Database id of the collection
        limit: Optional[int], optional
            Maximum number of results to return
        include: Optional[List[str]], optional
            Columns to return
        exclude: Optional[List[str]], optional
            Return all but these columns
        skip: int, optional
            Skip the first `skip` results

        Returns
        -------
        Dict[str, Any]
            A dict with keys: 'data' and 'meta'
            The data is a list of the collections found
        """

        meta = get_metadata_template()

        # Both name and type are stored lowercased
        lname = name.lower() if name else name
        ctype = collection.lower() if collection else collection

        orm_class = get_collection_class(ctype)
        search = format_query(orm_class, lname=lname, collection=ctype, id=col_id)

        rdata, meta["n_found"] = self.get_query_projection(
            orm_class, search, include=include, exclude=exclude, limit=limit, skip=skip
        )

        meta["success"] = True

        return {"data": rdata, "meta": meta}
1094

1095 4
    def del_collection(
        self, collection: Optional[str] = None, name: Optional[str] = None, col_id: Optional[int] = None
    ) -> bool:
        """
        Remove a collection from the database from its keys.

        Parameters
        ----------
        collection: Optional[str], optional
            CollectionORM type
        name : Optional[str], optional
            CollectionORM name
        col_id: Optional[int], optional
            Database id of the collection

        Raises
        ------
        ValueError
            If neither col_id nor the (collection, name) pair is given.

        Returns
        -------
        int
            Number of documents deleted
        """

        # Assuming here that we don't want to allow deletion of all collections, all datasets, etc.
        if not (col_id is not None or (collection is not None and name is not None)):
            # f-string so the offending argument values actually appear in the message
            # (was a plain string, leaving the {placeholders} uninterpolated)
            raise ValueError(
                f"Either col_id ({col_id}) must be specified, or collection ({collection}) and name ({name}) must be specified."
            )

        filter_spec = {}
        if collection is not None:
            filter_spec["collection"] = collection.lower()
        if name is not None:
            filter_spec["lname"] = name.lower()
        if col_id is not None:
            filter_spec["id"] = col_id

        with self.session_scope() as session:
            count = session.query(CollectionORM).filter_by(**filter_spec).delete(synchronize_session=False)
        return count
1132

1133
    ## ResultORMs functions
1134

1135 4
    def add_results(self, record_list: List[ResultRecord]):
        """
        Add results from a given list of records, deduplicating against
        existing database rows.

        Parameters
        ----------
        record_list : List[ResultRecord]
            Each record must have:
            program, driver, method, basis, options, molecule
            Where molecule is the molecule id in the DB
            In addition, it should have the other attributes that it needs
            to store

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta
            Data is the ids of the inserted/updated/existing docs, in the same order as the
            input record_list
        """

        meta = add_metadata_template()

        results_list = []
        duplicates_list = []

        # Stores indices referring to elements in record_list
        new_record_idx, duplicates_idx = [], []

        # creating condition for a multi-value select
        # This can be used to query for multiple results in a single query
        conds = [
            and_(
                ResultORM.program == res.program,
                ResultORM.driver == res.driver,
                ResultORM.method == res.method,
                ResultORM.basis == res.basis,
                ResultORM.keywords == res.keywords,
                ResultORM.molecule == res.molecule,
            )
            for res in record_list
        ]

        with self.session_scope() as session:
            # Query for all existing
            # TODO: RACE CONDITION: Records could be inserted between this query and inserting later

            existing_results = {}

            for cond in conds:
                doc = session.query(
                    ResultORM.program,
                    ResultORM.driver,
                    ResultORM.method,
                    ResultORM.basis,
                    ResultORM.keywords,
                    ResultORM.molecule,
                    ResultORM.id
                ).filter(cond).one_or_none()

                if doc is not None:
                    # Key mirrors the probe tuple built below; molecule is stringified
                    # so both sides compare as str
                    existing_results[(doc.program, doc.driver, doc.method, doc.basis, doc.keywords, str(doc.molecule))] = doc

            # Loop over all (input) records, keeping track each record's index in the list
            for i, result in enumerate(record_list):
                # constructing an index from the record compare against items existing_results
                # NOTE(review): element types (driver.value as str, keywords as int)
                # must match the DB-side key exactly for dedup to work -- confirm.
                idx = (
                    result.program,
                    result.driver.value,
                    result.method,
                    result.basis,
                    int(result.keywords) if result.keywords else None,
                    result.molecule,
                )

                if idx not in existing_results:
                    # Does not exist in the database. Construct a new ResultORM
                    doc = ResultORM(**result.dict(exclude={"id"}))

                    # Store in existing_results in case later records are duplicates
                    existing_results[idx] = doc

                    # add the object to the list for later adding and committing to database.
                    results_list.append(doc)

                    # Store the index of this record (in record_list) as a new_record
                    new_record_idx.append(i)
                    meta["n_inserted"] += 1
                else:
                    # This result already exists in the database
                    doc = existing_results[idx]

                    # Store the index of this record (in record_list) as a duplicate
                    duplicates_idx.append(i)

                    # Store the entire object. Since this may be a duplicate of a record
                    # added in a previous iteration of the loop, and the data hasn't been added/committed
                    # to the database, the id may not be known here
                    duplicates_list.append(doc)

            session.add_all(results_list)
            session.commit()

            # At this point, all ids should be known. So store only the ids in the returned metadata
            meta["duplicates"] = [str(doc.id) for doc in duplicates_list]

            # Construct the ID list to return (in the same order as the input data)
            # Use a placeholder for all, and we will fill later
            result_ids = [None] * len(record_list)

            # At this point:
            #     results_list: ORM objects for all newly-added results
            #     new_record_idx: indices (referring to record_list) of newly-added results
            #     duplicates_idx: indices (referring to record_list) of results that already existed
            #
            # results_list and new_record_idx are in the same order
            # (ie, the index stored at new_record_idx[0] refers to some element of record_list. That
            # newly-added ResultORM is located at results_list[0])
            #
            # Similarly, duplicates_idx and meta["duplicates"] are in the same order

            for idx, new_result in zip(new_record_idx, results_list):
                result_ids[idx] = str(new_result.id)

            # meta["duplicates"] only holds ids at this point
            for idx, existing_result_id in zip(duplicates_idx, meta["duplicates"]):
                result_ids[idx] = existing_result_id

        # Every slot must have been filled by one of the two loops above
        assert None not in result_ids

        meta["success"] = True

        ret = {"data": result_ids, "meta": meta}
        return ret
1270

1271 4
    def update_results(self, record_list: List[ResultRecord]):
        """
        Update results from a given list of records (replace existing)

        Parameters
        ----------
        record_list : List[ResultRecord]
            Records to update; each record's id must already exist in the DB.
            Shouldn't update:
            program, driver, method, basis, options, molecule

        Returns
        -------
        int
            number of records updated
        """
        query_ids = [res.id for res in record_list]
        # find duplicates among ids
        duplicates = len(query_ids) != len(set(query_ids))

        with self.session_scope() as session:

            found = session.query(ResultORM).filter(ResultORM.id.in_(query_ids)).all()
            # found items are stored in a dictionary keyed by the stringified id
            found_dict = {str(record.id): record for record in found}

            updated_count = 0
            for result in record_list:

                if result.id is None:
                    self.logger.error("Attempted update without ID, skipping")
                    continue

                data = result.dict(exclude={"id"})
                # retrieve the found item
                # NOTE(review): keys are str(record.id); a non-string result.id or an
                # id absent from the DB raises KeyError here -- confirm callers guarantee this.
                found_db = found_dict[result.id]

                # updating the found item with input attribute values.
                for attr, val in data.items():
                    setattr(found_db, attr, val)

                # if any duplicate ids are found in the input, commit should be called each iteration
                # so later updates to the same row see the earlier ones flushed
                if duplicates:
                    session.commit()
                updated_count += 1
            # if no duplicates found, only commit at the end of the loop.
            if not duplicates:
                session.commit()

        return updated_count
1322

1323 4
    def get_results_count(self):
        """
        TODO: just return the count, used for big queries

        Unimplemented stub -- currently does nothing and returns None.

        Returns
        -------

        """
1331

1332 4
    def get_results(
        self,
        id: Union[str, List] = None,
        program: str = None,
        method: str = None,
        basis: str = None,
        molecule: str = None,
        driver: str = None,
        keywords: str = None,
        task_id: Union[str, List] = None,
        manager_id: Union[str, List] = None,
        status: str = "COMPLETE",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
        with_ids=True,
    ):
        """Query result records by id or by search criteria.

        Parameters
        ----------
        id : str or List[str]
            Result id(s); when given, the status filter is ignored.
        program : str
        method : str
        basis : str
        molecule : str
            MoleculeORM id in the DB
        driver : str
        keywords : str
            The id of the option in the DB
        task_id: str or List[str]
            id or a list of ids of tasks; when given, results are looked up
            through their tasks and every other filter is ignored.
        manager_id: str or List[str]
            id or a list of ids of queue managers
        status : str, optional
            The status of the result: 'COMPLETE', 'INCOMPLETE', or 'ERROR'
            Default is 'COMPLETE'
        include : Optional[List[str]], optional
            The fields to return, default to return all
        exclude : Optional[List[str]], optional
            The fields to not return, default to return all
        limit : Optional[int], optional
            maximum number of results to return
            if 'limit' is greater than the global setting self._max_limit,
            the self._max_limit will be returned instead
            (This is to avoid overloading the server)
        skip : int, optional
            skip the first 'skip' results. Used to paginate.
            Default is 0
        return_json : bool, optional
            Return the results as a list of json instead of objects
            default is True
        with_ids : bool, optional
            Include the ids in the returned objects/dicts
            default is True

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta
            Data is the objects found
        """

        # Task-based lookup follows a completely separate path
        if task_id:
            return self._get_results_by_task_id(task_id)

        meta = get_metadata_template()

        # An explicit id lookup overrides any status filtering
        if id is not None:
            status = None

        search = format_query(
            ResultORM,
            id=id,
            program=program,
            method=method,
            basis=basis,
            molecule=molecule,
            driver=driver,
            keywords=keywords,
            manager_id=manager_id,
            status=status,
        )

        rdata, meta["n_found"] = self.get_query_projection(
            ResultORM, search, include=include, exclude=exclude, limit=limit, skip=skip
        )

        meta["success"] = True

        return {"data": rdata, "meta": meta}
1425

1426 4
    def _get_results_by_task_id(self, task_id: Union[str, List] = None, return_json=True):
        """Fetch result records through their task queue entries.

        Parameters
        ----------
        task_id : str or List[str]
            Id(s) of the tasks whose base results should be returned.
        return_json : bool, optional
            Return the results as a list of json instead of objects
            Default is True

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta
            Data is the objects found
        """

        meta = get_metadata_template()

        # Normalize a scalar id into a one-element list
        ids = [task_id] if isinstance(task_id, (int, str)) else task_id

        data = []
        with self.session_scope() as session:
            matches = (
                session.query(BaseResultORM)
                .filter(BaseResultORM.id == TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.id.in_(ids))
            )
            meta["n_found"] = get_count_fast(matches)
            data = [record.to_dict() for record in matches.all()]
            meta["success"] = True

        return {"data": data, "meta": meta}
1462

1463 4
    def del_results(self, ids: List[str]):
        """
        Removes results from the database using their ids
        (Should be cautious! other tables maybe referencing results)

        Parameters
        ----------
        ids : List[str]
            The Ids of the results to be deleted

        Returns
        -------
        int
            number of results deleted
        """

        with self.session_scope() as session:
            matching = session.query(ResultORM).filter(ResultORM.id.in_(ids)).all()
            # Delete one-by-one through the session so the joined base_result
            # rows are removed correctly as well.
            for record in matching:
                session.delete(record)
            session.commit()
            n_deleted = len(matching)

        return n_deleted
1488

1489 4
    def add_wavefunction_store(self, blobs_list: List[Dict[str, Any]]):
        """
        Adds to the wavefunction key/value store table.

        Parameters
        ----------
        blobs_list : List[Dict[str, Any]]
            A list of wavefunction data blobs to add; None entries pass
            through as None ids.

        Returns
        -------
        Dict[str, Any]
            Dict with keys data and meta, where data represent the blob_ids of inserted wavefuction data blobs.
        """

        meta = add_metadata_template()
        blob_ids = []

        with self.session_scope() as session:
            for blob in blobs_list:
                if blob is None:
                    blob_ids.append(None)
                else:
                    orm_blob = WavefunctionStoreORM(**blob)
                    session.add(orm_blob)
                    session.commit()
                    blob_ids.append(str(orm_blob.id))
                    meta["n_inserted"] += 1

        meta["success"] = True

        return {"data": blob_ids, "meta": meta}
1521

1522 4
    def get_wavefunction_store(
        self,
        id: List[str] = None,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        limit: int = None,
        skip: int = 0,
    ) -> Dict[str, Any]:
        """
        Pulls from the wavefunction key/value store table.

        Parameters
        ----------
        id : List[str], optional
            A list of ids to query
        include : Optional[List[str]], optional
            The fields to return, default to return all
        exclude : Optional[List[str]], optional
            The fields to not return, default to return all
        limit : int, optional
            Maximum number of results to return.
            Default is set to 0
        skip : int, optional
            Skips a number of results in the query, used for pagination
            Default is set to 0

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, where data is the found wavefunction items
        """

        meta = get_metadata_template()

        found, meta["n_found"] = self.get_query_projection(
            WavefunctionStoreORM,
            format_query(WavefunctionStoreORM, id=id),
            limit=limit,
            skip=skip,
            include=include,
            exclude=exclude,
        )

        meta["success"] = True

        return {"data": found, "meta": meta}
1564

1565
    ### Procedure/service functions
1566

1567 4
    def add_procedures(self, record_list: List["BaseRecord"]):
        """
        Add procedures from a given list of records. Each record should have
        all the required keys of a procedure.

        Parameters
        ----------
        record_list : List["BaseRecord"]
            Each record must have:
            procedure, program, keywords, qc_meta, hash_index
            In addition, it should have the other attributes that it needs
            to store

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, data is the ids of the inserted/updated/existing docs
        """

        meta = add_metadata_template()

        if not record_list:
            return {"data": [], "meta": meta}

        # All records are mapped through the class of the first one
        # NOTE(review): assumes a homogeneous record_list -- confirm callers guarantee this.
        procedure_class = get_procedure_class(record_list[0])

        procedure_ids = []
        with self.session_scope() as session:
            for procedure in record_list:
                # Deduplicate on the precomputed hash_index
                doc = session.query(procedure_class).filter_by(hash_index=procedure.hash_index)

                if get_count_fast(doc) == 0:
                    data = procedure.dict(exclude={"id"})
                    proc_db = procedure_class(**data)
                    session.add(proc_db)
                    session.commit()
                    # update_relations (defined on the ORM class) runs after the first
                    # commit, once the row id exists; its changes need a second commit.
                    proc_db.update_relations(**data)
                    session.commit()
                    procedure_ids.append(str(proc_db.id))
                    meta["n_inserted"] += 1
                else:
                    id = str(doc.first().id)
                    meta["duplicates"].append(id)  # TODO
                    procedure_ids.append(id)
        meta["success"] = True

        ret = {"data": procedure_ids, "meta": meta}
        return ret
1615

1616 4
    def get_procedures(
        self,
        id: Union[str, List] = None,
        procedure: str = None,
        program: str = None,
        hash_index: str = None,
        task_id: Union[str, List] = None,
        manager_id: Union[str, List] = None,
        status: str = "COMPLETE",
        include=None,
        exclude=None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
        with_ids=True,
    ):
        """Query procedure records (optimizations, torsiondrives, grid optimizations).

        Parameters
        ----------
        id : str or List[str]
            Procedure ids to fetch.
        procedure : str
            Procedure type: 'optimization', 'torsiondrive', or 'gridoptimization'.
            If omitted, the polymorphic base class is queried and `id` is required.
        program : str
            Program filter (ignored when `procedure` is not given).
        hash_index : str
            Hash index filter.
        task_id : str or List[str]
            Task ids pointing at the procedures.
        manager_id : str or List[str]
            Manager that produced the procedures.
        status : str, optional
            Record status ('COMPLETE', 'INCOMPLETE', or 'ERROR'); default 'COMPLETE'.
            Cleared automatically when querying by explicit id/task_id.
        include : Optional[List[str]], optional
            Fields to return; default is all.
        exclude : Optional[List[str]], optional
            Fields to omit; default is none.
        limit : Optional[int], optional
            Maximum number of results (capped at the server's max limit).
        skip : int, optional
            Number of leading results to skip (pagination); default 0.
        return_json : bool, optional
            Return dicts instead of objects; default True.
        with_ids : bool, optional
            Include ids in the returned records; default True.

        Returns
        -------
        Dict[str, Any]
            Dict with keys 'data' (the records found) and 'meta'.

        Raises
        ------
        KeyError
            If neither `procedure` nor `id` is supplied.
        """

        meta = get_metadata_template()

        # Explicit ids override the default status filter
        if id is not None or task_id is not None:
            status = None

        orm_by_procedure = {
            "optimization": OptimizationProcedureORM,
            "torsiondrive": TorsionDriveProcedureORM,
            "gridoptimization": GridOptimizationProcedureORM,
        }
        className = orm_by_procedure.get(procedure)
        if className is None:
            # Unknown/unspecified type: query every class via the polymorphic base
            className = BaseResultORM  # all classes, including those with 'selectin'
            program = None  # make sure it's not used
            if id is None:
                self.logger.error(f"Procedure type not specified({procedure}), and ID is not given.")
                raise KeyError("ID is required if procedure type is not specified.")

        query = format_query(
            className,
            id=id,
            procedure=procedure,
            program=program,
            hash_index=hash_index,
            task_id=task_id,
            manager_id=manager_id,
            status=status,
        )

        data = []
        try:
            data, meta["n_found"] = self.get_query_projection(
                className, query, limit=limit, skip=skip, include=include, exclude=exclude
            )
        except Exception as err:
            meta["error_description"] = str(err)
        else:
            meta["success"] = True

        return {"data": data, "meta": meta}
1712

1713 4
    def update_procedures(self, records_list: List["BaseRecord"]):
        """Update existing procedure records in place by id.

        TODO: needs to be of specific type

        Records without an id, or whose id no longer exists in the DB, are
        skipped with an error log rather than raising.

        Parameters
        ----------
        records_list : List["BaseRecord"]
            Procedure records to update; each must carry its DB id.

        Returns
        -------
        int
            Number of procedures successfully updated.
        """

        updated_count = 0
        with self.session_scope() as session:
            for procedure in records_list:

                className = get_procedure_class(procedure)
                # Must have ID
                if procedure.id is None:
                    self.logger.error(
                        "No procedure id found on update (hash_index={}), skipping.".format(procedure.hash_index)
                    )
                    continue

                proc_db = session.query(className).filter_by(id=procedure.id).first()

                # Guard against stale/invalid ids instead of raising AttributeError below
                if proc_db is None:
                    self.logger.error("Procedure id {} not found in the DB, skipping.".format(procedure.id))
                    continue

                data = procedure.dict(exclude={"id"})
                # Relations (e.g. trajectories) are handled by the ORM class itself
                proc_db.update_relations(**data)

                # Copy all scalar fields onto the mapped object
                for attr, val in data.items():
                    setattr(proc_db, attr, val)

                session.commit()
                updated_count += 1

        return updated_count
1757

1758 4
    def del_procedures(self, ids: List[str]):
        """Delete procedure records by id.

        (Should be cautious! other tables maybe referencing results)

        Parameters
        ----------
        ids : List[str]
            Ids of the procedures to remove.

        Returns
        -------
        int
            Number of procedures deleted.
        """

        # Query via the polymorphic base so subclass rows are loaded fully,
        # then delete through the session so base_result rows go with them.
        poly_entity = with_polymorphic(
            BaseResultORM,
            [OptimizationProcedureORM, TorsionDriveProcedureORM, GridOptimizationProcedureORM],
        )

        with self.session_scope() as session:
            matched = session.query(poly_entity).filter(BaseResultORM.id.in_(ids)).all()

            for record in matched:
                session.delete(record)

            deleted = len(matched)

        return deleted
1792

1793 4
    def add_services(self, service_list: List["BaseService"]):
        """
        Add services from a given list of dict.

        Each service first has its underlying procedure inserted; duplicate
        procedures short-circuit as duplicates of the service itself.

        Parameters
        ----------
        service_list : List["BaseService"]
            List of services to be added

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the hash_index of the inserted/existing docs
        """

        meta = add_metadata_template()

        procedure_ids = []
        with self.session_scope() as session:
            for service in service_list:

                # Add the underlying procedure
                new_procedure = self.add_procedures([service.output])

                # ProcedureORM already exists
                proc_id = new_procedure["data"][0]

                if new_procedure["meta"]["duplicates"]:
                    procedure_ids.append(proc_id)
                    meta["duplicates"].append(proc_id)
                    continue

                # search by hash index
                doc = session.query(ServiceQueueORM).filter_by(hash_index=service.hash_index)
                service.procedure_id = proc_id

                if doc.count() == 0:
                    doc = ServiceQueueORM(**service.dict(include=set(ServiceQueueORM.__dict__.keys())))
                    doc.extra = service.dict(exclude=set(ServiceQueueORM.__dict__.keys()))
                    doc.priority = doc.priority.value  # Must be an integer for sorting
                    session.add(doc)
                    session.commit()  # TODO
                    procedure_ids.append(proc_id)
                    meta["n_inserted"] += 1
                else:
                    procedure_ids.append(None)
                    # BUGFIX: `doc` is still a Query here; materialize the row
                    # before reading its id (previously `doc.id` raised AttributeError)
                    meta["errors"].append((doc.first().id, "Duplicate service, but not caught by procedure."))

        meta["success"] = True

        ret = {"data": procedure_ids, "meta": meta}
        return ret
1844

1845 4
    def get_services(
        self,
        id: Union[List[str], str] = None,
        procedure_id: Union[List[str], str] = None,
        hash_index: Union[List[str], str] = None,
        status: str = None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
    ):
        """Query service queue entries, ordered by priority then creation time.

        Parameters
        ----------
        id / hash_index : List[str] or str, optional
            service id
        procedure_id : List[str] or str, optional
            procedure_id for the specific procedure
        status : str, optional
            status of the record queried for
        limit : Optional[int], optional
            maximum number of results to return
            if 'limit' is greater than the global setting self._max_limit,
            the self._max_limit will be returned instead
            (This is to avoid overloading the server)
        skip : int, optional
            skip the first 'skip' results. Used to paginate
            Default is 0
        return_json : bool, default is True
            Return the results as a list of json instead of objects

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the objects found
        """

        meta = get_metadata_template()
        query = format_query(ServiceQueueORM, id=id, hash_index=hash_index, procedure_id=procedure_id, status=status)

        with self.session_scope() as session:
            data = (
                session.query(ServiceQueueORM)
                .filter(*query)
                .order_by(ServiceQueueORM.priority.desc(), ServiceQueueORM.created_on)
                # BUGFIX: clamp to self._max_limit as documented (was passing the
                # raw limit straight through, allowing unbounded queries)
                .limit(self.get_limit(limit))
                .offset(skip)
                .all()
            )
            data = [x.to_dict() for x in data]

        meta["n_found"] = len(data)
        meta["success"] = True

        return {"data": data, "meta": meta}
1903

1904 4
    def update_services(self, records_list: List["BaseService"]) -> int:
        """Replace existing service queue entries by id.

        Raises an exception if an id is invalid; entries with no id at all are
        skipped with an error log. The underlying procedure record of each
        service is refreshed as well.

        Parameters
        ----------
        records_list : List["BaseService"]
            Service items to be updated using their id.

        Returns
        -------
        int
            Number of updated services.
        """

        updated_count = 0
        for service in records_list:
            if service.id is None:
                self.logger.error("No service id found on update (hash_index={}), skipping.".format(service.hash_index))
                continue

            # Split the pydantic model into ORM columns vs. the 'extra' blob
            orm_keys = set(ServiceQueueORM.__dict__.keys())
            payload = service.dict(include=orm_keys)
            payload["extra"] = service.dict(exclude=orm_keys)
            payload["id"] = int(payload["id"])

            with self.session_scope() as session:
                doc_db = session.query(ServiceQueueORM).filter_by(id=service.id).first()

                for attr, val in payload.items():
                    setattr(doc_db, attr, val)

                session.add(doc_db)
                session.commit()

            # Push the service's current output into its procedure record
            procedure = service.output
            procedure.__dict__["id"] = service.procedure_id
            self.update_procedures([procedure])

            updated_count += 1

        return updated_count
1948

1949 4
    def update_service_status(
        self, status: str, id: Union[List[str], str] = None, procedure_id: Union[List[str], str] = None
    ) -> int:
        """Set the status of an existing service and its underlying procedure.

        Parameters
        ----------
        status : str
            New status string (case-insensitive).
        id : Optional[Union[List[str], str]], optional
            Service ids to update, by default None.
        procedure_id : Optional[Union[List[str], str]], optional
            Procedure ids identifying the services, by default None.

        Returns
        -------
        int
            1, indicating the status update succeeded.

        Raises
        ------
        KeyError
            If both `id` and `procedure_id` are None.
        """

        if id is None and procedure_id is None:
            raise KeyError("id or procedure_id must not be None.")

        status = status.lower()
        with self.session_scope() as session:
            query = format_query(ServiceQueueORM, id=id, procedure_id=procedure_id)

            # Update the service record itself
            service = session.query(ServiceQueueORM).filter(*query).first()
            service.status = status

            # Base results have no 'waiting' state; map it to 'incomplete'
            result_status = "incomplete" if status == "waiting" else status
            session.query(BaseResultORM).filter(BaseResultORM.id == service.procedure_id).update(
                {"status": result_status}
            )

            session.commit()

        return 1
1991

1992 4
    def services_completed(self, records_list: List["BaseService"]) -> int:
        """Remove completed services from the queue.

        For each completed service, the final output is written to its
        procedure record and the queue entry is deleted, both in a single
        transaction.

        Parameters
        ----------
        records_list : List["BaseService"]
            Completed service objects.

        Returns
        -------
        int
            Number of services removed from the database.
        """
        removed = 0
        for service in records_list:
            if service.id is None:
                self.logger.error(
                    "No service id found on completion (hash_index={}), skipping.".format(service.hash_index)
                )
                continue

            # Finalize procedure + delete queue entry atomically
            with self.session_scope() as session:
                procedure = service.output
                procedure.__dict__["id"] = service.procedure_id
                self.update_procedures([procedure])

                session.query(ServiceQueueORM).filter_by(id=service.id).delete()  # synchronize_session=False)

            removed += 1

        return removed
2026

2027
    ### Mongo queue handling functions
2028

2029 4
    def queue_submit(self, data: List[TaskRecord]):
        """Submit a list of tasks to the queue.

        Tasks are unique by their base_result, which should be inserted into
        the DB first before submitting its corresponding task to the queue
        (with result.status='INCOMPLETE' as the default).
        The default task.status is 'WAITING'.

        Duplicate tasks should be a rare case.
        Hooks are merged if the task already exists.

        Parameters
        ----------
        data : List[TaskRecord]
            A task is a dict, with the following fields:
            - hash_index: idx, not used anymore
            - spec: dynamic field (dict-like), can have any structure
            - tag: str
            - base_results: tuple (required), first value is the class type
             of the result, {'results' or 'procedure'). The second value is
             the ID of the result in the DB. Example:
             "base_result": ('results', result_id)

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta.
            'data' is a list of the IDs of the tasks IN ORDER, including
            duplicates. An errored task has 'None' in its ID
            meta['duplicates'] has the duplicate tasks
        """

        meta = add_metadata_template()

        # Output slots, one per input task; filled below either with an
        # existing task id (duplicate) or a newly inserted one.
        results = ["placeholder"] * len(data)

        with self.session_scope() as session:
            # preserving all the base results for later check
            all_base_results = [record.base_result for record in data]
            query_res = (
                session.query(TaskQueueORM.id, TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.base_result_id.in_(all_base_results))
                .all()
            )

            # Map base_result_id -> existing task id (both as strings).
            # NOTE: values may later also be un-flushed TaskQueueORM objects,
            # for duplicates *within* this submission batch.
            found_dict = {str(base_result_id): str(task_id) for task_id, base_result_id in query_res}
            new_tasks, new_idx = [], []
            duplicate_idx = []
            for task_num, record in enumerate(data):

                if found_dict.get(record.base_result):
                    # if found, get id from found_dict
                    # Note: found_dict may return a task object because the duplicate id is of an object in the input.
                    results[task_num] = found_dict.get(record.base_result)
                    # add index of duplicates
                    duplicate_idx.append(task_num)
                    meta["duplicates"].append(task_num)

                else:
                    task_dict = record.dict(exclude={"id"})
                    task = TaskQueueORM(**task_dict)
                    new_idx.append(task_num)
                    # priority arrives as an enum; store its integer value for sorting
                    task.priority = task.priority.value
                    # append all the new tasks that should be added
                    new_tasks.append(task)
                    # add the (yet to be) inserted object id to dictionary
                    found_dict[record.base_result] = task

            session.add_all(new_tasks)
            session.commit()

            meta["n_inserted"] += len(new_tasks)
            # setting the id for new inserted objects, cannot be done before commiting as new objects do not have ids
            for i, task_idx in enumerate(new_idx):
                results[task_idx] = str(new_tasks[i].id)

            # finding the duplicate items in input, for which ids are found only after insertion
            # (slots holding ORM objects rather than id strings)
            for i in duplicate_idx:
                if not isinstance(results[i], str):
                    results[i] = str(results[i].id)

        meta["success"] = True

        ret = {"data": results, "meta": meta}
        return ret
2114

2115 4
    def queue_get_next(
        self, manager, available_programs, available_procedures, limit=100, tag=None, as_json=True
    ) -> List[TaskRecord]:
        """Claim the next batch of waiting tasks for a manager.

        Matching tasks are selected (optionally per tag, in the given tag
        order) and atomically marked RUNNING with this manager, all within
        one transaction.

        Parameters
        ----------
        manager : str
            Name of the claiming manager, written to the tasks.
        available_programs : List[str]
            Programs the manager can run.
        available_procedures : List[str]
            Procedures the manager can run; tasks with no procedure also match.
        limit : int, optional
            Maximum number of tasks to claim across all tags, default 100.
        tag : Optional[Union[str, List[str]]], optional
            Tag(s) to restrict the search to, searched in order.
        as_json : bool, optional
            Return TaskRecord objects (built without a second DB round-trip).

        Returns
        -------
        List[TaskRecord]
            The claimed tasks.
        """

        # Figure out query, tagless has no requirements
        proc_filt = TaskQueueORM.procedure.in_([p.lower() for p in available_procedures])
        none_filt = TaskQueueORM.procedure == None  # lgtm [py/test-equals-none]

        order_by = []
        if tag is not None:
            if isinstance(tag, str):
                tag = [tag]

        order_by.extend([TaskQueueORM.priority.desc(), TaskQueueORM.created_on])

        # One query per tag (searched in order), or a single untagged query
        queries = []
        if tag is not None:
            for t in tag:
                query = format_query(TaskQueueORM, status=TaskStatusEnum.waiting, program=available_programs, tag=t)
                query.append(or_(proc_filt, none_filt))
                queries.append(query)
        else:
            query = format_query(TaskQueueORM, status=TaskStatusEnum.waiting, program=available_programs)
            query.append((or_(proc_filt, none_filt)))
            queries.append(query)

        new_limit = limit
        ids = []
        found = []
        with self.session_scope() as session:
            for q in queries:
                if new_limit == 0:
                    break
                query = session.query(TaskQueueORM).filter(*q).order_by(*order_by).limit(new_limit)
                new_items = query.all()
                found.extend(new_items)
                # BUGFIX: decrement the remaining budget cumulatively. The old
                # `new_limit = limit - len(new_items)` reset the budget on every
                # tag iteration, so multiple tags could claim more than `limit`.
                new_limit -= len(new_items)
                ids.extend([x.id for x in new_items])

            update_fields = {"status": TaskStatusEnum.running, "modified_on": dt.utcnow(), "manager": manager}
            # Bulk update operation in SQL
            update_count = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(ids))
                .update(update_fields, synchronize_session=False)
            )

            if as_json:
                # avoid another trip to the DB to get the updated values, set them here
                found = [TaskRecord(**task.to_dict(exclude=update_fields.keys()), **update_fields) for task in found]
            session.commit()

        if update_count != len(found):
            self.logger.warning("QUEUE: Number of found projects does not match the number of updated projects.")

        return found
2174

2175 4
    def get_queue(
        self,
        id=None,
        hash_index=None,
        program=None,
        status: str = None,
        base_result: str = None,
        tag=None,
        manager=None,
        include=None,
        exclude=None,
        limit: int = None,
        skip: int = 0,
        return_json=False,
        with_ids=True,
    ):
        """Query tasks in the queue.

        TODO: check what query keys are needs

        Parameters
        ----------
        id : Optional[List[str]], optional
            Task ids.
        hash_index : Optional[List[str]], optional
            Service hash_index (not used).
        program : Optional[Union[str, List[str]]], optional
            Program filter.
        status : Optional[str], optional
            Task status: 'COMPLETE', 'RUNNING', 'WAITING', or 'ERROR'.
        base_result : Optional[str], optional
            Base result id the task points to.
        tag, manager
            Additional queue filters.
        include : Optional[List[str]], optional
            Fields to return; default is all.
        exclude : Optional[List[str]], optional
            Fields to omit; default is none.
        limit : Optional[int], optional
            Maximum number of results (capped at the server's max limit).
        skip : int, optional
            Number of leading results to skip (pagination); default 0.
        return_json : bool, optional
            Return dicts instead of objects.
        with_ids : bool, optional
            Include ids in the returned records; default True.

        Returns
        -------
        Dict[str, Any]
            Dict with keys 'data' (TaskRecord objects found) and 'meta'.
        """

        meta = get_metadata_template()
        filters = dict(
            program=program,
            id=id,
            hash_index=hash_index,
            status=status,
            base_result_id=base_result,
            tag=tag,
            manager=manager,
        )
        query = format_query(TaskQueueORM, **filters)

        data = []
        try:
            data, meta["n_found"] = self.get_query_projection(
                TaskQueueORM, query, limit=limit, skip=skip, include=include, exclude=exclude
            )
        except Exception as err:
            meta["error_description"] = str(err)
        else:
            meta["success"] = True

        data = [TaskRecord(**task) for task in data]

        return {"data": data, "meta": meta}
2250

2251 4
    def queue_get_by_id(self, id: List[str], limit: int = None, skip: int = 0, as_json: bool = True):
        """Fetch queue tasks by their ids.

        Parameters
        ----------
        id : List[str]
            Task ids to look up.
        limit : Optional[int], optional
            Maximum number of tasks to return; values above the server's max
            limit are clamped (safe query).
        skip : int, optional
            Number of leading results to skip (pagination); default 0.
        as_json : bool, optional
            Return TaskRecord objects instead of ORM rows; default True.

        Returns
        -------
        List[TaskRecord]
            The tasks found.
        """

        with self.session_scope() as session:
            rows = session.query(TaskQueueORM).filter(TaskQueueORM.id.in_(id))
            found = rows.limit(self.get_limit(limit)).offset(skip)

            if as_json:
                found = [TaskRecord(**task.to_dict()) for task in found]

        return found
2281

2282 4
    def queue_mark_complete(self, task_ids: List[str]) -> int:
        """Update the given tasks as complete.

        Note that each task is already pointing to its result location.
        Marks the corresponding result/procedure as complete (stamping the
        owning manager's name) and deletes the tasks from the queue.

        Parameters
        ----------
        task_ids : List[str]
            IDs of the tasks to mark as COMPLETE

        Returns
        -------
        int
            Number of TaskRecord objects marked as COMPLETE, and deleted from
            the database consequently.
        """

        if not task_ids:
            return 0

        update_fields = dict(status=TaskStatusEnum.complete, modified_on=dt.utcnow())
        with self.session_scope() as session:
            # assuming all task_ids are valid, then managers will be in order by id
            managers = (
                session.query(TaskQueueORM.manager)
                .filter(TaskQueueORM.id.in_(task_ids))
                .order_by(TaskQueueORM.id)
                .all()
            )
            managers = [manager[0] if manager else manager for manager in managers]
            # BUGFIX: sort ids numerically to match the DB's ORDER BY on the
            # integer id column; a plain string sort misorders ids (e.g. "10" < "9")
            # and would pair tasks with the wrong manager names.
            task_manger_map = {task_id: manager for task_id, manager in zip(sorted(task_ids, key=int), managers)}
            # CASE expression: pick each base result's manager by its task id
            update_fields[BaseResultORM.manager_name] = case(task_manger_map, value=TaskQueueORM.id)

            session.query(BaseResultORM).filter(BaseResultORM.id == TaskQueueORM.base_result_id).filter(
                TaskQueueORM.id.in_(task_ids)
            ).update(update_fields, synchronize_session=False)

            # delete completed tasks
            tasks_c = (
                session.query(TaskQueueORM).filter(TaskQueueORM.id.in_(task_ids)).delete(synchronize_session=False)
            )

        return tasks_c
2324

2325 4
    def queue_mark_error(self, data: List[Tuple[int, str]]):
        """
        Update the given tasks as errored.
        Marks the corresponding result/procedure as errored and attaches the
        error message to the base result via the KV store.

        Parameters
        ----------
        data : List[Tuple[int, str]]
            List of task ids and their error messages desired to be assigned to them.

        Returns
        -------
        int
            Number of tasks updated as errored.
        """

        if not data:
            return 0

        task_ids = []
        with self.session_scope() as session:
            # Make sure returned results are in the same order as the task ids
            # SQL queries change the order when using "in"
            data_dict = {item[0]: item[1] for item in data}
            sorted_data = {key: data_dict[key] for key in sorted(data_dict.keys())}
            task_objects = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(sorted_data.keys()))
                .order_by(TaskQueueORM.id)
                .all()
            )
            base_results = (
                session.query(BaseResultORM)
                .filter(BaseResultORM.id == TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.id.in_(sorted_data.keys()))
                .order_by(TaskQueueORM.id)
                .all()
            )

            # Both queries are ordered by TaskQueueORM.id, matching the sorted
            # message dict, so the three sequences zip position-by-position.
            for (task_id, msg), task_obj, base_result in zip(sorted_data.items(), task_objects, base_results):

                task_ids.append(task_id)
                # update task
                task_obj.status = TaskStatusEnum.error
                task_obj.modified_on = dt.utcnow()

                # update result
                base_result.status = TaskStatusEnum.error
                base_result.manager_name = task_obj.manager
                base_result.modified_on = dt.utcnow()

                # store the error message in the KV store and link it
                err = KVStore(data=msg)
                err_id = self.add_kvstore([err])["data"][0]
                base_result.error = err_id

            session.commit()

        return len(task_ids)
2383

2384 4
    def queue_reset_status(
        self,
        id: Optional[Union[str, List[str]]] = None,
        base_result: Optional[Union[str, List[str]]] = None,
        manager: Optional[str] = None,
        reset_running: bool = False,
        reset_error: bool = False,
    ) -> int:
        """
        Reset the status of the tasks that a manager owns from Running to Waiting
        If reset_error is True, then also reset errored tasks AND its results/proc

        Parameters
        ----------
        id : Optional[Union[str, List[str]]], optional
            The id of the task to modify
        base_result : Optional[Union[str, List[str]]], optional
            The id of the base result to modify
        manager : Optional[str], optional
            The manager name to reset the status of
        reset_running : bool, optional
            If True, reset running tasks to be waiting
        reset_error : bool, optional
            If True, also reset errored tasks to be waiting,
            also update results/proc to be INCOMPLETE

        Returns
        -------
        int
            Updated count
        """

        if not (reset_running or reset_error):
            # nothing to do
            return 0

        # Require at least one selector so a bare call can't reset the whole queue
        if sum(x is not None for x in [id, base_result, manager]) == 0:
            raise ValueError("All query fields are None, reset_status must specify queries.")

        # Only tasks in these states are eligible for a reset
        status = []
        if reset_running:
            status.append(TaskStatusEnum.running)
        if reset_error:
            status.append(TaskStatusEnum.error)

        query = format_query(TaskQueueORM, id=id, base_result_id=base_result, manager=manager, status=status)

        # Must have status + something, checking above as well(being paranoid)
        if len(query) < 2:
            raise ValueError("All query fields are None, reset_status must specify queries.")

        with self.session_scope() as session:
            # Update results and procedures if reset_error
            # NOTE(review): the INCOMPLETE update below applies to the base results of
            # *every* matched task (running and/or errored), not only errored ones —
            # confirm this is intended.
            task_ids = session.query(TaskQueueORM.id).filter(*query)
            session.query(BaseResultORM).filter(TaskQueueORM.base_result_id == BaseResultORM.id).filter(
                TaskQueueORM.id.in_(task_ids)
            ).update(dict(status="INCOMPLETE", modified_on=dt.utcnow()), synchronize_session=False)

            # Flip the matched tasks themselves back to waiting
            updated = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(task_ids))
                .update(dict(status=TaskStatusEnum.waiting, modified_on=dt.utcnow()), synchronize_session=False)
            )

        return updated
2449

2450 4
    def del_tasks(self, id: Union[str, list]):
        """
        Delete tasks from the queue by id. Use with caution.

        Parameters
        ----------
        id : str or List
            Ids of the tasks to delete
        Returns
        -------
        int
            Number of tasks deleted
        """

        # Normalize a single id into a one-element list
        if isinstance(id, (int, str)):
            ids = [id]
        else:
            ids = id

        with self.session_scope() as session:
            matched = session.query(TaskQueueORM).filter(TaskQueueORM.id.in_(ids))
            count = matched.delete(synchronize_session=False)

        return count
2469

2470 4
    def _copy_task_to_queue(self, record_list: List[TaskRecord]):
        """
        Copy the given tasks as-is to the DB. Used for data migration.

        Parameters
        ----------
        record_list : List[TaskRecord]
            List of task records to be copied

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the ids of the inserted/updated/existing docs
        """

        meta = add_metadata_template()

        task_ids = []
        with self.session_scope() as session:
            for task in record_list:
                # A task is treated as a duplicate when one already points at the same base result
                doc = session.query(TaskQueueORM).filter_by(base_result_id=task.base_result_id)

                if get_count_fast(doc) == 0:
                    doc = TaskQueueORM(**task.dict(exclude={"id"}))
                    # Store the plain enum value rather than the enum member
                    doc.priority = doc.priority.value
                    # Errors arrive as dicts from the pydantic record; serialize for storage
                    if isinstance(doc.error, dict):
                        doc.error = json.dumps(doc.error)

                    session.add(doc)
                    session.commit()  # TODO: faster if done in bulk
                    task_ids.append(str(doc.id))
                    meta["n_inserted"] += 1
                else:
                    id = str(doc.first().id)
                    meta["duplicates"].append(id)  # TODO
                    # If new or duplicate, add the id to the return list
                    task_ids.append(id)
        meta["success"] = True

        ret = {"data": task_ids, "meta": meta}
        return ret
2511

2512
    ### QueueManagerORMs
2513

2514 4
    def manager_update(self, name, **kwargs):
        """
        Update (or create) the queue-manager row with the given name.

        Counter kwargs (submitted/completed/returned/failures) are added to the stored
        totals; other kwargs matching QueueManagerORM attributes overwrite stored values.
        Pass ``log=True`` to also append a QueueManagerLogORM snapshot of the manager.

        Returns
        -------
        bool
            True when exactly one manager row was updated or created.
        """

        do_log = kwargs.pop("log", False)

        inc_count = {
            # Increment relevant data
            "submitted": QueueManagerORM.submitted + kwargs.pop("submitted", 0),
            "completed": QueueManagerORM.completed + kwargs.pop("completed", 0),
            "returned": QueueManagerORM.returned + kwargs.pop("returned", 0),
            "failures": QueueManagerORM.failures + kwargs.pop("failures", 0),
        }

        # Keep only kwargs that correspond to QueueManagerORM attributes
        upd = {key: kwargs[key] for key in QueueManagerORM.__dict__.keys() if key in kwargs}

        with self.session_scope() as session:
            # QueueManagerORM.objects()  # init
            manager = session.query(QueueManagerORM).filter_by(name=name)
            if manager.count() > 0:  # existing
                upd.update(inc_count, modified_on=dt.utcnow())
                num_updated = manager.update(upd)
            else:  # create new, ensures defaults and validations
                manager = QueueManagerORM(name=name, **upd)
                session.add(manager)
                session.commit()
                num_updated = 1

            if do_log:
                # Pull again in case it was updated
                manager = session.query(QueueManagerORM).filter_by(name=name).first()

                # Snapshot the current manager counters into the log table
                manager_log = QueueManagerLogORM(
                    manager_id=manager.id,
                    completed=manager.completed,
                    submitted=manager.submitted,
                    failures=manager.failures,
                    total_worker_walltime=manager.total_worker_walltime,
                    total_task_walltime=manager.total_task_walltime,
                    active_tasks=manager.active_tasks,
                    active_cores=manager.active_cores,
                    active_memory=manager.active_memory,
                )

                session.add(manager_log)
                session.commit()

        return num_updated == 1
2560

2561 4
    def get_managers(
        self, name: str = None, status: str = None, modified_before=None, modified_after=None, limit=None, skip=0
    ):
        """Query queue managers by name/status, optionally bounded by modification time."""

        meta = get_metadata_template()

        conditions = format_query(QueueManagerORM, name=name, status=status)
        if modified_before:
            conditions.append(QueueManagerORM.modified_on <= modified_before)
        if modified_after:
            conditions.append(QueueManagerORM.modified_on >= modified_after)

        data, n_found = self.get_query_projection(QueueManagerORM, conditions, limit=limit, skip=skip)
        meta["n_found"] = n_found
        meta["success"] = True

        return {"data": data, "meta": meta}
2578

2579 4
    def get_manager_logs(self, manager_ids: Union[List[str], str], timestamp_after=None, limit=None, skip=0):
        """Fetch log snapshots for the given manager id(s), optionally newer than timestamp_after."""
        meta = get_metadata_template()

        conditions = format_query(QueueManagerLogORM, manager_id=manager_ids)
        if timestamp_after:
            conditions.append(QueueManagerLogORM.timestamp >= timestamp_after)

        data, n_found = self.get_query_projection(
            QueueManagerLogORM, conditions, limit=limit, skip=skip, exclude=["id"]
        )
        meta["n_found"] = n_found
        meta["success"] = True

        return {"data": data, "meta": meta}
2592

2593 4
    def _copy_managers(self, record_list: Dict):
        """
        Copy the given managers as-is to the DB. Used for data migration.

        Parameters
        ----------
        record_list : List[Dict[str, Any]]
            list of dict of managers data

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the names of the inserted/existing managers
        """

        meta = add_metadata_template()

        manager_names = []
        with self.session_scope() as session:
            for manager in record_list:
                # Managers are unique by name; an existing row means a duplicate
                doc = session.query(QueueManagerORM).filter_by(name=manager["name"])

                if get_count_fast(doc) == 0:
                    doc = QueueManagerORM(**manager)
                    # Migrated timestamps may arrive as epoch milliseconds; convert to datetime
                    if isinstance(doc.created_on, float):
                        doc.created_on = dt.fromtimestamp(doc.created_on / 1e3)
                    if isinstance(doc.modified_on, float):
                        doc.modified_on = dt.fromtimestamp(doc.modified_on / 1e3)
                    session.add(doc)
                    session.commit()  # TODO: faster if done in bulk
                    manager_names.append(doc.name)
                    meta["n_inserted"] += 1
                else:
                    name = doc.first().name
                    meta["duplicates"].append(name)  # TODO
                    # If new or duplicate, add the name to the return list.
                    # BUGFIX: previously appended the builtin `id` function here instead of `name`,
                    # corrupting the returned data list for duplicates.
                    manager_names.append(name)
        meta["success"] = True

        ret = {"data": manager_names, "meta": meta}
        return ret
2633

2634
    ### UserORMs
2635

2636 4
    # Closed set of roles accepted by the user-management methods (add_user/modify_user/verify_user).
    _valid_permissions = frozenset({"read", "write", "compute", "queue", "admin"})
2637

2638 4
    @staticmethod
2639 4
    def _generate_password() -> str:
2640
        """
2641
        Generates a random password e.g. for add_user and modify_user.
2642

2643
        Returns
2644
        -------
2645
        str
2646
            An unhashed random password.
2647
        """
2648 4
        return secrets.token_urlsafe(32)
2649

2650 4
    def add_user(
        self,
        username: str,
        password: Optional[str] = None,
        permissions: Optional[List[str]] = None,
        overwrite: bool = False,
    ) -> Tuple[bool, str]:
        """
        Adds a new user and associated permissions.

        Passwords are stored using bcrypt.

        Parameters
        ----------
        username : str
            New user's username
        password : Optional[str], optional
            The user's password. If None, a new password will be generated.
        permissions : Optional[List[str]], optional
            The associated permissions of a user ['read', 'write', 'compute', 'queue', 'admin'].
            Defaults to ['read'].
        overwrite: bool, optional
            Overwrite the user if it already exists.
        Returns
        -------
        Tuple[bool, str]
            A tuple of (success flag, password)
        """

        # Default inside the body: a ["read"] default argument would be a shared
        # mutable default, risking cross-call contamination if ever mutated.
        if permissions is None:
            permissions = ["read"]

        # Make sure permissions are valid
        if not self._valid_permissions >= set(permissions):
            raise KeyError("Permissions settings not understood: {}".format(set(permissions) - self._valid_permissions))

        if password is None:
            password = self._generate_password()

        # NOTE(review): work factor 6 is low by modern standards (bcrypt default is 12);
        # kept as-is to preserve existing behavior and hash compatibility.
        hashed = bcrypt.hashpw(password.encode("UTF-8"), bcrypt.gensalt(6))

        blob = {"username": username, "password": hashed, "permissions": permissions}

        success = False
        with self.session_scope() as session:
            if overwrite:
                # Update-in-place; success requires exactly one matching user row
                count = session.query(UserORM).filter_by(username=username).update(blob)
                # doc.upsert_one(**blob)
                success = count == 1

            else:
                try:
                    user = UserORM(**blob)
                    session.add(user)
                    session.commit()
                    success = True
                except IntegrityError as err:
                    # Duplicate username (unique constraint) — report failure, don't raise
                    self.logger.warning(str(err))
                    success = False
                    session.rollback()

        return success, password
2704

2705 4
    def verify_user(self, username: str, password: str, permission: str) -> Tuple[bool, str]:
        """
        Verifies if a user has the requested permissions or not.

        Passwords are stored and verified using bcrypt.

        Parameters
        ----------
        username : str
            The username to verify
        password : str
            The password associated with the username
        permission : str
            The associated permissions of a user ['read', 'write', 'compute', 'queue', 'admin']

        Returns
        -------
        Tuple[bool, str]
            A tuple of (success flag, failure string)

        Examples
        --------

        >>> db.add_user("george", "shortpw")

        >>> db.verify_user("george", "shortpw", "read")[0]
        True

        >>> db.verify_user("george", "shortpw", "admin")[0]
        False

        """

        # Short-circuit: security disabled entirely, or anonymous reads allowed
        if self._bypass_security or (self._allow_read and (permission == "read")):
            return (True, "Success")

        with self.session_scope() as session:
            data = session.query(UserORM).filter_by(username=username).first()

            if data is None:
                return (False, "User not found.")

            # Completely general failure
            try:
                pwcheck = bcrypt.checkpw(password.encode("UTF-8"), data.password)
            except Exception as e:
                # Stored hash could not be checked at all (e.g. corrupted / legacy format)
                self.logger.warning(f"Password check failure, error: {str(e)}")
                self.logger.warning(
                    f"Error likely caused by encryption salt mismatch, potentially fixed by creating a new password for user {username}."
                )
                return (False, "Password decryption failure, please contact your database administrator.")

            if pwcheck is False:
                return (False, "Incorrect password.")

            # Admin has access to everything
            if (permission.lower() not in data.permissions) and ("admin" not in data.permissions):
                return (False, "User has insufficient permissions.")

        return (True, "Success")
2765

2766 4
    def modify_user(
2767
        self,
2768
        username: str,
2769
        password: Optional[str] = None,
2770
        reset_password: bool = False,
2771
        permissions: Optional[List[str]] = None,
2772
    ) -> Tuple[bool, str]:
2773
        """
2774
        Alters a user's password, permissions, or both
2775

2776
        Passwords are stored using bcrypt.
2777

2778
        Parameters
2779
        ----------
2780
        username : str
2781
            The username
2782
        password : Optional[str], optional
2783
            The user's new password. If None, the password will not be updated. Excludes reset_password.
2784
        reset_password: bool, optional
2785
            Reset the user's password to a new autogenerated one. The default is False.
2786
        permissions : Optional[List[str]], optional
2787
            The associated permissions of a user ['read', 'write', 'compute', 'queue', 'admin']
2788

2789
        Returns
2790
        -------
2791
        Tuple[bool, str]
2792
            A tuple of (success flag, message)
2793
        """
2794

2795 4
        if reset_password and password is not None:
2796 4
            return False, "only one of reset_password and password may be specified"
2797

2798 4
        with self.session_scope() as session:
2799 4
            data = session.query(UserORM).filter_by(username=username).first()
2800

2801 4
            if data is None:
2802 4
                return False, f"User {username} not found."
2803

2804 4
            blob = {"username": username}
2805

2806 4
            if permissions is not None:
2807
                # Make sure permissions are valid
2808 4
                if not self._valid_permissions >= set(permissions):
2809 0
                    return False, "Permissions not understood: {}".format(set(permissions) - self._valid_permissions)
2810 4
                blob["permissions"] = permissions
2811 4
            if reset_password:
2812 4
                password = self._generate_password()
2813 4
            if password is not None:
2814 4
                blob["password"] = bcrypt.hashpw(password.encode("UTF-8"), bcrypt.gensalt(6))
2815

2816 4
            count = session.query(UserORM).filter_by(username=username).update(blob)
2817 4
            success = count == 1
2818

2819 4
        if success:
2820 4
            return True, None if password is None else f"New password is {password}"
2821