import ast
import functools
import inspect
from textwrap import dedent
from typing import Any, Optional, Tuple, List, NamedTuple

from torch._C import ErrorReport
from torch._C._jit_tree_views import SourceRangeFactory


def get_source_lines_and_file(
    obj: Any,
    error_msg: Optional[str] = None,
) -> Tuple[List[str], int, Optional[str]]:
    """
    Wrapper around inspect.getsourcelines and inspect.getsourcefile.

    Args:
        obj: object whose source should be looked up
        error_msg: optional extra text appended (on a new line) to the
            error message when the source cannot be retrieved

    Returns: (sourcelines, file_lineno, filename)

    Raises:
        OSError: if `inspect` cannot locate the source of `obj`; the original
            exception is chained as the cause.
    """
    filename = None  # in case getsourcefile throws
    try:
        filename = inspect.getsourcefile(obj)
        sourcelines, file_lineno = inspect.getsourcelines(obj)
    except OSError as e:
        msg = (f"Can't get source for {obj}. TorchScript requires source access in "
               "order to carry out compilation, make sure original .py files are "
               "available.")
        if error_msg:
            msg += '\n' + error_msg
        raise OSError(msg) from e

    return sourcelines, file_lineno, filename


def normalize_source_lines(sourcelines: List[str]) -> List[str]:
    """
    This helper function accepts a list of source lines. It finds the
    indentation level of the function definition (`def`), then it indents
    all lines in the function body to a point at or greater than that
    level. This allows for comments and continued string literals that
    are at a lower indentation than the rest of the code.

    Args:
        sourcelines: function source code, separated into lines by the
            newline character

    Returns:
        A list of source lines that have been correctly aligned. If no line
        begins with the `def` keyword (e.g. the source of a lambda), the
        input list is returned unchanged.
    """

    def remove_prefix(text, prefix):
        # Strip `prefix` only when it is actually present.
        return text[len(prefix):] if text.startswith(prefix) else text

    # Find the line number of the function definition. Compare the first
    # whitespace-delimited token against `def` so that lines which merely
    # start with those letters (e.g. a `default=...` keyword argument inside
    # a multi-line decorator call) are not mistaken for the definition.
    idx = None
    for i, line in enumerate(sourcelines):
        if line.split(None, 1)[:1] == ["def"]:
            idx = i
            break

    # No `def` found: there is no indentation level to align against, so
    # leave the lines alone and let the caller report the problem.
    if idx is None:
        return sourcelines

    fn_def = sourcelines[idx]

    # Get a string representing the amount of leading whitespace
    whitespace = fn_def[:len(fn_def) - len(fn_def.lstrip())]

    # Add this leading whitespace to all lines before and after the `def`
    aligned_prefix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[:idx]]
    aligned_suffix = [whitespace + remove_prefix(s, whitespace) for s in sourcelines[idx + 1:]]

    # Put it together again
    aligned_prefix.append(fn_def)
    return aligned_prefix + aligned_suffix


# Thin wrapper around SourceRangeFactory to store extra metadata
# about the function-to-be-compiled.
class SourceContext(SourceRangeFactory):
    def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
        super().__init__(source, filename, file_lineno, leading_whitespace_len)
        self.uses_true_division = uses_true_division
        self.filename = filename


# Memoized constructor: identical argument tuples share one SourceContext.
@functools.lru_cache(maxsize=None)
def make_source_context(*args):
    return SourceContext(*args)


def fake_range():
    """Build a raw range over an empty, file-less SourceContext."""
    return SourceContext('', None, 0, 0).make_raw_range(0, 1)


class ParsedDef(NamedTuple):
    """Result of `parse_def`: the parsed AST together with its source metadata."""
    # Module node produced by `ast.parse` on the dedented source
    ast: ast.Module
    # Source context built from the (non-dedented) source
    ctx: SourceContext
    # Normalized source text, before dedenting
    source: str
    # File the source came from, if it could be determined
    filename: Optional[str]
    # Line number reported by `inspect.getsourcelines`
    file_lineno: int


def parse_def(fn):
    """
    Retrieve the source of `fn`, parse it and bundle the result as a ParsedDef.

    Raises:
        OSError: if the source of `fn` cannot be retrieved.
        RuntimeError: if the source does not consist of exactly one
            top-level function definition.
    """
    sourcelines, file_lineno, filename = get_source_lines_and_file(fn, ErrorReport.call_stack())
    sourcelines = normalize_source_lines(sourcelines)
    source = ''.join(sourcelines)
    dedent_src = dedent(source)
    py_ast = ast.parse(dedent_src)
    if len(py_ast.body) != 1 or not isinstance(py_ast.body[0], ast.FunctionDef):
        raise RuntimeError(f"Expected a single top-level function: {filename}:{file_lineno}")
    # Number of characters `dedent` removed from the first line; passed to the
    # source context together with the original (indented) source.
    leading_whitespace_len = len(source.split('\n', 1)[0]) - len(dedent_src.split('\n', 1)[0])
    ctx = make_source_context(source, filename, file_lineno, leading_whitespace_len, True)
    return ParsedDef(py_ast, ctx, source, filename, file_lineno)