mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[functorch][reland] vmap: bitwise operators (#92836)
Previous PR: #91971 Fixes: https://github.com/pytorch/functorch/issues/1069 Pull Request resolved: https://github.com/pytorch/pytorch/pull/92836 Approved by: https://github.com/Chillee
This commit is contained in:
parent
ccad2e5000
commit
4c074ddfd2
|
|
@ -359,10 +359,20 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
|
|||
POINTWISE_BOXED(addcmul);
|
||||
BINARY_POINTWISE(atan2);
|
||||
BINARY_SCALAR_2(bitwise_and, Tensor, Scalar);
|
||||
BINARY_POINTWISE2(bitwise_and_, Tensor);
|
||||
POINTWISE_BOXED(bitwise_and.Scalar_Tensor);
|
||||
BINARY_POINTWISE2(bitwise_or, Tensor);
|
||||
BINARY_POINTWISE2(bitwise_or_, Tensor);
|
||||
POINTWISE_BOXED(bitwise_or.Scalar_Tensor);
|
||||
BINARY_POINTWISE2(bitwise_xor, Tensor);
|
||||
BINARY_POINTWISE2(bitwise_xor_, Tensor);
|
||||
POINTWISE_BOXED(bitwise_xor.Scalar_Tensor);
|
||||
BINARY_SCALAR_3(bitwise_left_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
|
||||
POINTWISE_BOXED(bitwise_left_shift_.Tensor_Scalar);
|
||||
POINTWISE_BOXED(bitwise_left_shift_.Tensor);
|
||||
BINARY_SCALAR_3(bitwise_right_shift, Tensor, Tensor_Scalar, Scalar_Tensor);
|
||||
POINTWISE_BOXED(bitwise_right_shift_.Tensor_Scalar);
|
||||
POINTWISE_BOXED(bitwise_right_shift_.Tensor);
|
||||
|
||||
UNARY_POINTWISE(clamp);
|
||||
POINTWISE_BOXED(clamp.Tensor);
|
||||
|
|
|
|||
|
|
@ -61,8 +61,11 @@ TORCH_LIBRARY_IMPL(aten, FuncTorchBatchedDecomposition, m) {
|
|||
OP_DECOMPOSE(atleast_3d);
|
||||
OP_DECOMPOSE2(atleast_3d, Sequence);
|
||||
OP_DECOMPOSE(batch_norm);
|
||||
OP_DECOMPOSE2(bitwise_and_, Scalar);
|
||||
OP_DECOMPOSE2(bitwise_or, Scalar);
|
||||
OP_DECOMPOSE2(bitwise_or_, Scalar);
|
||||
OP_DECOMPOSE2(bitwise_xor, Scalar);
|
||||
OP_DECOMPOSE2(bitwise_xor_, Scalar);
|
||||
OP_DECOMPOSE(broadcast_tensors);
|
||||
m.impl("broadcast_to", native::broadcast_to_symint);
|
||||
OP_DECOMPOSE(cartesian_prod);
|
||||
|
|
|
|||
|
|
@ -3739,14 +3739,17 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
xfail('linalg.lu', ''),
|
||||
skip('linalg.ldl_solve', ''),
|
||||
skip('_softmax_backward_data'),
|
||||
# AssertionError: Tensor-likes are not equal!
|
||||
# Issue: https://github.com/pytorch/pytorch/issues/70904
|
||||
xfail('bitwise_left_shift', device_type='cpu'),
|
||||
decorate('bitwise_right_shift', device_type='cpu',
|
||||
decorator=expectedFailureIf(not (IS_MACOS and IS_X86))),
|
||||
# UBSAN: runtime error: shift exponent -1 is negative
|
||||
decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
|
||||
decorate('bitwise_right_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
|
||||
# One or more of the overload doesn't have a Batch rule.
|
||||
xfail('where'),
|
||||
xfail('bincount'),
|
||||
xfail('bitwise_and'),
|
||||
xfail('bitwise_or'),
|
||||
xfail('bitwise_xor'),
|
||||
xfail('bitwise_left_shift'),
|
||||
xfail('bitwise_right_shift'),
|
||||
xfail('float_power'),
|
||||
xfail('gt'),
|
||||
xfail('le'),
|
||||
|
|
|
|||
|
|
@ -85,9 +85,6 @@ xfail_not_implemented = {
|
|||
"aten::arctanh_",
|
||||
"aten::argwhere",
|
||||
"aten::bilinear",
|
||||
"aten::bitwise_and_.Scalar",
|
||||
"aten::bitwise_or_.Scalar",
|
||||
"aten::bitwise_xor_.Scalar",
|
||||
"aten::can_cast",
|
||||
"aten::cat.names",
|
||||
"aten::chain_matmul",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user