Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
[quant][graphmode][fx] Add support for fp16 bmm pattern (#52808)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52808

Add support for producing fp16 bmm pattern

Test Plan:
python test/test_quantization.py TestQuantizeFxOps.test_bmm

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D26655616

fbshipit-source-id: 1d0639303e5ca2ca4ceae08d03ebc3b25256de57
This commit is contained in:
parent
4d94ee566e
commit
2c44b256d8
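As context for the diff below: the new pattern lets FX graph mode quantization handle torch.bmm under an fp16 static qconfig by inserting to() casts around the op (both inputs and the output, as the new test checks). The following is a minimal usage sketch, not part of this commit; the module name BMMModule, the tensor shapes, and the calibration call are illustrative, and it assumes the prepare_fx/convert_fx and float16_static_qconfig APIs as exposed by torch.quantization at this point in time.

import torch
from torch.quantization import float16_static_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx

class BMMModule(torch.nn.Module):  # illustrative module, not from the diff
    def forward(self, x, y):
        return torch.bmm(x, y)

model = BMMModule().eval()
# route torch.bmm to the fp16 static qconfig, mirroring the
# custom_qconfig_dict used by TestQuantizeFxOps.test_bmm in this commit
qconfig_dict = {"object_type": [(torch.bmm, float16_static_qconfig)]}
prepared = prepare_fx(model, qconfig_dict)
x, y = torch.randn(1, 2, 3), torch.randn(1, 3, 4)  # illustrative shapes
prepared(x, y)  # calibration pass
quantized = convert_fx(prepared)
# the converted graph is expected to contain call_method("to") casts
# around bmm (input_bmm1, input_bmm2, output_bmm), per the new test
out = quantized(x, y)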
@@ -87,6 +87,8 @@ from typing import Callable
class BinaryOp(torch.nn.Module):
    def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
        """ ibinary_op means inplace binary op
        """
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
        self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
@@ -104,9 +106,11 @@ class BinaryOp(torch.nn.Module):
class BinaryOpNonQuantizedInput(torch.nn.Module):
    def __init__(self, binary_op, ibinary_op, is_inplace, is_scalar):
        """ ibinary_op means inplace binary op
        """
        super().__init__()
        self.is_scalar = is_scalar
        self.op = ibinary_op if is_inplace else binary_op
        self.op = ibinary_op if ibinary_op and is_inplace else binary_op

    def forward(self, x, y):
        y = 3 if self.is_scalar else y
@@ -116,6 +120,8 @@ class BinaryOpNonQuantizedInput(torch.nn.Module):
class BinaryOpRelu(torch.nn.Module):
    def __init__(self, binary_op, ibinary_op, is_inplace, is_functional_relu,
                 is_scalar):
        """ ibinary_op means inplace binary op
        """
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 1, 1).float()
        self.conv2 = torch.nn.Conv2d(1, 1, 1).float()
@@ -2237,6 +2243,40 @@ class TestQuantizeFxOps(QuantizationTestCase):
            operator.mul, operator.imul, torch.ops.quantized.mul)
        self._test_binary_op_float16_impl(operator.mul, operator.imul)

    def test_bmm(self):
        class BMMMethod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                return x.bmm(y)

        data = (torch.randn(1, 1, 1, dtype=torch.float),
                torch.randn(1, 1, 1, dtype=torch.float))
        quant_type = QuantType.STATIC
        # testing for fp16 static quant
        # we are producing fp16 patterns
        custom_qconfig_dict = {
            "object_type": [(torch.bmm, float16_static_qconfig),
                            ("bmm", float16_static_qconfig)]
        }
        node_occurrence = {
            # input_bmm1, input_bmm2, output_bmm
            ns.call_method("to"): 3
        }
        self.checkGraphModeFxOp(
            BinaryOpNonQuantizedInput(torch.bmm, None, False, False), data, quant_type,
            expected_node_occurrence=node_occurrence,
            custom_qconfig_dict=custom_qconfig_dict)

        # TODO: support call_method("bmm")
        # we can transform call_method("bmm") to call_function(torch.bmm)
        # self.checkGraphModeFxOp(
        #     BMMMethod(), data, quant_type,
        #     expected_node_occurrence=node_occurrence,
        #     custom_qconfig_dict=custom_qconfig_dict,
        #     print_debug_info=True)

    @skipIfNoFBGEMM
    def test_add_relu(self):
        self._test_binary_op_relu_int8_impl(
@@ -42,7 +42,7 @@ from abc import ABC, abstractmethod
import operator
import warnings

from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Dict, Union, Optional, Tuple, List

# -------------------------
# Pattern Registrations
@@ -77,6 +77,7 @@ class QuantizeHandler(ABC):
@register_quant_pattern(operator.mul)
@register_quant_pattern(torch.add)
@register_quant_pattern(torch.mul)
@register_quant_pattern(torch.bmm)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.ReLU, torch.add))
@@ -93,9 +94,9 @@ class BinaryOp(QuantizeHandler):
                (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]  # type: ignore
        self.bop_node = node
        self.bop = node.target
        self.num_node_args = len([a for a in self.bop_node.args[:2] if isinstance(a, Node)])
        self.binary_op_node = node
        self.binary_op = node.target
        self.num_node_args = len([a for a in self.binary_op_node.args[:2] if isinstance(a, Node)])
        qbin_op_mapping: Dict[Union[Callable, str], Callable] = {
            operator.add: torch.ops.quantized.add,
            torch.add: torch.ops.quantized.add,
@@ -109,9 +110,11 @@ class BinaryOp(QuantizeHandler):
            torch.mul: torch.ops.quantized.mul_relu,
        }
        # corresponding quantized op
        self.qop = qbin_relu_op_mapping[self.bop] \
            if self.relu_node is not None \
            else qbin_op_mapping[self.bop]  # type: ignore
        self.quantized_binary_op: Optional[Callable] = None
        if self.binary_op in qbin_op_mapping:
            self.quantized_binary_op = qbin_relu_op_mapping[self.binary_op] \
                if self.relu_node is not None \
                else qbin_op_mapping[self.binary_op]  # type: ignore

    def convert(self, quantizer: QuantizerCls, node: Node, load_arg: Callable,
                is_reference: bool = False,
@@ -121,21 +124,32 @@ class BinaryOp(QuantizeHandler):
        # static quint8 qint8

        # tuple (activation_dtype, weight_dtype, compute_dtype)
        supported_dtypes = [
        # these are supported types for common binary ops like add/mul etc.
        all_bop_dtypes = [
            (torch.quint8, torch.qint8, None),
            (torch.float16, torch.float16, None),
        ]
        float16_dtypes = [
            (torch.float16, torch.float16, None)
        ]
        supported_dtypes : Dict[Union[Callable, str], List[Tuple[torch.dtype, torch.dtype, None]]] = {
            operator.add: all_bop_dtypes,
            torch.add: all_bop_dtypes,
            operator.mul: all_bop_dtypes,
            torch.mul: all_bop_dtypes,
            torch.bmm: float16_dtypes,
        }

        qconfig = quantizer.qconfig_map[node.name]
        dtypes = get_qconfig_dtypes(qconfig)
        # leave the op unquantized if the dtype combination is not supported
        if dtypes not in supported_dtypes:
        if dtypes not in supported_dtypes[self.binary_op]:
            warnings.warn(
                "dtype combination: {} is not "
                "supported by add/mul "
                "supported dtype combinations are: {}".format(dtypes, supported_dtypes))
                "supported by {} "
                "supported dtype combinations are: {}".format(dtypes, self.binary_op, supported_dtypes[self.binary_op]))
        if self.relu_node:
            op_out = quantizer.quantized_graph.node_copy(self.bop_node, load_arg(quantized=False))
            op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
            relu_args = [op_out]
            relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
            relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)
@@ -145,16 +159,17 @@ class BinaryOp(QuantizeHandler):
            return quantizer.quantized_graph.node_copy(node, load_arg(quantized=False))

        if dtypes in [(torch.quint8, torch.qint8, None)]:
            assert self.quantized_binary_op is not None
            if self.num_node_args == 1:
                # add/mul scalar
                if isinstance(self.bop_node.args[0], Node):
                if isinstance(self.binary_op_node.args[0], Node):
                    quantized_index = 0
                else:
                    quantized_index = 1

                return quantizer.quantized_graph.create_node(
                    'call_function', self.qop,
                    load_arg(quantized=[quantized_index])(self.bop_node.args), self.bop_node.kwargs)
                    'call_function', self.quantized_binary_op,
                    load_arg(quantized=[quantized_index])(self.binary_op_node.args), self.binary_op_node.kwargs)
            else:
                activation_post_process = quantizer.activation_post_process_map[node.name]
                scale, zero_point = activation_post_process.calculate_qparams()
@@ -166,15 +181,16 @@ class BinaryOp(QuantizeHandler):
                    op = torch.ops.quantized.add_relu
                else:
                    op = torch.ops.quantized.add
            kwargs = {**self.bop_node.kwargs}
            add_args = (*load_arg(quantized=True)(self.bop_node.args), scale_arg, zero_point_arg)
            kwargs = {**self.binary_op_node.kwargs}
            add_args = (*load_arg(quantized=True)(self.binary_op_node.args), scale_arg, zero_point_arg)
            op = quantizer.quantized_graph.create_node(
                'call_function', self.qop, add_args, kwargs)
                'call_function', self.quantized_binary_op, add_args, kwargs)
            return op
        elif dtypes in [(torch.float16, torch.float16, None)]:
        else:
            assert dtypes == (torch.float16, torch.float16, None)
            # TODO (refactor) this is duplicated, maybe have a helper function
            if self.relu_node:
                op_out = quantizer.quantized_graph.node_copy(self.bop_node, load_arg(quantized=False))
                op_out = quantizer.quantized_graph.node_copy(self.binary_op_node, load_arg(quantized=False))
                relu_args = [op_out]
                relu_args.extend(load_arg(quantized=False)(self.relu_node.args[1:]))
                relu_kwargs = load_arg(quantized=False)(self.relu_node.kwargs)