[pytorch] raise exception when calling dim order on sparse tensor (#145888)

This diff introduces a change to the PyTorch library that raises an exception when calling the `dim_order` method on a sparse tensor.

Differential Revision: [D68797044](https://our.internmc.facebook.com/intern/diff/D68797044/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145888
Approved by: https://github.com/Jack-Khuu
This commit is contained in:
gasoonjia 2025-01-28 15:03:39 -08:00 committed by PyTorch MergeBot
parent 2e8c080ab1
commit 501c5972f0
2 changed files with 14 additions and 1 deletions

View File

@ -8793,6 +8793,13 @@ tensor([[[1.+1.j, 1.+1.j, 1.+1.j, ..., 1.+1.j, 1.+1.j, 1.+1.j],
with self.assertRaises(TypeError):
torch.empty((1, 2, 3, 4)).dim_order(ambiguity_check="ILLEGAL_STR")
# sparse tensor does not support dim order
with self.assertRaises(AttributeError):
indices = torch.tensor([[0, 1, 2], [0, 1, 2]]) # (row, column) indices
values = torch.tensor([1.0, 2.0, 3.0]) # values at those indices
sparse_tensor = torch.sparse_coo_tensor(indices, values, size=(3, 3))
sparse_tensor.dim_order()
def test_subclass_tensors(self):
# raise an error when trying to subclass FloatTensor
with self.assertRaisesRegex(TypeError, "type 'torch.FloatTensor' is not an acceptable base type"):

View File

@ -1501,7 +1501,7 @@ class Tensor(torch._C.TensorBase):
Returns the uniquely determined tuple of int describing the dim order or
physical layout of :attr:`self`.
The dim order represents how dimensions are laid out in memory,
The dim order represents how dimensions are laid out in memory of dense tensors,
starting from the outermost to the innermost dimension.
Note that the dim order may not always be uniquely determined.
@ -1542,6 +1542,12 @@ class Tensor(torch._C.TensorBase):
if has_torch_function_unary(self):
return handle_torch_function(Tensor.dim_order, (self,), self)
if self.is_sparse:
raise AttributeError(
f"Can't get dim order on sparse type: {self.type()} "
"Use Tensor.to_dense() to convert to a dense tensor first."
)
# Sanity check ambiguity_check data types
if not isinstance(ambiguity_check, bool):
if not isinstance(ambiguity_check, list):