1
"""
2
SQLAlchemy Database class to handle access to Postgres through ORM
3
"""
4

5 8
try:
6 8
    from sqlalchemy import create_engine, and_, or_, case, func
7 8
    from sqlalchemy.exc import IntegrityError
8 8
    from sqlalchemy.orm import sessionmaker, with_polymorphic
9 8
    from sqlalchemy.sql.expression import desc
10 8
    from sqlalchemy.sql.expression import case as expression_case
11 0
except ImportError:
12 0
    raise ImportError(
13
        "SQLAlchemy_socket requires sqlalchemy, please install this python " "module or try a different db_socket."
14
    )
15

16 8
import json
17 8
import logging
18 8
import secrets
19 8
from collections.abc import Iterable
20 8
from contextlib import contextmanager
21 8
from datetime import datetime as dt
22 8
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
23

24 8
import bcrypt
25

26
# pydantic classes
27 8
from qcfractal.interface.models import (
28
    GridOptimizationRecord,
29
    KeywordSet,
30
    Molecule,
31
    ObjectId,
32
    OptimizationRecord,
33
    ResultRecord,
34
    TaskRecord,
35
    TaskStatusEnum,
36
    TorsionDriveRecord,
37
    KVStore,
38
    CompressionEnum,
39
    prepare_basis,
40
)
41 8
from qcfractal.interface.models.records import RecordStatusEnum
42

43 8
from qcfractal.storage_sockets.db_queries import QUERY_CLASSES
44 8
from qcfractal.storage_sockets.models import (
45
    AccessLogORM,
46
    BaseResultORM,
47
    CollectionORM,
48
    DatasetORM,
49
    GridOptimizationProcedureORM,
50
    KeywordsORM,
51
    KVStoreORM,
52
    MoleculeORM,
53
    OptimizationProcedureORM,
54
    QueueManagerLogORM,
55
    QueueManagerORM,
56
    ReactionDatasetORM,
57
    ResultORM,
58
    ServerStatsLogORM,
59
    ServiceQueueORM,
60
    TaskQueueORM,
61
    TorsionDriveProcedureORM,
62
    UserORM,
63
    VersionsORM,
64
    WavefunctionStoreORM,
65
)
66 8
from qcfractal.storage_sockets.storage_utils import add_metadata_template, get_metadata_template
67

68 8
from .models import Base
69

70 8
if TYPE_CHECKING:
71 0
    from ..services.service_util import BaseService
72

73
# for version checking
74 8
import qcelemental, qcfractal, qcengine
75

76 8
_null_keys = {"basis", "keywords"}
77 8
_id_keys = {"id", "molecule", "keywords", "procedure_id"}
78 8
_lower_func = lambda x: x.lower()
79 8
_prepare_keys = {"program": _lower_func, "basis": prepare_basis, "method": _lower_func, "procedure": _lower_func}
80

81

82 8
def dict_from_tuple(keys, values):
    """Pair *keys* with each row of *values*, yielding one dict per row."""
    rows = []
    for record in values:
        rows.append(dict(zip(keys, record)))
    return rows
84

85

86 8
def format_query(ORMClass, **query: Dict[str, Union[str, List[str]]]) -> Dict[str, Union[str, List[str]]]:
    """
    Formats a query into a SQLAlchemy format.

    ``None`` values are skipped; the literal string ``"null"`` for nullable
    keys is translated into an ``IS NULL`` comparison; list/tuple values
    become ``IN`` clauses.
    """

    clauses = []
    for key, value in query.items():
        if value is None:
            continue

        key = key.lower()

        # The literal string "null" means "match rows where this column IS NULL"
        if key in _null_keys and value == "null":
            value = None

        # Normalize case-sensitive fields (program, basis, method, procedure)
        prepare = _prepare_keys.get(key)
        if prepare is not None:
            if isinstance(value, (list, tuple)):
                value = [prepare(item) for item in value]
            else:
                value = prepare(value)

        column = getattr(ORMClass, key)
        if isinstance(value, (list, tuple)):
            clauses.append(column.in_(value))
        else:
            clauses.append(column == value)

    return clauses
115

116

117 8
def get_count_fast(query):
    """
    returns total count of the query using:
        Fast: SELECT COUNT(*) FROM TestModel WHERE ...

    Not like q.count():
        Slow: SELECT COUNT(*) FROM (SELECT ... FROM TestModel WHERE ...) ...
    """

    # Reuse the query's WHERE clause but select only COUNT(*); stripping the
    # ORDER BY avoids pointless sort work for a pure count.
    count_q = query.statement.with_only_columns([func.count()]).order_by(None)
    count = query.session.execute(count_q).scalar()

    return count
130

131

132 8
def get_procedure_class(record):
    """Return the procedure ORM class matching a record object.

    Raises
    ------
    TypeError
        If the record is not one of the supported procedure record types.
    """

    # Checked in order; isinstance order is preserved from the original chain
    dispatch = (
        (OptimizationRecord, OptimizationProcedureORM),
        (TorsionDriveRecord, TorsionDriveProcedureORM),
        (GridOptimizationRecord, GridOptimizationProcedureORM),
    )

    for record_type, orm_class in dispatch:
        if isinstance(record, record_type):
            return orm_class

    raise TypeError("Procedure of type {} is not valid or supported yet.".format(type(record)))
144

145

146 8
def get_collection_class(collection_type):
    """Map a collection-type string to its ORM class, defaulting to CollectionORM."""

    collection_map = {"dataset": DatasetORM, "reactiondataset": ReactionDatasetORM}

    return collection_map.get(collection_type, CollectionORM)
156

157

158 8
class SQLAlchemySocket:
159
    """
160
    SQLAlcehmy QCDB wrapper class.
161
    """
162

163 8
    def __init__(
        self,
        uri: str,
        project: str = "molssidb",
        bypass_security: bool = False,
        allow_read: bool = True,
        logger: "Logger" = None,
        sql_echo: bool = False,
        max_limit: int = 1000,
        skip_version_check: bool = False,
    ):
        """
        Constructs a new SQLAlchemy socket

        Parameters
        ----------
        uri : str
            Postgres connection URI. If no explicit driver is given,
            "postgresql" is rewritten to "postgresql+psycopg2".
        project : str, optional
            Database (project) name; appended to the URI.
        bypass_security : bool, optional
            Stored for use by security checks elsewhere; not acted on here.
        allow_read : bool, optional
            Stored for use by security checks elsewhere; not acted on here.
        logger : Logger, optional
            Logger to use; defaults to a dedicated module logger.
        sql_echo : bool, optional
            If True, SQLAlchemy echoes every statement through python logging.
        max_limit : int, optional
            Hard ceiling on the number of records returned by any query.
        skip_version_check : bool, optional
            If True, do not require the DB's stored fractal version to match
            the running QCFractal version.

        Raises
        ------
        TypeError
            If the DB's stored fractal version differs from the running one
            (and the check is not skipped).
        ValueError
            If table creation / connection fails.
        """

        # Logging data
        if logger:
            self.logger = logger
        else:
            # NOTE: logger name carries a historical misspelling ("SQLAlcehmy")
            self.logger = logging.getLogger("SQLAlcehmySocket")

        # Security
        self._bypass_security = bypass_security
        self._allow_read = allow_read

        # Columns that are stored lowercased and must be lowercased in queries
        self._lower_results_index = ["method", "basis", "program"]

        # disconnect from any active default connection
        # disconnect()
        # Force the psycopg2 driver unless one was specified explicitly
        if "psycopg2" not in uri:
            uri = uri.replace("postgresql", "postgresql+psycopg2")

        if project and not uri.endswith("/"):
            uri = uri + "/"

        uri = uri + project
        self.logger.info(f"SQLAlchemy attempt to connect to {uri}.")

        # Connect to DB and create session
        self.uri = uri
        self.engine = create_engine(
            uri,
            echo=sql_echo,  # echo for logging into python logging
            pool_size=5,  # 5 is the default, 0 means unlimited
        )
        self.logger.info(
            "Connected SQLAlchemy to DB dialect {} with driver {}".format(self.engine.dialect.name, self.engine.driver)
        )

        self.Session = sessionmaker(bind=self.engine)

        # check version compatibility
        db_ver = self.check_lib_versions()
        self.logger.info(f"DB versions: {db_ver}")
        if (not skip_version_check) and (db_ver and qcfractal.__version__ != db_ver["fractal_version"]):
            raise TypeError(
                f"You are running QCFractal version {qcfractal.__version__} "
                f'with an older DB version ({db_ver["fractal_version"]}). '
                f'Please run "qcfractal-server upgrade" first before starting the server.'
            )

        # actually create the tables
        try:
            Base.metadata.create_all(self.engine)
            self.check_lib_versions()  # update version if new DB
        except Exception as e:
            raise ValueError(f"SQLAlchemy Connection Error\n {str(e)}") from None

        # Advanced queries objects
        # Maps each query class's REST name to an instance bound to this DB
        self._query_classes = {
            cls._class_name: cls(self.engine.url.database, max_limit=max_limit) for cls in QUERY_CLASSES
        }

        # --- Legacy MongoDB connection code kept for reference ---
        # if expanded_uri["password"] is not None:
        #     # connect to mongoengine
        #     self.client = db.connect(db=project, host=uri, authMechanism=authMechanism, authSource=authSource)
        # else:
        #     # connect to mongoengine
        #     self.client = db.connect(db=project, host=uri)

        # self._url, self._port = expanded_uri["nodelist"][0]

        # try:
        #     version_array = self.client.server_info()['versionArray']
        #
        #     if tuple(version_array) < (3, 2):
        #         raise RuntimeError
        # except AttributeError:
        #     raise RuntimeError(
        #         "Could not detect MongoDB version at URL {}. It may be a very old version or installed incorrectly. "
        #         "Choosing to stop instead of assuming version is at least 3.2.".format(uri))
        # except RuntimeError:
        #     # Trap low version
        #     raise RuntimeError("Connected MongoDB at URL {} needs to be at least version 3.2, found version {}.".
        #                        format(uri, self.client.server_info()['version']))

        self._project_name = project
        self._max_limit = max_limit
262

263 8
    def __str__(self) -> str:
        """Human-readable representation showing the connection URI."""
        # Bug fix: the closing quote was a backtick (`) instead of an
        # apostrophe, producing a mismatched pair in the output.
        return f"<SQLAlchemySocket: address='{self.uri}'>"
265

266 8
    @contextmanager
    def session_scope(self):
        """Provide a transactional scope.

        Yields a fresh Session; commits on clean exit, rolls back and
        re-raises on any error, and always closes the session.
        """

        session = self.Session()
        try:
            yield session
            session.commit()
        # Intentionally bare: roll back on *any* exception (including
        # BaseException like KeyboardInterrupt) before re-raising.
        except:
            session.rollback()
            raise
        finally:
            session.close()
279

280 8
    def _clear_db(self, db_name: str = None):
        """Dangerous, make sure you are deleting the right DB.

        Drops and recreates all tables known to this codebase's metadata.
        Note: `db_name` is only used for the log message -- the database
        actually wiped is whatever `self.engine` is connected to.
        """

        self.logger.warning("SQL: Clearing database '{}' and dropping all tables.".format(db_name))

        # drop all tables that it knows about
        Base.metadata.drop_all(self.engine)

        # create the tables again
        Base.metadata.create_all(self.engine)

        # Legacy MongoDB equivalent, kept for reference:
        # self.client.drop_database(db_name)
292

