Grandfather in torchgen'ed aten ops to torch.Tag.pt2_compliant_tag (#112053)

In torchgen, we add the pt2_compliant_tag to all aten ops.

Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112053
Approved by: https://github.com/soulitzer
This commit is contained in:
rzou 2023-10-26 07:25:00 -07:00 committed by PyTorch MergeBot
parent 27cf49549a
commit d91a18c433
2 changed files with 8 additions and 0 deletions

View File

@ -1693,6 +1693,10 @@ def forward(self, x_1):
self.assertTrue(isinstance(actual, list))
self.assertEqual(actual, list(tags))
def test_builtin_aten_ops_are_pt2_compliant(self):
for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]:
self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
def test_define_bad_schema(self):
lib = self.lib()
with self.assertRaisesRegex(ValueError, "expected schema to look like"):

View File

@ -649,6 +649,10 @@ class NativeFunction:
tags_inp = [tags_inp]
assert isinstance(tags_inp, list)
# All aten ops generated by torchgen receive the pt2_compliant tag.
if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
tags_inp.append("pt2_compliant_tag")
tags: Set[str] = set()
for t in tags_inp:
assert len(valid_tags) > 0