Make LinearPackedParams work with both TorchScript and torch.package (#71656)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71656

The customized `__getstate__`/`__setstate__` did not call super (`torch.nn.Module`), so module attributes (e.g. `_modules`) were not restored after the module was serialized and deserialized via torch.package.
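A minimal sketch of the failure mode, using a hypothetical `Broken` module that mirrors the old pattern: a custom `__getstate__`/`__setstate__` pair that ignores `torch.nn.Module`'s state drops `_modules` and friends across the pickle round-trip that torch.package performs.

import io
import pickle

import torch

class Broken(torch.nn.Module):
    """Hypothetical repro of the old LinearPackedParams pattern."""
    def __init__(self):
        super().__init__()
        self.inner = torch.nn.Linear(2, 2)

    def __getstate__(self):
        # Returns only the custom state; torch.nn.Module's __dict__
        # (_modules, _parameters, _buffers, ...) is silently dropped.
        return (self.training,)

    def __setstate__(self, state):
        self.training = state[0]  # nothing else gets restored

buf = io.BytesIO()
pickle.dump(Broken(), buf)
buf.seek(0)
m = pickle.load(buf)
print("_modules" in m.__dict__)  # False: module bookkeeping is gone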

After a few iterations it turned out that packing/unpacking the linear params is already supported by the torchbind class, so there is no need to hack the torch module anymore.
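As a hedged illustration (shapes, scale, and zero_point below are arbitrary), the prepack/unpack round-trip already goes through that torchbind object, which carries its own C++-side serialization:

import torch

W_q = torch.quantize_per_tensor(torch.randn(4, 8), scale=0.1, zero_point=0,
                                dtype=torch.qint8)
B = torch.randn(4)

packed = torch.ops.quantized.linear_prepack(W_q, B)   # torchbind object
W_q2, B2 = torch.ops.quantized.linear_unpack(packed)  # recovers weight/bias
assert torch.equal(W_q2.int_repr(), W_q.int_repr())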

Test Plan: `buck test caffe2/test/:quantization -- test_linear_api`

Reviewed By: jerryzh168

Differential Revision: D33711086

fbshipit-source-id: 3a36d10c64b7da414d3657d2ef766bb9a9290ea9
(cherry picked from commit 6337b6c207)
Authored by Shijun Kong on 2022-02-07 10:29:07 -08:00; committed by PyTorch MergeBot
parent 717d8c6224
commit 09e2fb8f6e
2 changed files with 19 additions and 20 deletions


@@ -13,6 +13,7 @@ from torch.ao.quantization import (
     default_float_qparams_observer,
     PerChannelMinMaxObserver,
 )
+from torch.package import PackageExporter, PackageImporter
 from torch.testing._internal.common_quantization import (
     QuantizationTestCase,
     prepare_dynamic,
@@ -107,6 +108,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             qlinear = class_map[use_fused](in_features, out_features)
 
             qlinear_copy = copy.deepcopy(qlinear)
+            # set random quantized weight and bias before test torch scriptable
+            qlinear_copy.set_weight_bias(W_q, B)
             self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
             # Run module with default-initialized parameters.
             # This tests that the constructor is correct.
@@ -175,6 +178,22 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             self.assertEqual(qlinear.scale, loaded.scale)
             self.assertEqual(qlinear.zero_point, loaded.zero_point)
 
+            # Test torch.package
+            buffer = io.BytesIO()
+            with PackageExporter(buffer) as pe:
+                pe.save_pickle("module", "qlinear.pkl", qlinear)
+            buffer.seek(0)
+
+            importer = PackageImporter(buffer)
+            loaded_from_package = importer.load_pickle("module", "qlinear.pkl")
+            self.assertEqual(qlinear.weight(), loaded_from_package.weight())
+            self.assertEqual(qlinear.scale, loaded_from_package.scale)
+            self.assertEqual(qlinear.zero_point, loaded_from_package.zero_point)
+
+            for name, module in loaded_from_package.named_modules():
+                # noop, just make sure attribute "_modules" is restored correctly during torch.package import
+                assert(name is not None)
+
             # Test copy and deepcopy
             copied_linear = copy.copy(qlinear)
             self.assertEqual(copied_linear.bias(), qlinear.bias())


@@ -83,26 +83,6 @@ class LinearPackedParams(torch.nn.Module):
         super(LinearPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                               missing_keys, unexpected_keys, error_msgs)
 
-    @torch.jit.export
-    def __getstate__(self):
-        qweight, bias = self._weight_bias()
-        return qweight, bias, self.training, self.dtype
-
-    @torch.jit.export
-    def __setstate__(self, state):
-        self.dtype = state[3]
-        self.set_weight_bias(state[0], state[1])
-        self.training = state[2]
-
-    def __deepcopy__(self, memo):
-        new_instance = type(self).__new__(type(self))
-        torch.nn.Module.__init__(new_instance)
-        state = self.__getstate__()
-        new_instance.__setstate__(state)
-        return new_instance
-
-    def __copy__(self):
-        return self.__deepcopy__({})
-
     def __repr__(self):
         return self._weight_bias().__repr__()
 
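With the custom methods removed, pickling and copying fall back to `torch.nn.Module`'s defaults, which keep `_modules` and the rest of the module's `__dict__` intact. A rough sketch of the resulting behavior (import path as of this commit, values illustrative):

import copy

import torch
from torch.nn.quantized.modules.linear import LinearPackedParams

params = LinearPackedParams(dtype=torch.qint8)
W_q = torch.quantize_per_tensor(torch.randn(4, 8), 0.1, 0, torch.qint8)
params.set_weight_bias(W_q, torch.randn(4))

clone = copy.deepcopy(params)        # default nn.Module path, no overrides
assert "_modules" in clone.__dict__  # module attributes survive the copy
w, _ = clone._weight_bias()
assert torch.equal(w.int_repr(), W_q.int_repr())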