1
#! /usr/bin/env python3
2
# -*- coding: utf-8 -*-
3 27
"""
4
Code generation script for class methods
5
to be exported as public API
6
"""
7 27
import argparse
8 27
import ast
9 27
import astor
10 27
import os
11 27
from pathlib import Path
12 27
import sys
13

14 27
from textwrap import indent
15

16 27
# Prepended to a source file's basename to form the name of the generated
# module written next to it (see process()).
PREFIX = "_generated"

# First chunk of every generated module: the "do not edit" banner, the
# imports the wrappers rely on, and a "fmt: off" marker.
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from ._run import GLOBAL_RUN_CONTEXT, _NO_SEND
from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._instrumentation import Instrument

# fmt: off
"""

# Last chunk of every generated module; closes the "fmt: off" region
# opened by HEADER.
FOOTER = """# fmt: on
"""

# Body of each generated wrapper. gen_public_wrappers_source() fills the three
# "{}" slots, in order, with: the spacer after "return" (" " or " await "),
# the attribute path looked up on GLOBAL_RUN_CONTEXT, and the method call
# (name plus pass-through arguments).
TEMPLATE = """locals()[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
    return{}GLOBAL_RUN_CONTEXT.{}.{}
except AttributeError:
    raise RuntimeError("must be called from async context")
"""
37

38

39 27
def is_function(node):
    """Tell whether ``node`` is a function definition AST node.

    Both plain ``def`` and ``async def`` definitions count as functions.
    """
    return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef))
46

47

48 27
def is_public(node):
    """Tell whether ``node`` is a function decorated with a bare ``@_public``.

    Only decorators written as the plain name ``_public`` are recognised;
    anything that is not a function definition is never public.
    """
    return is_function(node) and any(
        isinstance(deco, ast.Name) and deco.id == "_public"
        for deco in node.decorator_list
    )
56

57

58 27
def get_public_methods(tree):
    """Yield the function nodes in ``tree`` that are marked public.

    The tree is traversed with ``ast.walk`` and every node accepted by
    ``is_public`` is produced lazily, in traversal order.
    """
    yield from filter(is_public, ast.walk(tree))
67

68

69 27
def create_passthrough_args(funcdef):
    """Build the argument list that forwards a function's own parameters.

    Given a function definition node, return a string that represents taking
    all the arguments from the function and passing them through to another
    invocation of the same function.

    Positional-only and regular positional parameters are forwarded
    positionally, ``*args`` and ``**kwargs`` are re-expanded, and keyword-only
    parameters are forwarded as ``name=name``.

    Example input: ast.parse("def f(a, *, b): ...").body[0]
    Example output: "(a, b=b)"
    """
    spec = funcdef.args
    # ``posonlyargs`` exists only on ASTs produced by Python 3.8+; fall back
    # to "none" on older interpreters. Without this, parameters declared
    # before a ``/`` would be dropped from the forwarded call.
    call_args = [param.arg for param in getattr(spec, "posonlyargs", [])]
    call_args.extend(param.arg for param in spec.args)
    if spec.vararg:
        call_args.append("*" + spec.vararg.arg)
    call_args.extend(param.arg + "=" + param.arg for param in spec.kwonlyargs)
    if spec.kwarg:
        call_args.append("**" + spec.kwarg.arg)
    return "({})".format(", ".join(call_args))
87

88

89 27
def gen_public_wrappers_source(source_path: Path, lookup_path: str) -> str:
    """Produce the generated-module text for one source file.

    The file at ``source_path`` is parsed and every method marked with
    ``@_public`` is turned into a module-level wrapper function: ``self`` and
    the decorators are dropped, the docstring (if any) is kept, and the body
    is replaced by TEMPLATE, which forwards the call to the attribute
    ``lookup_path`` of ``GLOBAL_RUN_CONTEXT``.

    Returns HEADER, the wrappers, and FOOTER joined by blank lines.
    """
    four_spaces = " " * 4
    tree = astor.code_to_ast.parse_file(source_path)
    pieces = [HEADER]

    for funcdef in get_public_methods(tree):
        # The wrapper is a free function, so the leading ``self`` goes away.
        params = funcdef.args.args
        assert params[0].arg == "self"
        del params[0]

        # Wrappers carry no decorators.
        funcdef.decorator_list = []

        # Must be computed after ``self`` has been removed.
        forwarded = create_passthrough_args(funcdef)

        # Strip the implementation; when a docstring exists it is the first
        # body entry and is the only one kept.
        keep = 0 if ast.get_docstring(funcdef) is None else 1
        del funcdef.body[keep:]

        # Signature (plus docstring, if kept) rendered back to source.
        signature_src = astor.to_source(funcdef, indent_with=four_spaces)

        # Async methods need the forwarded call awaited.
        if isinstance(funcdef, ast.AsyncFunctionDef):
            spacer = " await "
        else:
            spacer = " "
        body_src = TEMPLATE.format(spacer, lookup_path, funcdef.name + forwarded)

        pieces.append(signature_src + indent(body_src, four_spaces))

    pieces.append(FOOTER)
    return "\n\n".join(pieces)
131

132

133 27
def matches_disk_files(new_files):
    """Tell whether every generated source is already on disk unchanged.

    ``new_files`` maps a file path to the text that file is expected to hold.
    Returns True only when each path exists and its UTF-8 contents equal the
    expected text; checking stops at the first file that is missing or
    different.
    """

    def on_disk_equals(path, expected):
        # A file that is absent can never match.
        if not os.path.exists(path):
            return False
        with open(path, "r", encoding="utf-8") as handle:
            return handle.read() == expected

    return all(
        on_disk_equals(path, expected) for path, expected in new_files.items()
    )
142

143

144 27
def process(sources_and_lookups, *, do_test):
    """Generate the wrapper modules, then either verify or write them.

    ``sources_and_lookups`` is an iterable of ``(source_path, lookup_path)``
    pairs. For each pair the wrapper source is generated and associated with
    a file named PREFIX + the source's basename, in the source's directory.

    With ``do_test`` true nothing is written: the generated text is compared
    with what is on disk and the process exits with status 1 when they
    differ. Otherwise every generated file is (over)written.
    """
    generated_by_path = {}
    for source_path, lookup_path in sources_and_lookups:
        print("Scanning:", source_path)
        wrapper_source = gen_public_wrappers_source(source_path, lookup_path)
        directory, filename = os.path.split(source_path)
        generated_by_path[os.path.join(directory, PREFIX + filename)] = wrapper_source

    if not do_test:
        for target_path, text in generated_by_path.items():
            with open(target_path, "w", encoding="utf-8") as out:
                out.write(text)
        print("Regenerated sources successfully.")
        return

    if matches_disk_files(generated_by_path):
        print("Generated sources are up to date.")
        return

    print("Generated sources are outdated. Please regenerate.")
    sys.exit(1)
163

164

165
# This is in fact run in CI, but only in the formatting check job, which
166
# doesn't collect coverage.
167
def main():  # pragma: no cover
    """Command-line entry point.

    Must be run from the repository root (the directory containing
    ``LICENSE``). Pass ``--test``/``-t`` to check that the generated files
    are up to date instead of rewriting them.
    """
    arg_parser = argparse.ArgumentParser(
        description="Generate python code for public api wrappers"
    )
    arg_parser.add_argument(
        "--test", "-t", action="store_true", help="test if code is still up to date"
    )
    options = arg_parser.parse_args()

    repo_root = Path.cwd()
    # Double-check we found the right directory
    assert (repo_root / "LICENSE").exists()
    core_dir = repo_root / "trio/_core"

    # (file name under trio/_core, attribute path on GLOBAL_RUN_CONTEXT)
    wrap_table = (
        ("_run.py", "runner"),
        ("_instrumentation.py", "runner.instruments"),
        ("_io_windows.py", "runner.io_manager"),
        ("_io_epoll.py", "runner.io_manager"),
        ("_io_kqueue.py", "runner.io_manager"),
    )
    to_wrap = [(core_dir / file_name, lookup) for file_name, lookup in wrap_table]

    process(to_wrap, do_test=options.test)
189

190

191
# Run the generator only when this file is executed as a script.
if __name__ == "__main__":  # pragma: no cover
    main()

Read our documentation on viewing source code .

Loading