as_strided: Fix default storage_offset for reference implementation (#89513)

This fixes the default storage_offset in the reference implementation to be taken
from the input tensor instead of always defaulting to 0, matching the eager
behaviour of torch.as_strided. This behaviour was previously untested, so I've also
added a new OpInfo variant whose samples use non-zero storage_offsets on the input tensor.
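
As an illustration of the intended default (a minimal sketch using eager
torch.as_strided, which the reference implementation should now match; tensor
values are purely illustrative):

    import torch

    base = torch.arange(20.)
    view = base[5:15]  # view.storage_offset() == 5

    # With storage_offset omitted, the result inherits the input's offset
    out = torch.as_strided(view, (2, 2), (1, 2))
    assert out.storage_offset() == 5

    # An explicit storage_offset still takes precedence
    out0 = torch.as_strided(view, (2, 2), (1, 2), storage_offset=0)
    assert out0.storage_offset() == 0
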
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89513
Approved by: https://github.com/ezyang, https://github.com/ngimel
Peter Bell 2022-12-04 19:33:46 +00:00 committed by PyTorch MergeBot
parent 199b8b6025
commit eded97ac72
5 changed files with 87 additions and 11 deletions


@@ -1776,6 +1776,7 @@ aot_autograd_failures = {
xfail('scatter_reduce', 'prod'),
skip('as_strided_scatter'),
xfail('as_strided', 'partial_views'),
# Too annoying to generate random inputs
xfail('cholesky'),


@@ -414,6 +414,7 @@ class TestOperators(TestCase):
# BUG
# AssertionError: Tensor-likes are not close!
xfail('as_strided'),
xfail('as_strided', 'partial_views'),
decorate('linalg.det', 'singular',
decorator=unittest.skipIf(IS_MACOS and IS_X86, "Fails on x86 MacOS CI")),
}))
@@ -655,6 +656,7 @@ class TestOperators(TestCase):
skip("atleast_3d"), # Takes too long
skip("ormqr"), # Takes too long
xfail("as_strided"), # incorrect output
xfail("as_strided", "partial_views"), # incorrect output
xfail("as_strided_scatter"), # incorrect output
skip("bernoulli"), # calls random op
xfail("bfloat16"), # rank 4 tensor for channels_last
@@ -735,6 +737,9 @@ class TestOperators(TestCase):
tol1('svd',
{torch.float32: tol(atol=1e-03, rtol=5e-04)}),
))
@skipOps('TestOperators', 'test_vmapvjpvjp', {
xfail('as_strided', 'partial_views'),
})
def test_vmapvjpvjp(self, device, dtype, op):
# Since, we test `vjpvjp` independently,
# for this test, we just verify that vmap
@@ -802,6 +807,7 @@ class TestOperators(TestCase):
xfail('svd_lowrank', ''), # randomness
xfail('to_sparse', ''), # non-dense output
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail('as_strided', 'partial_views'),
# ----------------------------------------------------------------------
# ---------------------------- BUGS ------------------------------------
@@ -851,7 +857,9 @@ class TestOperators(TestCase):
tol1('linalg.householder_product',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
))
@skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail)
@skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail.union({
xfail('as_strided', 'partial_views'),
}))
def test_vmapvjp(self, device, dtype, op):
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
@@ -899,6 +907,7 @@ class TestOperators(TestCase):
decorate('linalg.det', 'singular', decorator=unittest.skipIf(IS_MACOS, "Fails on x86 MacOS CI")),
skip('nn.functional.max_pool1d'), # fails on cpu, runs on cuda
xfail('masked.mean'), # silent incorrectness (nan difference)
xfail('as_strided', 'partial_views'), # Tensor-likes are not close!
xfail('nn.functional.soft_margin_loss', ''), # soft_margin_loss_backward does not support forward-ad
xfail('tensor_split'), # data_ptr composite compliance
@@ -1201,6 +1210,7 @@ class TestOperators(TestCase):
xfail('sparse.sampled_addmm', ''),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
xfail('as_strided', 'partial_views'),
}))
def test_vjpvmap(self, device, dtype, op):
# NB: there is no vjpvmap_has_batch_rule test because that is almost
@@ -1383,6 +1393,7 @@ class TestOperators(TestCase):
# Potential bugs/errors
xfail('as_strided'), # AssertionError: Tensor-likes are not close!
xfail('as_strided', 'partial_views'), # AssertionError: Tensor-likes are not close!
xfail('as_strided_scatter'), # AssertionError: Tensor-likes are not close!
xfail('bernoulli'), # calls random op
xfail('bfloat16'), # required rank 4 tensor to use channels_last format


@@ -3302,6 +3302,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('triu'), # Exception not raised on error input
# The error inputs are vectors, that pass when batched as they are treated as a matrix
xfail('trace'),
xfail('as_strided', 'partial_views'),
}))
def test_vmap_exhaustive(self, device, dtype, op):
# needs to be fixed
@@ -3317,6 +3318,7 @@ class TestVmapOperatorsOpInfo(TestCase):
))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
xfail('as_strided', 'partial_views'),
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
xfail('complex'),
xfail('copysign'),


@@ -2513,9 +2513,15 @@ def atleast_3d(
def as_strided(
a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int = 0
a: TensorLikeType,
size: ShapeType,
stride: StrideType,
storage_offset: Optional[int] = None,
) -> TensorLikeType:
return prims.as_strided(a, size, stride, storage_offset)
storage_offset_int = (
storage_offset if storage_offset is not None else a.storage_offset()
)
return prims.as_strided(a, size, stride, storage_offset_int)
def broadcast_shapes(*shapes) -> ShapeType:


@@ -263,9 +263,15 @@ def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
kwargs = dict(storage_offset=storage_offset)
yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs)
def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs):
def make_arg():
base = make_tensor((20,), device=device, dtype=dtype)
return base[5:15].requires_grad_(requires_grad)
# as_strided on offset, partial views
# yield SampleInput(make_arg((20,))[5:15], args=((2, 2), (1, 2)))
# yield SampleInput(make_arg((20,))[5:15], args=((2, 2), (1, 2)), kwargs={'storage_offset': 0})
yield SampleInput(make_arg(), (2, 2), (1, 2))
yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0)
yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10)
def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -10721,8 +10727,6 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
)),
OpInfo('as_strided',
op=lambda x, size, stride, storage_offset=0:
torch.as_strided(x, size, stride, storage_offset=storage_offset),
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
supports_forward_ad=True,
@@ -10743,7 +10747,47 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'))),
DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
)),
OpInfo('as_strided',
variant_test_name='partial_views',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=False,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# vmap does not support inplace views
check_inplace_batched_forward_grad=False,
sample_inputs_func=sample_inputs_as_strided_partial_views,
skips=(
# Note: This xfail is fine -- it's inherent to how as_strided works
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
# RuntimeError: This operator is not Composite Compliant: the
# storage_offset of the tensor was modified directly without
# going through the PyTorch dispatcher.
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance'),
# These fail because the test changes the input's in-memory layout
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
dtypes=(torch.complex64, torch.complex128)),
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'),
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'),
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'),
DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo',
'test_make_fx_symbolic_exhaustive_inplace'),
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
# Fail but are also flaky
DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'),
DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon',
'test_non_standard_bool_values'),
)),
OpInfo('as_strided_scatter',
op=lambda x, src, size, stride, storage_offset=0:
torch.as_strided_scatter(x, src, size, stride, storage_offset=storage_offset),
@@ -18282,15 +18326,27 @@ python_ref_db = [
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_nvfuser=False,
skips=(
# TODO: fix and/or update to xfails
DecorateInfo(unittest.skip("Errors when storage_offset is included"),
'TestCommon', 'test_python_ref_meta'),
# cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
),
),
PythonRefInfo(
"_refs.as_strided",
torch_opinfo_name="as_strided",
torch_opinfo_variant_name="partial_views",
# FIXME: doesn't support chalf
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_nvfuser=False,
skips=(
# cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
),
),
PythonRefInfo(
"_refs.broadcast_shapes",
torch_opinfo_name="broadcast_shapes",