MolSSI / QCFractal
1
"""
2
Contains testing infrastructure for QCFractal.
3
"""
4

5 4
import os
6 4
import pkgutil
7 4
import shutil
8 4
import signal
9 4
import socket
10 4
import subprocess
11 4
import sys
12 4
import threading
13 4
import time
14 4
from collections import Mapping
15 4
from contextlib import contextmanager
16

17 4
import numpy as np
18 4
import pandas as pd
19 4
import pytest
20 4
import qcengine as qcng
21 4
import requests
22 4
from qcelemental.models import Molecule
23 4
from tornado.ioloop import IOLoop
24

25 4
from .interface import FractalClient
26 4
from .postgres_harness import PostgresHarness, TemporaryPostgres
27 4
from .queue import build_queue_adapter
28 4
from .server import FractalServer
29 4
from .snowflake import FractalSnowflake
30 4
from .storage_sockets import storage_socket_factory
31

32
### Addon testing capabilities
33

34

35 4
def pytest_addoption(parser):
    """Register the extra CLI flags used by this test suite.

    See `pytest_collection_modifyitems` for handling and `pytest_configure` for adding known in-line marks.
    """
    for flag, description in (
        ("--runslow", "run slow tests"),
        ("--runexamples", "run example tests"),
    ):
        parser.addoption(flag, action="store_true", default=False, help=description)
43

44

45 4
def pytest_collection_modifyitems(config, items):
    """Apply skip markers driven by the CLI flags added in ``pytest_addoption``.

    Tests decorated with ``@pytest.mark.slow`` are skipped unless ``--runslow``
    was passed; tests decorated with ``@pytest.mark.example`` are skipped
    unless ``--runexamples`` was passed.
    """
    keyword_gates = {
        "slow": (config.getoption("--runslow"), pytest.mark.skip(reason="need --runslow option to run")),
        "example": (config.getoption("--runexamples"), pytest.mark.skip(reason="need --runexamples option to run")),
    }
    for item in items:
        for keyword, (enabled, skip_marker) in keyword_gates.items():
            if keyword in item.keywords and not enabled:
                item.add_marker(skip_marker)
62

63

64 4
def pytest_configure(config):
    """Mark the interpreter as running under pytest and declare custom markers."""
    import sys

    # Sentinel so library code can detect a test session; removed in
    # pytest_unconfigure.
    sys._called_from_test = True

    for marker_definition in (
        "example: Mark a given test as an example which can be run",
        "slow: Mark a given test as slower than most other tests, needing a special flag to run.",
    ):
        config.addinivalue_line("markers", marker_definition)
72

73

74 4
def pytest_unconfigure(config):
    """Remove the test-session sentinel installed by ``pytest_configure``."""
    import sys

    delattr(sys, "_called_from_test")
78

79

80 4
def _plugin_import(plug):
81 4
    plug_spec = pkgutil.find_loader(plug)
82 4
    if plug_spec is None:
83 4
        return False
84
    else:
85 2
        return True
86

87

88 4
# Skip-message template shared with modules that probe optional dependencies.
_import_message = "Not detecting module {}. Install package if necessary and add to envvar PYTHONPATH"

# Queue-adapter backends exercised by the parametrized compute-server fixtures below.
_adapter_testing = ["pool", "dask", "fireworks", "parsl"]

# Figure out what is imported
_programs = {
    "fireworks": _plugin_import("fireworks"),
    "rdkit": _plugin_import("rdkit"),
    "psi4": _plugin_import("psi4"),
    "parsl": _plugin_import("parsl"),
    "dask": _plugin_import("dask"),
    "dask_jobqueue": _plugin_import("dask_jobqueue"),
    "geometric": _plugin_import("geometric"),
    "torsiondrive": _plugin_import("torsiondrive"),
    "torchani": _plugin_import("torchani"),
}
# Only probe "dask.distributed" when the top-level "dask" package was found.
if _programs["dask"]:
    _programs["dask.distributed"] = _plugin_import("dask.distributed")
else:
    _programs["dask.distributed"] = False

# dftd3 is an external executable, so detection goes through QCEngine rather
# than an import probe.
_programs["dftd3"] = "dftd3" in qcng.list_available_programs()
110

