import torch
from torch.fx.graph import (
    Node,
)
from ..quantization_mappings import (
    get_static_quant_module_class,
    get_quantized_operator,
)
from .pattern_utils import (
    register_quant_pattern,
    register_dynamic_quant_pattern,
)
from .utils import (
    _parent_name,
    quantize_node,
    get_per_tensor_qparams,
)

from abc import ABC, abstractmethod
import operator

# -------------------------
# Pattern Registrations
# -------------------------

# 1. Post Training Static Quantization and Quantization Aware Training Patterns

# Base Pattern Handler
class QuantizeHandler(ABC):
    """ Base handler class for the quantizer patterns """
    def __init__(self, quantizer, node):
        """ Records pattern information in __init__, which will be used
        in convert
        """
        # this is an indicator of whether all the inputs are Node or not
        # since some op might be quantized differently depending on whether
        # all inputs are tensors or not, e.g. add/mul
        self.all_nodes = True

    @abstractmethod
    def convert(self, quantizer, node, load_arg, debug=False):
        """ Convert the given node to a quantized node and insert
        it to the quantized graph
        """
        return NotImplemented


@register_quant_pattern(operator.add)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
class Add(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]
        assert node.op == 'call_function' and node.target == operator.add
        self.add_node = node
        self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])

    def convert(self, quantizer, node, load_arg, debug=False):
        if not self.all_nodes:
            # add scalar
            if self.relu_node is not None:
                op = torch.ops.quantized.add_relu
            else:
                op = torch.ops.quantized.add
            return quantizer.quantized_graph.create_node(
                'call_function', op,
                load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs)
        else:
            activation_post_process = quantizer.activation_post_process_map[node.name]
            scale, zero_point = activation_post_process.calculate_qparams()
            scale = float(scale)
            zero_point = int(zero_point)
            if self.relu_node is not None:
                op = torch.ops.quantized.add_relu
            else:
                op = torch.ops.quantized.add
            kwargs = self.add_node.kwargs
            kwargs.update({'scale': scale, 'zero_point': zero_point})
            return quantizer.quantized_graph.create_node(
                'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
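
# Illustrative sketch (hypothetical, not part of the pattern machinery): a
# module like the one below traces to `operator.add` followed by
# `torch.nn.functional.relu`, which matches the
# (torch.nn.functional.relu, operator.add) pattern registered on Add above.
# After convert, the pair is expected to collapse into a single
# torch.ops.quantized.add_relu(x, y, scale, zero_point) call, or into
# torch.ops.quantized.add when one of the two operands is a scalar rather
# than a Node.
class _ExampleAddRelu(torch.nn.Module):  # hypothetical, for illustration only
    def forward(self, x, y):
        return torch.nn.functional.relu(x + y)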
@register_quant_pattern(operator.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
class Mul(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]
        assert node.op == 'call_function' and node.target == operator.mul
        self.mul_node = node
        self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])

    def convert(self, quantizer, node, load_arg, debug=False):
        if not self.all_nodes:
            # mul scalar
            if self.relu_node is not None:
                op = torch.ops.quantized.mul_relu
            else:
                op = torch.ops.quantized.mul
            return quantizer.quantized_graph.create_node(
                'call_function', op,
                load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs)
        else:
            activation_post_process = quantizer.activation_post_process_map[node.name]
            scale, zero_point = activation_post_process.calculate_qparams()
            scale = float(scale)
            zero_point = int(zero_point)
            if self.relu_node is not None:
                op = torch.ops.quantized.mul_relu
            else:
                op = torch.ops.quantized.mul
            kwargs = self.mul_node.kwargs
            kwargs.update({'scale': scale, 'zero_point': zero_point})
            return quantizer.quantized_graph.create_node(
                'call_function', op, load_arg(quantized=True)(self.mul_node.args), kwargs)


@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
    def convert(self, quantizer, node, load_arg, debug=False):
        if not self.all_nodes:
            return NotImplemented
        activation_post_process = quantizer.activation_post_process_map[node.name]
        scale, zero_point = activation_post_process.calculate_qparams()
        scale = float(scale)
        zero_point = int(zero_point)
        kwargs = load_arg(quantized=False)(node.kwargs)
        kwargs.update({'scale': scale, 'zero_point': zero_point})
        return quantizer.quantized_graph.create_node(
            'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)
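
# Note on the load_arg convention used by the convert methods in this file.
# The real implementation lives in the quantizer and maps fx Nodes through its
# environment; the helper below is a simplified, hypothetical stand-in that
# only shows the meaning of the `quantized` argument:
#   load_arg(quantized=True)   - load every argument in quantized form
#   load_arg(quantized=False)  - load every argument in float form
#   load_arg(quantized=[0, 1]) - load only the listed positions quantized
#   load_arg(quantized=None)   - keep whatever form each argument already has
def _example_load_arg(quant_env, float_env, quantized):  # hypothetical sketch
    def load(args):
        if quantized is None:
            return args
        if quantized is True:
            return tuple(quant_env.get(a, a) for a in args)
        if quantized is False:
            return tuple(float_env.get(a, a) for a in args)
        # an iterable of indices: quantize only the selected positions
        return tuple(quant_env.get(a, a) if i in quantized else float_env.get(a, a)
                     for i, a in enumerate(args))
    return load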
# handle conv, maybe followed by relu
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
@register_quant_pattern(torch.nn.Conv1d)
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.Conv3d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU1d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU3d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
class ConvRelu(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]
        self.conv_node = node
        if node.op == 'call_module':
            self.conv = quantizer.modules[self.conv_node.target]

    def convert(self, quantizer, node, load_arg, debug=False):
        # TODO: debug option for conv module
        if self.conv_node.op == 'call_module':
            # note that relu should already be fused into the conv module in the fusion step
            assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
            # 1. attach activation post process to module
            if type(self.conv) in [
                    torch.nn.intrinsic.ConvReLU1d,
                    torch.nn.intrinsic.ConvReLU2d,
                    torch.nn.intrinsic.ConvReLU3d]:
                self.conv[1].activation_post_process = quantizer.activation_post_process_map[node.name]
            else:
                self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
            # 2. select quantized class
            qconv_cls = get_static_quant_module_class(type(self.conv))
            quantized = qconv_cls.from_float(self.conv)
            parent_name, name = _parent_name(self.conv_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.conv_node.target,
                (load_arg(quantized=True)(self.conv_node.args[0]),),
                {})
        elif self.conv_node.op == 'call_function':
            if self.relu_node is not None:
                raise Exception("functional conv + relu is not supported yet")
            if debug:
                args = load_arg(quantized=[0, 1])(self.conv_node.args)
                args = load_arg(quantized=False)(self.conv_node.args)
                kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                conv_out = quantizer.quantized_graph.create_node(
                    'call_function', torch.nn.functional.conv2d, args, kwargs)
                root_module = quantizer.modules['']
                return quantize_node(
                    root_module, quantizer.quantized_graph, conv_out,
                    quantizer.activation_post_process_map[self.conv_node.name])
            else:
                assert len(self.conv_node.args) == 7, \
                    'only conv2d calls with all arguments specified are supported right now in the debug=False option'
                args = load_arg(quantized=[0, 1])(self.conv_node.args)
                # pack weight
                weight = load_arg(quantized=True)(self.conv_node.args[1])
                other_args = load_arg(quantized=False)(self.conv_node.args[2:])
                prepack_args = tuple([weight] + list(other_args))
                packed_weight = quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
                # construct conv input
                conv_input = load_arg(quantized=True)(self.conv_node.args[0])
                activation_post_process = quantizer.activation_post_process_map[self.conv_node.name]
                scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
                qconv_args = (conv_input, packed_weight, scale, zero_point)
                kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                return quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.conv2d, qconv_args, kwargs)
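
# Hedged sketch of the functional path above (non-debug): with all seven
# positional arguments present,
#   torch.nn.functional.conv2d(x, w, b, stride, padding, dilation, groups)
# is expected to be rewritten into two nodes,
#   packed = torch.ops.quantized.conv2d_prepack(w_q, b, stride, padding, dilation, groups)
#   out    = torch.ops.quantized.conv2d(x_q, packed, scale, zero_point)
# where x_q / w_q are the quantized input and weight, and scale / zero_point are
# the per-tensor output qparams taken from the observer recorded for this node.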
# handle linear, maybe followed by relu
@register_quant_pattern(torch.nn.Linear)
@register_quant_pattern(torch.nn.functional.linear)
@register_quant_pattern(torch.nn.qat.Linear)
@register_quant_pattern(torch.nn.intrinsic.LinearReLU)
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLU(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]
        self.linear_node = node
        if node.op == 'call_module':
            self.linear = quantizer.modules[self.linear_node.target]

    def convert(self, quantizer, node, load_arg, debug=False):
        # TODO: debug option for linear module
        if self.linear_node.op == 'call_module':
            # note that relu should already be fused into the linear module in the fusion step
            assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
            # 1. attach activation post process to module
            if type(self.linear) == torch.nn.intrinsic.LinearReLU:
                self.linear[1].activation_post_process = quantizer.activation_post_process_map[node.name]
            else:
                self.linear.activation_post_process = quantizer.activation_post_process_map[node.name]
            # 2. select quantized class
            if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
                qlinear = torch.nn.quantized.Linear
            elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
                qlinear = torch.nn.intrinsic.quantized.LinearReLU
            else:
                raise Exception("unhandled linear type:", type(self.linear))
            quantized = qlinear.from_float(self.linear)
            parent_name, name = _parent_name(self.linear_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.linear_node.target,
                (load_arg(quantized=True)(self.linear_node.args[0]),), {})
        elif self.linear_node.op == 'call_function':
            if debug:
                args = load_arg(quantized=[0, 1])(self.linear_node.args)
                args = load_arg(quantized=False)(self.linear_node.args)
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                linear_out = quantizer.quantized_graph.create_node(
                    'call_function', torch.nn.functional.linear, args, kwargs)
                root_module = quantizer.modules['']
                return quantize_node(
                    root_module,
                    quantizer.quantized_graph,
                    linear_out,
                    quantizer.activation_post_process_map[self.linear_node.name])
            else:
                # TODO: this code can be merged with dynamic linear code
                # linear args
                # (x, weight, bias, ...)
                args = load_arg(quantized=[0, 1])(self.linear_node.args)
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                # pack weight
                weight = load_arg(quantized=True)(self.linear_node.args[1])
                bias = None
                # all args after bias, including bias
                other_args = load_arg(quantized=False)(self.linear_node.args[2:])
                if len(self.linear_node.args) > 2:
                    bias = load_arg(quantized=False)(self.linear_node.args[2])
                    other_args = other_args[1:]  # remove the bias argument
                else:
                    assert 'bias' in kwargs, \
                        'expect bias provided as a keyword argument when it is not a positional argument'
                    bias = kwargs['bias']
                    kwargs.pop('bias')
                prepack_args = (weight, bias)
                packed_weight = quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
                # construct linear input
                linear_input = load_arg(quantized=True)(self.linear_node.args[0])
                activation_post_process = \
                    quantizer.activation_post_process_map[self.linear_node.name]
                scale, zero_point, _ = get_per_tensor_qparams(activation_post_process)
                qlinear_args = (linear_input, packed_weight, scale, zero_point)
                return quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
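
# Hedged sketch of the functional path above (non-debug): a call
#   torch.nn.functional.linear(x, w, b)
# is expected to be rewritten into
#   packed = torch.ops.quantized.linear_prepack(w_q, b)
#   out    = torch.ops.quantized.linear(x_q, packed, scale, zero_point)
# Note that the bias has to be available, either positionally or via the `bias`
# keyword argument; otherwise the assert above fires.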
@register_quant_pattern(torch.nn.BatchNorm2d)
@register_quant_pattern(torch.nn.BatchNorm3d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU2d)
@register_quant_pattern(torch.nn.intrinsic.BNReLU3d)
class BatchNorm(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        assert node.op == 'call_module'
        self.bn_node = node
        self.bn = quantizer.modules[self.bn_node.target]

    def convert(self, quantizer, node, load_arg, debug=False):
        # 1. attach activation post process to module
        activation_post_process = quantizer.activation_post_process_map[node.name]
        if type(self.bn) in \
                [torch.nn.intrinsic.BNReLU2d, torch.nn.intrinsic.BNReLU3d]:
            self.bn[1].activation_post_process = activation_post_process
        else:
            self.bn.activation_post_process = activation_post_process
        qbn_cls = get_static_quant_module_class(type(self.bn))
        quantized = qbn_cls.from_float(self.bn)
        parent_name, name = _parent_name(self.bn_node.target)
        setattr(quantizer.modules[parent_name], name, quantized)
        return quantizer.quantized_graph.create_node(
            'call_module',
            self.bn_node.target,
            load_arg(quantized=[0])(self.bn_node.args),
            load_arg(quantized=False)(self.bn_node.kwargs))


ARGS_TO_SKIP = {
    torch._ops.ops.quantized.hardswish: ['inplace'],
    torch._ops.ops.quantized.instance_norm:
    ['running_mean', 'running_var', 'use_input_stats', 'momentum'],
}


@register_quant_pattern(torch.nn.ELU)
@register_quant_pattern(torch.nn.Hardswish)
@register_quant_pattern(torch.nn.InstanceNorm1d)
@register_quant_pattern(torch.nn.InstanceNorm2d)
@register_quant_pattern(torch.nn.InstanceNorm3d)
@register_quant_pattern(torch.nn.LayerNorm)
@register_quant_pattern(torch.nn.functional.hardswish)
@register_quant_pattern(torch.nn.functional.instance_norm)
@register_quant_pattern(torch.nn.functional.layer_norm)
class DefaultNode(QuantizeHandler):
    ''' Common quantized op, first input and first output will be quantized
    '''
    def convert(self, quantizer, node, load_arg, debug=False):
        if not self.all_nodes:
            return NotImplemented
        assert node.op in ['call_module', 'call_function'], 'Only call_module and ' + \
            'call_function are handled in DefaultNode'
        activation_post_process = quantizer.activation_post_process_map[node.name]
        if node.op == 'call_module':
            module = quantizer.modules[node.target]
            module.activation_post_process = activation_post_process
            quantized_module_cls = get_static_quant_module_class(type(module))
            quantized_module = quantized_module_cls.from_float(module)
            parent_name, name = _parent_name(node.target)
            setattr(quantizer.modules[parent_name], name, quantized_module)
            return quantizer.quantized_graph.create_node(
                'call_module',
                node.target,
                load_arg(quantized=[0])(node.args),
                load_arg(quantized=False)(node.kwargs))
        else:
            # call_function
            scale, zero_point = activation_post_process.calculate_qparams()
            scale = float(scale)
            zero_point = int(zero_point)
            quantized_op = get_quantized_operator(node.target)
            args = load_arg(quantized=[0])(node.args)
            kwargs = load_arg(quantized=False)(node.kwargs)
            kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})
            if quantized_op in ARGS_TO_SKIP:
                args_to_skip = ARGS_TO_SKIP[quantized_op]
                for arg in args_to_skip:
                    if arg in kwargs:
                        kwargs.pop(arg)
            return quantizer.quantized_graph.create_node(
                'call_function', quantized_op, args, kwargs)


# TODO: elu is using scale/zero_point instead of output_scale, output_zero_point
@register_quant_pattern(torch.nn.functional.elu)
class ELU(QuantizeHandler):
    def convert(self, quantizer, node, load_arg, debug=False):
        activation_post_process = quantizer.activation_post_process_map[node.name]
        scale, zero_point = activation_post_process.calculate_qparams()
        scale = float(scale)
        zero_point = int(zero_point)
        quantized_op = get_quantized_operator(node.target)
        args = load_arg(quantized=[0])(node.args)
        kwargs = load_arg(quantized=False)(node.kwargs)
        kwargs.update({'output_scale': scale, 'output_zero_point': zero_point})
        kwargs.pop('inplace')
        return quantizer.quantized_graph.create_node(
            'call_function', quantized_op, args, kwargs)
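
# Hedged example of the DefaultNode functional path: assuming
# get_quantized_operator maps torch.nn.functional.layer_norm to
# torch.ops.quantized.layer_norm, a call
#   torch.nn.functional.layer_norm(x, normalized_shape, weight, bias, eps)
# would be re-emitted against the quantized operator with only the first
# argument loaded in quantized form and with output_scale / output_zero_point
# keyword arguments taken from the observer attached to the node.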
# these ops have quantized equivalents that do not need any extra information
@register_quant_pattern(torch.nn.AdaptiveAvgPool1d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@register_quant_pattern(torch.nn.AdaptiveAvgPool3d)
@register_quant_pattern(torch.nn.AvgPool1d)
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.AvgPool3d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.Hardsigmoid)
@register_quant_pattern(torch.nn.Hardtanh)
@register_quant_pattern(torch.nn.LeakyReLU)
@register_quant_pattern(torch.nn.MaxPool1d)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.MaxPool3d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.nn.Sigmoid)
@register_quant_pattern(torch.nn.Tanh)
@register_quant_pattern(torch.adaptive_avg_pool1d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool3d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.hardsigmoid)
@register_quant_pattern(torch.nn.functional.hardtanh)
@register_quant_pattern(torch.nn.functional.hardtanh_)
@register_quant_pattern(torch.nn.functional.interpolate)
@register_quant_pattern(torch.nn.functional.leaky_relu)
@register_quant_pattern(torch.nn.functional.max_pool1d)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch.nn.functional.max_pool3d)
@register_quant_pattern(torch.nn.functional.relu)
@register_quant_pattern(torch.nn.functional.relu6)
@register_quant_pattern(torch.avg_pool1d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
@register_quant_pattern(torch._C._nn.avg_pool3d)
@register_quant_pattern(torch.chunk)
@register_quant_pattern(torch.clamp)
@register_quant_pattern(torch.flatten)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.max)
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.min)
@register_quant_pattern(torch.repeat_interleave)
@register_quant_pattern(torch.sigmoid)
@register_quant_pattern(torch.sort)
@register_quant_pattern(torch.squeeze)
@register_quant_pattern(torch.stack)
@register_quant_pattern(torch.tanh)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.getitem)
@register_quant_pattern(operator.floordiv)
@register_quant_pattern('chunk')
@register_quant_pattern('clamp')
@register_quant_pattern('contiguous')
@register_quant_pattern('detach')
@register_quant_pattern('detach_')
@register_quant_pattern('hardsigmoid')
@register_quant_pattern('hardsigmoid_')
@register_quant_pattern('leaky_relu')
@register_quant_pattern('leaky_relu_')
@register_quant_pattern('mean')
@register_quant_pattern('numel')
@register_quant_pattern('permute')
@register_quant_pattern('relu')
@register_quant_pattern('relu_')
@register_quant_pattern('repeat')
@register_quant_pattern('repeat_interleave')
@register_quant_pattern('reshape')
@register_quant_pattern('resize_')
@register_quant_pattern('shape')
@register_quant_pattern('sigmoid')
@register_quant_pattern('sigmoid_')
@register_quant_pattern('size')
@register_quant_pattern('squeeze')
@register_quant_pattern('squeeze_')
@register_quant_pattern('tanh')
@register_quant_pattern('tanh_')
@register_quant_pattern('transpose')
@register_quant_pattern('unsqueeze')
@register_quant_pattern('unsqueeze_')
@register_quant_pattern('view')
class CopyNode(QuantizeHandler):
    def convert(self, quantizer, node, load_arg, debug=False):
        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
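
# Illustrative (hypothetical) example of why CopyNode ops need no extra qparams:
# they operate directly on an already-quantized tensor and preserve its
# scale/zero_point, so copying the node with load_arg(quantized=None) is enough.
def _example_copy_node_op():  # hypothetical, for illustration only
    x = torch.quantize_per_tensor(torch.randn(2, 3), scale=0.1, zero_point=0, dtype=torch.quint8)
    return torch.flatten(x)  # still quantized, same qparams as x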
# Default quantization handler, used for quantization of input and output
# of quantizable objects (e.g. modules and functionals)
class DefaultQuant(QuantizeHandler):
    def convert(self, quantizer, node):
        assert self.all_nodes
        root_module = quantizer.modules['']
        return quantize_node(
            root_module,
            quantizer.quantized_graph,
            node, quantizer.activation_post_process_map[node.name])


# 2. Post Training Dynamic Quantization Patterns
@register_dynamic_quant_pattern(torch.nn.Linear)
@register_dynamic_quant_pattern(torch.nn.functional.linear)
class DynamicLinear(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.linear_node = node
        if node.op == 'call_module':
            assert isinstance(quantizer.modules[node.target], torch.nn.Linear)
            self.linear = quantizer.modules[self.linear_node.target]

    def convert(self, quantizer, node, load_arg, debug=False):
        if self.linear_node.op == 'call_module':
            quantized = torch.nn.quantized.dynamic.Linear.from_float(self.linear)
            parent_name, name = _parent_name(self.linear_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.linear_node.target,
                (load_arg(quantized=False)(self.linear_node.args[0]),), {})
        elif self.linear_node.op == 'call_function':
            if debug:
                # quantize and dequantize weight
                args = load_arg(quantized=[1])(self.linear_node.args)
                args = load_arg(quantized=False)(self.linear_node.args)
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                return quantizer.quantized_graph.create_node(
                    'call_function', torch.nn.functional.linear, args, kwargs)
            else:
                # linear args:
                # (x, weight, bias)
                # quantize weight
                quantized_weight = load_arg(quantized=True)(self.linear_node.args[1])
                bias = None
                # all args after bias, including bias
                other_args = load_arg(quantized=False)(self.linear_node.args[2:])
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                if len(self.linear_node.args) > 2:
                    bias = load_arg(quantized=False)(self.linear_node.args[2])
                    other_args = other_args[1:]  # remove the bias argument
                else:
                    assert 'bias' in kwargs, \
                        'expect bias provided as a keyword argument when it is not a positional argument'
                    bias = kwargs['bias']
                    kwargs.pop('bias')
                prepack_args = (quantized_weight, bias)
                # pack weight
                packed_weight = quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
                # construct dynamic linear input
                non_quantized_input = load_arg(quantized=False)(self.linear_node.args[0])
                qdynamic_linear_args = (non_quantized_input, packed_weight)
                return quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.linear_dynamic, qdynamic_linear_args, kwargs)
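
# Illustrative end-state sketch (hypothetical, using the existing eager-mode API
# rather than the FX flow in this file): after dynamic quantization an nn.Linear
# is expected to be replaced by nn.quantized.dynamic.Linear, which is what the
# call_module branch of DynamicLinear.convert above produces via from_float.
def _example_dynamic_linear():  # hypothetical, for illustration only
    float_model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    qmodel = torch.quantization.quantize_dynamic(
        float_model, {torch.nn.Linear}, dtype=torch.qint8)
    return qmodel  # qmodel[0] should now be torch.nn.quantized.dynamic.Linear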