mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[JIT] scripting, freezing, serialization for sparse csr (#69555)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/69555 1. Implement pickling/unpickling 2. Add `test_freeze_sparse_csr, tests_serialize_sparse_csr` tests Test Plan: Imported from OSS Reviewed By: mruberry Differential Revision: D33181367 Pulled By: davidberard98 fbshipit-source-id: a15d5193a7b1b1625a27e4af003cec33cdbc8071
This commit is contained in:
parent
bcb6076099
commit
41959ce77f
|
|
@ -288,12 +288,23 @@ private:
|
|||
// based on tensor impl. //TODO: find a way to use mkldnn storage
|
||||
if (a.is_mkldnn() || b.is_mkldnn()) {
|
||||
return a.unsafeGetTensorImpl() == b.unsafeGetTensorImpl();
|
||||
} else if (a.is_sparse() || b.is_sparse()) {
|
||||
if (a.is_sparse()) {
|
||||
return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
|
||||
} else {
|
||||
return isAliasOf(b._values(), a) || isAliasOf(b._indices(), a);
|
||||
}
|
||||
}
|
||||
|
||||
if (a.is_sparse()) {
|
||||
return isAliasOf(a._values(), b) || isAliasOf(a._indices(), b);
|
||||
}
|
||||
if (b.is_sparse()) {
|
||||
return isAliasOf(a, b._values()) || isAliasOf(a, b._indices());
|
||||
}
|
||||
if (a.is_sparse_csr()) {
|
||||
return isAliasOf(a.values(), b) ||
|
||||
isAliasOf(a.crow_indices(), b) ||
|
||||
isAliasOf(a.col_indices(), b);
|
||||
}
|
||||
if (b.is_sparse_csr()) {
|
||||
return isAliasOf(a, b.values()) ||
|
||||
isAliasOf(a, b.crow_indices()) ||
|
||||
isAliasOf(a, b.col_indices());
|
||||
}
|
||||
|
||||
return a.is_alias_of(b);
|
||||
|
|
@ -892,6 +903,11 @@ public:
|
|||
// so this will detect overlap of sparse tensors that share a values
|
||||
// tensor, but not sparse tensors that share an indices tensor.
|
||||
return hashTensor(ten._values());
|
||||
} else if (ten.is_sparse_csr()) {
|
||||
// COO sparse tensors have a "values" tensor and an "indices" tensor
|
||||
// so this will detect overlap of sparse tensors that share a values
|
||||
// tensor, but not sparse tensors that share an indices tensor.
|
||||
return hashTensor(ten.values());
|
||||
} else {
|
||||
return reinterpret_cast<size_t>(
|
||||
ten.storage().unsafeGetStorageImpl());
|
||||
|
|
|
|||
|
|
@ -1544,7 +1544,7 @@ TensorTypePtr TensorType::create(const at::Tensor& t) {
|
|||
VaryingShape<size_t> stride_indices;
|
||||
VaryingShape<int64_t> strides;
|
||||
VaryingShape<int64_t> sizes;
|
||||
if (!t.is_mkldnn() && !t.is_sparse()) {
|
||||
if (!t.is_mkldnn() && !t.is_sparse() && !t.is_sparse_csr()) {
|
||||
sizes = VaryingShape<int64_t>{t.sizes().vec()};
|
||||
strides = VaryingShape<int64_t>{t.strides().vec()};
|
||||
return TensorType::create(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
import io
|
||||
import torch
|
||||
import unittest
|
||||
from torch.testing._internal.common_utils import IS_WINDOWS, TEST_MKL
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
|
||||
|
|
@ -59,3 +61,61 @@ class TestSparse(JitTestCase):
|
|||
loaded_result = loaded_model.forward(x)
|
||||
|
||||
self.assertEqual(expected_result, loaded_result)
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul")
|
||||
def test_freeze_sparse_csr(self):
|
||||
class SparseTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.a = torch.rand(4, 4).to_sparse_csr()
|
||||
self.b = torch.rand(4, 4).to_sparse_csr()
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
return x.matmul(self.a).matmul(self.b)
|
||||
|
||||
x = torch.rand(4, 4).to_sparse_csr()
|
||||
|
||||
m = SparseTensorModule()
|
||||
unfrozen_result = m.forward(x)
|
||||
|
||||
m.eval()
|
||||
frozen = torch.jit.freeze(torch.jit.script(m))
|
||||
|
||||
frozen_result = frozen.forward(x)
|
||||
|
||||
self.assertEqual(unfrozen_result.to_dense(), frozen_result.to_dense())
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(frozen, buffer)
|
||||
buffer.seek(0)
|
||||
loaded_model = torch.jit.load(buffer)
|
||||
|
||||
loaded_result = loaded_model.forward(x)
|
||||
|
||||
self.assertEqual(unfrozen_result.to_dense(), loaded_result.to_dense())
|
||||
|
||||
@unittest.skipIf(IS_WINDOWS or not TEST_MKL, "Need MKL to run CSR matmul")
|
||||
def test_serialize_sparse_csr(self):
|
||||
class SparseTensorModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.a = torch.rand(4, 4).to_sparse_csr()
|
||||
self.b = torch.rand(4, 4).to_sparse_csr()
|
||||
|
||||
def forward(self, x):
|
||||
return x.matmul(self.a).matmul(self.b)
|
||||
|
||||
x = torch.rand(4, 4).to_sparse_csr()
|
||||
m = SparseTensorModule()
|
||||
expected_result = m.forward(x)
|
||||
|
||||
buffer = io.BytesIO()
|
||||
torch.jit.save(torch.jit.script(m), buffer)
|
||||
buffer.seek(0)
|
||||
loaded_model = torch.jit.load(buffer)
|
||||
|
||||
loaded_result = loaded_model.forward(x)
|
||||
|
||||
self.assertEqual(expected_result.to_dense(), loaded_result.to_dense())
|
||||
|
|
|
|||
|
|
@ -374,6 +374,18 @@ void Pickler::pushLiteralSparseTensor(const at::Tensor& tensor) {
|
|||
// values
|
||||
pushTensor(tensor._values());
|
||||
break;
|
||||
case static_cast<int>(c10::Layout::SparseCsr):
|
||||
push<PickleOpCode>(PickleOpCode::MARK);
|
||||
for (auto size : tensor.sizes()) {
|
||||
pushInt(size);
|
||||
}
|
||||
push<PickleOpCode>(PickleOpCode::TUPLE);
|
||||
|
||||
pushIValue(tensor.requires_grad());
|
||||
pushTensor(tensor.crow_indices());
|
||||
pushTensor(tensor.col_indices());
|
||||
pushTensor(tensor.values());
|
||||
break;
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
|
|
|||
|
|
@ -726,6 +726,21 @@ void Unpickler::rebuildSparseTensor() {
|
|||
result = autograd::make_variable(result, options.requires_grad());
|
||||
break;
|
||||
}
|
||||
case static_cast<int>(c10::Layout::SparseCsr): {
|
||||
std::vector<int64_t> size = tupleToIntList(elements.at(idx++));
|
||||
bool requires_grad = elements.at(idx++).toBool();
|
||||
auto& crow_indices = elements.at(idx++).toTensor();
|
||||
auto& col_indices = elements.at(idx++).toTensor();
|
||||
auto& values_tensor = elements.at(idx++).toTensor();
|
||||
auto options = values_tensor.options()
|
||||
.layout(c10::Layout::SparseCsr)
|
||||
.requires_grad(requires_grad);
|
||||
result = at::_sparse_csr_tensor_unsafe(
|
||||
crow_indices, col_indices, values_tensor, size, options);
|
||||
result =
|
||||
autograd::make_variable(std::move(result), options.requires_grad());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user