torch.jit.ignore as a context manager (#55172)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55172

Description:
This is part 1 of a series of PRs for supporting torch.jit.ignore as a context manager. The following features are implemented in this PR (a brief usage sketch follows the list):

- A unique name for the function registered under the torch.jit.frontend module; the name is generated from the source file name and the line number of the context manager.
- The user is forced to explicitly annotate the inputs and outputs of the ignored block.
- Side effects are not considered: only the explicitly declared outputs are propagated out of the ignored block.
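
A minimal usage sketch, adapted from the TestIgnoreContextManager tests added in this diff (the module name `M` is illustrative; per this PR, scripting it requires the `astunparse` package):

```python
import torch

class M(torch.nn.Module):
    def forward(self):
        a: int = 4
        b: int = 5
        c: int = 0
        # Inputs and outputs of the ignored region must be annotated explicitly.
        with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int"):
            # This body is not compiled by TorchScript; it runs eagerly in Python
            # through a generated @torch.jit.ignore'd function.
            l = [2 for i in range(a) if i > 2]
            c = l[0] + a + b
        return c

scripted = torch.jit.script(M())
assert scripted() == M()() == 11
```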

Test Plan: Imported from OSS

Reviewed By: gmagogsfm

Differential Revision: D27895283

Pulled By: tugsbayasgalan

fbshipit-source-id: 5d36d9aa5d457055a6bb1676f264647a745ec36a
Author: Tugsbayasgalan (Tugsuu) Manlaibaatar
Date: 2021-05-14 01:52:38 -07:00
Committed by: Facebook GitHub Bot
Commit: 88ff651e90 (parent: cf1daf571d)
9 changed files with 248 additions and 10 deletions


@@ -74,16 +74,16 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
SCIPY_VERSION=1.1.0
if [ "$ANACONDA_PYTHON_VERSION" = "3.9" ]; then
# Install llvm-8 as it is required to compile llvmlite-0.30.0 from source
conda_install numpy=1.19.2 pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 -c conda-forge
conda_install numpy=1.19.2 astunparse pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 -c conda-forge
SCIPY_VERSION=1.6.0
elif [ "$ANACONDA_PYTHON_VERSION" = "3.8" ]; then
# Install llvm-8 as it is required to compile llvmlite-0.30.0 from source
conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0
conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0
elif [ "$ANACONDA_PYTHON_VERSION" = "3.7" ]; then
# DO NOT install dataclasses if installing python-3.7, since its part of python-3.7 core packages
conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six typing_extensions
conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six typing_extensions
else
conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six dataclasses typing_extensions
conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six dataclasses typing_extensions
fi
if [[ "$CUDA_VERSION" == 10.0* ]]; then


@@ -195,7 +195,7 @@ Other potentially useful environment variables may be found in `setup.py`.
Common
```bash
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
conda install astunparse numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
```
On Linux


@@ -1,4 +1,5 @@
# Python dependencies required for development
astunparse
future
numpy
psutil


@@ -0,0 +1,104 @@
import os
import sys
import unittest
import torch
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
if __name__ == "__main__":
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead.")
class TestIgnoreContextManager(JitTestCase):
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
def test_with_ignore_context_manager_with_inp_out(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self):
a: int = 4
b: int = 5
c: int = 0
d: int = 6
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int", d="out:int"):
l = [2 for i in range(a) if i > 2]
c = l[0] + a + b
d = 9
return c + d
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), model())
self.assertEqual(s(), 20)
class B(torch.nn.Module):
def __init__(self):
super(B, self).__init__()
def forward(self):
a: int = 4
b: int = 5
c: int = 0
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int"):
l = [2 for i in range(a) if i > 2]
c = l[0] + a + b
return c
model = B()
s = torch.jit.script(model)
self.assertEqual(s(), 11)
self.assertEqual(s(), model())
class C(torch.nn.Module):
def __init__(self):
super(C, self).__init__()
def forward(self):
a: int = 4
b: int = 5
with torch.jit._IgnoreContextManager(a="inp:int", b="out:int"):
l = [2 for i in range(a) if i > 2]
b = l[0] + a
return b
model = C()
s = torch.jit.script(model)
self.assertEqual(s(), 6)
self.assertEqual(s(), model())
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
def test_with_ignore_context_manager_with_just_inp(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self):
a: int = 4
b: int = 5
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
l = [2 + b for i in range(a) if i > 2]
return a
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), 4)
self.assertEqual(s(), model())
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
def test_with_ignore_context_manager_with_just_out(self):
class A(torch.nn.Module):
def __init__(self):
super(A, self).__init__()
def forward(self):
with torch.jit._IgnoreContextManager(c="out:List[int]"):
c = [2 for i in range(7) if i > 2]
c[0] = 3
return c[0] + c[1]
model = A()
s = torch.jit.script(model)
self.assertEqual(s(), 5)
self.assertEqual(s(), model())


@@ -18,6 +18,7 @@ from jit.test_custom_operators import TestCustomOperators # noqa: F401
from jit.test_export_modes import TestExportModes # noqa: F401
from jit.test_class_type import TestClassType # noqa: F401
from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401
from jit.test_ignore_context_manager import TestIgnoreContextManager # noqa: F401
from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401
from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401
from jit.test_peephole import TestPeephole # noqa: F401


@@ -41,9 +41,17 @@ class JitPlugin(CoveragePlugin):
filename = getsourcefile(obj)
# We don't want to report for filename = None
if filename:
sourcelines, starting_lineno = getsourcelines(obj)
line_data = {filename: range(starting_lineno, starting_lineno + len(sourcelines))}
cov_data.add_lines(line_data)
# TODO: Because torch.jit._IgnoreContextManager relies on Python's `exec` method
# which doesn't generate source codelines, getsourcelines(obj) fails. For now,
# we just ignore the exception until we figure out a better way to
# implement torch.jit._IgnoreContextManager.
try:
sourcelines, starting_lineno = getsourcelines(obj)
except OSError:
pass
else:
line_data = {filename: range(starting_lineno, starting_lineno + len(sourcelines))}
cov_data.add_lines(line_data)
super().dynamic_context(frame)
def coverage_init(reg, options):


@@ -540,6 +540,14 @@ def unused(fn):
fn._torchscript_modifier = FunctionModifiers.UNUSED
return fn
# No op context manager from python side
class _IgnoreContextManager(contextlib.AbstractContextManager):
def __init__(self, **kwargs):
pass
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
pass
def ignore(drop=False, **kwargs):
"""
This decorator indicates to the compiler that a function or method should
@@ -961,6 +969,7 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
self.uses_true_division = uses_true_division
self.filename = filename
def fake_range():


@@ -9,6 +9,7 @@ from torch.utils import set_module
from torch._jit_internal import (
Final,
Future,
_IgnoreContextManager,
_overload,
_overload_method,
ignore,


@@ -3,8 +3,9 @@ import sys
import ast
import inspect
import string
from collections import namedtuple
from textwrap import dedent
from typing import List
from typing import List, Tuple # noqa: F401
from torch._C._jit_tree_views import (
ClassDef, Ident, Stmt, Decl, Def, Var,
EmptyTypeAnnotation, Param, ExprStmt, Assign,
@@ -18,9 +19,16 @@ from torch._C._jit_tree_views import (
)
from torch._utils_internal import get_source_lines_and_file
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
from torch._jit_internal import SourceContext, should_drop, is_static_fn
from torch._jit_internal import SourceContext, should_drop, is_static_fn, FunctionModifiers # noqa: F401
import torch.jit.annotations
_IS_ASTUNPARSE_INSTALLED = False
try:
import astunparse # type: ignore[import]
_IS_ASTUNPARSE_INSTALLED = True
except ImportError:
pass
# Borrowed from cPython implementation
# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#
@@ -299,6 +307,22 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
# TODO: more robust handling of recognizing ignore context manager
def is_torch_jit_ignore_context_manager(stmt):
# checks if the statement is torch.jit.ignore context manager
if isinstance(stmt.items[0].context_expr, ast.Call):
# extract torch part
function = stmt.items[0].context_expr.func
if isinstance(function, ast.Attribute):
attr_name = function.attr
attr_value = function.value
if attr_name == "_IgnoreContextManager" and isinstance(attr_value, ast.Attribute):
# there should be at most two nested attributes (e.g torch.jit._IgnoreContextManager)
if attr_value.attr == "jit" and isinstance(attr_value.value, ast.Name):
if attr_value.value.id == "torch":
return True
return False
class Builder(object):
def __call__(self, ctx, node):
method = getattr(self, 'build_' + node.__class__.__name__, None)
@@ -382,6 +406,89 @@ def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None):
annotation_expr = EmptyTypeAnnotation(r)
return Param(annotation_expr, Ident(r, name), kwarg_only)
def build_ignore_context_manager(ctx, stmt):
InputType = namedtuple('InputType', ['name', 'ann'])
OutputType = namedtuple('OutputType', ['name', 'ann'])
def process_ins_outs(args):
# parse the context manager to figure out inputs and outputs
# with their annotated types
# TODO: add input, output validator
inputs = []
outputs = []
for arg in args:
var_name = arg.arg
if sys.version_info < (3, 8):
# Starting python3.8 ast.Str is deprecated
var_ann = arg.value.s
else:
var_ann = arg.value.value
var_decl_type, var_ann = var_ann.split(":")
if var_decl_type == "inp":
inputs.append(InputType(var_name, var_ann))
if var_decl_type == "out":
outputs.append(OutputType(var_name, var_ann))
return inputs, outputs
def create_unique_name_ext(ctx, stmt):
# extension will be based on the full path filename plus
# the line number of original context manager
return ctx.filename.replace(".", "_").replace("/", "_") + "_" + str(stmt.lineno)
def build_return_ann_stmt(outputs):
return_type_ann = ""
return_statement_str = "return "
if len(outputs) == 0:
return_type_ann += " -> None"
if len(outputs) == 1:
return_type_ann = " -> " + outputs[0].ann
return_statement_str += outputs[0].name
if len(outputs) > 1:
return_type_ann = " -> Tuple"
return_type_ann += "[" + ", ".join([var.ann for var in outputs]) + "]"
return_statement_str += ", ".join([var.name for var in outputs])
return return_type_ann, return_statement_str
def build_args(args):
return ", ".join([arg.name for arg in args])
inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords)
# build the replacement function str with given inputs and outputs
ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt)
ignore_function_str = "\ndef " + ignore_function_name
ignore_function_str += "(" + ", ".join([var.name + " :" + var.ann for var in inputs]) + ")"
return_ann, return_stmt = build_return_ann_stmt(outputs)
ignore_function_str += return_ann + ": pass"
# first create the functionDef object from just declaration
ignore_function = ast.parse(ignore_function_str).body[0]
# dump the body of context manager to dummy function
ignore_function.body = stmt.body # type: ignore[attr-defined]
# insert return statement to the function
return_stmt = ast.parse(return_stmt).body[0]
ignore_function.body.append(return_stmt) # type: ignore[attr-defined]
# registers the custom function in the global context
ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
ignore_func_str += "\nglobals()[\"{}\"] = {}".format(ignore_function_name, ignore_function_name)
exec(ignore_func_str) # noqa: P204
# build the statements as:
# <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
assign_str_lhs = build_args(outputs)
# this function will be registered in torch.jit.frontend module by default
assign_str_rhs = "torch.jit.frontend.{}(".format(ignore_function_name) + build_args(inputs) + ")"
if len(outputs) > 0:
assign_str = assign_str_lhs + " = " + assign_str_rhs
else:
assign_str = assign_str_rhs
assign_ast = ast.parse(assign_str).body[0]
return assign_ast
def get_default_args(fn):
if fn is None:
@@ -563,6 +670,13 @@ class StmtBuilder(Builder):
@staticmethod
def build_With(ctx, stmt):
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with"))
# Handle ignore context manager
if is_torch_jit_ignore_context_manager(stmt):
if not _IS_ASTUNPARSE_INSTALLED:
raise RuntimeError("torch.jit._IgnoreContextManager requires installing Python library `astunparse`,\
please install it in your Python environment")
assign_ast = build_ignore_context_manager(ctx, stmt)
return build_stmt(ctx, assign_ast)
return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body))
class ExprBuilder(Builder):
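
For context on the frontend changes above, here is a rough sketch of what `build_ignore_context_manager` conceptually generates for the ignored block in the usage example earlier. The function name below is hypothetical; the real name is `func_ignore_` plus the mangled source file path and the line number of the `with` statement, and the function is registered in the `torch.jit.frontend` module's globals via `exec`:

```python
import torch

# Hypothetical generated name; the real one encodes the source file path and line number.
@torch.jit.ignore
def func_ignore_example_py_10(a: int, b: int) -> int:
    # Body lifted verbatim from the `with` block, plus a synthesized return of the
    # declared outputs.
    l = [2 for i in range(a) if i > 2]
    c = l[0] + a + b
    return c

# The original `with` statement is then rewritten to roughly:
#     c = torch.jit.frontend.func_ignore_example_py_10(a, b)
# which TorchScript compiles as a call to an @torch.jit.ignore'd (eager Python) function.
```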