111

112 4
def has_module(name):
    """Report whether optional dependency *name* was detected at import time."""
    detected = _programs[name]
    return detected
114

115

116 4
def check_has_module(program):
    """Skip the current test when optional dependency *program* is absent."""
    if not has_module(program):
        import_message = "Not detecting module {}. Install package if necessary to enable tests."
        pytest.skip(import_message.format(program))
120

121

122 4
def _build_pytest_skip(program):
    """Build a ``skipif`` marker that fires when *program* was not detected."""
    import_message = "Not detecting module {}. Install package if necessary to enable tests."
    reason = import_message.format(program)
    return pytest.mark.skipif(has_module(program) is False, reason=reason)
125

126

127
# Add a number of module testing options
# Each marker skips a test when the corresponding optional dependency is absent.
using_dask = _build_pytest_skip("dask.distributed")
using_dask_jobqueue = _build_pytest_skip("dask_jobqueue")
using_dftd3 = _build_pytest_skip("dftd3")
using_fireworks = _build_pytest_skip("fireworks")
using_geometric = _build_pytest_skip("geometric")
using_parsl = _build_pytest_skip("parsl")
using_psi4 = _build_pytest_skip("psi4")
using_rdkit = _build_pytest_skip("rdkit")
using_torsiondrive = _build_pytest_skip("torsiondrive")
# Some tests shell out to Bash; skip them on non-POSIX platforms.
using_unix = pytest.mark.skipif(
    os.name.lower() != "posix", reason="Not on Unix operating system, " "assuming Bash is not present"
)
140

141
### Generic helpers
142

143

144 4
def recursive_dict_merge(base_dict, dict_to_merge_in):
    """Recursive merge for more complex than a simple top-level merge {**x, **y} which does not handle nested dict.

    Mutates ``base_dict`` in place and returns None. Values from
    ``dict_to_merge_in`` win on conflicts unless both sides hold mappings,
    in which case the merge recurses.
    """
    # `collections.Mapping` (used by the original top-level import) was
    # removed in Python 3.10; import from collections.abc instead.
    from collections.abc import Mapping

    for key, value in dict_to_merge_in.items():
        if key in base_dict and isinstance(base_dict[key], dict) and isinstance(value, Mapping):
            recursive_dict_merge(base_dict[key], value)
        else:
            base_dict[key] = value
151

152

153 4
def find_open_port():
    """
    Use socket's built in ability to find an open port.

    Binding to port 0 makes the OS pick a currently-free port; the socket is
    closed before returning so the port can be reused immediately (the
    original implementation leaked the socket).
    """
    with socket.socket() as sock:
        sock.bind(("", 0))
        _host, port = sock.getsockname()

    return port
163

164

165 4
@contextmanager
def preserve_cwd():
    """Context manager that restores the current working directory on exit.

    Yields the directory that was current on entry.
    """
    original_dir = os.getcwd()
    try:
        yield original_dir
    finally:
        # Return to the starting directory no matter what the body did.
        os.chdir(original_dir)
173

174

175 4
def await_true(wait_time, func, *args, **kwargs):
    """Poll ``func(*args, **kwargs)`` until it returns truthy or time runs out.

    Parameters
    ----------
    wait_time : float
        Approximate total seconds to keep polling.
    func : callable
        Predicate to poll; remaining args/kwargs are forwarded to it.
    period : float, optional (keyword, default 4)
        Seconds to sleep between attempts; at least one attempt is made.

    Returns
    -------
    bool
        True as soon as ``func`` returns truthy, False if it never did.
    """
    wait_period = kwargs.pop("period", 4)
    periods = max(int(wait_time / wait_period), 1)
    for attempt in range(periods):
        if func(*args, **kwargs):
            return True
        # No point sleeping after the final failed attempt — nothing will
        # re-check (the original slept here and wasted one full period).
        if attempt + 1 < periods:
            time.sleep(wait_period)

    return False
186

187

188
### Background thread loops
189

190

