[functorch][reland] vmap: bitwise operators (#92836)

Previous PR: #91971

Fixes: https://github.com/pytorch/functorch/issues/1069

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92836
Approved by: https://github.com/Chillee
This commit is contained in:
Khushi Agrawal 2023-01-26 06:12:47 +00:00 committed by PyTorch MergeBot
parent ccad2e5000
commit 4c074ddfd2
4 changed files with 21 additions and 8 deletions

View File

@ -359,10 +359,20 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
POINTWISE_BOXED(addcmul);
BINARY_POINTWISE(atan2);
BINARY_SCALAR_2(bitwise_and, Tensor, Scalar);
BINARY_POINTWISE2(bitwise_and_, Tensor);
POINTWISE_BOXED(bitwise_and.Scalar_Tensor);
BINARY_POINTWISE2(bitwise_or, Tensor);
BINARY_POINTWISE2(bitwise_or_, Tensor);
POINTWISE_BOXED(bitwise_or.Scalar_Tensor);
BINARY_POINTWISE2(bitwise_xor, Tensor);
BINARY_POINTWISE2(bitwise_xor_, Tensor);
POINTWISE_BOXED(bitwise_xor.Scalar_Tensor);
BINARY_SCALAR_3(bitwise_left_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
POINTWISE_BOXED(bitwise_left_shift_.Tensor_Scalar);
POINTWISE_BOXED(bitwise_left_shift_.Tensor);
BINARY_SCALAR_3(bitwise_right_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
POINTWISE_BOXED(bitwise_right_shift_.Tensor_Scalar);
POINTWISE_BOXED(bitwise_right_shift_.Tensor);
UNARY_POINTWISE(clamp);
POINTWISE_BOXED(clamp.Tensor);

View File

@ -61,8 +61,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
OP_DECOMPOSE(atleast_3d);
OP_DECOMPOSE2(atleast_3d, Sequence);
OP_DECOMPOSE(batch_norm);
OP_DECOMPOSE2(bitwise_and_, Scalar);
OP_DECOMPOSE2(bitwise_or, Scalar);
OP_DECOMPOSE2(bitwise_or_, Scalar);
OP_DECOMPOSE2(bitwise_xor, Scalar);
OP_DECOMPOSE2(bitwise_xor_, Scalar);
OP_DECOMPOSE(broadcast_tensors);
m.impl("broadcast_to", native::broadcast_to_symint);
OP_DECOMPOSE(cartesian_prod);

View File

@ -3739,14 +3739,17 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('linalg.lu', ''),
skip('linalg.ldl_solve', ''),
skip('_softmax_backward_data'),
# AssertionError: Tensor-likes are not equal!
# Issue: https://github.com/pytorch/pytorch/issues/70904
xfail('bitwise_left_shift', device_type='cpu'),
decorate('bitwise_right_shift', device_type='cpu',
decorator=expectedFailureIf(not (IS_MACOS and IS_X86))),
# UBSAN: runtime error: shift exponent -1 is negative
decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
decorate('bitwise_right_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
# One or more of the overload doesn't have a Batch rule.
xfail('where'),
xfail('bincount'),
xfail('bitwise_and'),
xfail('bitwise_or'),
xfail('bitwise_xor'),
xfail('bitwise_left_shift'),
xfail('bitwise_right_shift'),
xfail('float_power'),
xfail('gt'),
xfail('le'),

View File

@ -85,9 +85,6 @@ xfail_not_implemented = {
"aten::arctanh_",
"aten::argwhere",
"aten::bilinear",
"aten::bitwise_and_.Scalar",
"aten::bitwise_or_.Scalar",
"aten::bitwise_xor_.Scalar",
"aten::can_cast",
"aten::cat.names",
"aten::chain_matmul",