# Owner(s): ["oncall: quantization"] import torch import torch._C from torch.ao.quantization import ( default_dynamic_qconfig, per_channel_dynamic_qconfig, ) from torch.ao.quantization.quantize_jit import ( prepare_dynamic_jit, convert_dynamic_jit, _prepare_ondevice_dynamic_jit, _quantize_ondevice_dynamic_jit, ) from torch.testing._internal.common_utils import TestCase from torch.testing._internal.common_quantization import ( get_script_module, LinearAddModel, ) from torch.jit.mobile import _load_for_lite_interpreter, LiteScriptModule from torch.testing import FileCheck from torch.utils import bundled_inputs as bundled_inputs import io from typing import Dict class myMod(torch.nn.Module): def __init__(self, weight): super().__init__() self.fc1 = torch.nn.Linear(5, 5).float() self.fc1.weight = weight self.fc2 = torch.nn.Linear(5, 5).float() def forward(self, x): return self.fc2(self.fc1(x)) class MyConvLinearModule(torch.nn.Module): def __init__(self): super().__init__() self.conv = torch.nn.Conv2d(3, 5, 3) weight = torch.nn.Parameter(torch.ones(5, 5)) self.weight1 = torch.nn.Parameter(torch.ones(5, 5)) self.mymod = myMod(weight) def forward(self, x): conv_output = self.conv(x) y = self.mymod(conv_output) z = torch.nn.functional.linear(y, self.weight1) return z def get_example_inputs(self): return (torch.rand(1, 3, 12, 7),) class OnDevicePTQUtils: observer_module_name = ['MinMaxObserver', 'PerChannelMinMaxObserver'] @staticmethod def insert_observers(model, qconfig_dict): inputs = model.get_example_inputs() scripted_model = get_script_module(model, False, inputs) scripted_model = _prepare_ondevice_dynamic_jit(scripted_model, qconfig_dict) return scripted_model @staticmethod def ptq_dynamic_quantize(model, qconfig_dict): inputs = model.get_example_inputs() m = get_script_module(model, False, inputs) m = _quantize_ondevice_dynamic_jit(m, qconfig_dict, 'forward', True) return m @staticmethod def find_observer_modules(m): observer_modules = [] for child_module in m.children(): if child_module.original_name in OnDevicePTQUtils.observer_module_name: observer_modules.append(child_module) return observer_modules @staticmethod def is_value_type_observer(value): type_name = value.type() for observer_type in OnDevicePTQUtils.observer_module_name: if observer_type in type_name.str(): return True return False @staticmethod def is_calculate_qparam(node): if node.kind() == "prim::CallMethod": if node.s('name') == "calculate_qparams": return True return False @staticmethod def get_linear_packed_param_fp_weight(node): weight = node.inputsAt(0).node() if weight.kind() != "aten::quantize_per_tensor" and weight.kind() != "aten::quantize_per_channel": raise ValueError("Quantized weight must be produced.") fp_weight = weight.inputsAt(0).node() assert fp_weight.kind() == "prim::GetAttr", "Weight must be an attribute of the module." fp_weight_name = fp_weight.s('name') return fp_weight_name @staticmethod def is_per_channel_quantized_packed_param(node): assert node.kind() == 'quantized::linear_prepack', "Node must corresponds to linear_prepack." weight = node.inputsAt(0).node() assert weight.kind() != "aten::quantize_per_tensor" or weight.kind() != "aten::quantize_per_channel" return weight.kind() != "aten::quantize_per_tensor" class TestOnDeviceDynamicPTQInsertObservers(TestCase): def _check_num_and_type_of_observers(self, model, num_observers): qconfig_dict = {"": default_dynamic_qconfig} scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict) observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model) self.assertTrue(len(observer_modules) == num_observers) for observer in observer_modules: self.assertTrue(observer.original_name == 'MinMaxObserver') qconfig_dict = {"": per_channel_dynamic_qconfig} scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict) observer_modules = OnDevicePTQUtils.find_observer_modules(scripted_model) self.assertTrue(len(observer_modules) == num_observers) for observer in observer_modules: self.assertTrue(observer.original_name == 'PerChannelMinMaxObserver') def _check_observer_method(self, model, num_observers): qconfig_dict = {"": default_dynamic_qconfig} inputs = model.get_example_inputs() orig_scripted_model = get_script_module(model, False, inputs) torch._C._jit_pass_inline(orig_scripted_model.graph) orig_forward_graph = orig_scripted_model.graph.str() scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict) quant_forward_graph = scripted_model.graph.str() # exact graph matching is difficult so just resorting to # of lines # instead of implementing graph matching self.assertEqual(len(orig_forward_graph.splitlines()), len(quant_forward_graph.splitlines())) observe_method = scripted_model.observe_forward.graph FileCheck().check_count("prim::CallMethod[name=\"forward\"](%_observer", num_observers, exactly=True).run(observe_method) reset_observers_method = scripted_model.reset_observers_forward.graph FileCheck().check_count( "prim::CallMethod[name=\"reset_min_max_vals\"](%_observer", num_observers, exactly=True).run(reset_observers_method) def _observer_is_weight_only(self, node): if (node.kind() == "prim::CallMethod") and node.s("name") == "forward": if (OnDevicePTQUtils.is_value_type_observer(node.inputsAt(0))): return (node.inputsAt(1).node().kind() == "prim::GetAttr") return False def test_num_observers(self): model = LinearAddModel() self._check_num_and_type_of_observers(model, 2) model = MyConvLinearModule() self._check_num_and_type_of_observers(model, 3) def test_observe_method(self): model = MyConvLinearModule() self._check_observer_method(model, 3) def test_weight_only_observers(self): model = MyConvLinearModule() qconfig_dict = {"": default_dynamic_qconfig} inputs = model.get_example_inputs() scripted_model = OnDevicePTQUtils.insert_observers(model, qconfig_dict) observe_forward_graph = scripted_model.observe_forward.graph num_weight_only_observers = 0 for node in observe_forward_graph.nodes(): if (self._observer_is_weight_only(node)): num_weight_only_observers += 1 self.assertEqual(num_weight_only_observers, 3) class TestOnDeviceDynamicPTQInsertQuantDequant(TestCase): def _validate_quant_dequant_nodes(self, model, num_nodes, per_channel=0): quantize_forward_graph = model.quantize_forward.graph quantize_per_tensor = quantize_per_channel = 0 for n in quantize_forward_graph.nodes(): if "aten::quantize_per_tensor" in n.kind(): quantize_per_tensor += 1 if "aten::quantize_per_channel" in n.kind(): quantize_per_channel += 1 self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes) def _validate_calculate_qparams(self, model, num_nodes): quantize_forward_graph = model.quantize_forward.graph num_calculate_qparams = 0 for n in quantize_forward_graph.nodes(): if OnDevicePTQUtils.is_calculate_qparam(n): num_calculate_qparams += 1 self.assertEqual(num_calculate_qparams, num_nodes) def _validate_no_observer_forward(self, model): quantize_forward_graph = model.quantize_forward.graph for n in quantize_forward_graph.nodes(): if (n.kind() == "prim::CallMethod") and n.s("name") == "forward": if (OnDevicePTQUtils.is_value_type_observer(n.inputsAt(0))): return False return True def _check_quant_dequant_and_calc_qparams(self, model, num_nodes): qconfig_dict = {"" : default_dynamic_qconfig} m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) self._validate_quant_dequant_nodes(m, num_nodes) self._validate_calculate_qparams(m, num_nodes) self._validate_no_observer_forward(m) qconfig_dict = {"" : per_channel_dynamic_qconfig} m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) self._validate_quant_dequant_nodes(m, num_nodes, num_nodes) self._validate_calculate_qparams(m, num_nodes) self._validate_no_observer_forward(m) def _check_quantize_forward_runs(self, model): inputs = model.get_example_inputs() qconfig_dict = {"" : default_dynamic_qconfig} m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) m.observe_forward(*inputs) m.quantize_forward(*inputs) qconfig_dict = {"" : per_channel_dynamic_qconfig} m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) # First must run observe forward to record the stats to produce # correct scales and zero points m.observe_forward(*inputs) m.quantize_forward(*inputs) def test_num_quant_dequant_nodes(self): model = LinearAddModel() self._check_quant_dequant_and_calc_qparams(model, 2) model = MyConvLinearModule() self._check_quant_dequant_and_calc_qparams(model, 3) def test_quantize_forward_runs(self): model = LinearAddModel() self._check_quantize_forward_runs(model) model = MyConvLinearModule() self._check_quantize_forward_runs(model) class TestOnDeviceDynamicPTQFinalize(TestCase): def _validate_packed_params(self, model, num_nodes, per_channel=0): quantize_forward_graph = model.quantize_forward.graph quantize_per_tensor = quantize_per_channel = 0 linear_prepack = 0 linear_prepack_uses = 0 for n in quantize_forward_graph.nodes(): if n.kind() == 'prim::SetAttr': maybe_packed_param_value = n.inputsAt(1) maybe_packed_param = maybe_packed_param_value.node() if maybe_packed_param.kind() == 'quantized::linear_prepack': linear_prepack += 1 linear_prepack_uses += len(maybe_packed_param_value.uses()) if OnDevicePTQUtils.is_per_channel_quantized_packed_param(maybe_packed_param): quantize_per_channel += 1 else: quantize_per_tensor += 1 self.assertEqual(quantize_per_tensor + quantize_per_channel, num_nodes) self.assertEqual(quantize_per_channel, per_channel) self.assertEqual(linear_prepack, num_nodes) self.assertEqual(linear_prepack_uses, num_nodes) def _validate_no_linear_unpack(self, model): quantize_forward_graph = model.quantize_forward.graph for n in quantize_forward_graph.nodes(): if n.kind() == 'quantized::linear_unpack': return False return True def _validate_setattr_fp_weights(self, model, num_nodes): quantize_forward_graph = model.quantize_forward.graph fp_weights_setattr = 0 fp_weight_names = [] for n in quantize_forward_graph.nodes(): if n.kind() == 'prim::SetAttr': maybe_packed_param = n.inputsAt(1).node() if maybe_packed_param.kind() == 'quantized::linear_prepack': weight_name = OnDevicePTQUtils.get_linear_packed_param_fp_weight(maybe_packed_param) fp_weight_names.append(weight_name) for n in quantize_forward_graph.nodes(): # This is basically detecting # %x = prim::Constant # = prim::SetAttr()(module_value, x) # Thus making sure that the original fp weights are # reset if n.kind() == 'prim::SetAttr': weight_name = n.s('name') if weight_name in fp_weight_names: maybe_constant = n.inputsAt(1).node() if maybe_constant.kind() == 'prim::Constant': fp_weights_setattr += 1 self.assertEqual(fp_weights_setattr, num_nodes) def _validate_quantized_forward(self, model, num_nodes): quantized_forward_graph = model.quantized_forward.graph quantize_per_tensor = quantize_per_channel = 0 quantized_linear_dynamic = 0 linear_packed_params = 0 num_setattr = 0 for n in quantized_forward_graph.nodes(): if "aten::quantize_per_tensor" in n.kind(): quantize_per_tensor += 1 if "aten::quantize_per_channel" in n.kind(): quantize_per_channel += 1 if "quantized::linear_dynamic" in n.kind(): quantized_linear_dynamic += 1 if n.kind() == 'prim::GetAttr': output = n.outputsAt(0) output_type = output.type() if "LinearPackedParamsBase" in output_type.str(): linear_packed_params += 1 if n.kind() == 'prim::SetAttr': num_setattr += 1 self.assertEqual(quantize_per_tensor, 0) self.assertEqual(quantize_per_channel, 0) self.assertEqual(quantized_linear_dynamic, num_nodes) self.assertEqual(linear_packed_params, num_nodes) # self.assertEqual(num_setattr, 0) def _check_quantize_forward(self, model, num_nodes): qconfig_dict = {"" : default_dynamic_qconfig} m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) self._validate_packed_params(m, num_nodes) self._validate_no_linear_unpack(m) self._validate_setattr_fp_weights(m, num_nodes) qconfig_dict = {"" : per_channel_dynamic_qconfig} m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) self._validate_packed_params(m, num_nodes, num_nodes) self._validate_no_linear_unpack(m) self._validate_setattr_fp_weights(m, num_nodes) def _check_quantized_forward(self, model, num_nodes): qconfig_dict = {"" : default_dynamic_qconfig} m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) self._validate_quantized_forward(m, num_nodes) qconfig_dict = {"" : per_channel_dynamic_qconfig} m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) self._validate_quantized_forward(m, num_nodes) def _check_against_ref_dynamic_ptq(self, model): model.eval() inputs = model.get_example_inputs() ref_m = torch.jit.script(model) torch._C._jit_pass_inline(ref_m.graph) qconfig_dict = {"" : default_dynamic_qconfig} ref_m = prepare_dynamic_jit(ref_m, qconfig_dict) ref_m = convert_dynamic_jit(ref_m) ref_output = ref_m(*inputs) m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) m.observe_forward(*inputs) m.quantize_forward(*inputs) output = m.quantized_forward(*inputs) self.assertTrue(torch.allclose(ref_output, output)) thrown = False try: m(*inputs) except Exception as e: thrown = True self.assertTrue(thrown) # test with per channel quant ref_m = torch.jit.script(model) torch._C._jit_pass_inline(ref_m.graph) qconfig_dict = {"" : per_channel_dynamic_qconfig} ref_m = prepare_dynamic_jit(ref_m, qconfig_dict) ref_m = convert_dynamic_jit(ref_m) ref_output = ref_m(*inputs) m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) m.observe_forward(*inputs) m.quantize_forward(*inputs) output = m.quantized_forward(*inputs) self.assertTrue(torch.allclose(ref_output, output)) thrown = False try: m(*inputs) except Exception as e: thrown = True self.assertTrue(thrown) def _check_serdes_and_device_side_api_helper(self, model, check_device_side_api=False): model.eval() inputs = model.get_example_inputs() ref_m = torch.jit.script(model) torch._C._jit_pass_inline(ref_m.graph) qconfig_dict = {"" : default_dynamic_qconfig} ref_m = prepare_dynamic_jit(ref_m, qconfig_dict) ref_m = convert_dynamic_jit(ref_m) buffer = io.BytesIO() torch.jit.save(ref_m, buffer) buffer.seek(0) ref_m = torch.jit.load(buffer) ref_output = ref_m(*inputs) if not check_device_side_api: m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) buffer = io.BytesIO() torch.jit.save(m, buffer) buffer.seek(0) m = torch.jit.load(buffer) m.reset_observers_forward() m.observe_forward(*inputs) m.quantize_forward(*inputs) output = m.quantized_forward(*inputs) self.assertTrue(torch.allclose(ref_output, output)) else: # check for lite interpreter m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) first_input, = inputs rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype) m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )]) buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter()) buffer.seek(0) m = _load_for_lite_interpreter(buffer) # Error here torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward") self.assertFalse(m.find_method("quantized_forward")) self.assertFalse(m.find_method("quantize_forward")) self.assertFalse(m.find_method("observe_forward")) self.assertFalse(m.find_method("reset_observers_forward")) output = m(*inputs) self.assertTrue(torch.allclose(ref_output, output)) # Now serialize to flabuffer and load from fb and check dict: Dict[str, str] = {} bytes = torch._C._save_mobile_module_to_bytes(m._c, dict) m = LiteScriptModule(torch._C._load_mobile_module_from_bytes(bytes)) fb_output = m(*inputs) self.assertTrue(torch.allclose(ref_output, fb_output)) model.eval() inputs = model.get_example_inputs() ref_m = torch.jit.script(model) torch._C._jit_pass_inline(ref_m.graph) qconfig_dict = {"" : per_channel_dynamic_qconfig} ref_m = prepare_dynamic_jit(ref_m, qconfig_dict) ref_m = convert_dynamic_jit(ref_m) buffer = io.BytesIO() torch.jit.save(ref_m, buffer) buffer.seek(0) ref_m = torch.jit.load(buffer) ref_output = ref_m(*inputs) if not check_device_side_api: m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) buffer = io.BytesIO() torch.jit.save(m, buffer) buffer.seek(0) m = torch.jit.load(buffer) m.reset_observers_forward() m.observe_forward(*inputs) m.quantize_forward(*inputs) output = m.quantized_forward(*inputs) self.assertTrue(torch.allclose(ref_output, output)) else: # check for lite interpreter m = OnDevicePTQUtils.ptq_dynamic_quantize(model, qconfig_dict) first_input, = inputs rand_input = bundled_inputs.bundle_randn(first_input.size(), dtype=first_input.dtype) m = bundled_inputs.bundle_inputs(m, inputs=[(rand_input, )]) buffer = io.BytesIO(m._save_to_buffer_for_lite_interpreter()) buffer.seek(0) m = _load_for_lite_interpreter(buffer) # Error here torch._C._quantize_ondevice_ptq_dynamic(m._c, "forward") self.assertFalse(m.find_method("quantized_forward")) self.assertFalse(m.find_method("quantize_forward")) self.assertFalse(m.find_method("observe_forward")) self.assertFalse(m.find_method("reset_observers_forward")) output = m(*inputs) self.assertTrue(torch.allclose(ref_output, output)) def _check_serialization_deserialization(self, model): self._check_serdes_and_device_side_api_helper(model, False) def _check_device_side_api(self, model): self._check_serdes_and_device_side_api_helper(model, True) def test_quantize_forward(self): model = LinearAddModel() self._check_quantize_forward(model, 2) model = MyConvLinearModule() self._check_quantize_forward(model, 3) def test_quantized_forward(self): model = LinearAddModel() self._check_quantized_forward(model, 2) model = MyConvLinearModule() self._check_quantized_forward(model, 3) def test_against_offdevice_dynamic_ptq(self): model = LinearAddModel() self._check_against_ref_dynamic_ptq(model) model = MyConvLinearModule() self._check_against_ref_dynamic_ptq(model) def test_serialization_deserialization(self): model = MyConvLinearModule() self._check_serialization_deserialization(model) def test_device_side_api(self): model = MyConvLinearModule() self._check_device_side_api(model)