191 4
@contextmanager
def pristine_loop():
    """
    Builds a clean IOLoop for using as a background request.
    Courtesy of Dask Distributed

    Yields a fresh tornado IOLoop installed as the current loop; on exit the
    loop is closed (best effort) and all loop state is cleared again.
    """
    # Drop any loop state left behind by a previous test before building anew.
    IOLoop.clear_instance()
    IOLoop.clear_current()
    loop = IOLoop()
    loop.make_current()
    assert IOLoop.current() is loop

    try:
        yield loop
    finally:
        try:
            # all_fds=True also closes file descriptors the loop still owns.
            loop.close(all_fds=True)
        except (ValueError, KeyError, RuntimeError):
            # The loop may already be closed or mid-shutdown; cleanup is
            # best-effort only.
            pass
        IOLoop.clear_instance()
        IOLoop.clear_current()
212

213

214 4
@contextmanager
def loop_in_thread():
    """Run a pristine IOLoop in a daemon thread and yield it.

    Blocks until the loop has actually started (a callback has run) before
    yielding. On exit the loop is asked to stop and the thread joined with a
    5 second timeout; failures during teardown are suppressed.
    """
    with pristine_loop() as loop:
        # Add the IOloop to a thread daemon
        thread = threading.Thread(target=loop.start, name="test IOLoop")
        thread.daemon = True
        thread.start()

        # Wait until a callback executes, proving the loop is live.
        loop_started = threading.Event()
        loop.add_callback(loop_started.set)
        loop_started.wait()

        try:
            yield loop
        finally:
            try:
                loop.add_callback(loop.stop)
                thread.join(timeout=5)
            except Exception:
                # Was a bare `except:`; narrowed so KeyboardInterrupt and
                # SystemExit are no longer swallowed during teardown.
                pass
233

234

235 4
def terminate_process(proc):
    """Stop subprocess *proc*: interrupt politely, then force-kill.

    Sends a keyboard-interrupt-style signal, waits up to 15 seconds for a
    graceful exit, and finally issues a hard ``kill`` regardless. A process
    that has already exited is left untouched.
    """
    if proc.poll() is not None:
        return

    # Keyboard-interrupt equivalent for the current platform.
    if sys.platform.startswith("win"):
        interrupt = signal.CTRL_BREAK_EVENT
    else:
        interrupt = signal.SIGINT
    proc.send_signal(interrupt)

    try:
        deadline = time.time() + 15
        while proc.poll() is None and time.time() < deadline:
            time.sleep(0.02)
    # Hard stop whether or not the grace period sufficed.
    finally:
        proc.kill()
251

252

253 4
@contextmanager
def popen(args, **kwargs):
    """
    Opens a background task.

    Code and idea from dask.distributed's testing suite
    https://github.com/dask/distributed

    Consumes these keyword arguments before forwarding the rest to
    ``subprocess.Popen``:
      append_prefix (bool, default True): resolve args[0] inside the current
        environment's bin/Scripts directory.
      coverage (bool, default False): wrap the command in ``coverage run``.
      dump_stdout (bool, default False): print captured output on exit
        (forced on when the managed block raises).
    Yields the running ``subprocess.Popen``; the process is terminated and
    its output collected on exit.
    """
    args = list(args)

    # Bin prefix
    if sys.platform.startswith("win"):
        bin_prefix = os.path.join(sys.prefix, "Scripts")
    else:
        bin_prefix = os.path.join(sys.prefix, "bin")

    # Do we prefix with Python?
    if kwargs.pop("append_prefix", True):
        args[0] = os.path.join(bin_prefix, args[0])

    # Add coverage testing
    if kwargs.pop("coverage", False):
        coverage_dir = os.path.join(bin_prefix, "coverage")
        if not os.path.exists(coverage_dir):
            print("Could not find Python coverage, skipping cov.")

        else:
            src_dir = os.path.dirname(os.path.abspath(__file__))
            # --parallel-mode writes one .coverage.* file per process so
            # concurrent runs can later be combined.
            coverage_flags = [coverage_dir, "run", "--parallel-mode", "--source=" + src_dir]

            # If python script, skip the python bin
            if args[0].endswith("python"):
                args.pop(0)
            args = coverage_flags + args

    # Do we optionally dumpstdout?
    dump_stdout = kwargs.pop("dump_stdout", False)

    if sys.platform.startswith("win"):
        # Allow using CTRL_C_EVENT / CTRL_BREAK_EVENT
        kwargs["creationflags"] = subprocess.CREATE_NEW_PROCESS_GROUP

    kwargs["stdout"] = subprocess.PIPE
    kwargs["stderr"] = subprocess.PIPE
    proc = subprocess.Popen(args, **kwargs)
    try:
        yield proc
    except Exception:
        # Any failure in the managed block forces the output dump below.
        dump_stdout = True
        raise

    finally:
        try:
            terminate_process(proc)
        finally:
            # Always reap the process and collect its output.
            output, error = proc.communicate()
            if dump_stdout:
                print("\n" + "-" * 30)
                print("\n|| Process command: {}".format(" ".join(args)))
                print("\n|| Process stderr: \n{}".format(error.decode()))
                print("-" * 30)
                print("\n|| Process stdout: \n{}".format(output.decode()))
                print("-" * 30)
