1 4
from typing import List, Optional, Set, Union
2

3 4
from sqlalchemy import Integer, inspect
4 4
from sqlalchemy.sql import bindparam, text
5

6 4
from qcfractal.interface.models import Molecule, ResultRecord
7 4
from qcfractal.storage_sockets.models import MoleculeORM, ResultORM
8

9 4
QUERY_CLASSES = set()  # every QueryBase subclass; filled in by QueryBase.__init_subclass__


class QueryBase:
    """Base class for raw-SQL query handlers.

    A subclass sets ``_class_name`` (the alias used by the REST APIs) and
    ``_query_method_map`` (public query key -> name of the method that runs
    it); :meth:`query` dispatches on that map.
    """

    # The name/alias used by the REST APIs to access this class
    _class_name = None
    _available_groupby = set()

    # Mapping of the requested feature and the internal query method
    _query_method_map = {}

    def __init__(self, database_name, max_limit=1000):
        self.database_name = database_name
        self.max_limit = max_limit

    def __init_subclass__(cls, **kwargs):
        # Record every subclass in the module-level QUERY_CLASSES registry.
        if cls not in QUERY_CLASSES:
            QUERY_CLASSES.add(cls)
        super().__init_subclass__(**kwargs)

    def query(self, session, query_key, limit=0, skip=0, include=None, exclude=None, **kwargs):
        """Run the query registered under ``query_key``.

        Parameters
        ----------
        session
            Database session; stored on ``self`` and used by :meth:`execute_query`.
        query_key : str
            Key looked up in ``_query_method_map``.
        limit, skip, include, exclude
            Accepted for API compatibility but currently unused.
        **kwargs
            Forwarded unchanged to the selected query method.

        Raises
        ------
        TypeError
            If ``query_key`` is not implemented by this class.
        """
        if query_key not in self._query_method_map:
            raise TypeError(f"Query type {query_key} is unimplemented for class {self._class_name}")

        self.session = session

        return getattr(self, self._query_method_map[query_key])(**kwargs)

    def execute_query(self, sql_statement, with_keys=True, **kwargs):
        """Execute ``sql_statement`` with ``kwargs`` as bound parameters, then commit.

        Returns
        -------
        list
            One dict per row (column name -> value) when ``with_keys`` is true,
            otherwise the raw rows returned by ``fetchall()``.

        Note: ``max_limit`` is not applied yet; every row is fetched.
        """
        # TODO: check count first, way to iterate
        # sql_statement += f' LIMIT {self.max_limit}'
        result = self.session.execute(sql_statement, kwargs)
        keys = result.keys()  # get keys before fetching
        rows = result.fetchall()
        self.session.commit()

        if with_keys:
            # dicts keyed by column name instead of positional tuples
            return [dict(zip(keys, row)) for row in rows]
        return rows

    def _base_count(self, table_name: str, available_groupbys: Set[str], groupby: Optional[List[str]] = None):
        """Count rows of ``table_name``, optionally grouped by some of its columns.

        Parameters
        ----------
        table_name : str
            Table to count. It is interpolated into the SQL, so callers must
            pass a fixed internal name (the callers in this module pass literals).
        available_groupbys : set of str
            Columns that ``groupby`` is allowed to name.
        groupby : list of str, optional
            Columns to group and order by; ``None`` or empty means no grouping.

        Returns
        -------
        The single total count when not grouping; otherwise one dict per group
        holding the group columns plus ``count``.

        Raises
        ------
        AttributeError
            If ``groupby`` names a column outside ``available_groupbys``.
        """
        if groupby:
            bad_groups = set(groupby) - available_groupbys
            if bad_groups:
                bad_list = ", ".join(sorted(str(group) for group in bad_groups))
                raise AttributeError(f"The following groups are not permissible: {bad_list}")

            # Safe to interpolate: every name was checked against available_groupbys.
            group_cols = ", ".join(groupby)
            select_str = group_cols + ", "
            extra_str = f"GROUP BY {group_cols}\nORDER BY {group_cols}"
        else:
            select_str = ""
            extra_str = ""

        sql_statement = f"""
select {select_str}count(*) from {table_name}
{extra_str}
"""

        ret = self.execute_query(sql_statement, with_keys=True)
        return ret if groupby else ret[0]["count"]

    @classmethod
    def _raise_missing_attribute(cls, query_key, missing_attribute, amend_msg=""):
        """Raise an AttributeError telling the REST user which argument a query needs.

        This is a classmethod so that the ``self._raise_missing_attribute(key, attr)``
        calls made by subclasses bind ``cls`` implicitly. As a staticmethod, the
        first argument was consumed as ``cls`` and every call failed with TypeError.
        ``amend_msg``, when non-empty, is appended to the message.
        """
        msg = f"To query {cls._class_name} for {query_key} you must provide {missing_attribute}."
        if amend_msg:
            msg = f"{msg} {amend_msg}"
        raise AttributeError(msg)
