mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make LinearPackedParams works with both torchscript and torch.package (#71656)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71656
Customized `__getstate__`/`__setstate__` didn't call super (torch.nn.Module), and won't restore attributes (e.g. `_modules`) after being serialized and deserialized via torch.package
After a few iteration, as it turns out, pack/unpack linear param has been supported in torchbind class already, no need to hack torch module anymore.
Test Plan: `buck test caffe2/test/:quantization -- test_linear_api`
Reviewed By: jerryzh168
Differential Revision: D33711086
fbshipit-source-id: 3a36d10c64b7da414d3657d2ef766bb9a9290ea9
(cherry picked from commit 6337b6c207)
This commit is contained in:
parent
717d8c6224
commit
09e2fb8f6e
|
|
@ -13,6 +13,7 @@ from torch.ao.quantization import (
|
|||
default_float_qparams_observer,
|
||||
PerChannelMinMaxObserver,
|
||||
)
|
||||
from torch.package import PackageExporter, PackageImporter
|
||||
from torch.testing._internal.common_quantization import (
|
||||
QuantizationTestCase,
|
||||
prepare_dynamic,
|
||||
|
|
@ -107,6 +108,8 @@ class TestStaticQuantizedModule(QuantizationTestCase):
|
|||
qlinear = class_map[use_fused](in_features, out_features)
|
||||
|
||||
qlinear_copy = copy.deepcopy(qlinear)
|
||||
# set random quantized weight and bias before test torch scriptable
|
||||
qlinear_copy.set_weight_bias(W_q, B)
|
||||
self.checkScriptable(qlinear_copy, [[X_q]], check_save_load=True)
|
||||
# Run module with default-initialized parameters.
|
||||
# This tests that the constructor is correct.
|
||||
|
|
@ -175,6 +178,22 @@ class TestStaticQuantizedModule(QuantizationTestCase):
|
|||
self.assertEqual(qlinear.scale, loaded.scale)
|
||||
self.assertEqual(qlinear.zero_point, loaded.zero_point)
|
||||
|
||||
# Test torch.package
|
||||
buffer = io.BytesIO()
|
||||
with PackageExporter(buffer) as pe:
|
||||
pe.save_pickle("module", "qlinear.pkl", qlinear)
|
||||
buffer.seek(0)
|
||||
|
||||
importer = PackageImporter(buffer)
|
||||
loaded_from_package = importer.load_pickle("module", "qlinear.pkl")
|
||||
self.assertEqual(qlinear.weight(), loaded_from_package.weight())
|
||||
self.assertEqual(qlinear.scale, loaded_from_package.scale)
|
||||
self.assertEqual(qlinear.zero_point, loaded_from_package.zero_point)
|
||||
|
||||
for name, module in loaded_from_package.named_modules():
|
||||
# noop, just make sure attribute "_modules" is restored correctly during torch.package import
|
||||
assert(name is not None)
|
||||
|
||||
# Test copy and deepcopy
|
||||
copied_linear = copy.copy(qlinear)
|
||||
self.assertEqual(copied_linear.bias(), qlinear.bias())
|
||||
|
|
|
|||
|
|
@ -83,26 +83,6 @@ class LinearPackedParams(torch.nn.Module):
|
|||
super(LinearPackedParams, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
|
||||
@torch.jit.export
|
||||
def __getstate__(self):
|
||||
qweight, bias = self._weight_bias()
|
||||
return qweight, bias, self.training, self.dtype
|
||||
|
||||
@torch.jit.export
|
||||
def __setstate__(self, state):
|
||||
self.dtype = state[3]
|
||||
self.set_weight_bias(state[0], state[1])
|
||||
self.training = state[2]
|
||||
|
||||
def __deepcopy__(self, memo):
|
||||
new_instance = type(self).__new__(type(self))
|
||||
torch.nn.Module.__init__(new_instance)
|
||||
state = self.__getstate__()
|
||||
new_instance.__setstate__(state)
|
||||
return new_instance
|
||||
|
||||
def __copy__(self):
|
||||
return self.__deepcopy__({})
|
||||
|
||||
def __repr__(self):
|
||||
return self._weight_bias().__repr__()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user