316

317

318 4
def run_process(args, **kwargs):
    """
    Runs a process in the background until complete.

    Returns True if exit code zero. Consumes ``timeout`` (default 30 s) and
    ``interupt_after`` (sic — established keyword spelling; seconds before
    the process is interrupted); remaining kwargs go to ``popen``.
    """
    timeout = kwargs.pop("timeout", 30)
    terminate_after = kwargs.pop("interupt_after", None)

    with popen(args, **kwargs) as proc:
        if terminate_after is not None:
            time.sleep(terminate_after)
            terminate_process(proc)
        else:
            proc.wait(timeout=timeout)

        retcode = proc.poll()

    return retcode == 0
337

338

339
### Server testing mechanics
340

341

342 4
@pytest.fixture(scope="session")
def postgres_server():
    """Session-scoped Postgres handle shared by all storage tests.

    Prefers an already-running server on the default port 5432; otherwise
    boots a TemporaryPostgres instance and stops it at session end. Skips
    the session entirely when the `psql` client is not installed.
    """

    if shutil.which("psql") is None:
        pytest.skip("Postgres is not installed on this server and no active postgres could be found.")

    storage = None
    psql = PostgresHarness({"database": {"port": 5432}})
    # psql = PostgresHarness({"database": {"port": 5432, "username": "qcarchive", "password": "mypass"}})
    if not psql.is_alive():
        print()
        print(
            f"Could not connect to a Postgres server at {psql.config.database_uri()}, this will increase time per test session by ~3 seconds."
        )
        print()
        storage = TemporaryPostgres()
        psql = storage.psql
        print("Using Database: ", psql.config.database_uri())

    yield psql

    # Only stop the database if this fixture started it.
    if storage:
        storage.stop()
365

366

367 4
def reset_server_database(server):
    """Resets the server database for testing.

    Normally performs a data-only wipe; set the QCFRACTAL_RESET_TESTING_DB
    environment variable to also drop and rebuild the schema first.
    """
    if "QCFRACTAL_RESET_TESTING_DB" in os.environ:
        server.storage._clear_db(server.storage._project_name)

    server.storage._delete_DB_data(server.storage._project_name)

    # Force a heartbeat after database clean if a manager is present.
    if server.queue_socket:
        server.await_results()
377

378

379 4
@pytest.fixture(scope="module")
def test_server(request, postgres_server):
    """
    Builds a server instance with the event loop running in a thread.

    Yields a compute-free (max_workers=0) FractalSnowflake backed by the
    session Postgres server; the database is reset on construction.
    """

    # Storage name
    storage_name = "test_qcfractal_server"
    postgres_server.create_database(storage_name)

    with FractalSnowflake(
        max_workers=0,
        # Was a duplicated string literal; use the variable so the project
        # name always matches the database created above.
        storage_project_name=storage_name,
        storage_uri=postgres_server.database_uri(),
        start_server=False,
        reset_database=True,
    ) as server:

        # Clean and re-init the database
        yield server
399

400