87

88

89
# ----------------------------------------------------------------------------
90

91

92 4
class TaskQueries(QueryBase):
    """Aggregate statistics over the ``task_queue`` table."""

    _class_name = "task"
    _query_method_map = {"counts": "_task_counts"}

    def _task_counts(self):
        """Count queued tasks per (tag, priority, status) combination.

        Returns a list of dicts with keys ``tag``, ``priority``, ``status`` and
        ``count``, ordered by tag, then priority, then status.
        """
        counts_sql = """
            SELECT tag, priority, status, count(*)
            FROM task_queue
            WHERE True
            group by tag, priority, status
            order by tag, priority, status
        """
        rows = self.execute_query(counts_sql, with_keys=True)
        return rows
108

109

110
# ----------------------------------------------------------------------------
111

112

113 4
class DatabaseStatQueries(QueryBase):
    """Database-level statistics: table row counts, total size, per-table sizes."""

    _class_name = "database_stats"

    _query_method_map = {
        "table_count": "_table_count",
        "database_size": "_database_size",
        "table_information": "_table_information",
    }

    def _table_count(self, table_name=None):
        """Return the single result row holding ``count(*)`` for ``table_name``.

        Raises
        ------
        AttributeError
            If ``table_name`` is missing, or is not a plain identifier
            (optionally schema-qualified as ``schema.table``).
        """
        if table_name is None:
            self._raise_missing_attribute("table_name", "table name")

        # Identifiers cannot be sent as bound parameters, so the name has to be
        # spliced into the SQL text. The value arrives through query(**kwargs),
        # so only identifier characters are accepted, which rules out SQL injection.
        if not all(part.isidentifier() for part in str(table_name).split(".")):
            raise AttributeError(f"Invalid table name: {table_name!r}")

        sql_statement = f"SELECT count(*) from {table_name}"
        return self.execute_query(sql_statement, with_keys=False)[0]

    def _database_size(self):
        """Return ``pg_database_size`` of the configured database."""
        sql_statement = f"SELECT pg_database_size('{self.database_name}')"
        return self.execute_query(sql_statement, with_keys=True)[0]["pg_database_size"]

    def _table_information(self):
        """Return size information for each ordinary table (``relkind = 'r'``).

        Returns
        -------
        dict
            ``{"columns": [...], "rows": [[table_name, row_estimate, total_bytes,
            index_bytes, toast_bytes], ...]}``.
        """
        sql_statement = """
SELECT relname                                AS table_name
     , c.reltuples::BIGINT                    AS row_estimate
     , pg_total_relation_size(c.oid)          AS total_bytes
     , pg_indexes_size(c.oid)                 AS index_bytes
     , pg_total_relation_size(reltoastrelid)  AS toast_bytes
FROM pg_class c
         LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
WHERE relkind = 'r';
 """

        rows = [
            list(row)
            for row in self.execute_query(sql_statement, with_keys=False)
            # Presumably meant to hide system tables. NOTE(review): this is a
            # substring match, so any table whose name contains "pg_" or "sql_"
            # is dropped as well.
            if "pg_" not in row[0] and "sql_" not in row[0]
        ]

        return {"columns": ["table_name", "row_estimate", "total_bytes", "index_bytes", "toast_bytes"], "rows": rows}
160

161

162 4
class ResultQueries(QueryBase):
    """Aggregate queries over the ``base_result`` table."""

    _class_name = "result"

    _query_method_map = {"count": "_count"}

    def _count(self, groupby: Optional[List[str]] = None):
        """Count rows of ``base_result``.

        ``groupby`` may name ``result_type`` and/or ``status``; the work is
        delegated to ``_base_count``.
        """
        return self._base_count("base_result", {"result_type", "status"}, groupby=groupby)
173

174

175 4
class MoleculeQueries(QueryBase):
    """Aggregate queries over the ``molecule`` table."""

    _class_name = "molecule"

    _query_method_map = {"count": "_count"}

    def _count(self, groupby: Optional[List[str]] = None):
        """Return the number of molecules.

        No grouping columns are permitted here, so ``_base_count`` raises an
        error for any non-empty ``groupby``.
        """
        no_groupings = set()
        return self._base_count("molecule", no_groupings, groupby=groupby)
