[quant][pt2e] Add convert callback to Observer module (#115001)
Summary:
This is to allow easier extension of the quantization workflow in the future. Since we are seeing more diverse ways of doing quantization, this is being put up for feedback first.

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_observer_callback

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115001
Approved by: https://github.com/kimishpatel
parent ca15671c30
commit cc8f6f56dc
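In brief, what the change enables: during convert, if an observer module defines a convert(model, observer_node) method, the workflow hands graph rewriting to the observer instead of applying the default quantize/dequantize replacement. A minimal sketch of that contract, assuming a hypothetical PassthroughObserver (the name and its no-op behavior are illustrative, not part of this PR; the convert signature follows the test and dispatch below):

import torch
from torch.ao.quantization.observer import ObserverBase
from torch.fx import Node

class PassthroughObserver(ObserverBase):
    def __init__(self, *args, **kwargs):
        # ObserverBase requires a dtype; it is unused by this observer
        super().__init__(dtype=torch.int8)

    def forward(self, x):
        return x  # collect no statistics

    def calculate_qparams(self, **kwargs):
        pass  # qparams are irrelevant for this observer

    def convert(self, model: torch.fx.GraphModule, observer_node: Node):
        # Called by the convert step (see the hasattr(..., "convert")
        # dispatch added below) in place of the default replacement:
        # rewire users to the observer's input and drop the node.
        observer_node.replace_all_uses_with(observer_node.args[0])
        model.graph.erase_node(observer_node)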
@@ -1762,3 +1762,106 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = self._quantize(m, quantizer, example_inputs)
         # make sure it runs
         m(*example_inputs)
+
+    def test_observer_callback(self):
+        from torch.library import Library, impl
+        test_lib = Library("test_int4", "DEF")
+        test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
+
+        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
+        def quantize_per_tensor_int4(
+            input: torch.Tensor,
+            scale: float,
+            zero_point: int,
+        ) -> torch.Tensor:
+            inv_scale = 1.0 / scale
+            return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)
+
+        test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
+
+        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
+        def dequantize_per_tensor_int4(
+            input: torch.Tensor,
+            scale: float,
+            zero_point: int,
+        ) -> torch.Tensor:
+            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale
+
+        from torch.ao.quantization.observer import ObserverBase
+
+        class Int4Observer(ObserverBase):
+            def __init__(self, *args, **kwargs):
+                # just faking a dtype here
+                super().__init__(dtype=torch.int8)
+
+            def forward(self, x):
+                return x
+
+            def calculate_qparams(self, **kwargs):
+                pass
+
+            def convert(self, model: torch.fx.GraphModule, observer_node: Node):
+                with model.graph.inserting_before(observer_node):
+                    q_node = model.graph.call_function(
+                        torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {})
+                    dq_node = model.graph.call_function(
+                        torch.ops.test_int4.dequantize_per_tensor_int4, (q_node, 1.0, 0), {})
+                    observer_node.replace_all_uses_with(dq_node)
+                    model.graph.erase_node(observer_node)
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.add.Tensor
+                    ):
+                        input_act0 = node.args[0]
+                        assert isinstance(input_act0, Node)
+                        input_act1 = node.args[1]
+                        assert isinstance(input_act1, Node)
+
+                        act_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=torch.per_tensor_affine,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=Int4Observer,
+                        )
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act0: act_qspec,
+                                input_act1: act_qspec,
+                            },
+                            output_qspec=act_qspec,
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+
+        class M(torch.nn.Module):
+            def forward(self, x1, x2):
+                return x1 + x2
+
+        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # two for the inputs of add, one for its output
+            torch.ops.test_int4.quantize_per_tensor_int4: 3,
+            torch.ops.test_int4.dequantize_per_tensor_int4: 3,
+        }
+        node_list = [
+            torch.ops.test_int4.dequantize_per_tensor_int4,
+            torch.ops.test_int4.dequantize_per_tensor_int4,
+            torch.ops.aten.add.Tensor,
+            torch.ops.test_int4.quantize_per_tensor_int4,
+        ]
+        self._test_quantizer(
+            M().eval(),
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
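For context, the _test_quantizer helper above wraps roughly the standard PT2E flow of this era. A hedged sketch of the equivalent manual steps, reusing M, BackendAQuantizer, and example_inputs from the test (API names as of this commit's timeframe, not part of the diff):

import torch
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

m = M().eval()
# capture an FX graph to quantize
m = capture_pre_autograd_graph(m, example_inputs)
# insert Int4Observer at the add node's inputs and output, per the annotations
m = prepare_pt2e(m, BackendAQuantizer())
m(*example_inputs)  # calibration; a no-op for the stateless Int4Observer
# convert dispatches to Int4Observer.convert(), which swaps each observer
# for the custom test_int4 quantize/dequantize pair
m = convert_pt2e(m)
m(*example_inputs)  # runs with torch.ops.test_int4.* nodes in the graph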
@@ -88,8 +88,7 @@ _QSCHEME_TO_CHOOSE_QPARAMS_OP = {
 }

 def _replace_observer_with_quantize_dequantize_node_decomposed(
-    model: torch.nn.Module,
-    graph: Graph,
+    model: torch.fx.GraphModule,
     node: Node,
     modules: Dict[str, torch.nn.Module],
     node_name_to_scope: Dict[str, Tuple[str, type]],
@@ -105,10 +104,14 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(

     or quantize_per_channel and dequantize_per_channel
     """
+    graph = model.graph
     assert modules is not None
     assert isinstance(node.target, str)
     module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
     activation_post_process = modules[node.target]
+    if hasattr(activation_post_process, "convert"):
+        activation_post_process.convert(model, node)
+        return
     # skip replacing observers to quant/dequant nodes if the qconfigs of all
     # consumers and producers of this observer are None
     skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
@@ -312,8 +315,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
 # activation_post_process is supported

 def _replace_observer_with_quantize_dequantize_node(
-    model: torch.nn.Module,
-    graph: Graph,
+    model: torch.fx.GraphModule,
     node: Node,
     modules: Dict[str, torch.nn.Module],
     node_name_to_scope: Dict[str, Tuple[str, type]],
@@ -328,6 +330,7 @@ def _replace_observer_with_quantize_dequantize_node(
     """
     assert modules is not None
     assert isinstance(node.target, str)
+    graph = model.graph
     module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
     activation_post_process = modules[node.target]
     # skip replacing observers to quant/dequant nodes if the qconfigs of all
@@ -1062,11 +1065,11 @@ def convert(
             else:
                 if is_decomposed:
                     _replace_observer_with_quantize_dequantize_node_decomposed(
-                        model, model.graph, node, modules, node_name_to_scope,
+                        model, node, modules, node_name_to_scope,
                         node_name_to_qconfig)
                 else:
                     _replace_observer_with_quantize_dequantize_node(
-                        model, model.graph, node, modules, node_name_to_scope,
+                        model, node, modules, node_name_to_scope,
                         node_name_to_qconfig)
         elif isinstance(mod, DeQuantStub):
             _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)