[JIT] scripting, freezing, serialization for sparse csr (#69555)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69555

1. Implement pickling/unpickling for sparse CSR tensors
2. Add `test_freeze_sparse_csr` and `test_serialize_sparse_csr` tests
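
For context, a minimal sketch of the user-facing flow this enables, modeled on the new tests (illustrative only; assumes MKL is available for the CSR matmul, as the tests do):

```python
import io
import torch

class SparseTensorModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # CSR tensor stored as a module attribute
        self.a = torch.rand(4, 4).to_sparse_csr()

    def forward(self, x):
        return x.matmul(self.a)

m = torch.jit.script(SparseTensorModule())  # scripting a module that holds a CSR tensor
frozen = torch.jit.freeze(m.eval())         # freezing bakes the CSR attribute in as a constant

buf = io.BytesIO()
torch.jit.save(frozen, buf)                 # serialization goes through the new pickling path
buf.seek(0)
loaded = torch.jit.load(buf)

x = torch.rand(4, 4).to_sparse_csr()
print(loaded(x).to_dense())
```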

Test Plan: Imported from OSS

Reviewed By: mruberry

Differential Revision: D33181367

Pulled By: davidberard98

fbshipit-source-id: a15d5193a7b1b1625a27e4af003cec33cdbc8071
Authored by David Berard on 2021-12-20 11:10:34 -08:00; committed by Facebook GitHub Bot
parent bcb6076099
commit 41959ce77f
5 changed files with 110 additions and 7 deletions


@@ -288,12 +288,23 @@ private:
     // based on tensor impl. //TODO: find a way to use mkldnn storage
     if (a.is_mkldnn() || b.is_mkldnn()) {
       return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
-    } else if (a.is_sparse() || b.is_sparse()) {
-      if (a.is_sparse()) {
-        return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
-      } else {
-        return isAliasOf(b._values(), a) || isAliasOf(b._indices(), a);
-      }
     }
+    if (a.is_sparse()) {
+      return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
+    }
+    if (b.is_sparse()) {
+      return isAliasOf(a, b._values()) || isAliasOf(a, b._indices());
+    }
+    if (a.is_sparse_csr()) {
+      return isAliasOf(a.values(), b) ||
+          isAliasOf(a.crow_indices(), b) ||
+          isAliasOf(a.col_indices(), b);
+    }
+    if (b.is_sparse_csr()) {
+      return isAliasOf(a, b.values()) ||
+          isAliasOf(a, b.crow_indices()) ||
+          isAliasOf(a, b.col_indices());
+    }
     return a.is_alias_of(b);
@@ -892,6 +903,11 @@ public:
       // so this will detect overlap of sparse tensors that share a values
       // tensor, but not sparse tensors that share an indices tensor.
       return hashTensor(ten._values());
+    } else if (ten.is_sparse_csr()) {
+      // Sparse CSR tensors have a "values" tensor plus "crow_indices" and
+      // "col_indices" tensors, so this will detect overlap of CSR tensors
+      // that share a values tensor, but not those that share an indices tensor.
+      return hashTensor(ten.values());
     } else {
       return reinterpret_cast<size_t>(
           ten.storage().unsafeGetStorageImpl());
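
The aliasing and hashing changes above rely on `values()`, `crow_indices()`, and `col_indices()` handing back the CSR tensor's component tensors rather than copies, so overlap has to be checked through those components. A small Python illustration of that sharing (a sketch of the underlying assumption, not code from this change):

```python
import torch

a = torch.eye(3).to_sparse_csr()
v = a.values()               # the values component itself, not a copy
v.zero_()                    # an in-place write to the component...
print(a.to_dense())          # ...shows up through the CSR tensor: all zeros
print(a.crow_indices())      # the indices components are exposed the same way
print(a.col_indices())
```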


@@ -1544,7 +1544,7 @@ TensorTypePtr TensorType::create(const at::Tensor& t) {
   VaryingShape<size_t> stride_indices;
   VaryingShape<int64_t> strides;
   VaryingShape<int64_t> sizes;
-  if (!t.is_mkldnn() && !t.is_sparse()) {
+  if (!t.is_mkldnn() && !t.is_sparse() && !t.is_sparse_csr()) {
     sizes = VaryingShape<int64_t>{t.sizes().vec()};
     strides = VaryingShape<int64_t>{t.strides().vec()};
     return TensorType::create(
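
The extra `!t.is_sparse_csr()` check keeps `TensorType::create` from asking a CSR tensor for strides, which sparse layouts do not define. A quick Python illustration of why that branch must be skipped (a sketch; the exact error message may differ):

```python
import torch

t = torch.eye(3).to_sparse_csr()
print(t.layout)        # torch.sparse_csr
print(t.shape)         # sizes are still defined for CSR tensors
try:
    t.stride()         # sparse layouts carry no strides, so this raises
except RuntimeError as err:
    print("no strides:", err)
```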


@@ -2,6 +2,8 @@
 import io
 import torch
+import unittest
+from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
 from torch.testing._internal.jit_utils import JitTestCase
@@ -59,3 +61,61 @@ class TestSparse(JitTestCase):
         loaded_result = loaded_model.forward(x)
         self.assertEqual(expected_result, loaded_result)
+
+    @unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul")
+    def test_freeze_sparse_csr(self):
+        class SparseTensorModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.rand(4, 4).to_sparse_csr()
+                self.b = torch.rand(4, 4).to_sparse_csr()
+
+            def forward(self, x):
+                return x.matmul(self.a).matmul(self.b)
+
+        x = torch.rand(4, 4).to_sparse_csr()
+        m = SparseTensorModule()
+        unfrozen_result = m.forward(x)
+
+        m.eval()
+        frozen = torch.jit.freeze(torch.jit.script(m))
+        frozen_result = frozen.forward(x)
+
+        self.assertEqual(unfrozen_result.to_dense(), frozen_result.to_dense())
+
+        buffer = io.BytesIO()
+        torch.jit.save(frozen, buffer)
+        buffer.seek(0)
+        loaded_model = torch.jit.load(buffer)
+        loaded_result = loaded_model.forward(x)
+
+        self.assertEqual(unfrozen_result.to_dense(), loaded_result.to_dense())
+
+    @unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul")
+    def test_serialize_sparse_csr(self):
+        class SparseTensorModule(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.a = torch.rand(4, 4).to_sparse_csr()
+                self.b = torch.rand(4, 4).to_sparse_csr()
+
+            def forward(self, x):
+                return x.matmul(self.a).matmul(self.b)
+
+        x = torch.rand(4, 4).to_sparse_csr()
+        m = SparseTensorModule()
+        expected_result = m.forward(x)
+
+        buffer = io.BytesIO()
+        torch.jit.save(torch.jit.script(m), buffer)
+        buffer.seek(0)
+        loaded_model = torch.jit.load(buffer)
+        loaded_result = loaded_model.forward(x)
+
+        self.assertEqual(expected_result.to_dense(), loaded_result.to_dense())


@@ -374,6 +374,18 @@ void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) {
       // values
       pushTensor(tensor._values());
       break;
+    case static_cast<int>(c10::Layout::SparseCsr):
+      // size
+      push<PickleOpCode>(PickleOpCode::MARK);
+      for (auto size : tensor.sizes()) {
+        pushInt(size);
+      }
+      push<PickleOpCode>(PickleOpCode::TUPLE);
+      // requires grad
+      pushIValue(tensor.requires_grad());
+      // crow_indices, col_indices, values
+      pushTensor(tensor.crow_indices());
+      pushTensor(tensor.col_indices());
+      pushTensor(tensor.values());
+      break;
     default:
       TORCH_CHECK(
           false,
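
The pickled record for a CSR literal is therefore a sizes tuple, the requires_grad flag, and then the crow_indices, col_indices, and values tensors, in that order; the unpickler below consumes the same fields. A rough Python equivalent of that reconstruction, using the public `torch.sparse_csr_tensor` constructor instead of the unsafe internal one (a sketch, not the actual unpickler):

```python
import torch

def rebuild_sparse_csr(size, requires_grad, crow_indices, col_indices, values):
    # Mirrors the SparseCsr branch of Unpickler::rebuildSparseTensor (sketch).
    t = torch.sparse_csr_tensor(crow_indices, col_indices, values, size)
    return t.requires_grad_(requires_grad)

a = torch.eye(3).to_sparse_csr()
b = rebuild_sparse_csr(
    list(a.shape), a.requires_grad,
    a.crow_indices(), a.col_indices(), a.values())
assert torch.equal(a.to_dense(), b.to_dense())
```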


@@ -726,6 +726,21 @@ void Unpickler::rebuildSparseTensor() {
       result = autograd::make_variable(result, options.requires_grad());
       break;
     }
+    case static_cast<int>(c10::Layout::SparseCsr): {
+      std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
+      bool requires_grad = elements.at(idx++).toBool();
+      auto& crow_indices = elements.at(idx++).toTensor();
+      auto& col_indices = elements.at(idx++).toTensor();
+      auto& values_tensor = elements.at(idx++).toTensor();
+      auto options = values_tensor.options()
+                         .layout(c10::Layout::SparseCsr)
+                         .requires_grad(requires_grad);
+      result = at::_sparse_csr_tensor_unsafe(
+          crow_indices, col_indices, values_tensor, size, options);
+      result =
+          autograd::make_variable(std::move(result), options.requires_grad());
+      break;
+    }
     default:
       TORCH_CHECK(
           false,