mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[opinfo] item (#100313)
Follows #100223 Pull Request resolved: https://github.com/pytorch/pytorch/pull/100313 Approved by: https://github.com/ezyang
This commit is contained in:
parent
55844dfdbc
commit
51fe53e619
|
|
@ -3594,6 +3594,9 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
xfail('addcmul'),
|
||||
xfail('clamp'),
|
||||
|
||||
# TypeError: expected Tensor as element 0 in argument 0, but got float
|
||||
xfail('item'),
|
||||
|
||||
# UBSAN: runtime error: shift exponent -1 is negative
|
||||
decorate('bitwise_left_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
|
||||
decorate('bitwise_right_shift', decorator=unittest.skipIf(TEST_WITH_UBSAN, "Fails with above error")),
|
||||
|
|
@ -3645,6 +3648,8 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
xfail('take'),
|
||||
xfail('tensor_split'),
|
||||
xfail('to_sparse'),
|
||||
# TypeError: expected Tensor as element 0 in argument 0, but got float
|
||||
xfail('item'),
|
||||
xfail('tril'), # Exception not raised on error input
|
||||
xfail('triu'), # Exception not raised on error input
|
||||
xfail('__getitem__', ''),
|
||||
|
|
|
|||
|
|
@ -169,6 +169,8 @@ inductor_expected_failures_single_sample["cpu"] = {
|
|||
"index_add": {f16},
|
||||
"index_reduce": {f16, f32, f64},
|
||||
"istft": {f32, f64},
|
||||
# Unsupported: data dependent operator: aten._local_scalar_dense.default
|
||||
"item": {b8, f16, f32, f64, i32, i64},
|
||||
"linalg.eig": {f32, f64},
|
||||
"linalg.eigh": {f32, f64},
|
||||
"linalg.eigvals": {f32, f64},
|
||||
|
|
@ -275,6 +277,8 @@ inductor_expected_failures_single_sample["cuda"] = {
|
|||
"equal": {b8, f16, f32, f64, i32, i64},
|
||||
"index_reduce": {f16, f32, f64},
|
||||
"istft": {f32, f64},
|
||||
# Unsupported: data dependent operator: aten._local_scalar_dense.default
|
||||
"item": {b8, f16, f32, f64, i32, i64},
|
||||
"linalg.eig": {f32, f64},
|
||||
"linalg.eigh": {f32, f64},
|
||||
"linalg.eigvals": {f32, f64},
|
||||
|
|
|
|||
|
|
@ -305,6 +305,9 @@ CROSS_REF_EXCLUDE_SET = {
|
|||
(None, None, "empty_like"),
|
||||
(None, None, "empty"),
|
||||
|
||||
# AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
|
||||
(None, None, "item"),
|
||||
|
||||
# It's the only in-place op without an out-of-place equivalent in the Python API
|
||||
# Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
|
||||
(None, None, "zero_"),
|
||||
|
|
|
|||
|
|
@ -603,7 +603,7 @@ meta_function_expected_failures = {
|
|||
torch.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
|
||||
torch.Tensor.nonzero : {f64, i32, c128, i64, i16, c32, f16, u8, c64, bf16, b8, i8, f32},
|
||||
torch.ormqr : {f64, c64, c128, f32},
|
||||
torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c64, bf16, b8, i8, f32},
|
||||
torch.Tensor.item : {f64, i32, c128, i64, i16, f16, u8, c32, c64, bf16, b8, i8, f32},
|
||||
torch.bincount : {i32, i64, u8, i16, i8},
|
||||
torch.frexp : {f64, f16, bf16, f32},
|
||||
torch.functional.unique : {f64, i32, i64, u8, i16, f16, bf16, b8, i8, f32},
|
||||
|
|
|
|||
|
|
@ -95,6 +95,8 @@ def mps_ops_grad_modifier(ops):
|
|||
# 'bool' object is not iterable
|
||||
'allclose': [torch.float16, torch.float32],
|
||||
'equal': [torch.float16, torch.float32],
|
||||
# 'float' object is not iterable
|
||||
'item': [torch.float16, torch.float32],
|
||||
# "mse_backward_cpu_out" not implemented for 'Half'
|
||||
'nn.functional.mse_loss': [torch.float16],
|
||||
# "smooth_l1_backward_cpu_out" not implemented for 'Half'
|
||||
|
|
|
|||
|
|
@ -1746,7 +1746,6 @@ class TestRefsOpsInfo(TestCase):
|
|||
'_refs.equal',
|
||||
'_refs.full',
|
||||
'_refs.full_like',
|
||||
'_refs.item',
|
||||
'_refs.to',
|
||||
'_refs.ones',
|
||||
'_refs.ones_like',
|
||||
|
|
|
|||
|
|
@ -1374,6 +1374,7 @@ make_fx_failures = {
|
|||
skip('linalg.lstsq'), # flaky, probably just a precision issue
|
||||
|
||||
# data-dependent control flow
|
||||
skip('item'),
|
||||
xfail('cov'),
|
||||
xfail('istft'),
|
||||
xfail('nn.functional.gaussian_nll_loss'),
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ __all__ = [
|
|||
#
|
||||
"clone",
|
||||
"copy_to", # TODO: add OpInfo (or implement .to)
|
||||
"item", # TODO: add OpInfo
|
||||
"item",
|
||||
"to",
|
||||
#
|
||||
# Reduction ops
|
||||
|
|
|
|||
|
|
@ -370,6 +370,36 @@ def sample_inputs_cosine_similarity(op_info, device, dtype, requires_grad, **kwa
|
|||
yield SampleInput(make_arg((1, 2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -2})
|
||||
yield SampleInput(make_arg((2, 3)), args=(make_arg((2, 1, 3)),), kwargs={'dim': -1})
|
||||
|
||||
|
||||
def sample_inputs_item(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=False)
|
||||
|
||||
cases = (
|
||||
(),
|
||||
(()),
|
||||
(1),
|
||||
((1,)),
|
||||
)
|
||||
|
||||
for shape in cases:
|
||||
yield SampleInput(make_arg(shape))
|
||||
|
||||
def error_inputs_item(op, device, **kwargs):
|
||||
make_arg = partial(make_tensor, dtype=torch.float32, device=device, requires_grad=False)
|
||||
|
||||
cases = (
|
||||
(M),
|
||||
((S,)),
|
||||
(S, S),
|
||||
(S, M, L),
|
||||
)
|
||||
|
||||
for shape in cases:
|
||||
yield ErrorInput(
|
||||
SampleInput(make_arg(shape)), error_type=RuntimeError,
|
||||
error_regex="elements cannot be converted to Scalar")
|
||||
|
||||
|
||||
def sample_inputs_batch_norm(op_info, device, dtype, requires_grad, **kwargs):
|
||||
make_arg = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
|
||||
make_arg_without_requires_grad = partial(make_tensor, device=device, dtype=dtype, requires_grad=False)
|
||||
|
|
@ -9021,6 +9051,28 @@ op_db: List[OpInfo] = [
|
|||
'test_reference_numerics_extremal_values',
|
||||
dtypes=(torch.complex64, torch.complex128)),
|
||||
)),
|
||||
OpInfo('item',
|
||||
op=lambda inp, *args, **kwargs: wrapper_set_seed(torch.Tensor.item, inp, *args, **kwargs),
|
||||
ref=np.ndarray.item,
|
||||
method_variant=None,
|
||||
dtypes=all_types_and_complex_and(torch.bfloat16, torch.float16, torch.chalf, torch.bool),
|
||||
supports_out=False,
|
||||
supports_autograd=False,
|
||||
error_inputs_func=error_inputs_item,
|
||||
sample_inputs_func=sample_inputs_item,
|
||||
skips=(
|
||||
# Error testing item function variant
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit',
|
||||
dtypes=(torch.float32, torch.complex64)),
|
||||
# FX failed to normalize op - add the op to the op_skip list.
|
||||
DecorateInfo(unittest.expectedFailure, 'TestNormalizeOperators', 'test_normalize_operator_exhaustive'),
|
||||
# RuntimeError: Composite compliance check failed with the above error.
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCompositeCompliance', 'test_operator'),
|
||||
# Booleans mismatch: AssertionError: False is not true
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake_autocast'),
|
||||
# Booleans mismatch: AssertionError: False is not true
|
||||
DecorateInfo(unittest.expectedFailure, 'TestFakeTensor', 'test_fake'),
|
||||
)),
|
||||
OpInfo('arange',
|
||||
dtypes=all_types_and(torch.bfloat16, torch.float16),
|
||||
supports_out=True,
|
||||
|
|
@ -18441,6 +18493,16 @@ python_ref_db = [
|
|||
# https://github.com/pytorch/pytorch/issues/85258
|
||||
supports_nvfuser=False,
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.item",
|
||||
torch_opinfo_name="item",
|
||||
supports_nvfuser=False,
|
||||
skips=(
|
||||
# RuntimeError: Cannot cast FakeTensor(FakeTensor(..., device='meta', size=()), cpu) to number
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_meta'),
|
||||
# ValueError: Can't convert a tensor with 10 elements to a number!
|
||||
DecorateInfo(unittest.expectedFailure, 'TestCommon', 'test_python_ref_errors'),),
|
||||
),
|
||||
ElementwiseUnaryPythonRefInfo(
|
||||
"_refs.conj_physical",
|
||||
torch_opinfo_name="conj_physical",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user