OpInfos for torch.atleast_{1d, 2d, 3d} (#67355)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67355

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D32649416

Pulled By: anjali411

fbshipit-source-id: 1b42e86c7124427880fff52fbe490481059da967
This commit is contained in:
anjali411 2021-11-24 09:52:25 -08:00 committed by Facebook GitHub Bot
parent b69155f754
commit c7d5e0f53f
2 changed files with 52 additions and 0 deletions

View File

@ -1531,6 +1531,9 @@ class TestNormalizeOperators(JitTestCase):
'__ror__',
'__rxor__',
"__rmatmul__",
"atleast_1d",
"atleast_2d",
"atleast_3d",
}
# Unsupported input types

View File

@ -6601,6 +6601,16 @@ def sample_inputs_view_as_reshape_as(op_info, device, dtype, requires_grad, **kw
return list(generator())
def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs):
input_list = []
shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),)
make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
samples = []
for shape in shapes:
input_list.append(make_tensor_partial(shape))
samples.append(SampleInput(make_tensor_partial(shape)))
samples.append(SampleInput(input_list, ))
return samples
def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
@ -11678,6 +11688,45 @@ op_db: List[OpInfo] = [
supports_forward_ad=True,
sample_inputs_func=sample_inputs_view_as_reshape_as,
),
OpInfo('atleast_1d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
sample_inputs_func=sample_inputs_atleast1d2d3d,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
),
OpInfo('atleast_2d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_atleast1d2d3d,
),
OpInfo('atleast_3d',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
supports_out=False,
supports_forward_ad=True,
skips=(
# JIT does not support variadic tensors.
# RuntimeError: input->type()->kind() == TypeKind::OptionalType
# INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
# please report a bug to PyTorch.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
),
sample_inputs_func=sample_inputs_atleast1d2d3d,
),
OpInfo('pinverse',
op=torch.pinverse,
dtypes=floating_and_complex_types(),