mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Grandfather in torchgen'ed aten ops to torch.Tag.pt2_compliant_tag (#112053)
In torchgen, we add the pt2_compliant_tag to all aten ops. Test Plan: - new test Pull Request resolved: https://github.com/pytorch/pytorch/pull/112053 Approved by: https://github.com/soulitzer
This commit is contained in:
parent
27cf49549a
commit
d91a18c433
|
|
@ -1693,6 +1693,10 @@ def forward(self, x_1):
|
|||
self.assertTrue(isinstance(actual, list))
|
||||
self.assertEqual(actual, list(tags))
|
||||
|
||||
def test_builtin_aten_ops_are_pt2_compliant(self):
|
||||
for op in [torch.ops.aten.sin.default, torch.ops.aten.sum.dim_IntList]:
|
||||
self.assertIn(torch.Tag.pt2_compliant_tag, op.tags)
|
||||
|
||||
def test_define_bad_schema(self):
|
||||
lib = self.lib()
|
||||
with self.assertRaisesRegex(ValueError, "expected schema to look like"):
|
||||
|
|
|
|||
|
|
@ -649,6 +649,10 @@ class NativeFunction:
|
|||
tags_inp = [tags_inp]
|
||||
assert isinstance(tags_inp, list)
|
||||
|
||||
# All aten ops generated by torchgen receive the pt2_compliant tag.
|
||||
if namespace == "aten" and "pt2_compliant_tag" in valid_tags:
|
||||
tags_inp.append("pt2_compliant_tag")
|
||||
|
||||
tags: Set[str] = set()
|
||||
for t in tags_inp:
|
||||
assert len(valid_tags) > 0
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user