diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h
index 96ddc75d14a..4c1e310f6e6 100644
--- a/aten/src/ATen/core/ivalue.h
+++ b/aten/src/ATen/core/ivalue.h
@@ -288,12 +288,23 @@ private:
     // based on tensor impl.
     //TODO: find a way to use mkldnn storage
     if (a.is_mkldnn() || b.is_mkldnn()) {
       return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
-    } else if (a.is_sparse() || b.is_sparse()) {
-      if (a.is_sparse()) {
-        return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
-      } else {
-        return isAliasOf(b._values(), a) || isAliasOf(b._indices(), a);
-      }
+    }
+
+    if (a.is_sparse()) {
+      return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
+    }
+    if (b.is_sparse()) {
+      return isAliasOf(a, b._values()) || isAliasOf(a, b._indices());
+    }
+    if (a.is_sparse_csr()) {
+      return isAliasOf(a.values(), b) ||
+          isAliasOf(a.crow_indices(), b) ||
+          isAliasOf(a.col_indices(), b);
+    }
+    if (b.is_sparse_csr()) {
+      return isAliasOf(a, b.values()) ||
+          isAliasOf(a, b.crow_indices()) ||
+          isAliasOf(a, b.col_indices());
     }
     return a.is_alias_of(b);
@@ -892,6 +903,11 @@ public:
       // so this will detect overlap of sparse tensors that share a values
      // tensor, but not sparse tensors that share an indices tensor.
       return hashTensor(ten._values());
+    } else if (ten.is_sparse_csr()) {
+      // CSR sparse tensors have a "values" tensor plus "crow_indices" and
+      // "col_indices" tensors, so this will detect overlap of CSR tensors
+      // that share a values tensor, but not ones sharing an indices tensor.
+      return hashTensor(ten.values());
     } else {
       return reinterpret_cast<size_t>(
           ten.storage().unsafeGetStorageImpl());
diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp
index 5c2aa83994a..730d129a41f 100644
--- a/aten/src/ATen/core/type.cpp
+++ b/aten/src/ATen/core/type.cpp
@@ -1544,7 +1544,7 @@ TensorTypePtr TensorType::create(const at::Tensor& t) {
   VaryingShape<size_t> stride_indices;
   VaryingShape<int64_t> strides;
   VaryingShape<int64_t> sizes;
-  if (!t.is_mkldnn() && !t.is_sparse()) {
+  if (!t.is_mkldnn() && !t.is_sparse() && !t.is_sparse_csr()) {
     sizes = VaryingShape<int64_t>{t.sizes().vec()};
     strides = VaryingShape<int64_t>{t.strides().vec()};
     return TensorType::create(
diff --git a/test/jit/test_sparse.py b/test/jit/test_sparse.py
index 4017ab3b310..00102ccc1c2 100644
--- a/test/jit/test_sparse.py
+++ b/test/jit/test_sparse.py
@@ -2,6 +2,8 @@
 import io
 
 import torch
+import unittest
+from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
 from torch.testing._internal.jit_utils import JitTestCase
 
 
@@ -59,3 +61,59 @@ class TestSparse(JitTestCase):
         loaded_result = loaded_model.forward(x)
 
         self.assertEqual(expected_result, loaded_result)
+
+    @unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul")
+    def test_freeze_sparse_csr(self):
+        class SparseTensorModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.rand(4, 4).to_sparse_csr()
+                self.b = torch.rand(4, 4).to_sparse_csr()
+
+            def forward(self, x):
+                return x.matmul(self.a).matmul(self.b)
+
+        x = torch.rand(4, 4).to_sparse_csr()
+
+        m = SparseTensorModule()
+        unfrozen_result = m.forward(x)
+
+        m.eval()
+        frozen = torch.jit.freeze(torch.jit.script(m))
+
+        frozen_result = frozen.forward(x)
+
+        self.assertEqual(unfrozen_result.to_dense(), frozen_result.to_dense())
+
+        buffer = io.BytesIO()
+        torch.jit.save(frozen, buffer)
+        buffer.seek(0)
+        loaded_model = torch.jit.load(buffer)
+
+        loaded_result = loaded_model.forward(x)
+
+        self.assertEqual(unfrozen_result.to_dense(), loaded_result.to_dense())
+
+    @unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul")
+    def test_serialize_sparse_csr(self):
+        class SparseTensorModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.rand(4, 4).to_sparse_csr()
+                self.b = torch.rand(4, 4).to_sparse_csr()
+
+            def forward(self, x):
+                return x.matmul(self.a).matmul(self.b)
+
+        x = torch.rand(4, 4).to_sparse_csr()
+        m = SparseTensorModule()
+        expected_result = m.forward(x)
+
+        buffer = io.BytesIO()
+        torch.jit.save(torch.jit.script(m), buffer)
+        buffer.seek(0)
+        loaded_model = torch.jit.load(buffer)
+
+        loaded_result = loaded_model.forward(x)
+
+        self.assertEqual(expected_result.to_dense(), loaded_result.to_dense())
diff --git a/torch/csrc/jit/serialization/pickler.cpp b/torch/csrc/jit/serialization/pickler.cpp
index 03846b4b3de..3f384f023b8 100644
--- a/torch/csrc/jit/serialization/pickler.cpp
+++ b/torch/csrc/jit/serialization/pickler.cpp
@@ -374,6 +374,18 @@ void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) {
       // values
       pushTensor(tensor._values());
       break;
+    case static_cast<int>(c10::Layout::SparseCsr):
+      push(PickleOpCode::MARK);
+      for (auto size : tensor.sizes()) {
+        pushInt(size);
+      }
+      push(PickleOpCode::TUPLE);
+
+      pushIValue(tensor.requires_grad());
+      pushTensor(tensor.crow_indices());
+      pushTensor(tensor.col_indices());
+      pushTensor(tensor.values());
+      break;
     default:
       TORCH_CHECK(
           false,
diff --git a/torch/csrc/jit/serialization/unpickler.cpp b/torch/csrc/jit/serialization/unpickler.cpp
index 1f5602a888c..4a01c9d81a0 100644
--- a/torch/csrc/jit/serialization/unpickler.cpp
+++ b/torch/csrc/jit/serialization/unpickler.cpp
@@ -726,6 +726,21 @@ void Unpickler::rebuildSparseTensor() {
       result = autograd::make_variable(result, options.requires_grad());
       break;
     }
+    case static_cast<int>(c10::Layout::SparseCsr): {
+      std::vector<int64_t> size = tupleToIntList<int64_t>(elements.at(idx++));
+      bool requires_grad = elements.at(idx++).toBool();
+      auto& crow_indices = elements.at(idx++).toTensor();
+      auto& col_indices = elements.at(idx++).toTensor();
+      auto& values_tensor = elements.at(idx++).toTensor();
+      auto options = values_tensor.options()
+                         .layout(c10::Layout::SparseCsr)
+                         .requires_grad(requires_grad);
+      result = at::_sparse_csr_tensor_unsafe(
+          crow_indices, col_indices, values_tensor, size, options);
+      result =
+          autograd::make_variable(std::move(result), options.requires_grad());
+      break;
+    }
     default:
       TORCH_CHECK(
           false,
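Note on the serialized format: for the SparseCsr layout the pickler writes the sizes tuple, requires_grad, and the three component tensors (crow_indices, col_indices, values); the unpickler reassembles them with at::_sparse_csr_tensor_unsafe, which skips invariant checks since the components came from a valid CSR tensor. As a minimal sketch of the same construction, here is the equivalent round trip in Python using the public, checked torch.sparse_csr_tensor constructor (nothing patch-specific, just standard PyTorch API):

    import torch

    # A dense 4x4 matrix and its CSR decomposition.
    dense = torch.tensor([[1., 0., 0., 2.],
                          [0., 0., 3., 0.],
                          [0., 0., 0., 0.],
                          [4., 0., 0., 5.]])
    csr = dense.to_sparse_csr()

    # Rebuild from the same three tensors the pickler serializes:
    #   crow_indices -- compressed row pointers, length nrows + 1
    #   col_indices  -- column index of each stored value
    #   values       -- the stored values themselves
    rebuilt = torch.sparse_csr_tensor(
        csr.crow_indices(),
        csr.col_indices(),
        csr.values(),
        size=csr.size(),
    )
    assert torch.equal(rebuilt.to_dense(), dense)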