[quant][pt2e] Add convert callback to Observer module (#115001)

Summary:
This is to allow easier extension of the quant workflow in the future, as we are seeing more
diverse ways of doing quantization.

Putting this up for feedback first.
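
A rough sketch of the intended contract, for reviewers. This is illustrative only: the hook is duck-typed (the convert pass just checks hasattr(observer, "convert")), the class name MyObserver is hypothetical, and the concrete example is the Int4Observer in the test below.

    import torch
    from torch.fx import GraphModule, Node
    from torch.ao.quantization.observer import ObserverBase

    class MyObserver(ObserverBase):
        # Example only: an observer that lowers itself during convert.
        def __init__(self, *args, **kwargs):
            super().__init__(dtype=torch.int8)  # placeholder dtype

        def forward(self, x):
            return x  # collect calibration stats here if needed

        def calculate_qparams(self, **kwargs):
            pass

        def convert(self, model: GraphModule, observer_node: Node):
            # Called by the convert pass instead of the default quant/dequant
            # lowering. Here we simply drop the observer and rewire its users
            # to its input; a real backend would insert its own quant/dequant
            # ops here (see Int4Observer in the test below).
            observer_node.replace_all_uses_with(observer_node.args[0])
            model.graph.erase_node(observer_node)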

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_observer_callback

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115001
Approved by: https://github.com/kimishpatel
Jerry Zhang 2023-12-07 18:35:06 -08:00 committed by PyTorch MergeBot
parent ca15671c30
commit cc8f6f56dc
2 changed files with 112 additions and 6 deletions


@@ -1762,3 +1762,106 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         m = self._quantize(m, quantizer, example_inputs)
         # make sure it runs
         m(*example_inputs)
+
+    def test_observer_callback(self):
+        from torch.library import Library, impl
+        test_lib = Library("test_int4", "DEF")
+        test_lib.define("quantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
+
+        @impl(test_lib, "quantize_per_tensor_int4", "CompositeExplicitAutograd")
+        def quantize_per_tensor_int4(
+            input: torch.Tensor,
+            scale: float,
+            zero_point: int,
+        ) -> torch.Tensor:
+            inv_scale = 1.0 / scale
+            return torch.clamp(torch.round(input * inv_scale) + zero_point, 0, 15).to(torch.uint8).view(torch.bits8)
+
+        test_lib.define("dequantize_per_tensor_int4(Tensor input, float scale, int zero_point) -> Tensor")
+
+        @impl(test_lib, "dequantize_per_tensor_int4", "CompositeExplicitAutograd")
+        def dequantize_per_tensor_int4(
+            input: torch.Tensor,
+            scale: float,
+            zero_point: int,
+        ) -> torch.Tensor:
+            return (input.view(torch.uint8).to(torch.float32) - zero_point) * scale
+
+        from torch.ao.quantization.observer import ObserverBase
+
+        class Int4Observer(ObserverBase):
+            def __init__(self, *args, **kwargs):
+                # just faking a dtype here
+                super().__init__(dtype=torch.int8)
+
+            def forward(self, x):
+                return x
+
+            def calculate_qparams(self, **kwargs):
+                pass
+
+            def convert(self, model: torch.fx.GraphModule, observer_node: Node):
+                with model.graph.inserting_before(observer_node):
+                    q_node = model.graph.call_function(
+                        torch.ops.test_int4.quantize_per_tensor_int4, (observer_node.args[0], 1.0, 0), {})
+                    dq_node = model.graph.call_function(
+                        torch.ops.test_int4.dequantize_per_tensor_int4, (q_node, 1.0, 0), {})
+                    observer_node.replace_all_uses_with(dq_node)
+                    model.graph.erase_node(observer_node)
+
+        class BackendAQuantizer(Quantizer):
+            def annotate(self, model: torch.fx.GraphModule) -> torch.fx.GraphModule:
+                for node in model.graph.nodes:
+                    if (
+                        node.op == "call_function"
+                        and node.target == torch.ops.aten.add.Tensor
+                    ):
+                        input_act0 = node.args[0]
+                        assert isinstance(input_act0, Node)
+                        input_act1 = node.args[1]
+                        assert isinstance(input_act1, Node)
+                        act_qspec = QuantizationSpec(
+                            dtype=torch.uint8,
+                            quant_min=0,
+                            quant_max=255,
+                            qscheme=torch.per_tensor_affine,
+                            is_dynamic=False,
+                            observer_or_fake_quant_ctr=Int4Observer,
+                        )
+                        node.meta["quantization_annotation"] = QuantizationAnnotation(
+                            input_qspec_map={
+                                input_act0: act_qspec,
+                                input_act1: act_qspec,
+                            },
+                            output_qspec=act_qspec,
+                            _annotated=True,
+                        )
+
+            def validate(self, model: torch.fx.GraphModule) -> None:
+                pass
+
+        class M(torch.nn.Module):
+            def forward(self, x1, x2):
+                return x1 + x2
+
+        example_inputs = (torch.randn(1, 3, 5, 5), torch.randn(1, 3, 5, 5),)
+        node_occurrence = {
+            # one quantize/dequantize pair for each of the two inputs of add,
+            # and one pair for its output
+            torch.ops.test_int4.quantize_per_tensor_int4: 3,
+            torch.ops.test_int4.dequantize_per_tensor_int4: 3,
+        }
+        node_list = [
+            torch.ops.test_int4.dequantize_per_tensor_int4,
+            torch.ops.test_int4.dequantize_per_tensor_int4,
+            torch.ops.aten.add.Tensor,
+            torch.ops.test_int4.quantize_per_tensor_int4,
+        ]
+        self._test_quantizer(
+            M().eval(),
+            example_inputs,
+            BackendAQuantizer(),
+            node_occurrence,
+            node_list,
+        )
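
For intuition, a round trip through the two helper ops registered above is just a clamp-and-round at scale 1.0 and zero point 0. A quick sanity-check sketch (it assumes the test_int4 library above has already been defined in the current process):

    x = torch.randn(2, 2)
    q = torch.ops.test_int4.quantize_per_tensor_int4(x, 1.0, 0)    # bits8 payload holding values in [0, 15]
    dq = torch.ops.test_int4.dequantize_per_tensor_int4(q, 1.0, 0)
    assert torch.equal(dq, torch.clamp(torch.round(x), 0, 15))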


@@ -88,8 +88,7 @@ _QSCHEME_TO_CHOOSE_QPARAMS_OP = {
 }
 def _replace_observer_with_quantize_dequantize_node_decomposed(
-    model: torch.nn.Module,
-    graph: Graph,
+    model: torch.fx.GraphModule,
     node: Node,
     modules: Dict[str, torch.nn.Module],
     node_name_to_scope: Dict[str, Tuple[str, type]],
@@ -105,10 +104,14 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     or quantize_per_channel and dequantize_per_channel
     """
+    graph = model.graph
     assert modules is not None
     assert isinstance(node.target, str)
     module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
     activation_post_process = modules[node.target]
+    if hasattr(activation_post_process, "convert"):
+        activation_post_process.convert(model, node)
+        return
     # skip replacing observers to quant/dequant nodes if the qconfigs of all
     # consumers and producers of this observer are None
     skip_replacement = all(_has_none_qconfig(n, node_name_to_qconfig) for n in
@@ -312,8 +315,7 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
 # activation_post_process is supported
 def _replace_observer_with_quantize_dequantize_node(
-    model: torch.nn.Module,
-    graph: Graph,
+    model: torch.fx.GraphModule,
     node: Node,
     modules: Dict[str, torch.nn.Module],
     node_name_to_scope: Dict[str, Tuple[str, type]],
@@ -328,6 +330,7 @@ def _replace_observer_with_quantize_dequantize_node(
     """
     assert modules is not None
     assert isinstance(node.target, str)
+    graph = model.graph
     module_path, prefix = _get_module_path_and_prefix(node, node_name_to_scope, node_name_to_qconfig)
     activation_post_process = modules[node.target]
     # skip replacing observers to quant/dequant nodes if the qconfigs of all
@@ -1062,11 +1065,11 @@ def convert(
                 else:
                     if is_decomposed:
                         _replace_observer_with_quantize_dequantize_node_decomposed(
-                            model, model.graph, node, modules, node_name_to_scope,
+                            model, node, modules, node_name_to_scope,
                             node_name_to_qconfig)
                     else:
                         _replace_observer_with_quantize_dequantize_node(
-                            model, model.graph, node, modules, node_name_to_scope,
+                            model, node, modules, node_name_to_scope,
                             node_name_to_qconfig)
             elif isinstance(mod, DeQuantStub):
                 _replace_observer_or_dequant_stub_with_dequantize_node(node, model.graph)
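
For reference, the flow that exercises the new callback end to end (the test drives it through the _test_quantizer helper) looks roughly like the sketch below. It assumes the standard PT2E entry points and reuses M, BackendAQuantizer, Int4Observer, and example_inputs from the test above.

    import torch
    from torch._export import capture_pre_autograd_graph
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

    m = capture_pre_autograd_graph(M().eval(), example_inputs)  # export to an FX graph
    m = prepare_pt2e(m, BackendAQuantizer())                    # inserts Int4Observer at annotated spots
    m(*example_inputs)                                          # calibration (a no-op for Int4Observer)
    m = convert_pt2e(m)                                         # convert pass calls Int4Observer.convert(...)
    m(*example_inputs)                                          # graph now runs the test_int4 quant/dequant ops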