diff --git a/test/quantization/core/test_quantized_module.py b/test/quantization/core/test_quantized_module.py
index 4da3a163fc4..613c237bdad 100644
--- a/test/quantization/core/test_quantized_module.py
+++ b/test/quantization/core/test_quantized_module.py
@@ -13,6 +13,7 @@ from torch.ao.quantization import (
     default_float_qparams_observer,
     PerChannelMinMaxObserver,
 )
+from torch.package import PackageExporter, PackageImporter
 from torch.testing._internal.common_quantization import (
     QuantizationTestCase,
     prepare_dynamic,
@@ -107,6 +108,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             qlinear = class_map[use_fused](in_features, out_features)
 
             qlinear_copy = copy.deepcopy(qlinear)
+            # Set a random quantized weight and bias before testing scriptability.
+            qlinear_copy.set_weight_bias(W_q, B)
             self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
             # Run module with default-initialized parameters.
             # This tests that the constructor is correct.
@@ -175,6 +178,22 @@ class TestStaticQuantizedModule(QuantizationTestCase):
             self.assertEqual(qlinear.scale, loaded.scale)
             self.assertEqual(qlinear.zero_point, loaded.zero_point)
 
+            # Test torch.package
+            buffer = io.BytesIO()
+            with PackageExporter(buffer) as pe:
+                pe.save_pickle("module", "qlinear.pkl", qlinear)
+            buffer.seek(0)
+
+            importer = PackageImporter(buffer)
+            loaded_from_package = importer.load_pickle("module", "qlinear.pkl")
+            self.assertEqual(qlinear.weight(), loaded_from_package.weight())
+            self.assertEqual(qlinear.scale, loaded_from_package.scale)
+            self.assertEqual(qlinear.zero_point, loaded_from_package.zero_point)
+
+            for name, module in loaded_from_package.named_modules():
+                # No-op; just make sure the "_modules" attribute is restored correctly during torch.package import.
+                assert name is not None
+
             # Test copy and deepcopy
             copied_linear = copy.copy(qlinear)
             self.assertEqual(copied_linear.bias(), qlinear.bias())
diff --git a/torch/nn/quantized/modules/linear.py b/torch/nn/quantized/modules/linear.py
index da2de6c2a2f..d343ed1b00d 100644
--- a/torch/nn/quantized/modules/linear.py
+++ b/torch/nn/quantized/modules/linear.py
@@ -83,26 +83,6 @@ class LinearPackedParams(torch.nn.Module):
         super(LinearPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                               missing_keys, unexpected_keys, error_msgs)
 
-    @torch.jit.export
-    def __getstate__(self):
-        qweight, bias = self._weight_bias()
-        return qweight, bias, self.training, self.dtype
-
-    @torch.jit.export
-    def __setstate__(self, state):
-        self.dtype = state[3]
-        self.set_weight_bias(state[0], state[1])
-        self.training = state[2]
-
-    def __deepcopy__(self, memo):
-        new_instance = type(self).__new__(type(self))
-        torch.nn.Module.__init__(new_instance)
-        state = self.__getstate__()
-        new_instance.__setstate__(state)
-        return new_instance
-
-    def __copy__(self):
-        return self.__deepcopy__({})
-
     def __repr__(self):
         return self._weight_bias().__repr__()
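
For context, below is a minimal standalone sketch of the torch.package round-trip that the new test exercises. It assumes the post-patch behavior; the tensor shapes, quantization parameters, and the package/resource names are illustrative, not part of the patch.

    import io

    import torch
    import torch.nn.quantized as nnq
    from torch.package import PackageExporter, PackageImporter

    # Build a quantized Linear with explicit weight and bias, mirroring the test setup.
    qlinear = nnq.Linear(4, 2)
    w_q = torch.quantize_per_tensor(torch.randn(2, 4), scale=0.1, zero_point=0,
                                    dtype=torch.qint8)
    qlinear.set_weight_bias(w_q, torch.zeros(2))

    # Export: pickle the module into an in-memory package archive.
    buffer = io.BytesIO()
    with PackageExporter(buffer) as pe:
        pe.save_pickle("module", "qlinear.pkl", qlinear)
    buffer.seek(0)

    # Import: load the module back and check that the packed params survived.
    importer = PackageImporter(buffer)
    loaded = importer.load_pickle("module", "qlinear.pkl")
    assert torch.equal(loaded.weight().int_repr(), qlinear.weight().int_repr())
    # With the __getstate__/__setstate__ overrides removed above, nn.Module
    # bookkeeping such as "_modules" is restored correctly as well.
    assert dict(loaded.named_modules())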