Support mypy ignore annotation with particular rule specified (#51675)

Summary:
Previously TorchScript allows a ignore-all type check suppression rule that looks like
```
code code code  # type: ignore
```

But a more common use case is
```
code code code  # type: ignore[specific-rule]
```
This PR allows the more common use case

Fixes https://github.com/pytorch/pytorch/issues/48643

Pull Request resolved: https://github.com/pytorch/pytorch/pull/51675

Reviewed By: ansley

Differential Revision: D26304870

Pulled By: gmagogsfm

fbshipit-source-id: 0ac9ee34f0219c86e428318a69484d5aa3ec433f
This commit is contained in:
Yanan Cao 2021-02-08 11:16:52 -08:00 committed by Facebook GitHub Bot
parent 41bab9a4b6
commit b9acfcddeb
2 changed files with 21 additions and 4 deletions

View File

@ -12628,6 +12628,17 @@ dedent """
test_str.append(str(fn.foo.schema))
self.assertExpectedStripMangled("\n".join(test_str))
# Tests that "# type: ignore[*]" is supported in type lines and is
# properly ignored.
def test_mypy_type_ignore(self):
@torch.jit.script
def foo(x): # type: ignore
return x
@torch.jit.script
def bar(x): # type: ignore[no-redef]
return x
def test_method_casts_script(self):
cast_types = [
'byte', 'char', 'double', 'float', 'int', 'long', 'short'

View File

@ -170,13 +170,19 @@ def get_type_line(source):
type_lines = list(filter(lambda line: type_comment in line[1], lines))
# `type: ignore` comments may be needed in JIT'ed functions for mypy, due
# to the hack in torch/_VF.py.
type_lines = list(filter(lambda line: not line[1].endswith("# type: ignore"),
# An ignore type line can be of following format:
# 1) # type: ignore
# 2) # type: ignore[rule-code]
# This ignore statement must be at the end of the line
type_pattern = re.compile("# type: ignore(\\[[a-zA-Z-]+\\])?$")
type_lines = list(filter(lambda line: not type_pattern.search(line[1]),
type_lines))
lines_with_type = list(filter(lambda line: 'type' in line[1], lines))
if len(type_lines) == 0:
type_pattern = re.compile('#[\t ]*type[\t ]*(?!: ignore$):')
wrong_type_lines = list(filter(lambda line: type_pattern.search(line[1]), lines))
# Catch common typo patterns like extra spaces, typo in 'ignore', etc.
wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):")
wrong_type_lines = list(filter(lambda line: wrong_type_pattern.search(line[1]), lines))
if len(wrong_type_lines) > 0:
raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0])
+ " is probably invalid.\nIt must be '# type:'"