mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
as_strided: Fix default storage_offset for reference implementation (#89513)
This fixes the default storage_offset to take it from the input. This was previously untested, so I've also added a new OpInfo which includes samples with non-zero storage_offsets on the input tensor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/89513 Approved by: https://github.com/ezyang, https://github.com/ngimel
This commit is contained in:
parent
199b8b6025
commit
eded97ac72
|
|
@ -1776,6 +1776,7 @@ aot_autograd_failures = {
|
|||
xfail('scatter_reduce', 'prod'),
|
||||
|
||||
skip('as_strided_scatter'),
|
||||
xfail('as_strided', 'partial_views'),
|
||||
|
||||
# Too annoying to generate random inputs
|
||||
xfail('cholesky'),
|
||||
|
|
|
|||
|
|
@ -414,6 +414,7 @@ class TestOperators(TestCase):
|
|||
# BUG
|
||||
# AssertionError: Tensor-likes are not close!
|
||||
xfail('as_strided'),
|
||||
xfail('as_strided', 'partial_views'),
|
||||
decorate('linalg.det', 'singular',
|
||||
decorator=unittest.skipIf(IS_MACOS and IS_X86, "Fails on x86 MacOS CI")),
|
||||
}))
|
||||
|
|
@ -655,6 +656,7 @@ class TestOperators(TestCase):
|
|||
skip("atleast_3d"), # Takes too long
|
||||
skip("ormqr"), # Takes too long
|
||||
xfail("as_strided"), # incorrect output
|
||||
xfail("as_strided", "partial_views"), # incorrect output
|
||||
xfail("as_strided_scatter"), # incorrect output
|
||||
skip("bernoulli"), # calls random op
|
||||
xfail("bfloat16"), # rank 4 tensor for channels_last
|
||||
|
|
@ -735,6 +737,9 @@ class TestOperators(TestCase):
|
|||
tol1('svd',
|
||||
{torch.float32: tol(atol=1e-03, rtol=5e-04)}),
|
||||
))
|
||||
@skipOps('TestOperators', 'test_vmapvjpvjp', {
|
||||
xfail('as_strided', 'partial_views'),
|
||||
})
|
||||
def test_vmapvjpvjp(self, device, dtype, op):
|
||||
# Since, we test `vjpvjp` independently,
|
||||
# for this test, we just verify that vmap
|
||||
|
|
@ -802,6 +807,7 @@ class TestOperators(TestCase):
|
|||
xfail('svd_lowrank', ''), # randomness
|
||||
xfail('to_sparse', ''), # non-dense output
|
||||
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
||||
xfail('as_strided', 'partial_views'),
|
||||
# ----------------------------------------------------------------------
|
||||
|
||||
# ---------------------------- BUGS ------------------------------------
|
||||
|
|
@ -851,7 +857,9 @@ class TestOperators(TestCase):
|
|||
tol1('linalg.householder_product',
|
||||
{torch.float32: tol(atol=1e-04, rtol=1e-04)}),
|
||||
))
|
||||
@skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail)
|
||||
@skipOps('TestOperators', 'test_vmapvjp', vmapvjp_fail.union({
|
||||
xfail('as_strided', 'partial_views'),
|
||||
}))
|
||||
def test_vmapvjp(self, device, dtype, op):
|
||||
if not op.supports_autograd:
|
||||
self.skipTest("Skipped! Autograd not supported.")
|
||||
|
|
@ -899,6 +907,7 @@ class TestOperators(TestCase):
|
|||
decorate('linalg.det', 'singular', decorator=unittest.skipIf(IS_MACOS, "Fails on x86 MacOS CI")),
|
||||
skip('nn.functional.max_pool1d'), # fails on cpu, runs on cuda
|
||||
xfail('masked.mean'), # silent incorrectness (nan difference)
|
||||
xfail('as_strided', 'partial_views'), # Tensor-likes are not close!
|
||||
|
||||
xfail('nn.functional.soft_margin_loss', ''), # soft_margin_loss_backward does not support forward-ad
|
||||
xfail('tensor_split'), # data_ptr composite compliance
|
||||
|
|
@ -1201,6 +1210,7 @@ class TestOperators(TestCase):
|
|||
xfail('sparse.sampled_addmm', ''),
|
||||
xfail("native_batch_norm"),
|
||||
xfail("_native_batch_norm_legit"),
|
||||
xfail('as_strided', 'partial_views'),
|
||||
}))
|
||||
def test_vjpvmap(self, device, dtype, op):
|
||||
# NB: there is no vjpvmap_has_batch_rule test because that is almost
|
||||
|
|
@ -1383,6 +1393,7 @@ class TestOperators(TestCase):
|
|||
|
||||
# Potential bugs/errors
|
||||
xfail('as_strided'), # AssertionError: Tensor-likes are not close!
|
||||
xfail('as_strided', 'partial_views'), # AssertionError: Tensor-likes are not close!
|
||||
xfail('as_strided_scatter'), # AssertionError: Tensor-likes are not close!
|
||||
xfail('bernoulli'), # calls random op
|
||||
xfail('bfloat16'), # required rank 4 tensor to use channels_last format
|
||||
|
|
|
|||
|
|
@ -3302,6 +3302,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
xfail('triu'), # Exception not raised on error input
|
||||
# The error inputs are vectors, that pass when batched as they are treated as a matrix
|
||||
xfail('trace'),
|
||||
xfail('as_strided', 'partial_views'),
|
||||
}))
|
||||
def test_vmap_exhaustive(self, device, dtype, op):
|
||||
# needs to be fixed
|
||||
|
|
@ -3317,6 +3318,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
))
|
||||
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
||||
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
|
||||
xfail('as_strided', 'partial_views'),
|
||||
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
||||
xfail('complex'),
|
||||
xfail('copysign'),
|
||||
|
|
|
|||
|
|
@ -2513,9 +2513,15 @@ def atleast_3d(
|
|||
|
||||
|
||||
def as_strided(
|
||||
a: TensorLikeType, size: ShapeType, stride: StrideType, storage_offset: int = 0
|
||||
a: TensorLikeType,
|
||||
size: ShapeType,
|
||||
stride: StrideType,
|
||||
storage_offset: Optional[int] = None,
|
||||
) -> TensorLikeType:
|
||||
return prims.as_strided(a, size, stride, storage_offset)
|
||||
storage_offset_int = (
|
||||
storage_offset if storage_offset is not None else a.storage_offset()
|
||||
)
|
||||
return prims.as_strided(a, size, stride, storage_offset_int)
|
||||
|
||||
|
||||
def broadcast_shapes(*shapes) -> ShapeType:
|
||||
|
|
|
|||
|
|
@ -263,9 +263,15 @@ def sample_inputs_as_strided(op_info, device, dtype, requires_grad, **kwargs):
|
|||
kwargs = dict(storage_offset=storage_offset)
|
||||
yield SampleInput(input_t, args=(output_shape, stride), kwargs=kwargs)
|
||||
|
||||
def sample_inputs_as_strided_partial_views(op_info, device, dtype, requires_grad, **kwargs):
|
||||
def make_arg():
|
||||
base = make_tensor((20,), device=device, dtype=dtype)
|
||||
return base[5:15].requires_grad_(requires_grad)
|
||||
|
||||
# as_strided on offset, partial views
|
||||
# yield SampleInput(make_arg((20,))[5:15], args=((2, 2), (1, 2)))
|
||||
# yield SampleInput(make_arg((20,))[5:15], args=((2, 2), (1, 2)), kwargs={'storage_offset': 0})
|
||||
yield SampleInput(make_arg(), (2, 2), (1, 2))
|
||||
yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=0)
|
||||
yield SampleInput(make_arg(), (2, 2), (1, 2), storage_offset=10)
|
||||
|
||||
def sample_inputs_as_strided_scatter(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
|
|
@ -10721,8 +10727,6 @@ op_db: List[OpInfo] = [
|
|||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_out', device_type='cpu'),
|
||||
)),
|
||||
OpInfo('as_strided',
|
||||
op=lambda x, size, stride, storage_offset=0:
|
||||
torch.as_strided(x, size, stride, storage_offset=storage_offset),
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
|
|
@ -10743,7 +10747,47 @@ op_db: List[OpInfo] = [
|
|||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
|
||||
DecorateInfo(unittest.skip("Numerous errors"), 'TestFwdGradients'),
|
||||
DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'))),
|
||||
DecorateInfo(unittest.skip("Numerous errors"), 'TestBwdGradients'),
|
||||
)),
|
||||
OpInfo('as_strided',
|
||||
variant_test_name='partial_views',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
|
||||
supports_out=False,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
# vmap does not support inplace views
|
||||
check_inplace_batched_forward_grad=False,
|
||||
sample_inputs_func=sample_inputs_as_strided_partial_views,
|
||||
skips=(
|
||||
# Note: This xfail is fine -- it's inherent to how as_strided works
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_noncontiguous_samples'),
|
||||
# RuntimeError: This operator is not Composite Compliant: the
|
||||
# storage_offset of the tensor was modified directly without
|
||||
# going through the PyTorch dispatcher.
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance'),
|
||||
|
||||
|
||||
# These fail because the test changes the input's in-memory layout
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_complex_half_reference_testing'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_variant_consistency_eager'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_fn_fwgrad_bwgrad',
|
||||
dtypes=(torch.complex64, torch.complex128)),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_forward_mode_AD'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFwdGradients', 'test_inplace_forward_mode_AD'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_grad'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestBwdGradients', 'test_inplace_gradgrad'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestProxyTensorOpInfo',
|
||||
'test_make_fx_symbolic_exhaustive_inplace'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestNNCOpInfo', 'test_nnc_correctness'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_correctness'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCudaFuserOpInfo', 'test_nvfuser_extremal_values'),
|
||||
# Fail but are also flaky
|
||||
DecorateInfo(unittest.skip("Test changes in memory layout"), 'TestMathBits'),
|
||||
DecorateInfo(unittest.skip("Modifies input strides and storage_offset"), 'TestCommon',
|
||||
'test_non_standard_bool_values'),
|
||||
)),
|
||||
OpInfo('as_strided_scatter',
|
||||
op=lambda x, src, size, stride, storage_offset=0:
|
||||
torch.as_strided_scatter(x, src, size, stride, storage_offset=storage_offset),
|
||||
|
|
@ -18282,15 +18326,27 @@ python_ref_db = [
|
|||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
supports_nvfuser=False,
|
||||
skips=(
|
||||
# TODO: fix and/or update to xfails
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"),
|
||||
'TestCommon', 'test_python_ref_meta'),
|
||||
# cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
|
||||
),
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.as_strided",
|
||||
torch_opinfo_name="as_strided",
|
||||
torch_opinfo_variant_name="partial_views",
|
||||
# FIXME: doesn't support chalf
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
|
||||
supports_nvfuser=False,
|
||||
skips=(
|
||||
# cloned_mutable_input.is_same(returned_output) INTERNAL ASSERT FAILED
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_view'),
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_conj_view'),
|
||||
DecorateInfo(unittest.skip("Errors when storage_offset is included"), 'TestMathBits', 'test_neg_conj_view'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_compare_cpu'),
|
||||
),
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.broadcast_shapes",
|
||||
torch_opinfo_name="broadcast_shapes",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user