293 8
    def _delete_DB_data(self, db_name):
        """TODO: needs more testing"""

        # Deletion order matches the original hand-written sequence:
        # metadata, then tasks/services, collections, records, auxiliary tables.
        tables_in_delete_order = [
            # Metadata
            VersionsORM,
            # Task and services
            TaskQueueORM,
            QueueManagerLogORM,
            QueueManagerORM,
            ServiceQueueORM,
            # Collections
            CollectionORM,
            # Records
            TorsionDriveProcedureORM,
            GridOptimizationProcedureORM,
            OptimizationProcedureORM,
            ResultORM,
            WavefunctionStoreORM,
            BaseResultORM,
            # Auxiliary tables
            KVStoreORM,
            MoleculeORM,
        ]

        with self.session_scope() as session:
            for orm_class in tables_in_delete_order:
                session.query(orm_class).delete(synchronize_session=False)
319

320 8
    def get_project_name(self) -> str:
        """Return the name of the database (project) this socket is connected to."""
        return self._project_name
322

323 8
    def get_limit(self, limit: Optional[int]) -> int:
        """Clamp a requested row limit to this socket's configured maximum.

        A ``None`` request, or any request at or above ``self._max_limit``,
        yields ``self._max_limit``.
        """

        if limit is None:
            return self._max_limit
        return min(limit, self._max_limit)
330

331 8
    def get_query_projection(self, className, query, *, limit=None, skip=0, include=None, exclude=None):
        """Query `className` rows matching `query`, optionally projecting columns.

        Parameters
        ----------
        className
            ORM class to query; must provide _get_col_types(), _all_col_names(),
            db_related_fields, and _get_fieldnames_with_DB_ids_().
        query
            Iterable of SQLAlchemy filter expressions (e.g. from format_query).
        limit, skip
            Pagination; `limit` is clamped by self.get_limit().
        include, exclude
            Mutually exclusive sets of column/hybrid names to return or omit.

        Returns
        -------
        Tuple[List[Dict[str, Any]], int]
            Matching rows as dicts, plus the total match count computed
            before limit/skip were applied.
        """

        if include and exclude:
            raise AttributeError(
                f"Either include or exclude can be "
                f"used, not both at the same query. "
                f"Given include: {include}, exclude: {exclude}"
            )

        prop, hybrids, relationships = className._get_col_types()

        # build projection from include or exclude
        _projection = []
        if include:
            _projection = set(include)
        elif exclude:
            _projection = set(className._all_col_names()) - set(exclude) - set(className.db_related_fields)
        _projection = list(_projection)

        proj = []
        join_attrs = {}
        callbacks = []

        # prepare hybrid attributes for callback and joins
        for key in _projection:
            if key in prop:  # normal column
                proj.append(getattr(className, key))

            # if hybrid property, save callback, and relation if any
            elif key in hybrids:
                callbacks.append(key)

                # if it has a relationship
                if key + "_obj" in relationships.keys():

                    # join_class_name = relationships[key + '_obj']
                    join_attrs[key] = relationships[key + "_obj"]
            else:
                raise AttributeError(f"Atrribute {key} is not found in class {className}.")

        # Relationship-backed keys are fetched via a second query, not projected
        for key in join_attrs:
            _projection.remove(key)

        with self.session_scope() as session:
            if _projection or join_attrs:

                if join_attrs and "id" not in _projection:  # if the id is need for joins
                    proj.append(getattr(className, "id"))
                    _projection.append("_id")  # not to be returned to user

                # query with projection, without joins
                data = session.query(*proj).filter(*query)

                n_found = get_count_fast(data)  # before iterating on the data
                data = data.limit(self.get_limit(limit)).offset(skip)
                rdata = [dict(zip(_projection, row)) for row in data]

                # query for joins if any (relationships and hybrids)
                if join_attrs:
                    res_ids = [d.get("id", d.get("_id")) for d in rdata]
                    res_ids.sort()
                    join_data = {res_id: {} for res_id in res_ids}

                    # relations data
                    for key, relation_details in join_attrs.items():
                        ret = (
                            session.query(
                                relation_details["remote_side_column"].label("id"), relation_details["join_class"]
                            )
                            .filter(relation_details["remote_side_column"].in_(res_ids))
                            .order_by(relation_details["remote_side_column"])
                            .all()
                        )
                        # O(n*m) match of joined rows to parents -- fine for
                        # page-sized result sets bounded by max_limit
                        for res_id in res_ids:
                            join_data[res_id][key] = []
                            for res in ret:
                                if res_id == res[0]:
                                    join_data[res_id][key].append(res[1])

                        # NOTE(review): 'data' and 'key' are re-bound here,
                        # shadowing the query object and the loop variable above
                        for data in rdata:
                            parent_id = data.get("id", data.get("_id"))
                            data[key] = join_data[parent_id][key]
                            data.pop("_id", None)

                # call hybrid methods
                for callback in callbacks:
                    for res in rdata:
                        res[callback] = getattr(className, "_" + callback)(res[callback])

                id_fields = className._get_fieldnames_with_DB_ids_()
                for d in rdata:
                    # Expand extra json into fields
                    if "extra" in d:
                        d.update(d["extra"])
                        del d["extra"]

                    # transform ids from int into str
                    for key in id_fields:
                        if key in d.keys() and d[key] is not None:
                            if isinstance(d[key], Iterable):
                                d[key] = [str(i) for i in d[key]]
                            else:
                                d[key] = str(d[key])
                # print('--------rdata after: ', rdata)
            else:
                # No projection requested: fetch full ORM objects
                data = session.query(className).filter(*query)

                # from sqlalchemy.dialects import postgresql
                # print(data.statement.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True}))
                n_found = get_count_fast(data)
                data = data.limit(self.get_limit(limit)).offset(skip).all()
                rdata = [d.to_dict() for d in data]

        return rdata, n_found
445

446
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
447

448 8
    def custom_query(self, class_name: str, query_key: str, **kwargs):
        """
        Run advanced or specialized queries on different classes

        Parameters
        ----------
        class_name : str
            REST APIs name of the class (not the actual python name),
             e.g., torsiondrive
        query_key : str
            The feature or attribute to look for, like initial_molecule
        kwargs
            Extra arguments needed by the query, like the id of the torison drive

        Returns
        -------
            Dict[str,Any]:
                Query result dictionary with keys:
                data: returned data by the query (variable format)
                meta:
                    success: True or False
                    error_description: Error msg to show to the user
        """

        ret = {"data": [], "meta": get_metadata_template()}

        try:
            if class_name not in self._query_classes:
                raise AttributeError(f"Class name {class_name} is not found.")

            # Bug fix: the session was previously never closed, leaking a
            # connection back to the pool on every call. Close it explicitly.
            session = self.Session()
            try:
                ret["data"] = self._query_classes[class_name].query(session, query_key, **kwargs)
            finally:
                session.close()

            ret["meta"]["success"] = True
            try:
                ret["meta"]["n_found"] = len(ret["data"])
            except TypeError:
                # Scalar (non-sized) results count as a single item
                ret["meta"]["n_found"] = 1
        except Exception as err:
            # Boundary handler: report the failure in meta rather than raising
            ret["meta"]["error_description"] = str(err)

        return ret
489

490
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Logging ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
491

492 8
    def save_access(self, log_data):
        """Persist a single access-log entry (kwargs for AccessLogORM)."""

        with self.session_scope() as session:
            access_entry = AccessLogORM(**log_data)
            session.add(access_entry)
            # Explicit commit; session_scope would also commit on clean exit
            session.commit()
498

499
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Logs (KV store) ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
500

501 8
    def add_kvstore(self, outputs: List[KVStore]):
        """
        Adds to the key/value store table.

        Parameters
        ----------
        outputs : List[Any]
            A list of KVStore objects add. ``None`` entries are preserved as
            ``None`` placeholders in the returned id list.

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, data is the ids of added blobs
        """

        meta = add_metadata_template()
        inserted_ids = []

        with self.session_scope() as session:
            for blob in outputs:
                if blob is None:
                    inserted_ids.append(None)
                else:
                    row = KVStoreORM(**blob.dict())
                    session.add(row)
                    session.commit()  # commit per row so row.id is populated
                    inserted_ids.append(str(row.id))
                    meta["n_inserted"] += 1

        meta["success"] = True

        return {"data": inserted_ids, "meta": meta}
533

534 8
    def get_kvstore(self, id: List[ObjectId] = None, limit: int = None, skip: int = 0):
        """
        Pulls from the key/value store table.

        Parameters
        ----------
        id : List[str]
            A list of ids to query
        limit : Optional[int], optional
            Maximum number of results to return.
        skip : Optional[int], optional
            skip the `skip` results

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta, data is a key-value dictionary of found key-value stored items.
        """

        meta = get_metadata_template()

        query = format_query(KVStoreORM, id=id)

        rdata, meta["n_found"] = self.get_query_projection(KVStoreORM, query, limit=limit, skip=skip)

        meta["success"] = True

        data = {}
        # TODO - after migrating everything, remove the 'value' column in the table
        for entry in rdata:
            legacy_value = entry.pop("value")
            if entry["data"] is None:
                # Old-style row: promote the legacy 'value' column into 'data'
                entry["data"] = legacy_value

                # Remove these and let the model handle the defaults
                entry.pop("compression")
                entry.pop("compression_level")

            # The KVStore constructor can handle conversion of strings and dictionaries
            data[entry["id"]] = KVStore(**entry)

        return {"data": data, "meta": meta}
576

577
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Molecule ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
578

579 8
    def get_add_molecules_mixed(self, data: List[Union[ObjectId, Molecule]]) -> List[Molecule]:
        """
        Get or add the given molecules (if they don't exit).
        MoleculeORMs are given in a mixed format, either as a dict of mol data
        or as existing mol id

        TODO: to be split into get by_id and get_by_data
        """

        meta = get_metadata_template()

        # Partition the input by position: ids to look up vs molecules to add
        indexed_input = dict(enumerate(data))
        to_add = {}
        known_ids = {}
        for position, item in indexed_input.items():
            if isinstance(item, (int, str)):
                known_ids[position] = item
            elif isinstance(item, Molecule):
                to_add[position] = item
            else:
                meta["errors"].append((position, "Data type not understood"))

        # Add all new molecules, remembering which input position each came from
        add_positions = list(to_add.keys())
        added_ids = self.add_molecules(list(to_add.values()))["data"]
        known_ids.update(dict(zip(add_positions, added_ids)))

        # Fetch every molecule by id and map it back to its input position
        fetched = self.get_molecules(list(known_ids.values()))
        meta["errors"].extend(fetched["meta"]["errors"])

        # TODO - duplicate ids get removed on the line below. Some
        # code may depend on this behavior, so careful changing it
        position_by_id = {mol_id: pos for pos, mol_id in known_ids.items()}

        found_by_position = {}
        for mol in fetched["data"]:
            found_by_position[position_by_id[mol.id]] = mol

        meta["success"] = True
        meta["n_found"] = len(found_by_position)
        meta["missing"] = list(indexed_input.keys() - found_by_position.keys())

        # Rebuild the flat list in original input order, None where missing
        ret = [found_by_position.get(pos) for pos in range(len(indexed_input))]

        return {"meta": meta, "data": ret}
638

