[opinfo] item (#100313)

Follows #100223

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100313
Approved by: https://github.com/ezyang
This commit is contained in:
Khushi 2023-05-10 11:32:45 +00:00 committed by PyTorch MergeBot
parent 55844dfdbc
commit 51fe53e619
9 changed files with 79 additions and 3 deletions

View File

@ -3594,6 +3594,9 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('addcmul'), xfail('addcmul'),
xfail('clamp'), xfail('clamp'),
# TypeError: expected Tensor as element 0 in argument 0, but got float
xfail('item'),
# UBSAN: runtime error: shift exponent -1 is negative # UBSAN: runtime error: shift exponent -1 is negative
decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")), decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
decorate('bitwise_right_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")), decorate('bitwise_right_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
@ -3645,6 +3648,8 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail('take'), xfail('take'),
xfail('tensor_split'), xfail('tensor_split'),
xfail('to_sparse'), xfail('to_sparse'),
# TypeError: expected Tensor as element 0 in argument 0, but got float
xfail('item'),
xfail('tril'), # Exception not raised on error input xfail('tril'), # Exception not raised on error input
xfail('triu'), # Exception not raised on error input xfail('triu'), # Exception not raised on error input
xfail('__getitem__', ''), xfail('__getitem__', ''),

View File

@ -169,6 +169,8 @@ inductor_expected_failures_single_sample["cpu"] = {
"index_add": {f16}, "index_add": {f16},
"index_reduce": {f16, f32, f64}, "index_reduce": {f16, f32, f64},
"istft": {f32, f64}, "istft": {f32, f64},
# Unsupported: data dependent operator: aten._local_scalar_dense.default
"item": {b8, f16, f32, f64, i32, i64},
"linalg.eig": {f32, f64}, "linalg.eig": {f32, f64},
"linalg.eigh": {f32, f64}, "linalg.eigh": {f32, f64},
"linalg.eigvals": {f32, f64}, "linalg.eigvals": {f32, f64},
@ -275,6 +277,8 @@ inductor_expected_failures_single_sample["cuda"] = {
"equal": {b8, f16, f32, f64, i32, i64}, "equal": {b8, f16, f32, f64, i32, i64},
"index_reduce": {f16, f32, f64}, "index_reduce": {f16, f32, f64},
"istft": {f32, f64}, "istft": {f32, f64},
# Unsupported: data dependent operator: aten._local_scalar_dense.default
"item": {b8, f16, f32, f64, i32, i64},
"linalg.eig": {f32, f64}, "linalg.eig": {f32, f64},
"linalg.eigh": {f32, f64}, "linalg.eigh": {f32, f64},
"linalg.eigvals": {f32, f64}, "linalg.eigvals": {f32, f64},

View File

@ -305,6 +305,9 @@ CROSS_REF_EXCLUDE_SET = {
(None, None, "empty_like"), (None, None, "empty_like"),
(None, None, "empty"), (None, None, "empty"),
# AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
(None, None, "item"),
# It's the only in-place op without an out-of-place equivalent in the Python API # It's the only in-place op without an out-of-place equivalent in the Python API
# Its OpInfo wrongly registers it as `torch.zero_(x.clone())`. # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
(None, None, "zero_"), (None, None, "zero_"),

View File

@ -603,7 +603,7 @@ meta_function_expected_failures = {
torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32}, torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32}, torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
torch.ormqr : {f64, c64, c128, f32}, torch.ormqr : {f64, c64, c128, f32},
torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32}, torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
torch.bincount : {i32, i64, u8, i16, i8}, torch.bincount : {i32, i64, u8, i16, i8},
torch.frexp : {f64, f16, bf16, f32}, torch.frexp : {f64, f16, bf16, f32},
torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32}, torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},

View File

@ -95,6 +95,8 @@ def mps_ops_grad_modifier(ops):
# 'bool' object is not iterable # 'bool' object is not iterable
'allclose': [torch.float16, torch.float32], 'allclose': [torch.float16, torch.float32],
'equal': [torch.float16, torch.float32], 'equal': [torch.float16, torch.float32],
# 'float' object is not iterable
'item': [torch.float16, torch.float32],
# "mse_backward_cpu_out" not implemented for 'Half' # "mse_backward_cpu_out" not implemented for 'Half'
'nn.functional.mse_loss': [torch.float16], 'nn.functional.mse_loss': [torch.float16],
# "smooth_l1_backward_cpu_out" not implemented for 'Half' # "smooth_l1_backward_cpu_out" not implemented for 'Half'

View File

@ -1746,7 +1746,6 @@ class TestRefsOpsInfo(TestCase):
'_refs.equal', '_refs.equal',
'_refs.full', '_refs.full',
'_refs.full_like', '_refs.full_like',
'_refs.item',
'_refs.to', '_refs.to',
'_refs.ones', '_refs.ones',
'_refs.ones_like', '_refs.ones_like',

View File

@ -1374,6 +1374,7 @@ make_fx_failures = {
skip('linalg.lstsq'), # flaky, probably just a precision issue skip('linalg.lstsq'), # flaky, probably just a precision issue
# data-dependent control flow # data-dependent control flow
skip('item'),
xfail('cov'), xfail('cov'),
xfail('istft'), xfail('istft'),
xfail('nn.functional.gaussian_nll_loss'), xfail('nn.functional.gaussian_nll_loss'),

View File

@ -186,7 +186,7 @@ __all__ = [
# #
"clone", "clone",
"copy_to", # TODO: add OpInfo (or implement .to) "copy_to", # TODO: add OpInfo (or implement .to)
"item", # TODO: add OpInfo "item",
"to", "to",
# #
# Reduction ops # Reduction ops

View File

@ -370,6 +370,36 @@ def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwa
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2}) yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1}) yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
cases = (
(),
(()),
(1),
((1,)),
)
for shape in cases:
yield SampleInput(make_arg(shape))
def error_inputs_item(op, device, **kwargs):
make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)
cases = (
(M),
((S,)),
(S, S),
(S, M, L),
)
for shape in cases:
yield ErrorInput(
SampleInput(make_arg(shape)), error_type=RuntimeError,
error_regex="elements cannot be converted to Scalar")
def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs): def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad) make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False) make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
@ -9021,6 +9051,28 @@ op_db: List[OpInfo] = [
'test_reference_numerics_extremal_values', 'test_reference_numerics_extremal_values',
dtypes=(torch.complex64, torch.complex128)), dtypes=(torch.complex64, torch.complex128)),
)), )),
OpInfo('item',
op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs),
ref=np.ndarray.item,
method_variant=None,
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool),
supports_out=False,
supports_autograd=False,
error_inputs_func=error_inputs_item,
sample_inputs_func=sample_inputs_item,
skips=(
# Error testing item function variant
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
dtypes=(torch.float32, torch.complex64)),
# FX failed to normalize op - add the op to the op_skip list.
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
# RuntimeError: Composite compliance check failed with the above error.
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
# Booleans mismatch: AssertionError: False is not true
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
# Booleans mismatch: AssertionError: False is not true
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
)),
OpInfo('arange', OpInfo('arange',
dtypes=all_types_and(torch.bfloat16, torch.float16), dtypes=all_types_and(torch.bfloat16, torch.float16),
supports_out=True, supports_out=True,
@ -18441,6 +18493,16 @@ python_ref_db = [
# https://github.com/pytorch/pytorch/issues/85258 # https://github.com/pytorch/pytorch/issues/85258
supports_nvfuser=False, supports_nvfuser=False,
), ),
PythonRefInfo(
"_refs.item",
torch_opinfo_name="item",
supports_nvfuser=False,
skips=(
# RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
# ValueError: Can't convert a tensor with 10 elements to a number!
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),),
),
ElementwiseUnaryPythonRefInfo( ElementwiseUnaryPythonRefInfo(
"_refs.conj_physical", "_refs.conj_physical",
torch_opinfo_name="conj_physical", torch_opinfo_name="conj_physical",