[quant][graphmode][fx] Add graph mode quantization on fx (#43175)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43175

This PR added graph mode quantization on fx: https://github.com/pytorch/pytorch/pull/42741
Currently it matches eager mode quantization for torchvision with static/dynamic/qat
ddp/synbn test is still wip

Test Plan:
python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23178602

fbshipit-source-id: 8e7e0322846fbda2cfa79ad188abd7235326f879
This commit is contained in:
Jerry Zhang 2020-08-20 14:48:04 -07:00 committed by Facebook GitHub Bot
parent c89d2c6bf2
commit dae2973fae
9 changed files with 1158 additions and 0 deletions

View File

@ -81,6 +81,12 @@ ignore_errors = True
[mypy-torch.quantization._numeric_suite]
ignore_errors = True
[mypy-torch.quantization._quantize_fx]
ignore_errors = True
[mypy-torch.quantization.fx.*]
ignore_errors = True
[mypy-torch.quasirandom]
ignore_errors = True

View File

@ -0,0 +1,118 @@
import torch
import torch.nn.functional as F
# symbolic trace
from torch.fx import symbolic_trace
# graph mode quantization based on fx
from torch.quantization._quantize_fx import (
Quantizer,
fuse,
)
# eager mode quantization
from torch.quantization import default_qconfig
# test utils
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skipIfNoFBGEMM
)
class TestQuantizeFx(QuantizationTestCase):
@skipIfNoFBGEMM
def test_functional(self):
""" Test quantizing functional conv and linear
"""
class Conv(torch.nn.Module):
def __init__(self):
super().__init__()
self.stride = (1, 1)
self.padding = (0, 0)
self.dilation = (1, 1)
self.groups = 1
def forward(self, x, weight):
return F.conv2d(x, weight, None, self.stride, self.padding, self.dilation, self.groups)
conv_input = torch.rand(1, 3, 224, 224)
conv_weight = torch.rand(3, 3, 3, 3)
class Linear(torch.nn.Module):
def __init__(self):
super().__init__()
def forward(self, x, weight):
return F.linear(x, weight)
linear_input = torch.rand(8, 5)
linear_weight = torch.rand(10, 5)
class LinearModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear = torch.nn.Linear(5, 10)
def forward(self, x):
return self.linear(x)
linear_module_input = torch.rand(8, 5)
tests = [
(False, Conv, (conv_input, conv_weight), ('call_function', torch.ops.quantized.conv2d)),
(True, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear_dynamic)),
(False, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear)),
(True, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.dynamic.Linear)),
(False, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.Linear)),
]
for is_dynamic, M, inputs, quantized_node in tests:
m = M().eval()
qconfig = default_qconfig
graph = symbolic_trace(m)
script = torch.jit.script(graph)
a = m(*inputs)
b = graph(*inputs)
c = script(*inputs)
assert (a - b).abs().max() == 0
assert (a - c).abs().max() == 0
assert torch.allclose(a, b)
assert torch.allclose(a, c)
graph = fuse(graph)
quantizer = Quantizer()
qconfig_dict = {'': qconfig}
if is_dynamic:
prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
else:
prepared = quantizer.prepare(graph, qconfig_dict)
prepared(*inputs)
qgraph = quantizer.convert(prepared)
qgraph_debug = quantizer.convert(prepared, debug=True)
qgraph.eval()
qgraph_debug.eval()
qgraph_script = torch.jit.script(qgraph)
d = qgraph(*inputs)
d_debug = qgraph_debug(*inputs)
e = qgraph_script(*inputs)
e_debug = qgraph_debug(*inputs)
found = False
modules = dict(qgraph.root.named_modules())
for node in qgraph.graph.nodes:
if node.op == 'call_function':
found = found or node.op == quantized_node[0] and node.target == quantized_node[1]
elif node.op == 'call_module':
found = found or node.op == quantized_node[0] and type(modules[node.target]) == quantized_node[1]
assert found, 'Expected to find quantized node:' + str(quantized_op)
# assert (a-d).abs().max() < 2
assert torch.allclose(d, e)
assert (d - d_debug).abs().max() == 0
assert (e - e_debug).abs().max() == 0