639 8
    def add_molecules(self, molecules: List[Molecule]):
        """
        Adds molecules to the database.

        Duplicates (by molecule hash) are not re-inserted; their existing ids
        are returned and recorded in meta["duplicates"].

        Parameters
        ----------
        molecules : List[Molecule]
            A List of molecule objects to add.

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta; data is the list of molecule
            ids (as strings) in the same order as the input.
        """

        meta = add_metadata_template()

        results = []
        with self.session_scope() as session:

            # Build out the ORMs
            orm_molecules = []
            for dmol in molecules:

                if dmol.validated is False:
                    dmol = Molecule(**dmol.dict(), validate=True)

                mol_dict = dmol.dict(exclude={"id", "validated"})

                # TODO: can set them as defaults in the sql_models, not here
                mol_dict["fix_com"] = True
                mol_dict["fix_orientation"] = True

                # Build fresh indices
                mol_dict["molecule_hash"] = dmol.get_hash()
                mol_dict["molecular_formula"] = dmol.get_molecular_formula()

                mol_dict["identifiers"] = {}
                mol_dict["identifiers"]["molecule_hash"] = mol_dict["molecule_hash"]
                mol_dict["identifiers"]["molecular_formula"] = mol_dict["molecular_formula"]

                # search by index keywords not by all keys, much faster
                orm_molecules.append(MoleculeORM(**mol_dict))

            # Check if we have duplicates
            hash_list = [x.molecule_hash for x in orm_molecules]
            query = format_query(MoleculeORM, molecule_hash=hash_list)
            indices = session.query(MoleculeORM.molecule_hash, MoleculeORM.id).filter(*query)
            previous_id_map = {k: v for k, v in indices}

            # For a bulk add there must be no pre-existing and there must be no duplicates in the add list
            bulk_ok = len(hash_list) == len(set(hash_list))
            bulk_ok &= len(previous_id_map) == 0
            # bulk_ok = False

            if bulk_ok:
                # Bulk save, doesn't update fields for speed
                session.bulk_save_objects(orm_molecules)
                session.commit()

                # Query ID's and reorder based off orm_molecule ordered list
                # (bulk_save_objects does not populate ids, so re-query by hash)
                query = format_query(MoleculeORM, molecule_hash=hash_list)
                indices = session.query(MoleculeORM.molecule_hash, MoleculeORM.id).filter(*query)

                id_map = {k: v for k, v in indices}
                n_inserted = len(orm_molecules)

            else:
                # Start from old ID map
                id_map = previous_id_map

                new_molecules = []
                n_inserted = 0

                for orm_mol in orm_molecules:
                    # False sentinel distinguishes "absent" from a real id
                    duplicate_id = id_map.get(orm_mol.molecule_hash, False)
                    if duplicate_id is not False:
                        meta["duplicates"].append(str(duplicate_id))
                    else:
                        new_molecules.append(orm_mol)
                        # Placeholder marks this hash as pending until commit
                        # assigns a real id below
                        id_map[orm_mol.molecule_hash] = "placeholder_id"
                        n_inserted += 1
                        session.add(orm_mol)

                    # We should make sure there was not a hash collision?
                    # new_mol.compare(old_mol)
                    # raise KeyError("!!! WARNING !!!: Hash collision detected")

                session.commit()

                # Replace placeholders with the ids assigned by the commit
                for new_mol in new_molecules:
                    id_map[new_mol.molecule_hash] = new_mol.id

            results = [str(id_map[x.molecule_hash]) for x in orm_molecules]
            assert "placeholder_id" not in results
            meta["n_inserted"] = n_inserted

        meta["success"] = True

        ret = {"data": results, "meta": meta}
        return ret
740

741 8
    def get_molecules(self, id=None, molecule_hash=None, molecular_formula=None, limit: int = None, skip: int = 0):
        """Query molecules by id, hash, and/or molecular formula.

        Returns a dict with 'meta' and 'data', where data is a list of
        Molecule objects.
        """
        try:
            # Normalize formulas to the canonical element ordering used by the DB
            if isinstance(molecular_formula, str):
                molecular_formula = qcelemental.molutil.order_molecular_formula(molecular_formula)
            elif isinstance(molecular_formula, list):
                molecular_formula = [qcelemental.molutil.order_molecular_formula(form) for form in molecular_formula]
        except ValueError:
            # Probably, the user provided an invalid chemical formula
            pass

        meta = get_metadata_template()

        query = format_query(MoleculeORM, id=id, molecule_hash=molecule_hash, molecular_formula=molecular_formula)

        # Don't include the hash or the molecular_formula in the returned result
        rdata, meta["n_found"] = self.get_query_projection(
            MoleculeORM, query, limit=limit, skip=skip, exclude=["molecule_hash", "molecular_formula"]
        )

        meta["success"] = True

        # This is required for sparse molecules as we don't know which values are sparse.
        # None is the default and carries no meaning for Molecule, so drop those keys.
        # This strategy does not work for other objects.
        data = [
            Molecule(**{key: val for key, val in record.items() if val is not None}, validate=False, validated=True)
            for record in rdata
        ]

        return {"meta": meta, "data": data}
771

772 8
    def del_molecules(self, id: List[str] = None, molecule_hash: List[str] = None):
        """
        Removes a molecule from the database from its hash.

        Parameters
        ----------
        id : str or List[str], optional
            ids of molecules, can use the hash parameter instead
        molecule_hash : str or List[str]
            The hash of a molecule.

        Returns
        -------
        int
            Number of deleted molecules.
        """

        filters = format_query(MoleculeORM, id=id, molecule_hash=molecule_hash)

        with self.session_scope() as session:
            deleted_count = session.query(MoleculeORM).filter(*filters).delete(synchronize_session=False)

        return deleted_count
795

796
    # ~~~~~~~~~~~~~~~~~~~~~~~ Keywords ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
797

798 8
    def add_keywords(self, keyword_sets: List[KeywordSet]):
        """Add one KeywordSet uniquly identified by 'program' and the 'name'.

        Parameters
        ----------
        keywords_set : List[KeywordSet]
            A list of KeywordSets to be inserted.

        Returns
        -------
        Dict[str, Any]
            (see add_metadata_template())
            The 'data' part is a list of ids of the inserted options
            data['duplicates'] has the duplicate entries

        Notes
        ------
            Duplicates are not considered errors.

        """

        meta = add_metadata_template()

        result_ids = []
        with self.session_scope() as session:
            for kw_set in keyword_sets:

                payload = kw_set.dict(exclude={"id"})

                # Look up by the hash index only -- much faster than matching every key
                existing = session.query(KeywordsORM).filter_by(hash_index=payload["hash_index"]).first()
                if existing is None:
                    row = KeywordsORM(**payload)
                    session.add(row)
                    session.commit()
                    result_ids.append(str(row.id))
                    meta["n_inserted"] += 1
                else:
                    meta["duplicates"].append(str(existing.id))  # TODO
                    result_ids.append(str(existing.id))
                # Deliberately inside the loop: success stays False for empty input
                meta["success"] = True

        return {"data": result_ids, "meta": meta}
843

844 8
    def get_keywords(
        self,
        id: Union[str, list] = None,
        hash_index: Union[str, list] = None,
        limit: int = None,
        skip: int = 0,
        return_json: bool = False,
        with_ids: bool = True,
    ) -> List[KeywordSet]:
        """Search for keyword sets by id and/or hash index.

        Parameters
        ----------
        id : List[str] or str
            Ids of the keywords.
        hash_index : List[str] or str
            Hash index of keywords.
        limit : Optional[int], optional
            Maximum number of results to return. Values above the socket's
            max_limit are clamped to max_limit; None/0 means max_limit.
        skip : int, optional
            Number of results to skip (pagination).
        return_json : bool, optional
            If True, return raw dicts instead of KeywordSet objects.
        with_ids : bool, optional
            Include the DB ids in the returned objects (named 'id').

        Returns
        -------
            A dict with keys: 'data' and 'meta'
            (see get_metadata_template())
        """

        meta = get_metadata_template()
        query = format_query(KeywordsORM, id=id, hash_index=hash_index)

        # Drop the id column when the caller does not want database ids.
        excluded = [None] if with_ids else ["id"]
        rdata, meta["n_found"] = self.get_query_projection(
            KeywordsORM, query, limit=limit, skip=skip, exclude=excluded
        )

        meta["success"] = True

        data = rdata if return_json else [KeywordSet(**entry) for entry in rdata]

        return {"data": data, "meta": meta}
900

901 8
    def get_add_keywords_mixed(self, data):
        """
        Get or add the given keyword sets (if they don't exist).

        Entries in ``data`` may be given in a mixed format: either an existing
        keyword id (int/str) or a KeywordSet object to be inserted first.

        TODO: to be split into get_by_id and get_by_data
        """

        meta = get_metadata_template()

        # Pass 1: normalize every entry to a database id (or None on bad input).
        ids = []
        for idx, kw in enumerate(data):
            if isinstance(kw, (int, str)):
                ids.append(kw)

            elif isinstance(kw, KeywordSet):
                # add_keywords deduplicates, so this returns the existing id
                # when the keyword set is already stored.
                new_id = self.add_keywords([kw])["data"][0]
                ids.append(new_id)
            else:
                meta["errors"].append((idx, "Data type not understood"))
                ids.append(None)

        # Pass 2: fetch the full KeywordSet for every resolved id.
        missing = []
        ret = []
        for idx, id in enumerate(ids):
            if id is None:
                # Unparseable input from pass 1 -- keep positional alignment.
                ret.append(None)
                missing.append(idx)
                continue

            tmp = self.get_keywords(id=id)["data"]
            if tmp:
                ret.append(tmp[0])
            else:
                # NOTE(review): an id that is not found in the DB yields None in
                # 'data' but is NOT added to 'missing', so n_found below still
                # counts it -- presumably intentional, but worth confirming.
                ret.append(None)

        meta["success"] = True
        meta["n_found"] = len(ret) - len(missing)
        meta["missing"] = missing

        return {"meta": meta, "data": ret}
943

944 8
    def del_keywords(self, id: str) -> int:
        """Delete a keyword set by its database id.

        Parameters
        ----------
        id : str
            id of the keyword set to remove.

        Returns
        -------
        int
           Number of deleted documents (0 or 1).
        """

        with self.session_scope() as session:
            deleted = session.query(KeywordsORM).filter_by(id=id).delete(synchronize_session=False)

        return deleted
964