186

187

188
# ----------------------------------------------------------------------------
189

190

191 4
class TorsionDriveQueries(QueryBase):
    """Raw-SQL lookups over the optimizations that make up one torsion drive.

    Every query requires ``torsion_id`` and restricts itself to the
    optimizations listed for it in ``optimization_history``. The id is passed
    as a bound parameter (``:torsion_id``) rather than being formatted into the
    SQL text, so a value arriving through ``query(**kwargs)`` cannot inject SQL.
    """

    _class_name = "torsiondrive"

    _query_method_map = {
        "initial_molecules": "_get_initial_molecules",
        "initial_molecules_ids": "_get_initial_molecules_ids",
        "final_molecules": "_get_final_molecules",
        "final_molecules_ids": "_get_final_molecules_ids",
        "return_results": "_get_return_results",
    }

    def _run_for_torsion(self, query_key, torsion_id, sql, with_keys):
        """Require ``torsion_id``, then execute ``sql`` with it bound to ``:torsion_id``.

        ``query_key`` is only used in the error raised when the id is missing.
        """
        if torsion_id is None:
            self._raise_missing_attribute(query_key, "torsion drive id")
        return self.execute_query(text(sql), with_keys=with_keys, torsion_id=torsion_id)

    def _get_initial_molecules_ids(self, torsion_id=None):
        """Return rows holding each optimization's ``initial_molecule`` id, ordered by optimization id."""
        sql = """
                select initial_molecule from optimization_procedure as opt where opt.id in
                (
                    select opt_id from optimization_history where torsion_id = :torsion_id
                )
                order by opt.id
        """
        return self._run_for_torsion("initial_molecules_ids", torsion_id, sql, with_keys=False)

    def _get_initial_molecules(self, torsion_id=None):
        """Return the initial molecules of the torsion drive's optimizations as column-name dicts."""
        sql = """
                select molecule.* from molecule
                join optimization_procedure as opt
                on molecule.id = opt.initial_molecule
                where opt.id in
                    (select opt_id from optimization_history where torsion_id = :torsion_id)
        """
        return self._run_for_torsion("initial_molecules", torsion_id, sql, with_keys=True)

    def _get_final_molecules_ids(self, torsion_id=None):
        """Return rows holding each optimization's ``final_molecule`` id, ordered by optimization id."""
        sql = """
                select final_molecule from optimization_procedure as opt where opt.id in
                (
                    select opt_id from optimization_history where torsion_id = :torsion_id
                )
                order by opt.id
        """
        return self._run_for_torsion("final_molecules_ids", torsion_id, sql, with_keys=False)

    def _get_final_molecules(self, torsion_id=None):
        """Return the final molecules of the torsion drive's optimizations as column-name dicts."""
        sql = """
                select molecule.* from molecule
                join optimization_procedure as opt
                on molecule.id = opt.final_molecule
                where opt.id in
                    (select opt_id from optimization_history where torsion_id = :torsion_id)
        """
        return self._run_for_torsion("final_molecules", torsion_id, sql, with_keys=True)

    def _get_return_results(self, torsion_id=None):
        """Return ``(opt_id, result_id, return_result)`` rows for every result of the torsion drive's optimizations."""
        sql = """
                select opt_res.opt_id, result.id as result_id, result.return_result from result
                join opt_result_association as opt_res
                on result.id = opt_res.result_id
                where opt_res.opt_id in
                (
                    select opt_id from optimization_history where torsion_id = :torsion_id
                )
        """
        return self._run_for_torsion("return_results", torsion_id, sql, with_keys=False)
280

281