View File

@ -60,6 +60,9 @@ from quantization.test_quantize_jit import TestQuantizeJitOps # noqa: F401
from quantization.test_quantize_jit import TestQuantizeDynamicJitPasses # noqa: F401
from quantization.test_quantize_jit import TestQuantizeDynamicJitOps # noqaa: F401
# 3. GraphModule based graph mode quantization
from quantization.test_quantize_fx import TestQuantizeFx # noqa: F401
# Tooling: numric_suite
from quantization.test_numeric_suite import TestEagerModeNumericSuite # noqa: F401

View File

@ -0,0 +1,2 @@
from .fx import Quantizer # noqa: F401
from .fx import fuse # noqa: F401

View File

@ -0,0 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .quantize import Quantizer
from .fuse import fuse

View File

@ -0,0 +1,151 @@
import torch
from torch.quantization.fuse_modules import (
fuse_conv_bn,
fuse_conv_bn_relu,
)
from torch.fx import (
GraphModule,
)
from torch.fx.graph import (
Graph,
map_arg,
)
from .pattern_utils import (
matches,
register_fusion_pattern,
get_fusion_patterns,
)
from .utils import _parent_name
import copy
# Fusion Patterns
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
class ConvBNReLUFusion():
def __init__(self, quantizer, node):
super().__init__()
self.relu_node = None
self.bn_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_module'
if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
self.bn_node = node
self.bn = quantizer.modules[self.bn_node.target]
node = node.args[0]
assert node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.modules.Conv2d
self.conv_node = node
self.conv = quantizer.modules[self.conv_node.target]
def fuse(self, quantizer, load_arg):
conv_parent_name, conv_name = _parent_name(self.conv_node.target)
if self.relu_node is not None:
# since relu can be used multiple times, we'll need to create a relu module for each match
if self.relu_node.op == 'call_module':
relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
else:
# TODO: get inplace argument from functional
relu = torch.nn.ReLU()
relu.training = self.conv.training
if self.bn_node is not None:
setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn_relu(self.conv, self.bn, relu))
else:
# conv_relu
setattr(quantizer.modules[conv_parent_name], conv_name, torch.nn.intrinsic.ConvReLU2d(self.conv, relu))
else:
assert self.bn_node is not None
setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn(self.conv, self.bn))
# TODO: do we need to make sure bn is only used once?
if self.bn_node is not None:
parent_name, name = _parent_name(self.bn_node.target)
setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
return quantizer.fused_graph.node_copy(self.conv_node, load_arg)
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
class LinearReLUFusion():
def __init__(self, quantizer, node):
super().__init__()
self.relu_node = node
node = node.args[0]
assert node.op == 'call_module'
assert isinstance(quantizer.modules[node.target], torch.nn.modules.Linear)
self.linear_node = node
self.linear = quantizer.modules[self.linear_node.target]
def fuse(self, quantizer, load_arg):
# since relu can be used multiple times, we'll need to create a relu module for each match
if self.relu_node.op == 'call_module':
relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
else:
# TODO: get inplace argument from functional
relu = torch.nn.ReLU()
relu.training = self.linear.training
# linear_relu
linear_parent_name, linear_name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[linear_parent_name], linear_name, torch.nn.intrinsic.LinearReLU(self.linear, relu))
return quantizer.fused_graph.node_copy(self.linear_node, load_arg)
class Fuser:
def fuse_conv_bn(self, model, inplace=False):
input_root = model.root
if not inplace:
input_root = copy.deepcopy(input_root)
input_graph = model.graph
self.modules = dict(input_root.named_modules())
fusion_patterns = get_fusion_patterns()
# find conv-bn pairs
conv_bn_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
self.fused_graph = Graph()
env = {}
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
for node in input_graph.nodes:
root_node, obj = conv_bn_pairs.get(node.name, (None, None))
if root_node is node:
env[node.name] = obj.fuse(self, load_arg)
elif root_node is None:
env[node.name] = self.fused_graph.node_copy(node, load_arg)
# node matched in patterns and is not root is removed here
self.fused_graph.output(load_arg(input_graph.result))
return GraphModule(input_root, self.fused_graph)
def _find_matches(self, root, graph, patterns):
modules = dict(root.named_modules())
match_map = {} # node name -> (root_node, match_value?)
def apply_match(pattern, node, match):
if isinstance(pattern, tuple):
s, *args = pattern
apply_match(s, node, match)
for subpattern, arg in zip(args, node.args):
apply_match(subpattern, arg, match)
else:
match_map[node.name] = match
for node in reversed(graph.nodes):
if node.name not in match_map:
for pattern, value in patterns.items():
if matches(modules, node, pattern):
apply_match(pattern, node, (node, value(self, node)))
return match_map
def fuse(graph_module, inplace=False):
fuser = Fuser()
return fuser.fuse_conv_bn(graph_module, inplace)

