From 4c074ddfd2e60b602e1a7eb5df1346958d08b979 Mon Sep 17 00:00:00 2001 From: Khushi Agrawal Date: Thu, 26 Jan 2023 06:12:47 +0000 Subject: [PATCH] [functorch][reland] vmap: bitwise operators (#92836) Previous PR: #91971 Fixes: https://github.com/pytorch/functorch/issues/1069 Pull Request resolved: https://github.com/pytorch/pytorch/pull/92836 Approved by: https://github.com/Chillee --- aten/src/ATen/functorch/BatchRulesBinaryOps.cpp | 10 ++++++++++ .../src/ATen/functorch/BatchRulesDecompositions.cpp | 3 +++ test/functorch/test_vmap.py | 13 ++++++++----- test/functorch/test_vmap_registrations.py | 3 --- 4 files changed, 21 insertions(+), 8 deletions(-) diff --git a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp index cc478faef7c..1c0f98949a5 100644 --- a/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp +++ b/aten/src/ATen/functorch/BatchRulesBinaryOps.cpp @@ -359,10 +359,20 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) { POINTWISE_BOXED(addcmul); BINARY_POINTWISE(atan2); BINARY_SCALAR_2(bitwise_and, Tensor, Scalar); + BINARY_POINTWISE2(bitwise_and_, Tensor); + POINTWISE_BOXED(bitwise_and.Scalar_Tensor); BINARY_POINTWISE2(bitwise_or, Tensor); + BINARY_POINTWISE2(bitwise_or_, Tensor); + POINTWISE_BOXED(bitwise_or.Scalar_Tensor); BINARY_POINTWISE2(bitwise_xor, Tensor); + BINARY_POINTWISE2(bitwise_xor_, Tensor); + POINTWISE_BOXED(bitwise_xor.Scalar_Tensor); BINARY_SCALAR_3(bitwise_left_shift, Tensor, Tensor_Scalar, Scalar_Tensor); + POINTWISE_BOXED(bitwise_left_shift_.Tensor_Scalar); + POINTWISE_BOXED(bitwise_left_shift_.Tensor); BINARY_SCALAR_3(bitwise_right_shift, Tensor, Tensor_Scalar, Scalar_Tensor); + POINTWISE_BOXED(bitwise_right_shift_.Tensor_Scalar); + POINTWISE_BOXED(bitwise_right_shift_.Tensor); UNARY_POINTWISE(clamp); POINTWISE_BOXED(clamp.Tensor); diff --git a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp index 359b9895457..5e2db011f97 100644 --- a/aten/src/ATen/functorch/BatchRulesDecompositions.cpp +++ b/aten/src/ATen/functorch/BatchRulesDecompositions.cpp @@ -61,8 +61,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) { OP_DECOMPOSE(atleast_3d); OP_DECOMPOSE2(atleast_3d, Sequence); OP_DECOMPOSE(batch_norm); + OP_DECOMPOSE2(bitwise_and_, Scalar); OP_DECOMPOSE2(bitwise_or, Scalar); + OP_DECOMPOSE2(bitwise_or_, Scalar); OP_DECOMPOSE2(bitwise_xor, Scalar); + OP_DECOMPOSE2(bitwise_xor_, Scalar); OP_DECOMPOSE(broadcast_tensors); m.impl("broadcast_to", native::broadcast_to_symint); OP_DECOMPOSE(cartesian_prod); diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index f322714b4b8..404e7c8b0fc 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -3739,14 +3739,17 @@ class TestVmapOperatorsOpInfo(TestCase): xfail('linalg.lu', ''), skip('linalg.ldl_solve', ''), skip('_softmax_backward_data'), + # AssertionError: Tensor-likes are not equal! + # Issue: https://github.com/pytorch/pytorch/issues/70904 + xfail('bitwise_left_shift', device_type='cpu'), + decorate('bitwise_right_shift', device_type='cpu', + decorator=expectedFailureIf(not (IS_MACOS and IS_X86))), + # UBSAN: runtime error: shift exponent -1 is negative + decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")), + decorate('bitwise_right_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")), # One or more of the overload doesn't have a Batch rule. xfail('where'), xfail('bincount'), - xfail('bitwise_and'), - xfail('bitwise_or'), - xfail('bitwise_xor'), - xfail('bitwise_left_shift'), - xfail('bitwise_right_shift'), xfail('float_power'), xfail('gt'), xfail('le'), diff --git a/test/functorch/test_vmap_registrations.py b/test/functorch/test_vmap_registrations.py index 26a489eb380..ed89f59ca44 100644 --- a/test/functorch/test_vmap_registrations.py +++ b/test/functorch/test_vmap_registrations.py @@ -85,9 +85,6 @@ xfail_not_implemented = { "aten::arctanh_", "aten::argwhere", "aten::bilinear", - "aten::bitwise_and_.Scalar", - "aten::bitwise_or_.Scalar", - "aten::bitwise_xor_.Scalar", "aten::can_cast", "aten::cat.names", "aten::chain_matmul",