diff --git a/test/test_jit.py b/test/test_jit.py index a7fd4204a86..009e66a9868 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -12628,6 +12628,17 @@ dedent """ test_str.append(str(fn.foo.schema)) self.assertExpectedStripMangled("\n".join(test_str)) + # Tests that "# type: ignore[*]" is supported in type lines and is + # properly ignored. + def test_mypy_type_ignore(self): + @torch.jit.script + def foo(x): # type: ignore + return x + + @torch.jit.script + def bar(x): # type: ignore[no-redef] + return x + def test_method_casts_script(self): cast_types = [ 'byte', 'char', 'double', 'float', 'int', 'long', 'short' diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index 82c3eac4ebc..55d08bfbc89 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -170,13 +170,19 @@ def get_type_line(source): type_lines = list(filter(lambda line: type_comment in line[1], lines)) # `type: ignore` comments may be needed in JIT'ed functions for mypy, due # to the hack in torch/_VF.py. - type_lines = list(filter(lambda line: not line[1].endswith("# type: ignore"), + + # An ignore type line can be of following format: + # 1) # type: ignore + # 2) # type: ignore[rule-code] + # This ignore statement must be at the end of the line + type_pattern = re.compile("# type: ignore(\\[[a-zA-Z-]+\\])?$") + type_lines = list(filter(lambda line: not type_pattern.search(line[1]), type_lines)) - lines_with_type = list(filter(lambda line: 'type' in line[1], lines)) if len(type_lines) == 0: - type_pattern = re.compile('#[\t ]*type[\t ]*(?!: ignore$):') - wrong_type_lines = list(filter(lambda line: type_pattern.search(line[1]), lines)) + # Catch common typo patterns like extra spaces, typo in 'ignore', etc. + wrong_type_pattern = re.compile("#[\t ]*type[\t ]*(?!: ignore(\\[.*\\])?$):") + wrong_type_lines = list(filter(lambda line: wrong_type_pattern.search(line[1]), lines)) if len(wrong_type_lines) > 0: raise RuntimeError("The annotation prefix in line " + str(wrong_type_lines[0][0]) + " is probably invalid.\nIt must be '# type:'"