mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE] Raise ValueError from torch.cat meta func (#158249)
Followup after https://github.com/pytorch/pytorch/pull/155460 From [Python documentation](https://docs.python.org/3/library/exceptions.html#ValueError): > Raised when an operation or function receives an argument that has the right type but an inappropriate value, and the situation is not described by a more precise exception such as IndexError. Raise [`TypeError`](https://docs.python.org/3/library/exceptions.html#TypeError) when input-output types are incompatible with each other > Raised when an operation or function is applied to an object of inappropriate type. The associated value is a string giving details about the type mismatch. > This exception may be raised by user code to indicate that an attempted operation on an object is not supported, and is not meant to be. If an object is meant to support a given operation but has not yet provided an implementation, [NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) is the proper exception to raise. Pull Request resolved: https://github.com/pytorch/pytorch/pull/158249 Approved by: https://github.com/jbschlosser, https://github.com/Skylion007, https://github.com/albanD
This commit is contained in:
parent
4b02bd76d3
commit
2cdafab0bd
|
|
@ -247,7 +247,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) {
|
|||
// Checking names before the actual dimensions.
|
||||
auto maybe_outnames = namedinference::compute_cat_outnames(materialized);
|
||||
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_VALUE(
|
||||
!materialized.empty(),
|
||||
"torch.cat(): expected a non-empty list of Tensors");
|
||||
|
||||
|
|
@ -274,7 +274,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) {
|
|||
// when computing the actual output dtype and the flags.
|
||||
if (is_out_defined) {
|
||||
// Check for type promotion, if the output tensor is defined.
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_TYPE(
|
||||
canCast(out_dtype, result.scalar_type()),
|
||||
"torch.cat(): input types can't be cast to the desired output type ",
|
||||
result.scalar_type());
|
||||
|
|
@ -293,7 +293,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) {
|
|||
// are compatible, i.e. we can execute `cat` on them.
|
||||
bool found_valid_tensor = valid < materialized.size();
|
||||
if (found_valid_tensor) {
|
||||
TORCH_CHECK(
|
||||
TORCH_CHECK_INDEX(
|
||||
dim <= materialized[valid].get().dim(),
|
||||
"torch.cat(): dimension ",
|
||||
dim,
|
||||
|
|
|
|||
|
|
@ -1109,7 +1109,22 @@ class TestCommon(TestCase):
|
|||
if op.is_factory_function and sample.kwargs.get("dtype", None) is None:
|
||||
op_out(out=out)
|
||||
else:
|
||||
with self.assertRaises(RuntimeError, msg=msg_fail):
|
||||
# TODO: Remove me when all ops will raise type error on mismatched types
|
||||
exc_type = (
|
||||
TypeError
|
||||
if op.name
|
||||
in [
|
||||
"_chunk_cat",
|
||||
"cat",
|
||||
"column_stack",
|
||||
"dstack",
|
||||
"hstack",
|
||||
"vstack",
|
||||
"stack",
|
||||
]
|
||||
else RuntimeError
|
||||
)
|
||||
with self.assertRaises(exc_type, msg=msg_fail):
|
||||
op_out(out=out)
|
||||
|
||||
@ops(
|
||||
|
|
|
|||
|
|
@ -1046,7 +1046,7 @@ class TestTypePromotion(TestCase):
|
|||
and not (out_dtype.is_floating_point or out_dtype.is_complex))
|
||||
or ((x_dtype.is_complex or y_dtype.is_complex) and not out_dtype.is_complex)):
|
||||
# This combinations do not support type conversion to a different class out type
|
||||
with self.assertRaises(RuntimeError):
|
||||
with self.assertRaises(TypeError):
|
||||
torch.cat([x, y], out=out)
|
||||
else:
|
||||
torch.cat([x, y], out=out)
|
||||
|
|
|
|||
|
|
@ -2456,7 +2456,7 @@ def error_inputs_cat(op_info, device, **kwargs):
|
|||
|
||||
# error inputs for empty tensors
|
||||
yield ErrorInput(SampleInput([], kwargs={'dim': 1}),
|
||||
error_regex='non-empty list of Tensors')
|
||||
error_regex='non-empty list of Tensors', error_type=ValueError)
|
||||
|
||||
# error inputs for different sizes
|
||||
yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user