Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
fix narrow_copy correctness issue for non-contiguous input for cpu path (reland) (#91883)
This PR is a re-land of https://github.com/pytorch/pytorch/pull/91789.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91883
Approved by: https://github.com/lezcano
This commit is contained in:
parent d1cc64b2ac
commit 1892c75a45
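For context, here is a minimal sketch of the failure mode this re-land addresses, modeled on the regression test added below and the repro in https://github.com/pytorch/pytorch/issues/91690 (the shapes are illustrative, and torch.testing.assert_close is used here only as a convenient check):

import torch

# movedim produces a non-contiguous view: same storage, swapped strides.
inp = torch.randn(10, 2).movedim(-1, 0)          # shape (2, 10), not contiguous
assert not inp.is_contiguous()

expected = torch.narrow_copy(inp.contiguous(), 1, 0, 10)  # reference result
actual = torch.narrow_copy(inp, 1, 0, 10)                 # CPU path under test

# Before this fix the CPU path could return wrong values for such inputs;
# with the fix, both calls agree.
torch.testing.assert_close(actual, expected)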
@@ -1217,7 +1217,9 @@ Tensor narrow_copy_dense(const Tensor& self, int64_t dim, int64_t start, int64_t
 // Should just use narrow_copy_out, but this API is used internally at Meta:
 // https://github.com/pytorch/pytorch/pull/87045#issuecomment-1309353561
 Tensor narrow_copy_dense_cpu(const Tensor& self, int64_t dim, int64_t start, int64_t length){
-  auto output = at::empty_like(self);
+  // narrow_copy_dense_cpu_out always resizes the output, so it is enough to
+  // create a zero-size tensor here.
+  auto output = at::empty({0}, self.options());
   return narrow_copy_dense_cpu_out(self, dim, start, length, output);
 }
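The substantive change is that the output is no longer pre-allocated with at::empty_like(self). A short Python illustration of why that matters, assuming empty_like's default preserve_format behavior and the narrow_copy out-variant; this reading of the root cause is mine and is not spelled out in the commit message:

import torch

x = torch.randn(10, 2).movedim(-1, 0)   # shape (2, 10), strides (1, 2), non-contiguous
out = torch.empty_like(x)                # preserve_format reproduces the odd strides
print(out.is_contiguous())               # False

# Starting from an empty zero-element tensor instead lets the out-variant
# resize the result and choose its own layout.
out2 = torch.empty(0)
torch.narrow_copy(x, 1, 0, 10, out=out2)
print(out2.is_contiguous())              # True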
@@ -3542,7 +3542,6 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('bitwise_left_shift', device_type='cpu'),
         decorate('bitwise_right_shift', device_type='cpu',
                  decorator=expectedFailureIf(not (IS_MACOS and IS_X86))),
-        xfail('narrow_copy', device_type='cpu'),
         # UBSAN: runtime error: shift exponent -1 is negative
         decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
@@ -3721,11 +3720,6 @@ class TestVmapOperatorsOpInfo(TestCase):
         xfail('le'),
         xfail('lt'),
         xfail('ne'),
-        # AssertionError
-        # Mismatched elements: 18 / 20 (90.0%)
-        # Greatest absolute difference: 14.031710147857666 at index (0, 5) (up to 0.0001 allowed)
-        # Greatest relative difference: 2.9177700113052603 at index (0, 3) (up to 0.0001 allowed)
-        xfail('narrow_copy', device_type='cpu'),
         # UBSAN: runtime error: 1.27043e+262 is outside the range of representable values of type 'float'
         decorate('special.zeta', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
         # RuntimeError: Expected all tensors to be on the same device,
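These two hunks drop the CPU xfail entries for narrow_copy from the vmap OpInfo tests, which now pass. A rough sketch of the kind of call those entries covered, assuming the torch.func.vmap API (the shapes and lambda are my own illustration, not taken from the test suite):

import torch
from torch.func import vmap

x = torch.randn(4, 10, 2)
# vmap runs narrow_copy per sample; batching can hand the kernel
# non-contiguous inputs, which is what used to fail on CPU.
out = vmap(lambda t: torch.narrow_copy(t, 1, 0, 2))(x)
print(out.shape)  # torch.Size([4, 10, 2])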
@@ -2971,6 +2971,13 @@ else:
             sz[d] = 0
             self.assertEqual(sz, y.size())
 
+    def test_narrow_copy_non_contiguous(self, device):
+        # see https://github.com/pytorch/pytorch/issues/91690.
+        inp = torch.randn(10, 2, device=device).movedim(-1, 0)
+        expected = torch.narrow_copy(inp.contiguous(), 1, 0, 10)
+        actual = torch.narrow_copy(inp, 1, 0, 10)
+        self.assertEqual(expected, actual)
+
     # FIXME: move to indexing test suite
     @parametrize("reduce", ['prod', 'amin', 'amax', 'mean'])
     @dtypes(*all_types_and(torch.half, torch.bfloat16))