OpInfos for torch.atleast_{1d, 2d, 3d} (#67355)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/67355

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D32649416

Pulled By: anjali411

fbshipit-source-id: 1b42e86c7124427880fff52fbe490481059da967
This commit is contained in:
parent b69155f754
commit c7d5e0f53f
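For context on what these OpInfos cover: torch.atleast_{1d,2d,3d} promote each input to at least N dimensions, and they also accept several tensors (or a list of tensors) at once, returning a tuple. A minimal sketch of the documented behavior:

import torch

scalar = torch.tensor(3.14)              # 0-d tensor
vec = torch.arange(3.)                   # shape (3,)

print(torch.atleast_1d(scalar).shape)    # torch.Size([1])
print(torch.atleast_2d(vec).shape)       # torch.Size([1, 3])
print(torch.atleast_3d(vec).shape)       # torch.Size([1, 3, 1])

# The variadic form returns a tuple of promoted tensors; this is the form
# the JIT-related skips below have to work around.
a, b = torch.atleast_1d(scalar, vec)
print(a.shape, b.shape)                  # torch.Size([1]) torch.Size([3])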
@@ -1531,6 +1531,9 @@ class TestNormalizeOperators(JitTestCase):
             '__ror__',
             '__rxor__',
             "__rmatmul__",
+            "atleast_1d",
+            "atleast_2d",
+            "atleast_3d",
         }
 
         # Unsupported input types
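A note on the hunk above: adding the three ops to this set excludes them from the normalize-operators test. The likely reason (an inference from the JIT comments later in this commit, not stated in the hunk itself) is that atleast_* dispatches on a single Tensor vs. a TensorList, so there is no one fixed signature to normalize against. The dual call forms are easy to confirm:

import torch

t = torch.tensor(1.0)
print(type(torch.atleast_1d(t)))         # <class 'torch.Tensor'>
print(type(torch.atleast_1d([t, t])))    # list input -> <class 'tuple'>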
@@ -6601,6 +6601,16 @@ def sample_inputs_view_as_reshape_as(op_info, device, dtype, requires_grad, **kw
 
     return list(generator())
 
+def sample_inputs_atleast1d2d3d(op_info, device, dtype, requires_grad, **kwargs):
+    input_list = []
+    shapes = ((S, S, S, S), (S, S, S), (S, S), (S, ), (),)
+    make_tensor_partial = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
+    samples = []
+    for shape in shapes:
+        input_list.append(make_tensor_partial(shape))
+        samples.append(SampleInput(make_tensor_partial(shape)))
+    samples.append(SampleInput(input_list, ))
+    return samples
 
 def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
     make_arg = partial(make_tensor, dtype=dtype, device=device, requires_grad=requires_grad)
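A standalone sketch of what sample_inputs_atleast1d2d3d generates, using the public torch.testing.make_tensor; S = 5 is an assumption matching the size constant used throughout common_methods_invocations.py. Each shape yields one single-tensor sample, and one final sample passes all of the tensors at once as a list:

import torch
from functools import partial
from torch.testing import make_tensor

S = 5  # assumption: the suite's standard size constant
shapes = ((S, S, S, S), (S, S, S), (S, S), (S,), ())
make = partial(make_tensor, dtype=torch.float32, device='cpu', requires_grad=True)

tensors = [make(shape) for shape in shapes]   # five single-tensor samples
out_single = torch.atleast_1d(tensors[-1])    # 0-d input -> shape (1,)
out_list = torch.atleast_1d(tensors)          # list sample -> tuple of outputs
print(out_single.shape, len(out_list))        # torch.Size([1]) 5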
@@ -11678,6 +11688,45 @@ op_db: List[OpInfo] = [
            supports_forward_ad=True,
            sample_inputs_func=sample_inputs_view_as_reshape_as,
            ),
+    OpInfo('atleast_1d',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           sample_inputs_func=sample_inputs_atleast1d2d3d,
+           skips=(
+               # JIT does not support variadic tensors.
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           ),
+    OpInfo('atleast_2d',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           skips=(
+               # JIT does not support variadic tensors.
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_atleast1d2d3d,
+           ),
+    OpInfo('atleast_3d',
+           dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16),
+           supports_out=False,
+           supports_forward_ad=True,
+           skips=(
+               # JIT does not support variadic tensors.
+               # RuntimeError: input->type()->kind() == TypeKind::OptionalType
+               # INTERNAL ASSERT FAILED at "../torch/csrc/jit/passes/utils/check_alias_annotation.cpp":252,
+               # please report a bug to PyTorch.
+               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit', dtypes=[torch.float32]),
+           ),
+           sample_inputs_func=sample_inputs_atleast1d2d3d,
+           ),
     OpInfo('pinverse',
            op=torch.pinverse,
            dtypes=floating_and_complex_types(),
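Finally, a hedged sketch of how the OpInfo machinery consumes these entries: the real harness (test_ops.py and friends) parametrizes over op_db, devices, and dtypes, calls each entry's sample_inputs_func, and runs every variant of the op on every SampleInput. The loop below mimics only the shape contract those samples exercise:

import torch

for op, min_dim in ((torch.atleast_1d, 1), (torch.atleast_2d, 2), (torch.atleast_3d, 3)):
    for shape in ((5, 5, 5), (5, 5), (5,), ()):
        result = op(torch.randn(shape))
        assert result.dim() >= min_dim, (op.__name__, shape, tuple(result.shape))
print("all atleast_* shape checks passed")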