[BE] Raise ValueError from torch.cat meta func (#158249)

Follow-up after https://github.com/pytorch/pytorch/pull/155460

From [Python documentation](https://docs.python.org/3/library/exceptions.html#ValueError):
> Raised when an operation or function receives an argument that has the right type but an inappropriate value, and the situation is not described by a more precise exception such as IndexError.

Raise [`TypeError`](https://docs.python.org/3/library/exceptions.html#TypeError) when input-output types are incompatible with each other
> Raised when an operation or function is applied to an object of inappropriate type. The associated value is a string giving details about the type mismatch.

> This exception may be raised by user code to indicate that an attempted operation on an object is not supported, and is not meant to be. If an object is meant to support a given operation but has not yet provided an implementation, [NotImplementedError](https://docs.python.org/3/library/exceptions.html#NotImplementedError) is the proper exception to raise.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158249
Approved by: https://github.com/jbschlosser, https://github.com/Skylion007, https://github.com/albanD
This commit is contained in:
Nikita Shulga 2025-07-20 23:49:18 +00:00 committed by PyTorch MergeBot
parent 4b02bd76d3
commit 2cdafab0bd
4 changed files with 21 additions and 6 deletions

View File

@ -247,7 +247,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) {
// Checking names before the actual dimensions.
auto maybe_outnames = namedinference::compute_cat_outnames(materialized);
TORCH_CHECK(
TORCH_CHECK_VALUE(
!materialized.empty(),
"torch.cat(): expected a non-empty list of Tensors");
@ -274,7 +274,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) {
// when computing the actual output dtype and the flags.
if (is_out_defined) {
// Check for type promotion, if the output tensor is defined.
TORCH_CHECK(
TORCH_CHECK_TYPE(
canCast(out_dtype, result.scalar_type()),
"torch.cat(): input types can't be cast to the desired output type ",
result.scalar_type());
@ -293,7 +293,7 @@ TORCH_PRECOMPUTE_META_FUNC(cat)(const ITensorListRef& tensors, int64_t dim) {
// are compatible, i.e. we can execute `cat` on them.
bool found_valid_tensor = valid < materialized.size();
if (found_valid_tensor) {
TORCH_CHECK(
TORCH_CHECK_INDEX(
dim <= materialized[valid].get().dim(),
"torch.cat(): dimension ",
dim,

View File

@ -1109,7 +1109,22 @@ class TestCommon(TestCase):
if op.is_factory_function and sample.kwargs.get("dtype", None) is None:
op_out(out=out)
else:
with self.assertRaises(RuntimeError, msg=msg_fail):
# TODO: Remove me when all ops will raise type error on mismatched types
exc_type = (
TypeError
if op.name
in [
"_chunk_cat",
"cat",
"column_stack",
"dstack",
"hstack",
"vstack",
"stack",
]
else RuntimeError
)
with self.assertRaises(exc_type, msg=msg_fail):
op_out(out=out)
@ops(

View File

@ -1046,7 +1046,7 @@ class TestTypePromotion(TestCase):
and not (out_dtype.is_floating_point or out_dtype.is_complex))
or ((x_dtype.is_complex or y_dtype.is_complex) and not out_dtype.is_complex)):
# This combinations do not support type conversion to a different class out type
with self.assertRaises(RuntimeError):
with self.assertRaises(TypeError):
torch.cat([x, y], out=out)
else:
torch.cat([x, y], out=out)

View File

@ -2456,7 +2456,7 @@ def error_inputs_cat(op_info, device, **kwargs):
# error inputs for empty tensors
yield ErrorInput(SampleInput([], kwargs={'dim': 1}),
error_regex='non-empty list of Tensors')
error_regex='non-empty list of Tensors', error_type=ValueError)
# error inputs for different sizes
yield ErrorInput(SampleInput([make_arg((S, S, L, L)), make_arg((S, 0, L - 1, L))], kwargs={'dim': 1}),