282 4
class OptimizationQueries(QueryBase):
    """Raw-SQL lookups of results and molecules attached to optimizations.

    Every query takes ``optimization_ids`` and returns a dict keyed by the
    optimization id (the ``opt_id`` column).
    """

    _class_name = "optimization"
    # Keys removed from every fetched row before it is turned into a model object.
    _exclude = ["molecule_hash", "molecular_formula", "result_type"]
    _query_method_map = {
        "all_results": "_get_all_results",
        "final_result": "_get_final_results",
        "initial_molecule": "_get_initial_molecules",
        "final_molecule": "_get_final_molecules",
    }

    # optimization_procedure columns that reference a molecule. They are spliced
    # into SQL text, so only these fixed names are accepted.
    _molecule_columns = ("initial_molecule", "final_molecule")

    def _remove_excluded_keys(self, data):
        """Remove every key listed in ``_exclude`` from ``data`` in place; absent keys are ignored."""
        for key in self._exclude:
            data.pop(key, None)

    def _fetch_by_optimization_ids(self, sql, orm_class, optimization_ids):
        """Execute ``sql`` for the given optimizations and return cleaned row dicts.

        ``:optimization_ids`` in ``sql`` is expanded into an ``IN (...)`` list.
        Result columns are typed from ``orm_class`` plus an integer ``opt_id``,
        and the ``_exclude`` keys are stripped from every row.
        """
        statement = text(sql).bindparams(bindparam("optimization_ids", expanding=True))
        statement = statement.columns(*inspect(orm_class).columns, opt_id=Integer)
        rows = self.execute_query(statement, optimization_ids=list(optimization_ids))
        for row in rows:
            self._remove_excluded_keys(row)
        return rows

    def _get_all_results(self, optimization_ids: List[Union[int, str]] = None):
        """Return the trajectory (all result records) of each optimization.

        Returns
        -------
        dict
            ``{opt_id: [ResultRecord, ...]}``, each list ordered by the
            ``position`` stored in ``opt_result_association``.
        """
        if optimization_ids is None:
            self._raise_missing_attribute("all_results", "List of optimizations ids")

        # traj_position is selected only so the trajectory can be ordered by it.
        # Without an ORDER BY the row order is undefined. It is dropped again
        # before the records are built.
        sql = """
            select * from base_result
            join (
                select opt_id, traj.position as traj_position, result.* from result
                join opt_result_association as traj
                on result.id = traj.result_id
                where traj.opt_id in :optimization_ids
            ) result
            on base_result.id = result.id
            order by result.opt_id, result.traj_position
        """

        ret = {}
        for rec in self._fetch_by_optimization_ids(sql, ResultORM, optimization_ids):
            rec.pop("traj_position", None)
            ret.setdefault(rec.pop("opt_id"), []).append(ResultRecord(**rec))
        return ret

    def _get_final_results(self, optimization_ids: List[Union[int, str]] = None):
        """Return, per optimization, the result record at its highest trajectory position.

        Returns
        -------
        dict
            ``{opt_id: ResultRecord}``.
        """
        if optimization_ids is None:
            self._raise_missing_attribute("final_result", "List of optimizations ids")

        sql = """
            select * from base_result
            join (
                select opt_id, result.* from result
                join (
                    select opt.opt_id, opt.result_id, max_pos from opt_result_association as opt
                    inner join (
                            select opt_id, max(position) as max_pos from opt_result_association
                            where opt_id in :optimization_ids
                            group by opt_id
                        ) opt2
                    on opt.opt_id = opt2.opt_id and opt.position = opt2.max_pos
                ) traj
                on result.id = traj.result_id
            ) result
            on base_result.id = result.id
        """

        ret = {}
        for rec in self._fetch_by_optimization_ids(sql, ResultORM, optimization_ids):
            key = rec.pop("opt_id")
            ret[key] = ResultRecord(**rec)
        return ret

    def _get_initial_molecules(self, optimization_ids=None):
        """Return ``{opt_id: Molecule}`` for the starting molecule of each optimization."""
        if optimization_ids is None:
            self._raise_missing_attribute("initial_molecule", "List of optimizations ids")
        return self._molecules_by_optimization("initial_molecule", optimization_ids)

    def _get_final_molecules(self, optimization_ids=None):
        """Return ``{opt_id: Molecule}`` for the final molecule of each optimization."""
        if optimization_ids is None:
            self._raise_missing_attribute("final_molecule", "List of optimizations ids")
        return self._molecules_by_optimization("final_molecule", optimization_ids)

    def _molecules_by_optimization(self, molecule_column, optimization_ids):
        """Return ``{opt_id: Molecule}`` for the molecule referenced by ``molecule_column``.

        Raises
        ------
        ValueError
            If ``molecule_column`` is not one of ``_molecule_columns``; the
            name is spliced into the SQL text.
        """
        if molecule_column not in self._molecule_columns:
            raise ValueError(f"Unsupported molecule column: {molecule_column!r}")

        sql = f"""
                select opt.id as opt_id, molecule.* from molecule
                join optimization_procedure as opt
                on molecule.id = opt.{molecule_column}
                where opt.id in :optimization_ids
        """

        ret = {}
        for rec in self._fetch_by_optimization_ids(sql, MoleculeORM, optimization_ids):
            key = rec.pop("opt_id")
            # NULL columns are dropped, presumably so Molecule applies its own
            # defaults for them.
            ret[key] = Molecule(**{k: v for k, v in rec.items() if v is not None})
        return ret

Read our documentation on viewing source code .

Loading