401 4
def build_adapter_clients(mtype, storage_name="test_qcfractal_compute_server"):
    """Construct the raw adapter client for queue-backend *mtype*.

    *mtype* must be one of "pool", "dask", "fireworks", or "parsl"; tests for
    backends whose package is missing are skipped via ``pytest.importorskip``.
    Raises TypeError for any other value.
    """

    # Basic boot and loop information
    if mtype == "pool":
        from concurrent.futures import ProcessPoolExecutor

        adapter_client = ProcessPoolExecutor(max_workers=2)

    elif mtype == "dask":
        dd = pytest.importorskip("dask.distributed")
        adapter_client = dd.Client(n_workers=2, threads_per_worker=1, resources={"process": 1})

        # Not super happy about this line, but shuts up dangling reference errors
        adapter_client._should_close_loop = False

    elif mtype == "fireworks":
        fireworks = pytest.importorskip("fireworks")

        fireworks_name = storage_name + "_fireworks_queue"
        adapter_client = fireworks.LaunchPad(name=fireworks_name, logdir="/tmp/", strm_lvl="CRITICAL")

    elif mtype == "parsl":
        parsl = pytest.importorskip("parsl")

        # Must only be a single thread as we run thread unsafe applications.
        adapter_client = parsl.config.Config(executors=[parsl.executors.threads.ThreadPoolExecutor(max_threads=1)])

    else:
        raise TypeError("fractal_compute_server: internal parametrize error")

    return adapter_client
432

433

434 4
@pytest.fixture(scope="module", params=_adapter_testing)
def adapter_client_fixture(request):
    """Module-scoped adapter client parametrized over every queue backend."""
    adapter_client = build_adapter_clients(request.param)
    yield adapter_client

    # Do a final close with existing tech
    build_queue_adapter(adapter_client).close()
441

442

443 4
@pytest.fixture(scope="module", params=_adapter_testing)
def managed_compute_server(request, postgres_server):
    """
    A FractalServer with compute associated parametrize for all managers.

    Yields a ``(client, server, manager)`` triple; the manager and its
    adapter are shut down at module teardown.
    """

    storage_name = "test_qcfractal_compute_server"
    postgres_server.create_database(storage_name)

    adapter_client = build_adapter_clients(request.param, storage_name=storage_name)

    # Build a server with the thread in a outer context loop
    # Not all adapters play well with internal loops
    with loop_in_thread() as loop:
        server = FractalServer(
            port=find_open_port(),
            storage_project_name=storage_name,
            storage_uri=postgres_server.database_uri(),
            loop=loop,
            queue_socket=adapter_client,
            ssl_options=False,
            skip_storage_version_check=True,
        )

        # Clean and re-init the database
        reset_server_database(server)

        # Build Client and Manager
        from qcfractal.interface import FractalClient

        client = FractalClient(server)

        from qcfractal.queue import QueueManager

        manager = QueueManager(client, adapter_client)

        yield client, server, manager

        # Close down and clean the adapter
        manager.close_adapter()
        manager.stop()
484

485

486 4
@pytest.fixture(scope="module")
def fractal_compute_server(postgres_server):
    """
    A FractalServer with a local Pool manager.

    Yields a FractalSnowflake (server plus two pool workers in-process)
    backed by the session Postgres server.
    """

    # Storage name
    storage_name = "test_qcfractal_compute_snowflake"
    postgres_server.create_database(storage_name)

    with FractalSnowflake(
        max_workers=2,
        storage_project_name=storage_name,
        storage_uri=postgres_server.database_uri(),
        reset_database=True,
        start_server=False,
    ) as server:
        reset_server_database(server)
        yield server
505

506

507 4
def build_socket_fixture(stype, server=None):
    """Generator backing the storage-socket fixtures.

    Builds a storage socket of type *stype* (currently only "sqlalchemy" is
    supported) against *server*, clears its database, and yields it.
    Raises KeyError for an unknown *stype*.
    """
    print("")

    # Check mongo
    storage_name = "test_qcfractal_storage" + stype

    # IP/port/drop table is specific to build
    if stype == "sqlalchemy":

        server.create_database(storage_name)
        storage = storage_socket_factory(server.database_uri(), storage_name, db_type=stype, sql_echo=False)

        # Clean and re-init the database
        storage._clear_db(storage_name)
    else:
        raise KeyError("Storage type {} not understood".format(stype))

    yield storage

    # Teardown branch mirrors the per-type dispatch above.
    if stype == "sqlalchemy":
        # todo: drop db
        # storage._clear_db(storage_name)
        pass
    else:
        raise KeyError("Storage type {} not understood".format(stype))
