From 09e2fb8f6ebf402baa31c4866ec07fd99a2736eb Mon Sep 17 00:00:00 2001 From: Shijun Kong Date: Mon, 7 Feb 2022 10:29:07 -0800 Subject: [PATCH] Make LinearPackedParams works with both torchscript and torch.package (#71656) Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/71656 Customized `__getstate__`/`__setstate__` didn't call super (torch.nn.Module), and won't restore attributes (e.g. `_modules`) after being serialized and deserialized via torch.package After a few iteration, as it turns out, pack/unpack linear param has been supported in torchbind class already, no need to hack torch module anymore. Test Plan: `buck test caffe2/test/:quantization -- test_linear_api` Reviewed By: jerryzh168 Differential Revision: D33711086 fbshipit-source-id: 3a36d10c64b7da414d3657d2ef766bb9a9290ea9 (cherry picked from commit 6337b6c20747d661a3920bf553dcc911dfa77671) --- .../core/test_quantized_module.py | 19 ++++++++++++++++++ torch/nn/quantized/modules/linear.py | 20 ------------------- 2 files changed, 19 insertions(+), 20 deletions(-) diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py index 4da3a163fc4..613c237bdad 100644 --- a/test/quantization/core/test_quantized_module.py +++ b/test/quantization/core/test_quantized_module.py @@ -13,6 +13,7 @@ from torch.ao.quantization import ( default_float_qparams_observer, PerChannelMinMaxObserver, ) +from torch.package import PackageExporter, PackageImporter from torch.testing._internal.common_quantization import ( QuantizationTestCase, prepare_dynamic, @@ -107,6 +108,8 @@ class TestStaticQuantizedModule(QuantizationTestCase): qlinear = class_map[use_fused](in_features, out_features) qlinear_copy = copy.deepcopy(qlinear) + # set random quantized weight and bias before test torch scriptable + qlinear_copy.set_weight_bias(W_q, B) self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True) # Run module with default-initialized parameters. # This tests that the constructor is correct. @@ -175,6 +178,22 @@ class TestStaticQuantizedModule(QuantizationTestCase): self.assertEqual(qlinear.scale, loaded.scale) self.assertEqual(qlinear.zero_point, loaded.zero_point) + # Test torch.package + buffer = io.BytesIO() + with PackageExporter(buffer) as pe: + pe.save_pickle("module", "qlinear.pkl", qlinear) + buffer.seek(0) + + importer = PackageImporter(buffer) + loaded_from_package = importer.load_pickle("module", "qlinear.pkl") + self.assertEqual(qlinear.weight(), loaded_from_package.weight()) + self.assertEqual(qlinear.scale, loaded_from_package.scale) + self.assertEqual(qlinear.zero_point, loaded_from_package.zero_point) + + for name, module in loaded_from_package.named_modules(): + # noop, just make sure attribute "_modules" is restored correctly during torch.package import + assert(name is not None) + # Test copy and deepcopy copied_linear = copy.copy(qlinear) self.assertEqual(copied_linear.bias(), qlinear.bias()) diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py index da2de6c2a2f..d343ed1b00d 100644 --- a/torch/nn/quantized/modules/linear.py +++ b/torch/nn/quantized/modules/linear.py @@ -83,26 +83,6 @@ class LinearPackedParams(torch.nn.Module): super(LinearPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False, missing_keys, unexpected_keys, error_msgs) - @torch.jit.export - def __getstate__(self): - qweight, bias = self._weight_bias() - return qweight, bias, self.training, self.dtype - - @torch.jit.export - def __setstate__(self, state): - self.dtype = state[3] - self.set_weight_bias(state[0], state[1]) - self.training = state[2] - - def __deepcopy__(self, memo): - new_instance = type(self).__new__(type(self)) - torch.nn.Module.__init__(new_instance) - state = self.__getstate__() - new_instance.__setstate__(state) - return new_instance - - def __copy__(self): - return self.__deepcopy__({}) def __repr__(self): return self._weight_bias().__repr__()