965
    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~`
966

967
    ### database functions
968

969 8
    def add_collection(self, data: Dict[str, Any], overwrite: bool = False):
        """Add (or update) a collection to the database.

        Parameters
        ----------
        data : Dict[str, Any]
            should include at least (keys):
            collection : str (immutable)
            name : str (immutable)

        overwrite : bool
            Update existing collection

        Returns
        -------
        Dict[str, Any]
        A dict with keys: 'data' and 'meta'
            (see add_metadata_template())
            The 'data' part is the id of the inserted document or None on error

        Notes
        -----
        ** Change: The data doesn't have to include the ID, the document
        is identified by the (collection, name) pairs.
        ** Change: New fields will be added to the collection, but existing won't
            be removed.
        """

        meta = add_metadata_template()
        col_id = None

        if "id" in data:  # remove the ID in any case
            data.pop("id", None)
        # collection/name are matched case-insensitively: both stored lowercased
        lname = data.get("name").lower()
        collection = data.pop("collection").lower()

        # Get collection class if special type is implemented
        collection_class = get_collection_class(collection)

        # Split the input into mapped ORM columns vs. everything else.
        update_fields = {}
        for field in collection_class._all_col_names():
            if field in data:
                update_fields[field] = data.pop(field)

        # Leftover (non-column) keys are stashed in the 'extra' JSON column.
        update_fields["extra"] = data  # TODO: check for sql injection

        with self.session_scope() as session:

            try:
                if overwrite:
                    # NOTE(review): if no matching row exists, 'col' is None and
                    # setattr below raises -- caught and reported via meta.
                    col = session.query(collection_class).filter_by(collection=collection, lname=lname).first()
                    for key, value in update_fields.items():
                        setattr(col, key, value)
                else:
                    col = collection_class(collection=collection, lname=lname, **update_fields)

                # First commit assigns the id; relations may need it.
                session.add(col)
                session.commit()
                col.update_relations(**update_fields)
                session.commit()

                col_id = str(col.id)
                meta["success"] = True
                meta["n_inserted"] = 1

            except Exception as err:
                session.rollback()
                meta["error_description"] = str(err)

        ret = {"data": col_id, "meta": meta}
        return ret
1043

1044 8
    def get_collections(
        self,
        collection: Optional[str] = None,
        name: Optional[str] = None,
        col_id: Optional[int] = None,
        limit: Optional[int] = None,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        skip: int = 0,
    ) -> Dict[str, Any]:
        """Look up collections by type, name, and/or database id.

        Parameters
        ----------
        collection: Optional[str], optional
            Type of the collection, e.g. ReactionDataset.
        name: Optional[str], optional
            Name of the collection, e.g. S22.
        col_id: Optional[int], optional
            Database id of the collection.
        limit: Optional[int], optional
            Maximum number of results to return.
        include: Optional[List[str]], optional
            Columns to return.
        exclude: Optional[List[str]], optional
            Return all but these columns.
        skip: int, optional
            Skip the first `skip` results.

        Returns
        -------
        Dict[str, Any]
            A dict with keys: 'data' (the collections found) and 'meta'.
        """

        meta = get_metadata_template()

        # Both type and name are stored lowercased in the database.
        lname = name.lower() if name else name
        if collection:
            collection = collection.lower()

        collection_class = get_collection_class(collection)
        query = format_query(collection_class, lname=lname, collection=collection, id=col_id)

        rdata, meta["n_found"] = self.get_query_projection(
            collection_class, query, include=include, exclude=exclude, limit=limit, skip=skip
        )
        meta["success"] = True

        return {"data": rdata, "meta": meta}
1099

1100 8
    def del_collection(
        self, collection: Optional[str] = None, name: Optional[str] = None, col_id: Optional[int] = None
    ) -> int:
        """
        Remove a collection from the database from its keys.

        Parameters
        ----------
        collection: Optional[str], optional
            CollectionORM type
        name : Optional[str], optional
            CollectionORM name
        col_id: Optional[int], optional
            Database id of the collection

        Returns
        -------
        int
            Number of documents deleted

        Raises
        ------
        ValueError
            If neither ``col_id`` nor both ``collection`` and ``name`` are given.
        """

        # Assuming here that we don't want to allow deletion of all collections, all datasets, etc.
        if not (col_id is not None or (collection is not None and name is not None)):
            # BUG FIX: this message was a plain string, so the {placeholders}
            # were never interpolated; it is now an f-string.
            raise ValueError(
                f"Either col_id ({col_id}) must be specified, or collection ({collection}) and name ({name}) must be specified."
            )

        filter_spec = {}
        if collection is not None:
            filter_spec["collection"] = collection.lower()
        if name is not None:
            # Names are stored lowercased in the 'lname' column.
            filter_spec["lname"] = name.lower()
        if col_id is not None:
            filter_spec["id"] = col_id

        with self.session_scope() as session:
            count = session.query(CollectionORM).filter_by(**filter_spec).delete(synchronize_session=False)
        return count
1137

1138
    ## ResultORMs functions
1139

1140 8
    def add_results(self, record_list: List[ResultRecord]):
        """
        Add results from a given list of records, deduplicating against
        records already stored (and against duplicates within the input).

        Parameters
        ----------
        record_list : List[ResultRecord]
            Each record must have:
            program, driver, method, basis, keywords, molecule
            Where molecule is the molecule id in the DB
            In addition, it should have the other attributes that it needs
            to store

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta
            Data is the ids of the inserted/updated/existing docs, in the same order as the
            input record_list
        """

        meta = add_metadata_template()

        results_list = []
        duplicates_list = []

        # Stores indices referring to elements in record_list
        new_record_idx, duplicates_idx = [], []

        # creating condition for a multi-value select
        # This can be used to query for multiple results in a single query
        conds = [
            and_(
                ResultORM.program == res.program,
                ResultORM.driver == res.driver,
                ResultORM.method == res.method,
                ResultORM.basis == res.basis,
                ResultORM.keywords == res.keywords,
                ResultORM.molecule == res.molecule,
            )
            for res in record_list
        ]

        with self.session_scope() as session:
            # Query for all existing
            # TODO: RACE CONDITION: Records could be inserted between this query and inserting later

            # Maps the (program, driver, method, basis, keywords, molecule)
            # tuple of an existing row to its (partial) ORM row.
            existing_results = {}

            for cond in conds:
                doc = (
                    session.query(
                        ResultORM.program,
                        ResultORM.driver,
                        ResultORM.method,
                        ResultORM.basis,
                        ResultORM.keywords,
                        ResultORM.molecule,
                        ResultORM.id,
                    )
                    .filter(cond)
                    .one_or_none()
                )

                if doc is not None:
                    existing_results[
                        (doc.program, doc.driver, doc.method, doc.basis, doc.keywords, str(doc.molecule))
                    ] = doc

            # Loop over all (input) records, keeping track each record's index in the list
            for i, result in enumerate(record_list):
                # constructing an index from the record compare against items existing_results
                # NOTE(review): the element types here (driver .value, keywords
                # cast to int-or-None) must match how the DB rows were keyed
                # above -- confirm if the ORM column types ever change.
                idx = (
                    result.program,
                    result.driver.value,
                    result.method,
                    result.basis,
                    int(result.keywords) if result.keywords else None,
                    result.molecule,
                )

                if idx not in existing_results:
                    # Does not exist in the database. Construct a new ResultORM
                    doc = ResultORM(**result.dict(exclude={"id"}))

                    # Store in existing_results in case later records are duplicates
                    existing_results[idx] = doc

                    # add the object to the list for later adding and committing to database.
                    results_list.append(doc)

                    # Store the index of this record (in record_list) as a new_record
                    new_record_idx.append(i)
                    meta["n_inserted"] += 1
                else:
                    # This result already exists in the database
                    doc = existing_results[idx]

                    # Store the index of this record (in record_list) as a duplicate
                    duplicates_idx.append(i)

                    # Store the entire object. Since this may be a duplicate of a record
                    # added in a previous iteration of the loop, and the data hasn't been added/committed
                    # to the database, the id may not be known here
                    duplicates_list.append(doc)

            session.add_all(results_list)
            session.commit()

            # At this point, all ids should be known. So store only the ids in the returned metadata
            meta["duplicates"] = [str(doc.id) for doc in duplicates_list]

            # Construct the ID list to return (in the same order as the input data)
            # Use a placeholder for all, and we will fill later
            result_ids = [None] * len(record_list)

            # At this point:
            #     results_list: ORM objects for all newly-added results
            #     new_record_idx: indices (referring to record_list) of newly-added results
            #     duplicates_idx: indices (referring to record_list) of results that already existed
            #
            # results_list and new_record_idx are in the same order
            # (ie, the index stored at new_record_idx[0] refers to some element of record_list. That
            # newly-added ResultORM is located at results_list[0])
            #
            # Similarly, duplicates_idx and meta["duplicates"] are in the same order

            for idx, new_result in zip(new_record_idx, results_list):
                result_ids[idx] = str(new_result.id)

            # meta["duplicates"] only holds ids at this point
            for idx, existing_result_id in zip(duplicates_idx, meta["duplicates"]):
                result_ids[idx] = existing_result_id

        # Every input record must have been classified as new or duplicate.
        assert None not in result_ids

        meta["success"] = True

        ret = {"data": result_ids, "meta": meta}
        return ret
1281

1282 8
    def update_results(self, record_list: List[ResultRecord]):
        """
        Update results from the given records (replace existing).

        Parameters
        ----------
        record_list : List[ResultRecord]
            Records to update; each record's ``id`` must already exist in
            the DB. All attributes except the id are overwritten.
            Shouldn't update:
            program, driver, method, basis, options, molecule

        Returns
        -------
        int
            Number of records updated
        """
        query_ids = [res.id for res in record_list]
        # find duplicates among ids
        duplicates = len(query_ids) != len(set(query_ids))

        with self.session_scope() as session:

            found = session.query(ResultORM).filter(ResultORM.id.in_(query_ids)).all()
            # found items are stored in a dictionary keyed by stringified id
            found_dict = {str(record.id): record for record in found}

            updated_count = 0
            for result in record_list:

                if result.id is None:
                    self.logger.error("Attempted update without ID, skipping")
                    continue

                data = result.dict(exclude={"id"})
                # retrieve the found item
                # NOTE(review): raises KeyError if the id is not in the DB --
                # consistent with the documented "must exist" contract.
                found_db = found_dict[result.id]

                # updating the found item with input attribute values.
                for attr, val in data.items():
                    setattr(found_db, attr, val)

                # if any duplicate ids are found in the input, commit should be called each iteration
                # so later updates of the same row see the earlier ones applied
                if duplicates:
                    session.commit()
                updated_count += 1
            # if no duplicates found, only commit at the end of the loop.
            if not duplicates:
                session.commit()

        return updated_count
1333

1334 8
    def get_results_count(self):
        """Return only the number of matching results (not implemented).

        TODO: just return the count, used for big queries

        Returns
        -------
        None
            Placeholder stub -- currently does nothing and returns None.
        """
1342

1343 8
    def get_results(
        self,
        id: Union[str, List] = None,
        program: str = None,
        method: str = None,
        basis: str = None,
        molecule: str = None,
        driver: str = None,
        keywords: str = None,
        task_id: Union[str, List] = None,
        manager_id: Union[str, List] = None,
        status: str = "COMPLETE",
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
        with_ids=True,
    ):
        """Query single results by attribute or by task id.

        Parameters
        ----------
        id : str or List[str]
        program : str
        method : str
        basis : str
        molecule : str
            MoleculeORM id in the DB.
        driver : str
        keywords : str
            The id of the keyword set in the DB.
        task_id: str or List[str]
            id or a list of task ids; when given, all other filters
            are ignored and the lookup is delegated to
            _get_results_by_task_id.
        manager_id: str or List[str]
            id or a list of ids of queue managers.
        status : str, optional
            The status of the result: 'COMPLETE', 'INCOMPLETE', or 'ERROR'.
            Default is 'COMPLETE'. Ignored when ``id`` is given.
        include : Optional[List[str]], optional
            The fields to return; default is all.
        exclude : Optional[List[str]], optional
            The fields to not return; default is none.
        limit : Optional[int], optional
            Maximum number of results to return (clamped to the socket's
            max limit to avoid overloading the server).
        skip : int, optional
            Skip the first 'skip' results; used to paginate. Default is 0.
        return_json : bool, optional
            Return the results as a list of json instead of objects.
        with_ids : bool, optional
            Include the ids in the returned objects/dicts.

        Returns
        -------
        Dict[str, Any]
            Dict with keys: 'data' (the objects found) and 'meta'.
        """

        # Task-id lookups use a dedicated join query.
        if task_id:
            return self._get_results_by_task_id(task_id)

        meta = get_metadata_template()

        # An explicit id wins over any status filter.
        if id is not None:
            status = None

        search_spec = dict(
            id=id,
            program=program,
            method=method,
            basis=basis,
            molecule=molecule,
            driver=driver,
            keywords=keywords,
            manager_id=manager_id,
            status=status,
        )
        query = format_query(ResultORM, **search_spec)

        data, meta["n_found"] = self.get_query_projection(
            ResultORM, query, include=include, exclude=exclude, limit=limit, skip=skip
        )
        meta["success"] = True

        return {"data": data, "meta": meta}
1436

1437 8
    def _get_results_by_task_id(self, task_id: Union[str, List] = None, return_json=True):
        """Fetch base results joined to the given queue task ids.

        Parameters
        ----------
        task_id : str or List[str]
            One task id or a list of task ids.
        return_json : bool, optional
            Return the results as a list of json instead of objects.
            (Currently the results are always returned as dicts.)

        Returns
        -------
        Dict[str, Any]
            Dict with keys: 'data' (the records found) and 'meta'.
        """

        meta = get_metadata_template()

        data = []
        # Accept a scalar id or a list of ids.
        if isinstance(task_id, (int, str)):
            task_ids = [task_id]
        else:
            task_ids = task_id

        with self.session_scope() as session:
            # Join base results to the task queue through base_result_id.
            result_query = (
                session.query(BaseResultORM)
                .filter(BaseResultORM.id == TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.id.in_(task_ids))
            )
            meta["n_found"] = get_count_fast(result_query)
            data = [record.to_dict() for record in result_query.all()]
            meta["success"] = True

        return {"data": data, "meta": meta}
1473

1474 8
    def del_results(self, ids: List[str]):
        """Delete results by their database ids.

        (Should be cautious! Other tables may be referencing these results.)

        Parameters
        ----------
        ids : List[str]
            The ids of the results to be deleted.

        Returns
        -------
        int
            Number of results deleted.
        """

        with self.session_scope() as session:
            matches = session.query(ResultORM).filter(ResultORM.id.in_(ids)).all()
            # Delete one-by-one through the session so the corresponding
            # base_result rows are removed correctly as well.
            for record in matches:
                session.delete(record)
            session.commit()
            count = len(matches)

        return count
1499

1500 8
    def add_wavefunction_store(self, blobs_list: List[Dict[str, Any]]):
        """Insert entries into the wavefunction key/value store table.

        Parameters
        ----------
        blobs_list : List[Dict[str, Any]]
            A list of wavefunction data blobs to add; None entries are
            passed through as None ids.

        Returns
        -------
        Dict[str, Any]
            Dict with keys 'data' (blob ids of the inserted wavefunction
            data, in input order) and 'meta'.
        """

        meta = add_metadata_template()
        blob_ids = []

        with self.session_scope() as session:
            for blob in blobs_list:
                if blob is None:
                    # Keep positional correspondence with the input list.
                    blob_ids.append(None)
                    continue

                orm = WavefunctionStoreORM(**blob)
                session.add(orm)
                session.commit()
                blob_ids.append(str(orm.id))
                meta["n_inserted"] += 1

        meta["success"] = True

        return {"data": blob_ids, "meta": meta}
1532

1533 8
    def get_wavefunction_store(
        self,
        id: List[str] = None,
        include: Optional[List[str]] = None,
        exclude: Optional[List[str]] = None,
        limit: int = None,
        skip: int = 0,
    ) -> Dict[str, Any]:
        """Retrieve entries from the wavefunction key/value store table.

        Parameters
        ----------
        id : List[str], optional
            A list of ids to query.
        include : Optional[List[str]], optional
            The fields to return; default is all.
        exclude : Optional[List[str]], optional
            The fields to not return; default is none.
        limit : int, optional
            Maximum number of results to return.
        skip : int, optional
            Skips a number of results in the query; used for pagination.

        Returns
        -------
        Dict[str, Any]
            Dict with keys 'data' (the found wavefunction items) and 'meta'.
        """

        meta = get_metadata_template()
        query = format_query(WavefunctionStoreORM, id=id)

        found, meta["n_found"] = self.get_query_projection(
            WavefunctionStoreORM, query, limit=limit, skip=skip, include=include, exclude=exclude
        )
        meta["success"] = True

        return {"data": found, "meta": meta}
1575

1576
    ### Mongo procedure/service functions
1577

1578 8
    def add_procedures(self, record_list: List["BaseRecord"]):
        """
        Add procedures, deduplicating on hash_index.

        Parameters
        ----------
        record_list : List["BaseRecord"]
            Each record must have:
            procedure, program, keywords, qc_meta, hash_index
            In addition, it should have the other attributes that it needs
            to store. All records must be of the same procedure type.

        Returns
        -------
        Dict[str, Any]
            Dict with keys 'data' (the ids of the inserted/existing docs,
            in input order) and 'meta'.
        """

        meta = add_metadata_template()

        if not record_list:
            return {"data": [], "meta": meta}

        # The ORM class is chosen from the first record; all records are
        # assumed to be of the same procedure type.
        procedure_class = get_procedure_class(record_list[0])

        procedure_ids = []
        with self.session_scope() as session:
            for procedure in record_list:
                # Single query for the existing row (the previous code counted
                # and then re-fetched, executing the same query twice).
                existing = session.query(procedure_class).filter_by(hash_index=procedure.hash_index).first()

                if existing is None:
                    data = procedure.dict(exclude={"id"})
                    proc_db = procedure_class(**data)
                    # First commit assigns the id; relations may need it.
                    session.add(proc_db)
                    session.commit()
                    proc_db.update_relations(**data)
                    session.commit()
                    procedure_ids.append(str(proc_db.id))
                    meta["n_inserted"] += 1
                else:
                    # Renamed from 'id' to avoid shadowing the builtin.
                    dup_id = str(existing.id)
                    meta["duplicates"].append(dup_id)  # TODO
                    procedure_ids.append(dup_id)
        meta["success"] = True

        ret = {"data": procedure_ids, "meta": meta}
        return ret
1626

1627 8
    def get_procedures(
        self,
        id: Union[str, List] = None,
        procedure: str = None,
        program: str = None,
        hash_index: str = None,
        task_id: Union[str, List] = None,
        manager_id: Union[str, List] = None,
        status: str = "COMPLETE",
        include=None,
        exclude=None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
        with_ids=True,
    ):
        """
        Query procedure records (optimization, torsiondrive, gridoptimization).

        If no recognized procedure type is given, the polymorphic base class is
        queried instead, in which case `id` becomes required.

        Parameters
        ----------
        id : str or List[str]
            Ids of the procedures to query
        procedure : str
            Procedure type: 'optimization', 'torsiondrive', or 'gridoptimization'
        program : str
            Program name; ignored when `procedure` is not recognized
        hash_index : str
            Hash index of the procedure
        task_id : str or List[str]
            Ids of the corresponding task queue entries
        manager_id : str or List[str]
            Ids of the managers that ran the procedures
        status : str, optional
            The status of the result: 'COMPLETE', 'INCOMPLETE', or 'ERROR'
            Default is 'COMPLETE'. Cleared (set to None) when `id` or `task_id`
            is given, so those records are found regardless of status.
        include : Optional[List[str]], optional
            The fields to return, default to return all
        exclude : Optional[List[str]], optional
            The fields to not return, default to return all
        limit : Optional[int], optional
            maximum number of results to return
            if 'limit' is greater than the global setting self._max_limit,
            the self._max_limit will be returned instead
            (This is to avoid overloading the server)
        skip : int, optional
            skip the first 'skip' results. Used to paginate
            Default is 0
        return_json : bool, optional
            Return the results as a list of json instead of objects
            Default is True
        with_ids : bool, optional
            Include the ids in the returned objects/dicts
            Default is True

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data and meta. Data is the objects found

        Raises
        ------
        KeyError
            If `procedure` is not a recognized type and `id` is None.
        """

        meta = get_metadata_template()

        # An explicit id or task_id should match regardless of record status
        if id is not None or task_id is not None:
            status = None

        # Map the procedure type to its concrete ORM class
        if procedure == "optimization":
            className = OptimizationProcedureORM
        elif procedure == "torsiondrive":
            className = TorsionDriveProcedureORM
        elif procedure == "gridoptimization":
            className = GridOptimizationProcedureORM
        else:
            # raise TypeError('Unsupported procedure type {}. Id: {}, task_id: {}'
            #                 .format(procedure, id, task_id))
            className = BaseResultORM  # all classes, including those with 'selectin'
            program = None  # make sure it's not used
            if id is None:
                self.logger.error(f"Procedure type not specified({procedure}), and ID is not given.")
                raise KeyError("ID is required if procedure type is not specified.")

        query = format_query(
            className,
            id=id,
            procedure=procedure,
            program=program,
            hash_index=hash_index,
            task_id=task_id,
            manager_id=manager_id,
            status=status,
        )

        data = []
        try:
            # TODO: decide a way to find the right type

            data, meta["n_found"] = self.get_query_projection(
                className, query, limit=limit, skip=skip, include=include, exclude=exclude
            )
            meta["success"] = True
        except Exception as err:
            # Query failures are reported through meta rather than raised
            meta["error_description"] = str(err)

        return {"data": data, "meta": meta}
1723

1724 8
    def update_procedures(self, records_list: List["BaseRecord"]):
        """
        Update existing procedure records from pydantic record objects.

        Records without a database id are skipped (an error is logged). For
        each remaining record, the ORM relations are refreshed through
        ``update_relations`` and every field from the record dict is written
        onto the ORM object, committing once per record.

        TODO: needs to be of specific type

        Parameters
        ----------
        records_list : List["BaseRecord"]
            Procedure records to update; each must carry its database id.

        Returns
        -------
        int
            Number of procedures actually updated.
        """

        updated_count = 0
        with self.session_scope() as session:
            for procedure in records_list:

                className = get_procedure_class(procedure)
                # join_table = get_procedure_join(procedure)
                # Must have ID
                if procedure.id is None:
                    self.logger.error(
                        "No procedure id found on update (hash_index={}), skipping.".format(procedure.hash_index)
                    )
                    continue

                proc_db = session.query(className).filter_by(id=procedure.id).first()

                data = procedure.dict(exclude={"id"})
                proc_db.update_relations(**data)

                # Copy every field from the record onto the ORM object
                for attr, val in data.items():
                    setattr(proc_db, attr, val)

                # session.add(proc_db)

                # Upsert relations (insert or update)
                # needs primarykeyconstraint on the table keys
                # for result_id in procedure.trajectory:
                #     statement = postgres_insert(opt_result_association)\
                #         .values(opt_id=procedure.id, result_id=result_id)\
                #         .on_conflict_do_update(
                #             index_elements=[opt_result_association.c.opt_id, opt_result_association.c.result_id],
                #             set_=dict(result_id=result_id))
                #     session.execute(statement)

                # Commit after each record
                session.commit()
                updated_count += 1

        # session.commit()  # save changes, takes care of inheritance

        return updated_count
1768

1769 8
    def del_procedures(self, ids: List[str]):
        """
        Remove procedure records (of any polymorphic subtype) by their ids.

        Be cautious: other tables may still be referencing these results.

        Parameters
        ----------
        ids : List[str]
            The Ids of the procedures to be deleted

        Returns
        -------
        int
            Number of procedures deleted
        """

        # Query across the polymorphic hierarchy so subtype rows are loaded too
        poly_orm = with_polymorphic(
            BaseResultORM,
            [OptimizationProcedureORM, TorsionDriveProcedureORM, GridOptimizationProcedureORM],
        )

        with self.session_scope() as session:
            matches = session.query(poly_orm).filter(BaseResultORM.id.in_(ids)).all()

            # Delete through the session (not a bulk query delete) so the
            # base_result rows are removed correctly for each subtype
            for record in matches:
                session.delete(record)

            deleted = len(matches)

        return deleted
1803

1804 8
    def add_services(self, service_list: List["BaseService"]):
        """
        Add services from a given list.

        Each service's underlying output procedure is inserted first. If the
        procedure already existed, the service is recorded as a duplicate and
        skipped; otherwise the service queue entry is created.

        Parameters
        ----------
        service_list : List["BaseService"]
            List of services to be added

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the ids of the inserted/existing
            underlying procedures (None for a service that failed to insert).
        """

        meta = add_metadata_template()

        procedure_ids = []
        with self.session_scope() as session:
            for service in service_list:

                # Add the underlying procedure
                new_procedure = self.add_procedures([service.output])

                # ProcedureORM already exists
                proc_id = new_procedure["data"][0]

                if new_procedure["meta"]["duplicates"]:
                    procedure_ids.append(proc_id)
                    meta["duplicates"].append(proc_id)
                    continue

                # search by hash index
                doc = session.query(ServiceQueueORM).filter_by(hash_index=service.hash_index)
                service.procedure_id = proc_id

                if doc.count() == 0:
                    # Fields declared on the ORM become columns; the rest goes into 'extra'
                    doc = ServiceQueueORM(**service.dict(include=set(ServiceQueueORM.__dict__.keys())))
                    doc.extra = service.dict(exclude=set(ServiceQueueORM.__dict__.keys()))
                    doc.priority = doc.priority.value  # Must be an integer for sorting
                    session.add(doc)
                    session.commit()  # TODO
                    procedure_ids.append(proc_id)
                    meta["n_inserted"] += 1
                else:
                    procedure_ids.append(None)
                    # BUG FIX: `doc` is still a Query object here and has no `.id`
                    # attribute; fetch the existing row to report its id.
                    meta["errors"].append((doc.first().id, "Duplicate service, but not caught by procedure."))

        meta["success"] = True

        ret = {"data": procedure_ids, "meta": meta}
        return ret
1855

1856 8
    def get_services(
        self,
        id: Union[List[str], str] = None,
        procedure_id: Union[List[str], str] = None,
        hash_index: Union[List[str], str] = None,
        status: str = None,
        limit: int = None,
        skip: int = 0,
        return_json=True,
    ):
        """
        Query service queue entries.

        Results are ordered by descending priority, then creation time.

        Parameters
        ----------
        id / hash_index : List[str] or str, optional
            service id / hash index
        procedure_id : List[str] or str, optional
            procedure_id for the specific procedure
        status : str, optional
            status of the record queried for
        limit : Optional[int], optional
            maximum number of results to return
        skip : int, optional
            skip the first 'skip' results; used to paginate (default 0)
        return_json : bool, default is True
            Return the results as a list of json instead of objects

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the objects found
        """

        meta = get_metadata_template()
        filters = format_query(ServiceQueueORM, id=id, hash_index=hash_index, procedure_id=procedure_id, status=status)

        with self.session_scope() as session:
            base_query = session.query(ServiceQueueORM).filter(*filters)
            ordered = base_query.order_by(ServiceQueueORM.priority.desc(), ServiceQueueORM.created_on)
            rows = ordered.limit(limit).offset(skip).all()
            data = [row.to_dict() for row in rows]

        meta["n_found"] = len(data)
        meta["success"] = True

        return {"data": data, "meta": meta}
1914

1915 8
    def update_services(self, records_list: List["BaseService"]) -> int:
        """
        Replace existing services with the given service objects.

        For each service, the queue entry is rewritten and the service's
        current output (plus any stdout/error text, stored through kvstore)
        is written back onto the underlying procedure record.

        Raises exception if the id is invalid

        Parameters
        ----------
        records_list: List["BaseService"]
            List of Service items to be updated using their id

        Returns
        -------
        int
            number of updated services
        """

        updated_count = 0
        for service in records_list:
            if service.id is None:
                self.logger.error("No service id found on update (hash_index={}), skipping.".format(service.hash_index))
                continue

            with self.session_scope() as session:

                doc_db = session.query(ServiceQueueORM).filter_by(id=service.id).first()

                # Fields declared on the ORM are stored as columns; everything
                # else from the pydantic model goes into the 'extra' blob
                data = service.dict(include=set(ServiceQueueORM.__dict__.keys()))
                data["extra"] = service.dict(exclude=set(ServiceQueueORM.__dict__.keys()))

                data["id"] = int(data["id"])
                for attr, val in data.items():
                    setattr(doc_db, attr, val)

                session.add(doc_db)
                session.commit()

            # Point the service's output at its procedure record for the update
            procedure = service.output
            procedure.__dict__["id"] = service.procedure_id

            # Copy the stdout/error from the service itself to its procedure
            if service.stdout:
                stdout = KVStore(data=service.stdout)
                stdout_id = self.add_kvstore([stdout])["data"][0]
                procedure.__dict__["stdout"] = stdout_id
            if service.error:
                error = KVStore(data=service.error)
                error_id = self.add_kvstore([error])["data"][0]
                procedure.__dict__["error"] = error_id

            self.update_procedures([procedure])

            updated_count += 1

        return updated_count
1970

1971 8
    def update_service_status(
        self, status: str, id: Union[List[str], str] = None, procedure_id: Union[List[str], str] = None
    ) -> int:
        """
        Update the status of the existing services in the database.

        Raises an exception if any of the ids are invalid.

        Parameters
        ----------
        status : str
            The input status string ready to replace the previous status
        id : Optional[Union[List[str], str]], optional
            ids of all the services requested to be updated, by default None
        procedure_id : Optional[Union[List[str], str]], optional
            procedure_ids for the specific procedures, by default None

        Returns
        -------
        int
            1 indicating that the status update was successful

        Raises
        ------
        KeyError
            If both `id` and `procedure_id` are None.
        """

        if (id is None) and (procedure_id is None):
            raise KeyError("id or procedure_id must not be None.")

        status = status.lower()
        with self.session_scope() as session:

            query = format_query(ServiceQueueORM, id=id, procedure_id=procedure_id)

            # Update the service
            # NOTE(review): .first() returns None when nothing matches, so an
            # invalid id raises AttributeError on the next line — confirm that
            # callers always pass valid ids.
            service = session.query(ServiceQueueORM).filter(*query).first()
            service.status = status

            # Update the procedure
            # A 'waiting' service corresponds to an 'incomplete' procedure
            if status == "waiting":
                status = "incomplete"
            session.query(BaseResultORM).filter(BaseResultORM.id == service.procedure_id).update({"status": status})

            session.commit()

        return 1
2013

2014 8
    def services_completed(self, records_list: List["BaseService"]) -> int:
        """
        Delete completed services from the database.

        For every service, the final output procedure is written back and the
        corresponding service queue entry is removed.

        Parameters
        ----------
        records_list : List["BaseService"]
            List of Service objects which are completed.

        Returns
        -------
        int
            Number of deleted active services from database.
        """
        completed_count = 0
        for service in records_list:
            if service.id is None:
                self.logger.error(
                    "No service id found on completion (hash_index={}), skipping.".format(service.hash_index)
                )
                continue

            # Finalize the procedure and drop the queue entry in one transaction
            with self.session_scope() as session:
                final_procedure = service.output
                final_procedure.__dict__["id"] = service.procedure_id
                self.update_procedures([final_procedure])

                session.query(ServiceQueueORM).filter_by(id=service.id).delete()

            completed_count += 1

        return completed_count
2048

2049
    ### Mongo queue handling functions
2050

2051 8
    def queue_submit(self, data: List[TaskRecord]):
        """Submit a list of tasks to the queue.
        Tasks are unique by their base_result, which should be inserted into
        the DB first before submitting it's corresponding task to the queue
        (with result.status='INCOMPLETE' as the default)
        The default task.status is 'WAITING'

        Parameters
        ----------
        data : List[TaskRecord]
            A task is a dict, with the following fields:
            - hash_index: idx, not used anymore
            - spec: dynamic field (dict-like), can have any structure
            - tag: str
            - base_results: tuple (required), first value is the class type
             of the result, ('results' or 'procedure'). The second value is
             the ID of the result in the DB. Example:
             "base_result": ('results', result_id)

        Returns
        -------
        Dict[str, Any]
            Dictionary with keys data and meta.
            'data' is a list of the IDs of the tasks IN ORDER, including
            duplicates. An errored task has 'None' in its ID
            meta['duplicates'] has the duplicate tasks
        """

        meta = add_metadata_template()

        # Output ids, positionally aligned with the input list
        results = ["placeholder"] * len(data)

        with self.session_scope() as session:
            # preserving all the base results for later check
            all_base_results = [record.base_result for record in data]
            query_res = (
                session.query(TaskQueueORM.id, TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.base_result_id.in_(all_base_results))
                .all()
            )

            # constructing a dict of found tasks and their ids
            found_dict = {str(base_result_id): str(task_id) for task_id, base_result_id in query_res}
            new_tasks, new_idx = [], []
            duplicate_idx = []
            for task_num, record in enumerate(data):

                if found_dict.get(record.base_result):
                    # if found, get id from found_dict
                    # Note: found_dict may return a task object because the duplicate id is of an object in the input.
                    results[task_num] = found_dict.get(record.base_result)
                    # add index of duplicates
                    duplicate_idx.append(task_num)
                    meta["duplicates"].append(task_num)

                else:
                    task_dict = record.dict(exclude={"id"})
                    task = TaskQueueORM(**task_dict)
                    new_idx.append(task_num)
                    # Priority must be an integer for sorting
                    task.priority = task.priority.value
                    # append all the new tasks that should be added
                    new_tasks.append(task)
                    # add the (yet to be) inserted object id to dictionary
                    found_dict[record.base_result] = task

            session.add_all(new_tasks)
            session.commit()

            meta["n_inserted"] += len(new_tasks)
            # setting the id for new inserted objects, cannot be done before commiting as new objects do not have ids
            for i, task_idx in enumerate(new_idx):
                results[task_idx] = str(new_tasks[i].id)

            # finding the duplicate items in input, for which ids are found only after insertion
            for i in duplicate_idx:
                if not isinstance(results[i], str):
                    results[i] = str(results[i].id)

        meta["success"] = True

        ret = {"data": results, "meta": meta}
        return ret
2133

2134 8
    def queue_get_next(
        self, manager, available_programs, available_procedures, limit=100, tag=None
    ) -> List[TaskRecord]:
        """Obtain waiting tasks for a manager to run.

        Given tags and available programs/procedures on the manager, obtain
        waiting tasks to run. Matching tasks are claimed by setting their
        status to RUNNING (rows are locked with SKIP LOCKED so concurrent
        managers do not claim the same tasks).

        Parameters
        ----------
        manager : str
            Name of the manager claiming the tasks
        available_programs : List[str]
            Programs the manager can run
        available_procedures : List[str]
            Procedures the manager can run
        limit : int, optional
            Maximum total number of tasks to claim (across all tags)
        tag : Optional[Union[str, List[str]]], optional
            Tag(s) restricting the search, queried in the given order

        Returns
        -------
        List[TaskRecord]
            The claimed tasks, already reflecting the RUNNING status.
        """

        # A task matches if its procedure is one the manager provides, or is unset
        proc_filt = TaskQueueORM.procedure.in_([p.lower() for p in available_procedures])
        none_filt = TaskQueueORM.procedure == None  # lgtm [py/test-equals-none]

        order_by = []
        if tag is not None:
            if isinstance(tag, str):
                tag = [tag]

        order_by.extend([TaskQueueORM.priority.desc(), TaskQueueORM.created_on])

        # One query per tag (searched in order), or a single untagged query
        queries = []
        if tag is not None:
            for t in tag:
                query = format_query(TaskQueueORM, status=TaskStatusEnum.waiting, program=available_programs, tag=t)
                query.append(or_(proc_filt, none_filt))
                queries.append(query)
        else:
            query = format_query(TaskQueueORM, status=TaskStatusEnum.waiting, program=available_programs)
            query.append((or_(proc_filt, none_filt)))
            queries.append(query)

        new_limit = limit
        found = []
        update_count = 0

        update_fields = {"status": TaskStatusEnum.running, "modified_on": dt.utcnow(), "manager": manager}
        with self.session_scope() as session:
            for q in queries:

                # Have we found all we needed to find
                if new_limit == 0:
                    break

                # with_for_update locks the rows. skip_locked=True makes it skip already-locked rows
                # (possibly from another process)
                query = (
                    session.query(TaskQueueORM)
                    .filter(*q)
                    .order_by(*order_by)
                    .limit(new_limit)
                    .with_for_update(skip_locked=True)
                )

                new_items = query.all()
                new_ids = [x.id for x in new_items]

                # Update all the task records to reflect this manager claiming them
                update_count += (
                    session.query(TaskQueueORM)
                    .filter(TaskQueueORM.id.in_(new_ids))
                    .update(update_fields, synchronize_session=False)
                )

                # After commiting, the row locks are released
                session.commit()

                # How many more do we have to query
                # BUG FIX: decrement the *remaining* limit rather than recomputing
                # from the original limit; otherwise, with multiple tags, later
                # iterations could claim more than `limit` tasks in total.
                new_limit = new_limit - len(new_items)

                # I would assume this is always true. If it isn't,
                # that would be really bad, and lead to an infinite loop
                assert new_limit >= 0

                # Store in dict form for returning. We will add the updated fields later
                found.extend([task.to_dict(exclude=update_fields.keys()) for task in new_items])

            # avoid another trip to the DB to get the updated values, set them here
            found = [TaskRecord(**task, **update_fields) for task in found]

        if update_count != len(found):
            self.logger.warning("QUEUE: Number of found tasks does not match the number of updated tasks.")

        return found
2215

2216 8
    def get_queue(
        self,
        id=None,
        hash_index=None,
        program=None,
        status: str = None,
        base_result: str = None,
        tag=None,
        manager=None,
        include=None,
        exclude=None,
        limit: int = None,
        skip: int = 0,
        return_json=False,
        with_ids=True,
    ):
        """
        Query task queue entries.

        TODO: check what query keys are needs

        Parameters
        ----------
        id : Optional[List[str]], optional
            Ids of the tasks
        hash_index : Optional[List[str]], optional
            hash_index of service, not used
        program : list of str or str, optional
            Program name(s) to filter on
        status : Optional[str], optional (find all)
            The status of the task: 'COMPLETE', 'RUNNING', 'WAITING', or 'ERROR'
        base_result : Optional[str], optional
            base_result id
        tag / manager : optional
            Task tag and manager name filters
        include : Optional[List[str]], optional
            The fields to return, default to return all
        exclude : Optional[List[str]], optional
            The fields to not return, default to return all
        limit : Optional[int], optional
            maximum number of results to return
        skip : int, optional
            skip the first 'skip' results; used to paginate (default 0)
        return_json : bool, optional
            Return the results as a list of json instead of objects
        with_ids : bool, optional
            Include the ids in the returned objects/dicts, default is True

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the TaskRecords found
        """

        meta = get_metadata_template()
        filters = format_query(
            TaskQueueORM,
            program=program,
            id=id,
            hash_index=hash_index,
            status=status,
            base_result_id=base_result,
            tag=tag,
            manager=manager,
        )

        records = []
        try:
            records, meta["n_found"] = self.get_query_projection(
                TaskQueueORM, filters, limit=limit, skip=skip, include=include, exclude=exclude
            )
            meta["success"] = True
        except Exception as err:
            # Report query failures through meta rather than raising
            meta["error_description"] = str(err)

        tasks = [TaskRecord(**record) for record in records]

        return {"data": tasks, "meta": meta}
2291

2292 8
    def queue_get_by_id(self, id: List[str], limit: int = None, skip: int = 0, as_json: bool = True):
        """Get task queue entries by their IDs.

        Parameters
        ----------
        id : List[str]
            List of the task Ids in the DB
        limit : Optional[int], optional
            max number of returned tasks. If limit > max_limit, max_limit
            will be returned instead (safe query)
        skip : int, optional
            skip the first 'skip' results. Used to paginate, default is 0
        as_json : bool, optional
            Return tasks as TaskRecord objects built from dicts, default is True

        Returns
        -------
        List[TaskRecord]
            List of the found tasks
        """

        with self.session_scope() as session:
            matching = session.query(TaskQueueORM).filter(TaskQueueORM.id.in_(id))
            results = matching.limit(self.get_limit(limit)).offset(skip)

            if as_json:
                results = [TaskRecord(**row.to_dict()) for row in results]

        return results
2322

2323 8
    def queue_mark_complete(self, task_ids: List[str]) -> int:
        """Update the given tasks as complete
        Note that each task is already pointing to its result location
        Mark the corresponding result/procedure as complete

        The base results are updated with status=COMPLETE and the name of the
        manager that ran them; the queue entries themselves are then deleted.

        Parameters
        ----------
        task_ids : List[str]
            IDs of the tasks to mark as COMPLETE

        Returns
        -------
        int
            number of TaskRecord objects marked as COMPLETE, and deleted from the database consequtively.
        """

        if not task_ids:
            return 0

        update_fields = dict(status=TaskStatusEnum.complete, modified_on=dt.utcnow())
        with self.session_scope() as session:
            # assuming all task_ids are valid, then managers will be in order by id
            managers = (
                session.query(TaskQueueORM.manager)
                .filter(TaskQueueORM.id.in_(task_ids))
                .order_by(TaskQueueORM.id)
                .all()
            )
            managers = [manager[0] if manager else manager for manager in managers]
            # BUG FIX: sort the incoming ids numerically to match the DB's ORDER BY
            # on the integer id column; a plain lexicographic sort of string ids
            # mispairs ids of different digit lengths (e.g. "10" < "9" as strings).
            task_manager_map = {task_id: manager for task_id, manager in zip(sorted(task_ids, key=int), managers)}
            # Per-row manager name through a CASE expression keyed on the task id
            update_fields[BaseResultORM.manager_name] = case(task_manager_map, value=TaskQueueORM.id)

            # Update the base results (implicit join through base_result_id)
            session.query(BaseResultORM).filter(BaseResultORM.id == TaskQueueORM.base_result_id).filter(
                TaskQueueORM.id.in_(task_ids)
            ).update(update_fields, synchronize_session=False)

            # delete completed tasks
            tasks_c = (
                session.query(TaskQueueORM).filter(TaskQueueORM.id.in_(task_ids)).delete(synchronize_session=False)
            )

        return tasks_c
2365

2366 8
    def queue_mark_error(self, data: List[Tuple[int, Dict[str, str]]]):
        """
        update the given tasks as errored
        Mark the corresponding result/procedure as Errored

        The error dicts are compressed and stored through kvstore, and the
        resulting ids are written onto the base results.

        Parameters
        ----------
        data : List[Tuple[int, Dict[str, str]]]
            List of task ids and their error messages desired to be assigned to them.

        Returns
        -------
        int
            Number of tasks updated as errored.
        """

        if not data:
            return 0

        task_ids = []
        with self.session_scope() as session:
            # Make sure returned results are in the same order as the task ids
            # SQL queries change the order when using "in"
            data_dict = {item[0]: item[1] for item in data}
            sorted_data = {key: data_dict[key] for key in sorted(data_dict.keys())}
            task_objects = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(sorted_data.keys()))
                .order_by(TaskQueueORM.id)
                .all()
            )
            # Base results fetched in the same (task id) order as task_objects
            base_results = (
                session.query(BaseResultORM)
                .filter(BaseResultORM.id == TaskQueueORM.base_result_id)
                .filter(TaskQueueORM.id.in_(sorted_data.keys()))
                .order_by(TaskQueueORM.id)
                .all()
            )

            # NOTE(review): this zip assumes every task id exists and has exactly
            # one matching base result; a missing id would silently mispair the
            # remaining items — confirm callers always pass valid ids.
            for (task_id, error_dict), task_obj, base_result in zip(sorted_data.items(), task_objects, base_results):

                task_ids.append(task_id)
                # update task
                task_obj.status = TaskStatusEnum.error
                task_obj.modified_on = dt.utcnow()

                # update result
                base_result.status = TaskStatusEnum.error
                base_result.manager_name = task_obj.manager
                base_result.modified_on = dt.utcnow()

                # Compress error dicts here. Should be fast, since errors are small
                err = KVStore.compress(error_dict, CompressionEnum.lzma, 1)
                err_id = self.add_kvstore([err])["data"][0]
                base_result.error = err_id

            session.commit()

        return len(task_ids)
2425

2426 8
    def queue_reset_status(
        self,
        id: Union[str, List[str]] = None,
        base_result: Union[str, List[str]] = None,
        manager: Optional[str] = None,
        reset_running: bool = False,
        reset_error: bool = False,
    ) -> int:
        """
        Reset the status of the tasks that a manager owns from Running to Waiting.
        If reset_error is True, then also reset errored tasks AND their
        results/procedures (which are set back to incomplete).

        Parameters
        ----------
        id : Optional[Union[str, List[str]]], optional
            The id of the task to modify
        base_result : Optional[Union[str, List[str]]], optional
            The id of the base result to modify
        manager : Optional[str], optional
            The manager name to reset the status of
        reset_running : bool, optional
            If True, reset running tasks to be waiting
        reset_error : bool, optional
            If True, also reset errored tasks to be waiting,
            also update results/proc to be INCOMPLETE

        Returns
        -------
        int
            Updated count

        Raises
        ------
        ValueError
            If no query field (id, base_result, manager) is provided.
        """

        if not (reset_running or reset_error):
            # nothing to do
            return 0

        if sum(x is not None for x in [id, base_result, manager]) == 0:
            raise ValueError("All query fields are None, reset_status must specify queries.")

        # Only the requested statuses are eligible for reset
        status = []
        if reset_running:
            status.append(TaskStatusEnum.running)
        if reset_error:
            status.append(TaskStatusEnum.error)

        query = format_query(TaskQueueORM, id=id, base_result_id=base_result, manager=manager, status=status)

        # Must have status + something, checking above as well(being paranoid)
        if len(query) < 2:
            raise ValueError("All query fields are None, reset_status must specify queries.")

        with self.session_scope() as session:
            # Update results and procedures if reset_error.
            # task_ids is a lazily-evaluated subquery over the matching tasks;
            # the base-result update must run BEFORE the task update below,
            # since the subquery filters on the tasks' current status.
            task_ids = session.query(TaskQueueORM.id).filter(*query)
            session.query(BaseResultORM).filter(TaskQueueORM.base_result_id == BaseResultORM.id).filter(
                TaskQueueORM.id.in_(task_ids)
            ).update(dict(status=RecordStatusEnum.incomplete, modified_on=dt.utcnow()), synchronize_session=False)

            updated = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(task_ids))
                .update(dict(status=TaskStatusEnum.waiting, modified_on=dt.utcnow()), synchronize_session=False)
            )

        return updated
2491

2492 8
    def reset_base_result_status(
        self,
        id: Union[str, List[str]] = None,
    ) -> int:
        """
        Set the status of matching base results back to "incomplete".

        Records that are already complete are left untouched.
        This should be rarely called. Handle with care!

        Parameters
        ----------
        id : Optional[Union[str, List[str]]], optional
            The id of the base result to modify

        Returns
        -------
        int
            Number of base results modified
        """

        id_filters = format_query(BaseResultORM, id=id)
        new_values = {"status": RecordStatusEnum.incomplete, "modified_on": dt.utcnow()}

        with self.session_scope() as session:
            not_complete = BaseResultORM.status != RecordStatusEnum.complete
            target = session.query(BaseResultORM).filter(*id_filters).filter(not_complete)
            n_updated = target.update(new_values, synchronize_session=False)

        return n_updated
2525

2526 8
    def queue_modify_tasks(
        self,
        id: Union[str, List[str]] = None,
        base_result: Union[str, List[str]] = None,
        new_tag: Optional[str] = None,
        new_priority: Optional[int] = None,
    ):
        """
        Modify the tag and/or priority of tasks.

        Tasks whose status is running are left untouched.

        Parameters
        ----------
        id : Optional[Union[str, List[str]]], optional
            The id of the task to modify
        base_result : Optional[Union[str, List[str]]], optional
            The id of the base result to modify
        new_tag : Optional[str], optional
            New tag to assign to the given tasks
        new_priority: int, optional
            New priority to assign to the given tasks

        Returns
        -------
        int
            Updated count

        Raises
        ------
        ValueError
            If neither id nor base_result is given.
        """

        # Nothing requested -> nothing to change
        if new_tag is None and new_priority is None:
            return 0

        if id is None and base_result is None:
            raise ValueError("All query fields are None, modify_task must specify queries.")

        filters = format_query(TaskQueueORM, id=id, base_result_id=base_result)

        changes = {}
        if new_tag is not None:
            changes["tag"] = new_tag
        if new_priority is not None:
            changes["priority"] = new_priority
        changes["modified_on"] = dt.utcnow()

        with self.session_scope() as session:
            not_running = TaskQueueORM.status != TaskStatusEnum.running
            n_updated = (
                session.query(TaskQueueORM)
                .filter(*filters)
                .filter(not_running)
                .update(changes, synchronize_session=False)
            )

        return n_updated
2581

2582 8
    def del_tasks(self, id: Union[str, list]):
        """
        Delete tasks from the queue. Use with caution!

        Parameters
        ----------
        id : str or List
            Ids of the tasks to delete

        Returns
        -------
        int
            Number of tasks deleted
        """

        # Accept a single id or a list of ids
        if isinstance(id, (int, str)):
            ids = [id]
        else:
            ids = id

        with self.session_scope() as session:
            n_deleted = (
                session.query(TaskQueueORM)
                .filter(TaskQueueORM.id.in_(ids))
                .delete(synchronize_session=False)
            )

        return n_deleted
2601

2602 8
    def _copy_task_to_queue(self, record_list: List[TaskRecord]):
        """
        Copy the given tasks as-is to the DB. Used for data migration.

        Tasks whose base_result_id already exists in the queue are reported
        as duplicates and not re-inserted.

        Parameters
        ----------
        record_list : List[TaskRecords]
            List of task records to be copied

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the ids of the inserted/updated/existing docs
        """

        meta = add_metadata_template()

        task_ids = []
        with self.session_scope() as session:
            for task in record_list:
                # `doc` is first a query; it is rebound to an ORM instance below
                doc = session.query(TaskQueueORM).filter_by(base_result_id=task.base_result_id)

                if get_count_fast(doc) == 0:
                    doc = TaskQueueORM(**task.dict(exclude={"id"}))
                    # Store the raw enum value rather than the enum member
                    doc.priority = doc.priority.value
                    # Legacy error payloads may be dicts; serialize for storage
                    if isinstance(doc.error, dict):
                        doc.error = json.dumps(doc.error)

                    session.add(doc)
                    session.commit()  # TODO: faster if done in bulk
                    task_ids.append(str(doc.id))
                    meta["n_inserted"] += 1
                else:
                    # NOTE: shadows the builtin `id` within this branch
                    id = str(doc.first().id)
                    meta["duplicates"].append(id)  # TODO
                    # If new or duplicate, add the id to the return list
                    task_ids.append(id)
        meta["success"] = True

        ret = {"data": task_ids, "meta": meta}
        return ret
2643

2644
    ### QueueManagerORMs
2645

2646 8
    def manager_update(self, name, **kwargs):
        """
        Update a queue manager's record, creating it if it does not exist.

        Parameters
        ----------
        name : str
            Name of the manager to update.
        **kwargs
            Fields to set on the manager. The keys ``submitted``, ``completed``,
            ``returned`` and ``failures`` are treated as increments to the
            existing counters. The special key ``log`` (bool) requests that a
            QueueManagerLogORM snapshot be written after the update.

        Returns
        -------
        bool
            True if exactly one manager row was updated (or created).
        """

        do_log = kwargs.pop("log", False)

        # SQL expressions that add the supplied deltas to the stored counters
        inc_count = {
            # Increment relevant data
            "submitted": QueueManagerORM.submitted + kwargs.pop("submitted", 0),
            "completed": QueueManagerORM.completed + kwargs.pop("completed", 0),
            "returned": QueueManagerORM.returned + kwargs.pop("returned", 0),
            "failures": QueueManagerORM.failures + kwargs.pop("failures", 0),
        }

        # Keep only kwargs that correspond to actual ORM attributes
        upd = {key: kwargs[key] for key in QueueManagerORM.__dict__.keys() if key in kwargs}

        with self.session_scope() as session:
            # QueueManagerORM.objects()  # init
            manager = session.query(QueueManagerORM).filter_by(name=name)
            if manager.count() > 0:  # existing
                upd.update(inc_count, modified_on=dt.utcnow())
                num_updated = manager.update(upd)
            else:  # create new, ensures defaults and validations
                manager = QueueManagerORM(name=name, **upd)
                session.add(manager)
                session.commit()
                num_updated = 1

            if do_log:
                # Pull again in case it was updated
                manager = session.query(QueueManagerORM).filter_by(name=name).first()

                manager_log = QueueManagerLogORM(
                    manager_id=manager.id,
                    completed=manager.completed,
                    submitted=manager.submitted,
                    failures=manager.failures,
                    total_worker_walltime=manager.total_worker_walltime,
                    total_task_walltime=manager.total_task_walltime,
                    active_tasks=manager.active_tasks,
                    active_cores=manager.active_cores,
                    active_memory=manager.active_memory,
                )

                session.add(manager_log)
                session.commit()

        return num_updated == 1
2692

2693 8
    def get_managers(
        self, name: str = None, status: str = None, modified_before=None, modified_after=None, limit=None, skip=0
    ):
        """
        Query queue managers by name, status, and modification-time window.

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta.
        """

        meta = get_metadata_template()

        filters = format_query(QueueManagerORM, name=name, status=status)
        if modified_before:
            filters.append(QueueManagerORM.modified_on <= modified_before)
        if modified_after:
            filters.append(QueueManagerORM.modified_on >= modified_after)

        data, n_found = self.get_query_projection(QueueManagerORM, filters, limit=limit, skip=skip)
        meta["n_found"] = n_found
        meta["success"] = True

        return {"data": data, "meta": meta}
2710

2711 8
    def get_manager_logs(self, manager_ids: Union[List[str], str], timestamp_after=None, limit=None, skip=0):
        """
        Fetch log entries for the given manager(s), optionally restricted to
        entries after a given timestamp.

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta.
        """
        meta = get_metadata_template()

        filters = format_query(QueueManagerLogORM, manager_id=manager_ids)
        if timestamp_after:
            filters.append(QueueManagerLogORM.timestamp >= timestamp_after)

        data, n_found = self.get_query_projection(
            QueueManagerLogORM, filters, limit=limit, skip=skip, exclude=["id"]
        )
        meta["n_found"] = n_found
        meta["success"] = True

        return {"data": data, "meta": meta}
2724

2725 8
    def _copy_managers(self, record_list: Dict):
        """
        Copy the given managers as-is to the DB. Used for data migration.

        Managers whose name already exists are reported as duplicates and
        not re-inserted.

        Parameters
        ----------
        record_list : List[Dict[str, Any]]
            List of dicts of manager data.

        Returns
        -------
        Dict[str, Any]
            Dict with keys: data, meta. Data is the names of the
            inserted/existing managers.
        """

        meta = add_metadata_template()

        manager_names = []
        with self.session_scope() as session:
            for manager in record_list:
                existing = session.query(QueueManagerORM).filter_by(name=manager["name"])

                if get_count_fast(existing) == 0:
                    doc = QueueManagerORM(**manager)
                    # Migrated timestamps may arrive as epoch milliseconds; convert
                    if isinstance(doc.created_on, float):
                        doc.created_on = dt.fromtimestamp(doc.created_on / 1e3)
                    if isinstance(doc.modified_on, float):
                        doc.modified_on = dt.fromtimestamp(doc.modified_on / 1e3)
                    session.add(doc)
                    session.commit()  # TODO: faster if done in bulk
                    manager_names.append(doc.name)
                    meta["n_inserted"] += 1
                else:
                    name = existing.first().name
                    meta["duplicates"].append(name)  # TODO
                    # If new or duplicate, add the name to the return list.
                    # BUGFIX: previously this appended the *builtin* `id`
                    # function (no local `id` exists here) instead of the name.
                    manager_names.append(name)
        meta["success"] = True

        ret = {"data": manager_names, "meta": meta}
        return ret
2765

2766
    ### UserORMs
2767

2768 8
    # Permission levels a user may be granted; validated in add_user below.
    _valid_permissions = frozenset({"read", "write", "compute", "queue", "admin"})
2769

2770 8
    @staticmethod
2771 8
    def _generate_password() -> str:
2772
        """
2773
        Generates a random password e.g. for add_user and modify_user.
2774

2775
        Returns
2776
        -------
2777
        str
2778
            An unhashed random password.
2779
        """
2780 8
        return secrets.token_urlsafe(32)
2781

2782 8
    def add_user(
2783
        self, username: str, password: Optional[str] = None, permissions: List[str] = ["read"], overwrite: bool = False
2784
    ) -> Tuple[bool, str]:
2785
        """
2786
        Adds a new user and associated permissions.
2787

2788
        Passwords are stored using bcrypt.
2789

2790
        Parameters
2791
        ----------
2792
        username : str
2793
            New user's username
2794
        password : Optional[str], optional
2795
            The user's password. If None, a new password will be generated.
2796
        permissions : Optional[List[str]], optional
2797
            The associated permissions of a user ['read', 'write', 'compute', 'queue', 'admin']
2798
        overwrite: bool, optional
2799
            Overwrite the user if it already exists.
2800
        Returns
2801
        -------
2802
        Tuple[bool, str]
2803
            A tuple of (success flag, password)
2804
        """
2805

2806
        # Make sure permissions are valid
2807 8
        if not self._valid_permissions >= set(permissions):
2808 0
            raise KeyError("Permissions settings not understood: {}".format(set(permissions) - self._valid_permissions))
2809

2810 8