532

533

534 4
@pytest.fixture(scope="module", params=["sqlalchemy"])
def socket_fixture(request):
    """Storage socket parametrized over backend type (currently sqlalchemy only)."""

    yield from build_socket_fixture(request.param)
538

539

540 4
@pytest.fixture(scope="module")
def sqlalchemy_socket_fixture(request, postgres_server):
    """SQLAlchemy storage socket bound to the session Postgres server."""

    yield from build_socket_fixture("sqlalchemy", postgres_server)
544

545

546 4
def live_fractal_or_skip():
    """
    Ensure Fractal live connection can be made
    First looks for a local staging server, then tries QCArchive.

    Returns a connected FractalClient, or skips the calling test when
    neither endpoint is reachable.
    """
    try:
        return FractalClient("localhost:7777", verify=False)
    except (requests.exceptions.ConnectionError, ConnectionRefusedError):
        print("Failed to connect to localhost, trying MolSSI QCArchive.")
        try:
            # Cheap reachability probe before constructing a full client.
            requests.get("https://api.qcarchive.molssi.org:443", json={}, timeout=5)
            return FractalClient()
        except (requests.exceptions.ConnectionError, ConnectionRefusedError):
            return pytest.skip("Could not make a connection to central Fractal server")
560

561

562 4
def df_compare(df1, df2, sort=False):
    """Deep-compare two DataFrames/Series whose cells may contain numpy
    arrays, Molecules, or NaN — cases where ``.equals`` and ``==`` struggle.

    Parameters
    ----------
    df1, df2 : pd.DataFrame or pd.Series
        Objects to compare. DataFrame columns whose name starts with "_"
        are ignored. Inputs are NOT mutated (the previous implementation
        dropped private columns in place on the caller's frames).
    sort : bool, optional
        If True, sort columns (DataFrame) or index (Series) first.

    Returns
    -------
    bool
    """
    if sort:
        if isinstance(df1, pd.DataFrame):
            df1 = df1.reindex(sorted(df1.columns), axis=1)
        elif isinstance(df1, pd.Series):
            df1 = df1.sort_index()
        if isinstance(df2, pd.DataFrame):
            df2 = df2.reindex(sorted(df2.columns), axis=1)
        elif isinstance(df2, pd.Series):
            df2 = df2.sort_index()

    def element_equal(e1, e2):
        """Equality for one cell, dispatching on the left-hand type."""
        if isinstance(e1, np.ndarray):
            return np.array_equal(e1, e2)
        if isinstance(e1, Molecule):
            return e1.get_hash() == e2.get_hash()
        # Because nan != nan; also avoid calling np.isnan on a non-float e2,
        # which used to raise TypeError.
        if isinstance(e1, float) and np.isnan(e1):
            return isinstance(e2, float) and np.isnan(e2)
        return bool(e1 == e2)

    if isinstance(df1, pd.Series):
        if not isinstance(df2, pd.Series):
            return False
        if len(df1) != len(df2):
            return False
        # Use positional (.iloc) access: plain `series[i]` is label-based
        # and breaks on Series with a non-default index.
        return all(element_equal(df1.iloc[i], df2.iloc[i]) for i in range(len(df1)))

    # Drop private bookkeeping columns on copies so callers' frames are
    # untouched (the original used inplace drops while iterating columns).
    df1 = df1.drop(columns=[c for c in df1.columns if c.startswith("_")])
    df2 = df2.drop(columns=[c for c in df2.columns if c.startswith("_")])

    # The elementwise comparisons below require matching shapes.
    if df1.shape != df2.shape:
        return False
    if not all(df1.columns == df2.columns):
        return False
    if not all(df1.index.values == df2.index.values):
        return False
    for i in range(df1.shape[0]):
        for j in range(df1.shape[1]):
            if not element_equal(df1.iloc[i, j], df2.iloc[i, j]):
                return False

    return True

Read our documentation on viewing source code .

Loading