mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fixes the sparse tensor issue (#163535)
Fixes #148324 Pull Request resolved: https://github.com/pytorch/pytorch/pull/163535 Approved by: https://github.com/janeyx99
This commit is contained in:
parent
fd68d409ad
commit
c01636e1bc
|
|
@ -467,6 +467,28 @@ Tensor sparse_coo_tensor(const Tensor& indices, const Tensor& values, IntArrayRe
|
|||
!options.has_layout() || options.layout() == kSparse,
|
||||
"expected sparse layout, but got layout ",
|
||||
options.layout());
|
||||
|
||||
if (indices.numel() > 0) {
|
||||
Tensor min_indices =
|
||||
std::get</* values */ 0>(indices.min(/* dim */ 1, /* keepdim */ false));
|
||||
Tensor cpu_min_indices;
|
||||
if (!indices.is_cpu()) {
|
||||
cpu_min_indices = min_indices.to(at::DeviceType::CPU);
|
||||
} else {
|
||||
cpu_min_indices = min_indices;
|
||||
}
|
||||
auto cpu_min_indices_accessor = cpu_min_indices.accessor<int64_t, 1>();
|
||||
for (const auto d : c10::irange(indices.size(0))) {
|
||||
int64_t min_index_in_dim = cpu_min_indices_accessor[d];
|
||||
TORCH_CHECK(
|
||||
min_index_in_dim >= 0,
|
||||
"found negative index ",
|
||||
min_index_in_dim,
|
||||
" for dim ",
|
||||
d);
|
||||
}
|
||||
}
|
||||
|
||||
return at::native::_sparse_coo_tensor_unsafe(
|
||||
indices,
|
||||
values,
|
||||
|
|
|
|||
|
|
@ -217,6 +217,12 @@ class TestSparse(TestSparseBase):
|
|||
else:
|
||||
existing_indices.add(index)
|
||||
|
||||
def test_negative_indices(self):
|
||||
indices = torch.tensor([[0, 1, -1], [2, 0, 1]])
|
||||
values = torch.tensor([1, 2, 3])
|
||||
shape = torch.Size([3, 3])
|
||||
self.assertRaisesRegex(RuntimeError, "found negative index", lambda: torch.sparse_coo_tensor(indices, values, shape))
|
||||
|
||||
def randn(self, *args, **kwargs):
|
||||
"""
|
||||
Variant of torch.randn that also works in the TEST_CUDA case.
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user