mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
torch.jit.ignore as a context manager (#55172)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/55172 Description: This is part 1 of series of PRs for supporting torch.jit.ignore as context manager. Following features are implemented in this PR: - Unique name for the registered function under torch.jit.frontend module. The unique name is generated based on the file name and line number of context manager - Forcing user to explicitly annotate the input and outputs. - No side effects are considered. Test Plan: Imported from OSS Reviewed By: gmagogsfm Differential Revision: D27895283 Pulled By: tugsbayasgalan fbshipit-source-id: 5d36d9aa5d457055a6bb1676f264647a745ec36a
This commit is contained in:
parent
cf1daf571d
commit
88ff651e90
|
|
@ -74,16 +74,16 @@ if [ -n "$ANACONDA_PYTHON_VERSION" ]; then
|
|||
SCIPY_VERSION=1.1.0
|
||||
if [ "$ANACONDA_PYTHON_VERSION" = "3.9" ]; then
|
||||
# Install llvm-8 as it is required to compile llvmlite-0.30.0 from source
|
||||
conda_install numpy=1.19.2 pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 -c conda-forge
|
||||
conda_install numpy=1.19.2 astunparse pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0 -c conda-forge
|
||||
SCIPY_VERSION=1.6.0
|
||||
elif [ "$ANACONDA_PYTHON_VERSION" = "3.8" ]; then
|
||||
# Install llvm-8 as it is required to compile llvmlite-0.30.0 from source
|
||||
conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0
|
||||
conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six llvmdev=8.0.0
|
||||
elif [ "$ANACONDA_PYTHON_VERSION" = "3.7" ]; then
|
||||
# DO NOT install dataclasses if installing python-3.7, since its part of python-3.7 core packages
|
||||
conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six typing_extensions
|
||||
conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six typing_extensions
|
||||
else
|
||||
conda_install numpy=1.18.5 pyyaml mkl mkl-include setuptools cffi future six dataclasses typing_extensions
|
||||
conda_install numpy=1.18.5 astunparse pyyaml mkl mkl-include setuptools cffi future six dataclasses typing_extensions
|
||||
fi
|
||||
|
||||
if [[ "$CUDA_VERSION" == 10.0* ]]; then
|
||||
|
|
|
|||
|
|
@ -195,7 +195,7 @@ Other potentially useful environment variables may be found in `setup.py`.
|
|||
|
||||
Common
|
||||
```bash
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
|
||||
conda install astunparse numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
|
||||
```
|
||||
|
||||
On Linux
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
# Python dependencies required for development
|
||||
astunparse
|
||||
future
|
||||
numpy
|
||||
psutil
|
||||
|
|
|
|||
104
test/jit/test_ignore_context_manager.py
Normal file
104
test/jit/test_ignore_context_manager.py
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
import os
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
# Make the helper files in test/ importable
|
||||
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
||||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
from torch.jit.frontend import _IS_ASTUNPARSE_INSTALLED
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
"\tpython test/test_jit.py TESTNAME\n\n"
|
||||
"instead.")
|
||||
|
||||
class TestIgnoreContextManager(JitTestCase):
|
||||
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
||||
def test_with_ignore_context_manager_with_inp_out(self):
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(A, self).__init__()
|
||||
|
||||
def forward(self):
|
||||
a: int = 4
|
||||
b: int = 5
|
||||
c: int = 0
|
||||
d: int = 6
|
||||
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int", d="out:int"):
|
||||
l = [2 for i in range(a) if i > 2]
|
||||
c = l[0] + a + b
|
||||
d = 9
|
||||
return c + d
|
||||
model = A()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), model())
|
||||
self.assertEqual(s(), 20)
|
||||
|
||||
class B(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(B, self).__init__()
|
||||
|
||||
def forward(self):
|
||||
a: int = 4
|
||||
b: int = 5
|
||||
c: int = 0
|
||||
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int", c="out:int"):
|
||||
l = [2 for i in range(a) if i > 2]
|
||||
c = l[0] + a + b
|
||||
return c
|
||||
model = B()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 11)
|
||||
self.assertEqual(s(), model())
|
||||
|
||||
class C(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(C, self).__init__()
|
||||
|
||||
def forward(self):
|
||||
a: int = 4
|
||||
b: int = 5
|
||||
with torch.jit._IgnoreContextManager(a="inp:int", b="out:int"):
|
||||
l = [2 for i in range(a) if i > 2]
|
||||
b = l[0] + a
|
||||
return b
|
||||
model = C()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 6)
|
||||
self.assertEqual(s(), model())
|
||||
|
||||
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
||||
def test_with_ignore_context_manager_with_just_inp(self):
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(A, self).__init__()
|
||||
|
||||
def forward(self):
|
||||
a: int = 4
|
||||
b: int = 5
|
||||
with torch.jit._IgnoreContextManager(a="inp:int", b="inp:int"):
|
||||
l = [2 + b for i in range(a) if i > 2]
|
||||
return a
|
||||
model = A()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 4)
|
||||
self.assertEqual(s(), model())
|
||||
|
||||
@unittest.skipUnless(_IS_ASTUNPARSE_INSTALLED, "astunparse package is required")
|
||||
def test_with_ignore_context_manager_with_just_out(self):
|
||||
class A(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(A, self).__init__()
|
||||
|
||||
def forward(self):
|
||||
with torch.jit._IgnoreContextManager(c="out:List[int]"):
|
||||
c = [2 for i in range(7) if i > 2]
|
||||
c[0] = 3
|
||||
return c[0] + c[1]
|
||||
model = A()
|
||||
s = torch.jit.script(model)
|
||||
self.assertEqual(s(), 5)
|
||||
self.assertEqual(s(), model())
|
||||
|
|
@ -18,6 +18,7 @@ from jit.test_custom_operators import TestCustomOperators # noqa: F401
|
|||
from jit.test_export_modes import TestExportModes # noqa: F401
|
||||
from jit.test_class_type import TestClassType # noqa: F401
|
||||
from jit.test_builtins import TestBuiltins, TestTensorBuiltins # noqa: F401
|
||||
from jit.test_ignore_context_manager import TestIgnoreContextManager # noqa: F401
|
||||
from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401
|
||||
from jit.test_freezing import TestFreezing, TestFrozenOptimizations, TestMKLDNNReinplacing # noqa: F401
|
||||
from jit.test_peephole import TestPeephole # noqa: F401
|
||||
|
|
|
|||
|
|
@ -41,9 +41,17 @@ class JitPlugin(CoveragePlugin):
|
|||
filename = getsourcefile(obj)
|
||||
# We don't want to report for filename = None
|
||||
if filename:
|
||||
sourcelines, starting_lineno = getsourcelines(obj)
|
||||
line_data = {filename: range(starting_lineno, starting_lineno + len(sourcelines))}
|
||||
cov_data.add_lines(line_data)
|
||||
# TODO: Because torch.jit._IgnoreContextManager relies on Python's `exec` method
|
||||
# which doesn't generate source codelines, getsourcelines(obj) fails. For now,
|
||||
# we just ignore the exception until we figure out a better way to
|
||||
# implement torch.jit._IgnoreContextManager.
|
||||
try:
|
||||
sourcelines, starting_lineno = getsourcelines(obj)
|
||||
except OSError:
|
||||
pass
|
||||
else:
|
||||
line_data = {filename: range(starting_lineno, starting_lineno + len(sourcelines))}
|
||||
cov_data.add_lines(line_data)
|
||||
super().dynamic_context(frame)
|
||||
|
||||
def coverage_init(reg, options):
|
||||
|
|
|
|||
|
|
@ -540,6 +540,14 @@ def unused(fn):
|
|||
fn._torchscript_modifier = FunctionModifiers.UNUSED
|
||||
return fn
|
||||
|
||||
# No op context manager from python side
|
||||
class _IgnoreContextManager(contextlib.AbstractContextManager):
|
||||
def __init__(self, **kwargs):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
||||
pass
|
||||
|
||||
def ignore(drop=False, **kwargs):
|
||||
"""
|
||||
This decorator indicates to the compiler that a function or method should
|
||||
|
|
@ -961,6 +969,7 @@ class SourceContext(torch._C._jit_tree_views.SourceRangeFactory):
|
|||
def __init__(self, source, filename, file_lineno, leading_whitespace_len, uses_true_division=True):
|
||||
super(SourceContext, self).__init__(source, filename, file_lineno, leading_whitespace_len)
|
||||
self.uses_true_division = uses_true_division
|
||||
self.filename = filename
|
||||
|
||||
|
||||
def fake_range():
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from torch.utils import set_module
|
|||
from torch._jit_internal import (
|
||||
Final,
|
||||
Future,
|
||||
_IgnoreContextManager,
|
||||
_overload,
|
||||
_overload_method,
|
||||
ignore,
|
||||
|
|
|
|||
|
|
@ -3,8 +3,9 @@ import sys
|
|||
import ast
|
||||
import inspect
|
||||
import string
|
||||
from collections import namedtuple
|
||||
from textwrap import dedent
|
||||
from typing import List
|
||||
from typing import List, Tuple # noqa: F401
|
||||
from torch._C._jit_tree_views import (
|
||||
ClassDef, Ident, Stmt, Decl, Def, Var,
|
||||
EmptyTypeAnnotation, Param, ExprStmt, Assign,
|
||||
|
|
@ -18,9 +19,16 @@ from torch._C._jit_tree_views import (
|
|||
)
|
||||
from torch._utils_internal import get_source_lines_and_file
|
||||
from torch.jit._monkeytype_config import monkeytype_trace, get_qualified_name
|
||||
from torch._jit_internal import SourceContext, should_drop, is_static_fn
|
||||
from torch._jit_internal import SourceContext, should_drop, is_static_fn, FunctionModifiers # noqa: F401
|
||||
import torch.jit.annotations
|
||||
|
||||
_IS_ASTUNPARSE_INSTALLED = False
|
||||
try:
|
||||
import astunparse # type: ignore[import]
|
||||
_IS_ASTUNPARSE_INSTALLED = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
# Borrowed from cPython implementation
|
||||
# https://github.com/python/cpython/blob/561612d8456cfab5672c9b445521113b847bd6b3/Lib/textwrap.py#L411#
|
||||
|
||||
|
|
@ -299,6 +307,22 @@ def get_jit_def(fn, def_name, self_name=None, is_classmethod=False):
|
|||
|
||||
return build_def(ctx, fn_def, type_line, def_name, self_name=self_name, pdt_arg_types=pdt_arg_types)
|
||||
|
||||
# TODO: more robust handling of recognizing ignore context manager
|
||||
def is_torch_jit_ignore_context_manager(stmt):
|
||||
# checks if the statement is torch.jit.ignore context manager
|
||||
if isinstance(stmt.items[0].context_expr, ast.Call):
|
||||
# extract torch part
|
||||
function = stmt.items[0].context_expr.func
|
||||
if isinstance(function, ast.Attribute):
|
||||
attr_name = function.attr
|
||||
attr_value = function.value
|
||||
if attr_name == "_IgnoreContextManager" and isinstance(attr_value, ast.Attribute):
|
||||
# there should be at most two nested attributes (e.g torch.jit._IgnoreContextManager)
|
||||
if attr_value.attr == "jit" and isinstance(attr_value.value, ast.Name):
|
||||
if attr_value.value.id == "torch":
|
||||
return True
|
||||
return False
|
||||
|
||||
class Builder(object):
|
||||
def __call__(self, ctx, node):
|
||||
method = getattr(self, 'build_' + node.__class__.__name__, None)
|
||||
|
|
@ -382,6 +406,89 @@ def build_param(ctx, py_arg, self_name, kwarg_only, pdt_arg_type=None):
|
|||
annotation_expr = EmptyTypeAnnotation(r)
|
||||
return Param(annotation_expr, Ident(r, name), kwarg_only)
|
||||
|
||||
def build_ignore_context_manager(ctx, stmt):
|
||||
InputType = namedtuple('InputType', ['name', 'ann'])
|
||||
OutputType = namedtuple('OutputType', ['name', 'ann'])
|
||||
|
||||
def process_ins_outs(args):
|
||||
# parse the context manager to figure out inputs and outputs
|
||||
# with their annotated types
|
||||
# TODO: add input, output validator
|
||||
inputs = []
|
||||
outputs = []
|
||||
for arg in args:
|
||||
var_name = arg.arg
|
||||
if sys.version_info < (3, 8):
|
||||
# Starting python3.8 ast.Str is deprecated
|
||||
var_ann = arg.value.s
|
||||
else:
|
||||
var_ann = arg.value.value
|
||||
var_decl_type, var_ann = var_ann.split(":")
|
||||
if var_decl_type == "inp":
|
||||
inputs.append(InputType(var_name, var_ann))
|
||||
if var_decl_type == "out":
|
||||
outputs.append(OutputType(var_name, var_ann))
|
||||
return inputs, outputs
|
||||
|
||||
def create_unique_name_ext(ctx, stmt):
|
||||
# extension will be based on the full path filename plus
|
||||
# the line number of original context manager
|
||||
return ctx.filename.replace(".", "_").replace("/", "_") + "_" + str(stmt.lineno)
|
||||
|
||||
def build_return_ann_stmt(outputs):
|
||||
return_type_ann = ""
|
||||
return_statement_str = "return "
|
||||
if len(outputs) == 0:
|
||||
return_type_ann += " -> None"
|
||||
if len(outputs) == 1:
|
||||
return_type_ann = " -> " + outputs[0].ann
|
||||
return_statement_str += outputs[0].name
|
||||
if len(outputs) > 1:
|
||||
return_type_ann = " -> Tuple"
|
||||
return_type_ann += "[" + ", ".join([var.ann for var in outputs]) + "]"
|
||||
return_statement_str += ", ".join([var.name for var in outputs])
|
||||
return return_type_ann, return_statement_str
|
||||
|
||||
def build_args(args):
|
||||
return ", ".join([arg.name for arg in args])
|
||||
|
||||
inputs, outputs = process_ins_outs(stmt.items[0].context_expr.keywords)
|
||||
|
||||
# build the replacement function str with given inputs and outputs
|
||||
ignore_function_name = "func_ignore_" + create_unique_name_ext(ctx, stmt)
|
||||
ignore_function_str = "\ndef " + ignore_function_name
|
||||
ignore_function_str += "(" + ", ".join([var.name + " :" + var.ann for var in inputs]) + ")"
|
||||
|
||||
return_ann, return_stmt = build_return_ann_stmt(outputs)
|
||||
ignore_function_str += return_ann + ": pass"
|
||||
|
||||
# first create the functionDef object from just declaration
|
||||
ignore_function = ast.parse(ignore_function_str).body[0]
|
||||
|
||||
# dump the body of context manager to dummy function
|
||||
ignore_function.body = stmt.body # type: ignore[attr-defined]
|
||||
|
||||
# insert return statement to the function
|
||||
return_stmt = ast.parse(return_stmt).body[0]
|
||||
ignore_function.body.append(return_stmt) # type: ignore[attr-defined]
|
||||
|
||||
# registers the custom function in the global context
|
||||
ignore_func_str = "@torch.jit.ignore\n" + astunparse.unparse(ignore_function)
|
||||
ignore_func_str += "\nglobals()[\"{}\"] = {}".format(ignore_function_name, ignore_function_name)
|
||||
exec(ignore_func_str) # noqa: P204
|
||||
|
||||
# build the statements as:
|
||||
# <out_1>, <out_2>, ... = torch.jit.frontend.<func>(<in_1>, <in_2>)
|
||||
assign_str_lhs = build_args(outputs)
|
||||
# this function will be registered in torch.jit.frontend module by default
|
||||
assign_str_rhs = "torch.jit.frontend.{}(".format(ignore_function_name) + build_args(inputs) + ")"
|
||||
|
||||
if len(outputs) > 0:
|
||||
assign_str = assign_str_lhs + " = " + assign_str_rhs
|
||||
else:
|
||||
assign_str = assign_str_rhs
|
||||
assign_ast = ast.parse(assign_str).body[0]
|
||||
return assign_ast
|
||||
|
||||
def get_default_args(fn):
|
||||
if fn is None:
|
||||
|
|
@ -563,6 +670,13 @@ class StmtBuilder(Builder):
|
|||
@staticmethod
|
||||
def build_With(ctx, stmt):
|
||||
r = ctx.make_range(stmt.lineno, stmt.col_offset, stmt.col_offset + len("with"))
|
||||
# Handle ignore context manager
|
||||
if is_torch_jit_ignore_context_manager(stmt):
|
||||
if not _IS_ASTUNPARSE_INSTALLED:
|
||||
raise RuntimeError("torch.jit._IgnoreContextManager requires installing Python library `astunparse`,\
|
||||
please install it in your Python environment")
|
||||
assign_ast = build_ignore_context_manager(ctx, stmt)
|
||||
return build_stmt(ctx, assign_ast)
|
||||
return With(r, build_withitems(ctx, stmt.items), build_stmts(ctx, stmt.body))
|
||||
|
||||
class ExprBuilder(Builder):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user