View File

@ -0,0 +1,86 @@
import torch
import sys
from collections import OrderedDict
# pattern for conv bn fusion
FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
def insert(fn):
FUSION_PATTERNS[pattern] = fn
return fn
return insert
def get_fusion_patterns():
return FUSION_PATTERNS
# pattern for both static quantization and qat
QUANTIZATION_PATTERNS = OrderedDict()
def register_quant_pattern(pattern):
def insert(fn):
QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert
def get_quant_patterns():
return QUANTIZATION_PATTERNS
# pattern for dynamic quantization
DYNAMIC_QUANTIZATION_PATTERNS = OrderedDict()
def register_dynamic_pattern(pattern):
def insert(fn):
DYNAMIC_QUANTIZATION_PATTERNS[pattern] = fn
return fn
return insert
def get_dynamic_quant_patterns():
return DYNAMIC_QUANTIZATION_PATTERNS
# Example use of register pattern function:
# @register_fusion_pattern(torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
# def __init__(...):
# ...
#
# Note: The order of patterns is important! match function will take whatever is matched first, so we'll
# need to put the fusion patterns before single patterns. For example, add_relu should be registered come before relu.
# decorators are applied in the reverse order we see. Also when we match the nodes in the graph with these patterns,
# we'll start from the last node of the graph and traverse back.
def matches(modules, node, pattern, max_uses=sys.maxsize):
""" Matches a node in fx against a pattern
"""
if isinstance(pattern, tuple):
self_match, *arg_matches = pattern
if self_match is getattr:
assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
arg_matches = []
else:
self_match = pattern
arg_matches = []
if node.uses > max_uses:
return False
if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
if node.op != 'call_module':
return False
if not type(modules[node.target]) == self_match:
return False
elif callable(self_match):
if node.op != 'call_function' or node.target is not self_match:
return False
elif node.target is getattr:
if node.args[1] != pattern[1]:
return False
elif node.target != self_match:
return False
if not arg_matches:
return True
if len(arg_matches) != len(node.args):
return False
return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))

View File

@ -0,0 +1,782 @@
import torch
from torch.quantization import (
propagate_qconfig_,
convert,
DEFAULT_QAT_MODULE_MAPPING,
)
from torch.fx import (
GraphModule,
Proxy,
)
from torch.fx.graph import (
Graph,
Node,
map_arg,
)
from .pattern_utils import (
matches,
register_quant_pattern,
get_quant_patterns,
register_dynamic_pattern,
get_dynamic_quant_patterns,
)
from .utils import _parent_name
from abc import ABC, abstractmethod
import copy
import enum
import operator
# Quantization type (dynamic quantization, static quantization).
# Should match the c++ enum in quantization_type.h
class QuantType(enum.IntEnum):
DYNAMIC = 0
STATIC = 1
QAT = 2
# ------------------------
# Helper Functions
# ------------------------
def get_qparams(activation_post_process):
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
dtype = activation_post_process.dtype
return scale, zero_point, dtype
def quantize_node(node, activation_post_process):
scale, zero_point, dtype = get_qparams(activation_post_process)
return torch.quantize_per_tensor(node, scale, zero_point, dtype)
def quantize(quantizer, node):
quantize_node(node, quantizer.activation_post_process_map[node.name])
# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
torch.nn.functional.conv2d : [1],
torch.nn.functional.linear : [1],
}
# Pattern Registrations
# 1. Post Training Static Quantization and Quantization Aware Training Patterns
# Base Pattern Handler
class QuantizeHandler(ABC):
""" Base handler class for the quantizer patterns
"""
def __init__(self, quantizer, node):
""" Records pattern information in __init__, which will be used
in convert
"""
# this is an indicator of whether all the inputs are Node or not
# since some op might be quantized differently depending on whether
# all inputs are tensors or not, e.g. add/mul
self.all_nodes = True
@abstractmethod
def convert(self, quantizer, node, load_arg, debug=False):
""" Convert the given node to a quantized node and insert
it to the quantized graph
"""
return NotImplemented
@register_quant_pattern(operator.add)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
class Add(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_function' and node.target == operator.add
self.add_node = node
self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
# add scalar
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
return quantizer.quantized_graph.create_node(
'call_function', op,
load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs)
else:
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.add_relu
else:
op = torch.ops.quantized.add
kwargs = self.add_node.kwargs
kwargs.update({'scale': scale, 'zero_point': zero_point})
return quantizer.quantized_graph.create_node(
'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)
@register_quant_pattern(operator.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
class Mul(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
assert node.op == 'call_function' and node.target == operator.mul
self.mul_node = node
self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
# mul scalar
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
return quantizer.quantized_graph.create_node(
'call_function', op, load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs)
else:
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
if self.relu_node is not None:
op = torch.ops.quantized.mul_relu
else:
op = torch.ops.quantized.mul
kwargs = self.mul_node.kwargs
kwargs.update({'scale': scale, 'zero_point': zero_point})
return quantizer.quantized_graph.create_node('call_function', op, load_arg(quantized=True)(self.mul_node.args), kwargs)
@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
def convert(self, quantizer, node, load_arg, debug=False):
if not self.all_nodes:
return NotImplemented
activation_post_process = quantizer.activation_post_process_map[node.name]
scale, zero_point = activation_post_process.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
kwargs = load_arg(quantized=False)(node.kwargs)
kwargs.update({'scale': scale, 'zero_point': zero_point})
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)
# handle conv, maybe followed by relu
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
class ConvRelu(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
self.conv_node = node
if node.op == 'call_module':
self.conv = quantizer.modules[self.conv_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
# TODO: debug option for conv module
if self.conv_node.op == 'call_module':
# note that relu should already be fused into conv module in the fusion step
assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
'please make sure to run fusion before prepare'
# 1. attach activation post process to module
if type(self.conv) == torch.nn.intrinsic.ConvReLU2d:
self.conv[1].activation_post_process = quantizer.activation_post_process_map[node.name]
else:
self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
if type(self.conv) in [torch.nn.Conv2d,
torch.nn.qat.Conv2d,
torch.nn.intrinsic.qat.ConvBn2d]:
qconv = torch.nn.quantized.Conv2d
elif type(self.conv) in [torch.nn.intrinsic.ConvReLU2d,
torch.nn.intrinsic.qat.ConvReLU2d,
torch.nn.intrinsic.qat.ConvBnReLU2d]:
qconv = torch.nn.intrinsic.quantized.ConvReLU2d
else:
raise Exception("unhandled conv type:", type(self.conv))
quantized = qconv.from_float(self.conv)
parent_name, name = _parent_name(self.conv_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.conv_node.target,
(load_arg(quantized=True)(self.conv_node.args[0]),),
{})
elif self.conv_node.op == 'call_function':
if self.relu_node is not None:
raise Exception("functional conv + relu is not supported yet")
if debug:
args = load_arg(quantized=[0, 1])(self.conv_node.args)
args = load_arg(quantized=False)(self.conv_node.args)
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
conv_out = quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.conv2d, args, kwargs)
return quantize_node(
conv_out, quantizer.activation_post_process_map[self.conv_node.name])
else:
assert len(self.conv_node.args) == 7, \
'only conv2d calls with all arguments specified is support right now in debug=False option'
args = load_arg(quantized=[0, 1])(self.conv_node.args)
# pack weight
weight = load_arg(quantized=True)(self.conv_node.args[1])
other_args = load_arg(quantized=False)(self.conv_node.args[2:])
prepack_args = [weight] + list(other_args)
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
# construct conv input
conv_input = load_arg(quantized=True)(self.conv_node.args[0])
activation_post_process = quantizer.activation_post_process_map[self.conv_node.name]
scale, zero_point, _ = get_qparams(activation_post_process)
qconv_args = [conv_input, packed_weight, scale, zero_point]
kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.conv2d, qconv_args, kwargs)
# handle linear, maybe followed by relu
@register_quant_pattern(torch.nn.Linear)
@register_quant_pattern(torch.nn.functional.linear)
@register_quant_pattern(torch.nn.qat.Linear)
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLU(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.relu_node = None
if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
(node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
self.relu_node = node
node = node.args[0]
self.linear_node = node
if node.op == 'call_module':
self.linear = quantizer.modules[self.linear_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
# TODO: debug option for linear module
if self.linear_node.op == 'call_module':
# note that relu should already be fused into conv module in the fusion step
assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
'please make sure to run fusion before prepare'
# 1. attach activation post process to module
if type(self.linear) == torch.nn.intrinsic.LinearReLU:
self.linear[1].activation_post_process = quantizer.activation_post_process_map[node.name]
else:
self.linear.activation_post_process = quantizer.activation_post_process_map[node.name]
# 2. select quantized class
if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
qlinear = torch.nn.quantized.Linear
elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
qlinear = torch.nn.intrinsic.quantized.LinearReLU
else:
raise Exception("unhandled linear type:", type(self.linear))
quantized = qlinear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target, (load_arg(quantized=True)(self.linear_node.args[0]),), {})
elif self.linear_node.op == 'call_function':
if debug:
args = load_arg(quantized=[0, 1])(self.linear_node.args)
args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
linear_out = quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
return quantize_node(
linear_out,
quantizer.activation_post_process_map[self.linear_node.name])
else:
args = load_arg(quantized=[0, 1])(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
# pack weight
weight = load_arg(quantized=True)(self.linear_node.args[1])
bias = None
other_args = load_arg(quantized=False)(self.linear_node.args[1:])
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
else:
assert 'bias' in kwargs, \
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = [weight, bias]
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
# construct linear input
linear_input = load_arg(quantized=True)(self.linear_node.args[0])
activation_post_process = \
quantizer.activation_post_process_map[self.linear_node.name]
scale, zero_point, _ = get_qparams(activation_post_process)
qlinear_args = [linear_input, packed_weight, scale, zero_point]
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)
# these ops have quantized equivalents that do not need any extra information
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
@register_quant_pattern(torch.flatten)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.getitem)
@register_quant_pattern(operator.floordiv)
@register_quant_pattern('chunk')
@register_quant_pattern('contiguous')
@register_quant_pattern('mean')
@register_quant_pattern('reshape')
@register_quant_pattern('shape')
@register_quant_pattern('size')
@register_quant_pattern('view')
class CopyNode(QuantizeHandler):
def convert(self, quantizer, node, load_arg, debug=False):
return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))
class DefaultQuant(QuantizeHandler):
def convert(self, quantizer, node):
assert self.all_nodes
return quantize(quantizer, node)
# 2. Post Training Dynamic Quantizatoin Patterns
@register_dynamic_pattern(torch.nn.Linear)
@register_dynamic_pattern(torch.nn.functional.linear)
class DynamicLinear(QuantizeHandler):
def __init__(self, quantizer, node):
super().__init__(quantizer, node)
self.linear_node = node
if node.op == 'call_module':
assert isinstance(quantizer.modules[node.target], torch.nn.Linear)
self.linear = quantizer.modules[self.linear_node.target]
def convert(self, quantizer, node, load_arg, debug=False):
if self.linear_node.op == 'call_module':
quantized = torch.nn.quantized.dynamic.Linear.from_float(self.linear)
parent_name, name = _parent_name(self.linear_node.target)
setattr(quantizer.modules[parent_name], name, quantized)
return quantizer.quantized_graph.create_node(
'call_module',
self.linear_node.target,
(load_arg(quantized=False)(self.linear_node.args[0]),),
{})
elif self.linear_node.op == 'call_function':
if debug:
# quantize and dequantize weight
args = load_arg(quantized=[1])(self.linear_node.args)
args = load_arg(quantized=False)(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
return quantizer.quantized_graph.create_node(
'call_function', torch.nn.functional.linear, args, kwargs)
else:
# quantize and dequantize weight
args = load_arg(quantized=[1])(self.linear_node.args)
kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
# pack weight
weight = load_arg(quantized=True)(self.linear_node.args[1])
bias = None
other_args = load_arg(quantized=False)(self.linear_node.args[1:])
if len(self.linear_node.args) > 2:
bias = load_arg(quantized=False)(self.linear_node.args[2])
other_args = other_args[1:] # remove the bias argument
else:
assert 'bias' in kwargs, \
'expect bias provided as a keyword argument when it is not a positional argument'
bias = kwargs['bias']
kwargs.pop('bias')
prepack_args = [weight, bias]
packed_weight = quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
# construct dynamic linear input
linear_input = load_arg(quantized=False)(self.linear_node.args[0])
qdynamic_linear_args = [linear_input, packed_weight]
return quantizer.quantized_graph.create_node(
'call_function', torch.ops.quantized.linear_dynamic, qdynamic_linear_args, kwargs)
class Quantizer:
def __init__(self):
# mapping from matched node to activation_post_process
# must be filled before convert
self.activation_post_process_map = None
def _qat_swap_modules(self, root):
convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)
def _generate_qconfig_map(self, root, input_graph):
def get_qconfig(module):
return module.qconfig if hasattr(module, 'qconfig') else None
self.qconfig_map = dict()
for node in input_graph.nodes:
if node.op == 'get_param':
parent, _ = _parent_name(node.target)
self.qconfig_map[node.name] = get_qconfig(self.modules[parent])
elif node.op == 'call_function':
self.qconfig_map[node.name] = get_qconfig(root)
elif node.op == 'call_method':
self_obj = node.args[0]
# qconfig for call_method should be the same as the `self` object for the call
self.qconfig_map[node.name] = self.qconfig_map[self_obj.name]
elif node.op == 'call_module':
self.qconfig_map[node.name] = get_qconfig(self.modules[node.target])
def _prepare(self, model, qconfig_dict, inplace, quant_type):
input_root = model.root
if not inplace:
input_root = copy.deepcopy(input_root)
input_graph = model.graph
self.quant_type = quant_type
# TODO: allow user specified patterns
if self.quant_type == QuantType.DYNAMIC:
self.patterns = get_dynamic_quant_patterns()
else:
self.patterns = get_quant_patterns()
propagate_qconfig_(input_root, qconfig_dict)
if input_root.training:
self._qat_swap_modules(input_root)
self.modules = dict(input_root.named_modules())
# map from node name to qconfig, used in _find_matches
self._generate_qconfig_map(input_root, input_graph)
# match the patterns that will get quantized
matches = self._find_matches(input_graph, self.modules, self.patterns)
# find _inputs_ to matched nodes that are not quantized, these
# have to be quantized, which requires measuring stats,
# initialize an DefaultQuant object for each
quants = self._find_quants(input_graph, matches)
self.activation_post_process_map = dict()
env = {}
observed_graph = Graph()
observed = set()
def load_arg(a):
return map_arg(a, lambda node: env[node.name])
for node in input_graph.nodes:
if node.name in observed:
continue
def get_new_observer_name(parent_module):
i = 0
def get_observer_name(i):
return 'activation_post_process_' + str(i)
observer_name = get_observer_name(i)
while hasattr(parent_module, observer_name):
i += 1
observer_name = get_observer_name(i)
return observer_name
root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
if root_node is None:
env[node.name] = observed_graph.node_copy(node, load_arg)
elif root_node is node:
env[node.name] = observed_graph.node_copy(node, load_arg)
def insert_observer(node, observer):
observer_name = get_new_observer_name(input_root)
setattr(input_root, observer_name, observer)
self.activation_post_process_map[node.name] = observer
env[node.name] = observed_graph.create_node('call_module', observer_name, [load_arg(node)], {})
observed.add(node.name)
# don't need to insert observer for output in dynamic quantization
if self.quant_type == QuantType.DYNAMIC:
continue
if isinstance(obj, CopyNode):
assert node.op in [
'call_module',
'call_function',
'call_method'], \
'CopyNode of type ' + node.op + ' is not handled'
# propagate observed property from input
if node.args[0].name in observed:
observed.add(node.name)
elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
if node.args[0].name in observed:
observed.add(node.name)
elif qconfig is not None and obj.all_nodes:
# observer for outputs
insert_observer(node, qconfig.activation())
else:
env[node.name] = observed_graph.node_copy(node, load_arg)
if node.name not in observed and node.name in quants:
observer_name = get_new_observer_name(input_root)
_, qconfig, is_weight = quants[node.name]
if qconfig is not None:
self.activation_post_process_map[node.name] = qconfig.weight() if is_weight else qconfig.activation()
setattr(input_root, observer_name, self.activation_post_process_map[node.name])
env[node.name] = observed_graph.create_node('call_module', observer_name, [load_arg(node)], {})
observed.add(node.name)
observed_graph.output(load_arg(input_graph.result))
return GraphModule(input_root, observed_graph)
def prepare(self, model, qconfig_dict, inplace=False):
return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)
def prepare_dynamic(self, model, qconfig_dict, inplace=False):
return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)
def convert(self, observed, inplace=False, debug=False):
assert self.activation_post_process_map is not None
observed_root = observed.root
observed_graph = observed.graph
if not inplace:
observed_root = copy.deepcopy(observed_root)
self.modules = dict(observed_root.named_modules())
matches = self._find_matches(observed.graph, self.modules, self.patterns)
quants = self._find_quants(observed.graph, matches)
self.quantized_graph = Graph()
env = {}
quant_env = {}
def load_non_quantized(n):
if n.name not in env:
assert n.name in quant_env, \
'trying to load float node but did not find node:' + n.name + \
' in quantized environment:' + str(quant_env)
env[n.name] = Proxy(quant_env[n.name]).dequantize().node
return env[n.name]
def load_quantized(n):
if n.name not in quant_env:
assert n.name in env, \
'trying to load quantized node but did not find node:' + n.name + \
' in float environment:' + str(env)
assert n.name in quants, 'did not find quant object for node:' + n.name
quant = quants[n.name][0]
quant_env[n.name] = quant.convert(self, env[n.name])
return quant_env[n.name]
def load_x(n):
assert n.name in env or n.name in quant_env, \
'node ' + n.name + ' does not exist in either of the environment'
if n.name in quant_env:
return quant_env[n.name]
else:
return env[n.name]
def load_arg(quantized):
"""
if quantized is a list, then arg should be a list and the args with corresponding
indexes will be quantized
if quantized is a boolean, then all args will be quantized/not quantized
if quantized is None, then we'll load the node as long as it exists
"""
assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)
def load_arg_impl(arg):
if quantized is None:
return map_arg(arg, load_x)
if isinstance(quantized, bool):
return map_arg(arg, load_quantized if quantized else load_non_quantized)
elif isinstance(quantized, (tuple, list)):
assert isinstance(arg, (tuple, list)), arg
loaded_arg = []
# for now, we only support quantizing positional arguments
for i, a in enumerate(arg):
if i in quantized:
loaded_arg.append(map_arg(a, load_quantized))
else:
loaded_arg.append(map_arg(a, load_non_quantized))
return type(arg)(loaded_arg)
return load_arg_impl
def is_quantized(node):
assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
# there might be nodes appearing in both environemnts, but quant_env will take
# precedence
if node.name in quant_env:
return True
elif node.name in env:
return False
for node in observed_graph.nodes:
root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
if root_node is node:
result = obj.convert(self, node, load_arg)
quantized = True
# Need to get correct quantized/non-quantized state for the output of CopyNode
if isinstance(obj, CopyNode):
assert node.op in [
'call_module',
'call_function',
'call_method'], \
'CopyNode of type ' + node.op + ' is not handled'
quantized = is_quantized(node.args[0])
if self.quant_type == QuantType.DYNAMIC:
quantized = False
if quantized:
quant_env[node.name] = result
else:
env[node.name] = result
continue
elif root_node is not None:
continue
# handle activation post process calls
if node.op == 'call_module':
if node.target.split('.')[-1].startswith('activation_post_process_'):
observer_module = self.modules[node.target]
prev_node = node.args[0]
if prev_node.name in quant_env:
# if previous node is already quantized, we'll just remove the activation_post_process
quant_env[node.name] = quant_env[prev_node.name]
continue
# replace activation post process with quantization ops
parent_name = ''
scale, zero_point = observer_module.calculate_qparams()
scale = float(scale)
zero_point = int(zero_point)
dtype = observer_module.dtype
qparams = {'_scale_': scale, '_zero_point_': zero_point, '_dtype_': dtype}
i = 0
def noattr(module, qparams, i):
for name in qparams.keys():
if hasattr(module, name + str(i)):
return False
return True
def get_next_i(module, qparams):
i = 0
while not noattr(module, qparams, i):
i += 1
return i
parent_module = self.modules[parent_name]
i = get_next_i(parent_module, qparams)
inputs = [load_non_quantized(node.args[0])]
for key, value in qparams.items():
setattr(parent_module, key + str(i), value)
qparam_full_path = key + str(i)
if parent_name:
qparam_full_path = parent_name + '.' + qparam_full_path
inputs.append(self.quantized_graph.get_param(qparam_full_path))
quant_env[node.name] = self.quantized_graph.create_node('call_function', torch.quantize_per_tensor, inputs, {})
continue
# dequantize inputs for the node that are not quantized
env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)
self.quantized_graph.output(load_non_quantized(observed_graph.result))
to_be_removed = []
for name, _ in observed_root.named_modules():
if name.split('.')[-1].startswith('activation_post_process_'):
to_be_removed.append(name)
for n in to_be_removed:
delattr(observed_root, n)
return GraphModule(observed_root, self.quantized_graph)
def _find_matches(self, graph, modules, patterns):
match_map = {} # node name -> (root_node, match_value?)
all_matched = set()
def record_match(pattern, node, matched):
if isinstance(pattern, tuple):
s, *args = pattern
record_match(s, node, matched)
if pattern[0] is not getattr:
for subpattern, arg in zip(args, node.args):
record_match(subpattern, arg, matched)
else:
matched.append(node)
for node in reversed(graph.nodes):
if node.name not in match_map and node.name not in all_matched:
for pattern, value in patterns.items():
if matches(modules, node, pattern):
matched = []
record_match(pattern, node, matched)
for n in matched:
match_map[n.name] = (node, matched, value(self, node), self.qconfig_map[n.name])
all_matched.add(n.name)
# break after finding the first match
break
return match_map
def _find_quants(self, graph, matches):
quants = {}
def visit(node, qconfig):
def visit_arg(arg):
# note: we have to measure quantization information
# even for nodes where we might not use it because it is already
# quantized. This is because each match has the option to
# say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
is_weight = False
if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
for i, node_arg in enumerate(node.args):
if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
is_weight = True
if self.quant_type != QuantType.DYNAMIC or is_weight:
# overwrite previous quant config
quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
return visit_arg
for node in graph.nodes:
if node.name in matches:
root_node, matched, obj, qconfig = matches[node.name]
# don't attach observer/fake_quant for CopyNode
if isinstance(obj, CopyNode):
qconfig = None
if root_node is node:
# matched[-1] is the first op in the sequence and
# matched[0] is the last op in the sequence
# inputs
map_arg(matched[-1].args, visit(matched[-1], qconfig))
map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
# output
map_arg(matched[0], visit(None, qconfig))
return quants

View File

@ -0,0 +1,7 @@
# turn foo.bar -> ['foo', 'bar']
def _parent_name(target):
r = target.rsplit('.', 1)
if len(r) == 1:
return '', r[0]
else:
return r[0], r[1]