[quant][graphmode][fx] Add graph mode quantization on fx (#43175)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/43175

This PR adds graph mode quantization on fx: https://github.com/pytorch/pytorch/pull/42741
It currently matches eager mode quantization for torchvision with static/dynamic/qat.
The DDP/SyncBN test is still WIP.

Test Plan: python test/test_quantization.py TestQuantizeFx

Imported from OSS

Reviewed By: vkuzo

Differential Revision: D23178602

fbshipit-source-id: 8e7e0322846fbda2cfa79ad188abd7235326f879
This commit is contained in:
parent c89d2c6bf2
commit dae2973fae
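
For orientation, this is the end-to-end flow of the new API as exercised by the test below (a minimal sketch; MyModel and example_input are placeholders, not names from this diff):

    import torch
    from torch.fx import symbolic_trace
    from torch.quantization import default_qconfig
    from torch.quantization._quantize_fx import Quantizer, fuse

    model = MyModel().eval()          # any symbolically traceable module
    graph = symbolic_trace(model)     # trace to a GraphModule
    graph = fuse(graph)               # fuse conv+bn(+relu) / linear+relu patterns
    quantizer = Quantizer()
    prepared = quantizer.prepare(graph, {'': default_qconfig})  # insert observers
    prepared(example_input)           # calibrate with representative data
    quantized = quantizer.convert(prepared)                     # lower to quantized ops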

6  mypy.ini

@@ -81,6 +81,12 @@ ignore_errors = True
[mypy-torch.quantization._numeric_suite]
ignore_errors = True

[mypy-torch.quantization._quantize_fx]
ignore_errors = True

[mypy-torch.quantization.fx.*]
ignore_errors = True

[mypy-torch.quasirandom]
ignore_errors = True

118  test/quantization/test_quantize_fx.py  Normal file

@@ -0,0 +1,118 @@
import torch
import torch.nn.functional as F

# symbolic trace
from torch.fx import symbolic_trace

# graph mode quantization based on fx
from torch.quantization._quantize_fx import (
    Quantizer,
    fuse,
)

# eager mode quantization
from torch.quantization import default_qconfig

# test utils
from torch.testing._internal.common_quantization import (
    QuantizationTestCase,
    skipIfNoFBGEMM
)

class TestQuantizeFx(QuantizationTestCase):
    @skipIfNoFBGEMM
    def test_functional(self):
        """ Test quantizing functional conv and linear
        """
        class Conv(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.stride = (1, 1)
                self.padding = (0, 0)
                self.dilation = (1, 1)
                self.groups = 1

            def forward(self, x, weight):
                return F.conv2d(x, weight, None, self.stride, self.padding, self.dilation, self.groups)

        conv_input = torch.rand(1, 3, 224, 224)
        conv_weight = torch.rand(3, 3, 3, 3)

        class Linear(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, weight):
                return F.linear(x, weight)

        linear_input = torch.rand(8, 5)
        linear_weight = torch.rand(10, 5)

        class LinearModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(5, 10)

            def forward(self, x):
                return self.linear(x)

        linear_module_input = torch.rand(8, 5)

        tests = [
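            # each entry: (is_dynamic, module class, example inputs,
            #              (expected quantized node op, expected quantized node target))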
            (False, Conv, (conv_input, conv_weight), ('call_function', torch.ops.quantized.conv2d)),
            (True, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear_dynamic)),
            (False, Linear, (linear_input, linear_weight), ('call_function', torch.ops.quantized.linear)),
            (True, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.dynamic.Linear)),
            (False, LinearModule, (linear_module_input,), ('call_module', torch.nn.quantized.Linear)),
        ]

        for is_dynamic, M, inputs, quantized_node in tests:
            m = M().eval()
            qconfig = default_qconfig

            graph = symbolic_trace(m)
            script = torch.jit.script(graph)

            a = m(*inputs)
            b = graph(*inputs)
            c = script(*inputs)
            assert (a - b).abs().max() == 0
            assert (a - c).abs().max() == 0
            assert torch.allclose(a, b)
            assert torch.allclose(a, c)

            graph = fuse(graph)

            quantizer = Quantizer()
            qconfig_dict = {'': qconfig}
            if is_dynamic:
                prepared = quantizer.prepare_dynamic(graph, qconfig_dict)
            else:
                prepared = quantizer.prepare(graph, qconfig_dict)

            prepared(*inputs)

            qgraph = quantizer.convert(prepared)
            qgraph_debug = quantizer.convert(prepared, debug=True)
            qgraph.eval()
            qgraph_debug.eval()
            qgraph_script = torch.jit.script(qgraph)

            d = qgraph(*inputs)
            d_debug = qgraph_debug(*inputs)
            e = qgraph_script(*inputs)
            e_debug = qgraph_debug(*inputs)

            found = False
            modules = dict(qgraph.root.named_modules())
            for node in qgraph.graph.nodes:
                if node.op == 'call_function':
                    found = found or node.op == quantized_node[0] and node.target == quantized_node[1]
                elif node.op == 'call_module':
                    found = found or node.op == quantized_node[0] and type(modules[node.target]) == quantized_node[1]
            assert found, 'Expected to find quantized node:' + str(quantized_node)
            # assert (a - d).abs().max() < 2
            assert torch.allclose(d, e)
            assert (d - d_debug).abs().max() == 0
            assert (e - e_debug).abs().max() == 0

test/test_quantization.py

@@ -60,6 +60,9 @@ from quantization.test_quantize_jit import TestQuantizeJitOps  # noqa: F401
from quantization.test_quantize_jit import TestQuantizeDynamicJitPasses  # noqa: F401
from quantization.test_quantize_jit import TestQuantizeDynamicJitOps  # noqa: F401

# 3. GraphModule based graph mode quantization
from quantization.test_quantize_fx import TestQuantizeFx  # noqa: F401

# Tooling: numeric_suite
from quantization.test_numeric_suite import TestEagerModeNumericSuite  # noqa: F401

2  torch/quantization/_quantize_fx.py  Normal file

@@ -0,0 +1,2 @@
from .fx import Quantizer  # noqa: F401
from .fx import fuse  # noqa: F401

3  torch/quantization/fx/__init__.py  Normal file

@@ -0,0 +1,3 @@
from __future__ import absolute_import, division, print_function, unicode_literals
from .quantize import Quantizer
from .fuse import fuse

151  torch/quantization/fx/fuse.py  Normal file

@@ -0,0 +1,151 @@
import torch
from torch.quantization.fuse_modules import (
    fuse_conv_bn,
    fuse_conv_bn_relu,
)

from torch.fx import (
    GraphModule,
)

from torch.fx.graph import (
    Graph,
    map_arg,
)

from .pattern_utils import (
    matches,
    register_fusion_pattern,
    get_fusion_patterns,
)

from .utils import _parent_name

import copy

# Fusion Patterns
@register_fusion_pattern((torch.nn.BatchNorm2d, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
@register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
@register_fusion_pattern((torch.nn.functional.relu, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
class ConvBNReLUFusion():
    def __init__(self, quantizer, node):
        super().__init__()
        self.relu_node = None
        self.bn_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.ReLU):
            self.relu_node = node
            node = node.args[0]
        assert node.op == 'call_module'
        if isinstance(quantizer.modules[node.target], torch.nn.BatchNorm2d):
            self.bn_node = node
            self.bn = quantizer.modules[self.bn_node.target]
            node = node.args[0]
        assert node.op == 'call_module' and type(quantizer.modules[node.target]) == torch.nn.modules.Conv2d
        self.conv_node = node
        self.conv = quantizer.modules[self.conv_node.target]

    def fuse(self, quantizer, load_arg):
        conv_parent_name, conv_name = _parent_name(self.conv_node.target)
        if self.relu_node is not None:
            # since relu can be used multiple times, we'll need to create a relu module for each match
            if self.relu_node.op == 'call_module':
                relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
            else:
                # TODO: get inplace argument from functional
                relu = torch.nn.ReLU()
            relu.training = self.conv.training
            if self.bn_node is not None:
                setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn_relu(self.conv, self.bn, relu))
            else:
                # conv_relu
                setattr(quantizer.modules[conv_parent_name], conv_name, torch.nn.intrinsic.ConvReLU2d(self.conv, relu))
        else:
            assert self.bn_node is not None
            setattr(quantizer.modules[conv_parent_name], conv_name, fuse_conv_bn(self.conv, self.bn))

        # TODO: do we need to make sure bn is only used once?
        if self.bn_node is not None:
            parent_name, name = _parent_name(self.bn_node.target)
            setattr(quantizer.modules[parent_name], name, torch.nn.Identity())
        return quantizer.fused_graph.node_copy(self.conv_node, load_arg)

@register_fusion_pattern((torch.nn.functional.relu, torch.nn.Linear))
@register_fusion_pattern((torch.nn.ReLU, torch.nn.Linear))
class LinearReLUFusion():
    def __init__(self, quantizer, node):
        super().__init__()
        self.relu_node = node
        node = node.args[0]
        assert node.op == 'call_module'
        assert isinstance(quantizer.modules[node.target], torch.nn.modules.Linear)
        self.linear_node = node
        self.linear = quantizer.modules[self.linear_node.target]

    def fuse(self, quantizer, load_arg):
        # since relu can be used multiple times, we'll need to create a relu module for each match
        if self.relu_node.op == 'call_module':
            relu = torch.nn.ReLU(quantizer.modules[self.relu_node.target].inplace)
        else:
            # TODO: get inplace argument from functional
            relu = torch.nn.ReLU()
        relu.training = self.linear.training
        # linear_relu
        linear_parent_name, linear_name = _parent_name(self.linear_node.target)
        setattr(quantizer.modules[linear_parent_name], linear_name, torch.nn.intrinsic.LinearReLU(self.linear, relu))
        return quantizer.fused_graph.node_copy(self.linear_node, load_arg)

class Fuser:
    def fuse_conv_bn(self, model, inplace=False):
        input_root = model.root
        if not inplace:
            input_root = copy.deepcopy(input_root)
        input_graph = model.graph
        self.modules = dict(input_root.named_modules())

        fusion_patterns = get_fusion_patterns()
        # find conv-bn pairs
        conv_bn_pairs = self._find_matches(input_root, input_graph, fusion_patterns)
        self.fused_graph = Graph()
        env = {}

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            root_node, obj = conv_bn_pairs.get(node.name, (None, None))
            if root_node is node:
                env[node.name] = obj.fuse(self, load_arg)
            elif root_node is None:
                env[node.name] = self.fused_graph.node_copy(node, load_arg)
            # nodes that matched a pattern but are not the root of the match are dropped here

        self.fused_graph.output(load_arg(input_graph.result))
        return GraphModule(input_root, self.fused_graph)

    def _find_matches(self, root, graph, patterns):
        modules = dict(root.named_modules())
        match_map = {}  # node name -> (root_node, fusion handler)

        def apply_match(pattern, node, match):
            if isinstance(pattern, tuple):
                s, *args = pattern
                apply_match(s, node, match)
                for subpattern, arg in zip(args, node.args):
                    apply_match(subpattern, arg, match)
            else:
                match_map[node.name] = match

        for node in reversed(graph.nodes):
            if node.name not in match_map:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        apply_match(pattern, node, (node, value(self, node)))

        return match_map

def fuse(graph_module, inplace=False):
    fuser = Fuser()
    return fuser.fuse_conv_bn(graph_module, inplace)

86  torch/quantization/fx/pattern_utils.py  Normal file

@@ -0,0 +1,86 @@
import torch
import sys
from collections import OrderedDict

# pattern for conv bn fusion
FUSION_PATTERNS = OrderedDict()
def register_fusion_pattern(pattern):
    def insert(fn):
        FUSION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_fusion_patterns():
    return FUSION_PATTERNS

# pattern for both static quantization and qat
QUANTIZATION_PATTERNS = OrderedDict()
def register_quant_pattern(pattern):
    def insert(fn):
        QUANTIZATION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_quant_patterns():
    return QUANTIZATION_PATTERNS

# pattern for dynamic quantization
DYNAMIC_QUANTIZATION_PATTERNS = OrderedDict()
def register_dynamic_pattern(pattern):
    def insert(fn):
        DYNAMIC_QUANTIZATION_PATTERNS[pattern] = fn
        return fn
    return insert

def get_dynamic_quant_patterns():
    return DYNAMIC_QUANTIZATION_PATTERNS

# Example use of register pattern function:
# @register_fusion_pattern((torch.nn.ReLU, (torch.nn.BatchNorm2d, torch.nn.Conv2d)))
# class ConvBNReLUFusion():
#     def __init__(...):
#         ...
#
# Note: The order of patterns is important! The match function will take whatever is
# matched first, so we need to register fusion patterns before single-op patterns.
# For example, add_relu should be registered before relu.
# Decorators are applied in the reverse order in which they appear. Also, when we match
# the nodes in the graph against these patterns, we start from the last node of the
# graph and traverse back.
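#
# For illustration (a sketch, not code in this file): with the ordering above,
# a graph node sequence add -> relu is matched against the fused pattern
# (torch.nn.ReLU, operator.add) before the standalone torch.nn.ReLU pattern is
# ever consulted; if the standalone entry came first in the OrderedDict, the
# relu node would be claimed by it and the fused add_relu handler would not fire.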

def matches(modules, node, pattern, max_uses=sys.maxsize):
    """ Matches a node in fx against a pattern
    """
    if isinstance(pattern, tuple):
        self_match, *arg_matches = pattern
        if self_match is getattr:
            assert len(pattern) == 2, 'Expecting getattr pattern to have two elements'
            arg_matches = []
    else:
        self_match = pattern
        arg_matches = []

    if node.uses > max_uses:
        return False

    if isinstance(self_match, type) and issubclass(self_match, torch.nn.Module):
        if node.op != 'call_module':
            return False
        if not type(modules[node.target]) == self_match:
            return False
    elif callable(self_match):
        if node.op != 'call_function' or node.target is not self_match:
            return False
        elif node.target is getattr:
            if node.args[1] != pattern[1]:
                return False
    elif node.target != self_match:
        return False

    if not arg_matches:
        return True

    if len(arg_matches) != len(node.args):
        return False

    return all(matches(modules, node, arg_match, max_uses=1) for node, arg_match in zip(node.args, arg_matches))
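
# Example (an illustrative sketch, not part of the API): for a graph built from
# relu(conv(x)), matches(modules, relu_node, (torch.nn.ReLU, torch.nn.Conv2d))
# checks that relu_node is a call_module whose module type is exactly torch.nn.ReLU,
# then recurses into relu_node.args[0] with the torch.nn.Conv2d sub-pattern,
# passing max_uses=1 so a conv output that is also consumed elsewhere will not match.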

782  torch/quantization/fx/quantize.py  Normal file

@@ -0,0 +1,782 @@
import torch
from torch.quantization import (
    propagate_qconfig_,
    convert,
    DEFAULT_QAT_MODULE_MAPPING,
)

from torch.fx import (
    GraphModule,
    Proxy,
)

from torch.fx.graph import (
    Graph,
    Node,
    map_arg,
)

from .pattern_utils import (
    matches,
    register_quant_pattern,
    get_quant_patterns,
    register_dynamic_pattern,
    get_dynamic_quant_patterns,
)

from .utils import _parent_name

from abc import ABC, abstractmethod
import copy
import enum
import operator

# Quantization type (dynamic quantization, static quantization).
# Should match the c++ enum in quantization_type.h
class QuantType(enum.IntEnum):
    DYNAMIC = 0
    STATIC = 1
    QAT = 2

# ------------------------
# Helper Functions
# ------------------------
def get_qparams(activation_post_process):
    scale, zero_point = activation_post_process.calculate_qparams()
    scale = float(scale)
    zero_point = int(zero_point)
    dtype = activation_post_process.dtype
    return scale, zero_point, dtype

def quantize_node(node, activation_post_process):
    scale, zero_point, dtype = get_qparams(activation_post_process)
    return torch.quantize_per_tensor(node, scale, zero_point, dtype)

def quantize(quantizer, node):
    return quantize_node(node, quantizer.activation_post_process_map[node.name])

# A dictionary for querying the weight index for a given op
WEIGHT_INDEX_DICT = {
    torch.nn.functional.conv2d : [1],
    torch.nn.functional.linear : [1],
}
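
# e.g. in torch.nn.functional.conv2d(input, weight, bias, ...) and
# torch.nn.functional.linear(input, weight, bias), the weight is positional
# argument 1, which is why both ops map to [1] above.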

# Pattern Registrations

# 1. Post Training Static Quantization and Quantization Aware Training Patterns

# Base Pattern Handler
class QuantizeHandler(ABC):
    """ Base handler class for the quantizer patterns
    """
    def __init__(self, quantizer, node):
        """ Records pattern information in __init__, which will be used
        in convert
        """
        # this is an indicator of whether all the inputs are Node or not
        # since some op might be quantized differently depending on whether
        # all inputs are tensors or not, e.g. add/mul
        self.all_nodes = True

    @abstractmethod
    def convert(self, quantizer, node, load_arg, debug=False):
        """ Convert the given node to a quantized node and insert
        it to the quantized graph
        """
        return NotImplemented

@register_quant_pattern(operator.add)
@register_quant_pattern((torch.nn.ReLU, operator.add))
@register_quant_pattern((torch.nn.functional.relu, operator.add))
class Add(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]
        assert node.op == 'call_function' and node.target == operator.add
        self.add_node = node
        self.all_nodes = all([isinstance(a, Node) for a in self.add_node.args[:2]])

    def convert(self, quantizer, node, load_arg, debug=False):
        if not self.all_nodes:
            # add scalar
            if self.relu_node is not None:
                op = torch.ops.quantized.add_relu
            else:
                op = torch.ops.quantized.add
            return quantizer.quantized_graph.create_node(
                'call_function', op,
                load_arg(quantized=[0])(self.add_node.args), self.add_node.kwargs)
        else:
            activation_post_process = quantizer.activation_post_process_map[node.name]
            scale, zero_point = activation_post_process.calculate_qparams()
            scale = float(scale)
            zero_point = int(zero_point)
            if self.relu_node is not None:
                op = torch.ops.quantized.add_relu
            else:
                op = torch.ops.quantized.add
            kwargs = self.add_node.kwargs
            kwargs.update({'scale': scale, 'zero_point': zero_point})
            return quantizer.quantized_graph.create_node(
                'call_function', op, load_arg(quantized=True)(self.add_node.args), kwargs)

@register_quant_pattern(operator.mul)
@register_quant_pattern((torch.nn.ReLU, operator.mul))
@register_quant_pattern((torch.nn.functional.relu, operator.mul))
class Mul(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]
        assert node.op == 'call_function' and node.target == operator.mul
        self.mul_node = node
        self.all_nodes = all([isinstance(a, Node) for a in self.mul_node.args[:2]])

    def convert(self, quantizer, node, load_arg, debug=False):
        if not self.all_nodes:
            # mul scalar
            if self.relu_node is not None:
                op = torch.ops.quantized.mul_relu
            else:
                op = torch.ops.quantized.mul
            return quantizer.quantized_graph.create_node(
                'call_function', op, load_arg(quantized=[0])(self.mul_node.args), self.mul_node.kwargs)
        else:
            activation_post_process = quantizer.activation_post_process_map[node.name]
            scale, zero_point = activation_post_process.calculate_qparams()
            scale = float(scale)
            zero_point = int(zero_point)
            if self.relu_node is not None:
                op = torch.ops.quantized.mul_relu
            else:
                op = torch.ops.quantized.mul
            kwargs = self.mul_node.kwargs
            kwargs.update({'scale': scale, 'zero_point': zero_point})
            return quantizer.quantized_graph.create_node('call_function', op, load_arg(quantized=True)(self.mul_node.args), kwargs)

@register_quant_pattern(torch.cat)
class Cat(QuantizeHandler):
    def convert(self, quantizer, node, load_arg, debug=False):
        if not self.all_nodes:
            return NotImplemented
        activation_post_process = quantizer.activation_post_process_map[node.name]
        scale, zero_point = activation_post_process.calculate_qparams()
        scale = float(scale)
        zero_point = int(zero_point)
        kwargs = load_arg(quantized=False)(node.kwargs)
        kwargs.update({'scale': scale, 'zero_point': zero_point})
        return quantizer.quantized_graph.create_node(
            'call_function', torch.ops.quantized.cat, load_arg(quantized=[0])(node.args), kwargs)

# handle conv, maybe followed by relu
# NB: matching order is reversed, that is we match from the bottom of this list to the beginning
@register_quant_pattern(torch.nn.Conv2d)
@register_quant_pattern(torch.nn.functional.conv2d)
@register_quant_pattern(torch.nn.qat.Conv2d)
@register_quant_pattern(torch.nn.intrinsic.ConvReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBn2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvBnReLU2d)
@register_quant_pattern(torch.nn.intrinsic.qat.ConvReLU2d)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.conv2d))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.conv2d))
# just for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Conv2d))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Conv2d))
class ConvRelu(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]
        self.conv_node = node
        if node.op == 'call_module':
            self.conv = quantizer.modules[self.conv_node.target]

    def convert(self, quantizer, node, load_arg, debug=False):
        # TODO: debug option for conv module
        if self.conv_node.op == 'call_module':
            # note that relu should already be fused into conv module in the fusion step
            assert self.relu_node is None, 'conv module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
            # 1. attach activation post process to module
            if type(self.conv) == torch.nn.intrinsic.ConvReLU2d:
                self.conv[1].activation_post_process = quantizer.activation_post_process_map[node.name]
            else:
                self.conv.activation_post_process = quantizer.activation_post_process_map[node.name]
            # 2. select quantized class
            if type(self.conv) in [torch.nn.Conv2d,
                                   torch.nn.qat.Conv2d,
                                   torch.nn.intrinsic.qat.ConvBn2d]:
                qconv = torch.nn.quantized.Conv2d
            elif type(self.conv) in [torch.nn.intrinsic.ConvReLU2d,
                                     torch.nn.intrinsic.qat.ConvReLU2d,
                                     torch.nn.intrinsic.qat.ConvBnReLU2d]:
                qconv = torch.nn.intrinsic.quantized.ConvReLU2d
            else:
                raise Exception("unhandled conv type:", type(self.conv))
            quantized = qconv.from_float(self.conv)
            parent_name, name = _parent_name(self.conv_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.conv_node.target,
                (load_arg(quantized=True)(self.conv_node.args[0]),),
                {})
        elif self.conv_node.op == 'call_function':
            if self.relu_node is not None:
                raise Exception("functional conv + relu is not supported yet")
            if debug:
                args = load_arg(quantized=[0, 1])(self.conv_node.args)
                args = load_arg(quantized=False)(self.conv_node.args)
                kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                conv_out = quantizer.quantized_graph.create_node(
                    'call_function', torch.nn.functional.conv2d, args, kwargs)
                return quantize_node(
                    conv_out, quantizer.activation_post_process_map[self.conv_node.name])
            else:
                assert len(self.conv_node.args) == 7, \
                    'only conv2d calls with all arguments specified are supported right now when debug=False'
                args = load_arg(quantized=[0, 1])(self.conv_node.args)
                # pack weight
                weight = load_arg(quantized=True)(self.conv_node.args[1])
                other_args = load_arg(quantized=False)(self.conv_node.args[2:])
                prepack_args = [weight] + list(other_args)
                packed_weight = quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.conv2d_prepack, prepack_args, {})
                # construct conv input
                conv_input = load_arg(quantized=True)(self.conv_node.args[0])
                activation_post_process = quantizer.activation_post_process_map[self.conv_node.name]
                scale, zero_point, _ = get_qparams(activation_post_process)
                qconv_args = [conv_input, packed_weight, scale, zero_point]
                kwargs = load_arg(quantized=False)(self.conv_node.kwargs)
                return quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.conv2d, qconv_args, kwargs)

# handle linear, maybe followed by relu
@register_quant_pattern(torch.nn.Linear)
@register_quant_pattern(torch.nn.functional.linear)
@register_quant_pattern(torch.nn.qat.Linear)
@register_quant_pattern(torch.nn.intrinsic.qat.LinearReLU)
@register_quant_pattern((torch.nn.functional.relu, torch.nn.functional.linear))
@register_quant_pattern((torch.nn.ReLU, torch.nn.functional.linear))
# for error checks
@register_quant_pattern((torch.nn.ReLU, torch.nn.Linear))
@register_quant_pattern((torch.nn.functional.relu, torch.nn.Linear))
class LinearReLU(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.relu_node = None
        if (node.op == 'call_function' and node.target is torch.nn.functional.relu) or \
           (node.op == 'call_module' and isinstance(quantizer.modules[node.target], torch.nn.ReLU)):
            self.relu_node = node
            node = node.args[0]
        self.linear_node = node
        if node.op == 'call_module':
            self.linear = quantizer.modules[self.linear_node.target]

    def convert(self, quantizer, node, load_arg, debug=False):
        # TODO: debug option for linear module
        if self.linear_node.op == 'call_module':
            # note that relu should already be fused into linear module in the fusion step
            assert self.relu_node is None, 'linear module and relu fusion is not executed, ' \
                'please make sure to run fusion before prepare'
            # 1. attach activation post process to module
            if type(self.linear) == torch.nn.intrinsic.LinearReLU:
                self.linear[1].activation_post_process = quantizer.activation_post_process_map[node.name]
            else:
                self.linear.activation_post_process = quantizer.activation_post_process_map[node.name]
            # 2. select quantized class
            if type(self.linear) in [torch.nn.Linear, torch.nn.qat.Linear]:
                qlinear = torch.nn.quantized.Linear
            elif type(self.linear) in [torch.nn.intrinsic.LinearReLU, torch.nn.intrinsic.qat.LinearReLU]:
                qlinear = torch.nn.intrinsic.quantized.LinearReLU
            else:
                raise Exception("unhandled linear type:", type(self.linear))
            quantized = qlinear.from_float(self.linear)
            parent_name, name = _parent_name(self.linear_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.linear_node.target, (load_arg(quantized=True)(self.linear_node.args[0]),), {})
        elif self.linear_node.op == 'call_function':
            if debug:
                args = load_arg(quantized=[0, 1])(self.linear_node.args)
                args = load_arg(quantized=False)(self.linear_node.args)
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                linear_out = quantizer.quantized_graph.create_node(
                    'call_function', torch.nn.functional.linear, args, kwargs)
                return quantize_node(
                    linear_out,
                    quantizer.activation_post_process_map[self.linear_node.name])
            else:
                args = load_arg(quantized=[0, 1])(self.linear_node.args)
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                # pack weight
                weight = load_arg(quantized=True)(self.linear_node.args[1])
                bias = None
                other_args = load_arg(quantized=False)(self.linear_node.args[1:])
                if len(self.linear_node.args) > 2:
                    bias = load_arg(quantized=False)(self.linear_node.args[2])
                    other_args = other_args[1:]  # remove the bias argument
                else:
                    assert 'bias' in kwargs, \
                        'expect bias provided as a keyword argument when it is not a positional argument'
                    bias = kwargs['bias']
                    kwargs.pop('bias')
                prepack_args = [weight, bias]
                packed_weight = quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
                # construct linear input
                linear_input = load_arg(quantized=True)(self.linear_node.args[0])
                activation_post_process = \
                    quantizer.activation_post_process_map[self.linear_node.name]
                scale, zero_point, _ = get_qparams(activation_post_process)
                qlinear_args = [linear_input, packed_weight, scale, zero_point]
                return quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.linear, qlinear_args, kwargs)

# these ops have quantized equivalents that do not need any extra information
@register_quant_pattern(torch.nn.AdaptiveAvgPool2d)
@register_quant_pattern(torch.nn.AvgPool2d)
@register_quant_pattern(torch.nn.Dropout)
@register_quant_pattern(torch.nn.MaxPool2d)
@register_quant_pattern(torch.nn.ReLU)
@register_quant_pattern(torch.nn.ReLU6)
@register_quant_pattern(torch.nn.functional.adaptive_avg_pool2d)
@register_quant_pattern(torch.nn.functional.dropout)
@register_quant_pattern(torch.nn.functional.max_pool2d)
@register_quant_pattern(torch._C._nn.avg_pool2d)
@register_quant_pattern(torch.flatten)
@register_quant_pattern(torch.transpose)
@register_quant_pattern(torch.mean)
@register_quant_pattern(torch.unsqueeze)
@register_quant_pattern(operator.getitem)
@register_quant_pattern(operator.floordiv)
@register_quant_pattern('chunk')
@register_quant_pattern('contiguous')
@register_quant_pattern('mean')
@register_quant_pattern('reshape')
@register_quant_pattern('shape')
@register_quant_pattern('size')
@register_quant_pattern('view')
class CopyNode(QuantizeHandler):
    def convert(self, quantizer, node, load_arg, debug=False):
        return quantizer.quantized_graph.node_copy(node, load_arg(quantized=None))

class DefaultQuant(QuantizeHandler):
    def convert(self, quantizer, node):
        assert self.all_nodes
        return quantize(quantizer, node)

# 2. Post Training Dynamic Quantization Patterns
@register_dynamic_pattern(torch.nn.Linear)
@register_dynamic_pattern(torch.nn.functional.linear)
class DynamicLinear(QuantizeHandler):
    def __init__(self, quantizer, node):
        super().__init__(quantizer, node)
        self.linear_node = node
        if node.op == 'call_module':
            assert isinstance(quantizer.modules[node.target], torch.nn.Linear)
            self.linear = quantizer.modules[self.linear_node.target]

    def convert(self, quantizer, node, load_arg, debug=False):
        if self.linear_node.op == 'call_module':
            quantized = torch.nn.quantized.dynamic.Linear.from_float(self.linear)
            parent_name, name = _parent_name(self.linear_node.target)
            setattr(quantizer.modules[parent_name], name, quantized)
            return quantizer.quantized_graph.create_node(
                'call_module',
                self.linear_node.target,
                (load_arg(quantized=False)(self.linear_node.args[0]),),
                {})
        elif self.linear_node.op == 'call_function':
            if debug:
                # quantize and dequantize weight
                args = load_arg(quantized=[1])(self.linear_node.args)
                args = load_arg(quantized=False)(self.linear_node.args)
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                return quantizer.quantized_graph.create_node(
                    'call_function', torch.nn.functional.linear, args, kwargs)
            else:
                # quantize and dequantize weight
                args = load_arg(quantized=[1])(self.linear_node.args)
                kwargs = load_arg(quantized=False)(self.linear_node.kwargs)
                # pack weight
                weight = load_arg(quantized=True)(self.linear_node.args[1])
                bias = None
                other_args = load_arg(quantized=False)(self.linear_node.args[1:])
                if len(self.linear_node.args) > 2:
                    bias = load_arg(quantized=False)(self.linear_node.args[2])
                    other_args = other_args[1:]  # remove the bias argument
                else:
                    assert 'bias' in kwargs, \
                        'expect bias provided as a keyword argument when it is not a positional argument'
                    bias = kwargs['bias']
                    kwargs.pop('bias')
                prepack_args = [weight, bias]
                packed_weight = quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.linear_prepack, prepack_args, {})
                # construct dynamic linear input
                linear_input = load_arg(quantized=False)(self.linear_node.args[0])
                qdynamic_linear_args = [linear_input, packed_weight]
                return quantizer.quantized_graph.create_node(
                    'call_function', torch.ops.quantized.linear_dynamic, qdynamic_linear_args, kwargs)

class Quantizer:
    def __init__(self):
        # mapping from matched node to activation_post_process
        # must be filled before convert
        self.activation_post_process_map = None

    def _qat_swap_modules(self, root):
        convert(root, mapping=DEFAULT_QAT_MODULE_MAPPING, inplace=True, remove_qconfig=False)

    def _generate_qconfig_map(self, root, input_graph):
        def get_qconfig(module):
            return module.qconfig if hasattr(module, 'qconfig') else None

        self.qconfig_map = dict()
        for node in input_graph.nodes:
            if node.op == 'get_param':
                parent, _ = _parent_name(node.target)
                self.qconfig_map[node.name] = get_qconfig(self.modules[parent])
            elif node.op == 'call_function':
                self.qconfig_map[node.name] = get_qconfig(root)
            elif node.op == 'call_method':
                self_obj = node.args[0]
                # qconfig for call_method should be the same as the `self` object for the call
                self.qconfig_map[node.name] = self.qconfig_map[self_obj.name]
            elif node.op == 'call_module':
                self.qconfig_map[node.name] = get_qconfig(self.modules[node.target])
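
    # e.g. with qconfig_dict = {'': some_qconfig} (some_qconfig being a placeholder),
    # propagate_qconfig_ in _prepare below attaches some_qconfig to the root and all
    # submodules, so every node above ends up mapped to that same qconfig.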

    def _prepare(self, model, qconfig_dict, inplace, quant_type):
        input_root = model.root
        if not inplace:
            input_root = copy.deepcopy(input_root)

        input_graph = model.graph
        self.quant_type = quant_type
        # TODO: allow user specified patterns
        if self.quant_type == QuantType.DYNAMIC:
            self.patterns = get_dynamic_quant_patterns()
        else:
            self.patterns = get_quant_patterns()

        propagate_qconfig_(input_root, qconfig_dict)
        if input_root.training:
            self._qat_swap_modules(input_root)

        self.modules = dict(input_root.named_modules())

        # map from node name to qconfig, used in _find_matches
        self._generate_qconfig_map(input_root, input_graph)

        # match the patterns that will get quantized
        matches = self._find_matches(input_graph, self.modules, self.patterns)

        # find _inputs_ to matched nodes that are not quantized, these
        # have to be quantized, which requires measuring stats,
        # initialize a DefaultQuant object for each
        quants = self._find_quants(input_graph, matches)

        self.activation_post_process_map = dict()

        env = {}
        observed_graph = Graph()
        observed = set()

        def load_arg(a):
            return map_arg(a, lambda node: env[node.name])

        for node in input_graph.nodes:
            if node.name in observed:
                continue

            def get_new_observer_name(parent_module):
                i = 0

                def get_observer_name(i):
                    return 'activation_post_process_' + str(i)
                observer_name = get_observer_name(i)
                while hasattr(parent_module, observer_name):
                    i += 1
                    observer_name = get_observer_name(i)
                return observer_name
            root_node, _, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is None:
                env[node.name] = observed_graph.node_copy(node, load_arg)
            elif root_node is node:
                env[node.name] = observed_graph.node_copy(node, load_arg)

                def insert_observer(node, observer):
                    observer_name = get_new_observer_name(input_root)
                    setattr(input_root, observer_name, observer)
                    self.activation_post_process_map[node.name] = observer
                    env[node.name] = observed_graph.create_node('call_module', observer_name, [load_arg(node)], {})
                    observed.add(node.name)

                # don't need to insert observer for output in dynamic quantization
                if self.quant_type == QuantType.DYNAMIC:
                    continue

                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    # propagate observed property from input
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif (isinstance(obj, Add) or isinstance(obj, Mul)) and not obj.all_nodes:
                    if node.args[0].name in observed:
                        observed.add(node.name)
                elif qconfig is not None and obj.all_nodes:
                    # observer for outputs
                    insert_observer(node, qconfig.activation())
            else:
                env[node.name] = observed_graph.node_copy(node, load_arg)

            if node.name not in observed and node.name in quants:
                observer_name = get_new_observer_name(input_root)
                _, qconfig, is_weight = quants[node.name]
                if qconfig is not None:
                    self.activation_post_process_map[node.name] = qconfig.weight() if is_weight else qconfig.activation()
                    setattr(input_root, observer_name, self.activation_post_process_map[node.name])
                    env[node.name] = observed_graph.create_node('call_module', observer_name, [load_arg(node)], {})
                    observed.add(node.name)
        observed_graph.output(load_arg(input_graph.result))

        return GraphModule(input_root, observed_graph)
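
    # Schematically (an illustration, not generated output): after _prepare, a
    # static-mode graph like x -> conv -> output becomes
    #   x -> conv -> activation_post_process_0 -> output
    # where activation_post_process_0 is the observer created from qconfig.activation()
    # (inputs that need quantization stats get their own observers as well, via quants).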

    def prepare(self, model, qconfig_dict, inplace=False):
        return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.STATIC)

    def prepare_dynamic(self, model, qconfig_dict, inplace=False):
        return self._prepare(model, qconfig_dict, inplace, quant_type=QuantType.DYNAMIC)

    def convert(self, observed, inplace=False, debug=False):
        assert self.activation_post_process_map is not None
        observed_root = observed.root
        observed_graph = observed.graph
        if not inplace:
            observed_root = copy.deepcopy(observed_root)
        self.modules = dict(observed_root.named_modules())

        matches = self._find_matches(observed.graph, self.modules, self.patterns)
        quants = self._find_quants(observed.graph, matches)
        self.quantized_graph = Graph()
        env = {}
        quant_env = {}

        def load_non_quantized(n):
            if n.name not in env:
                assert n.name in quant_env, \
                    'trying to load float node but did not find node:' + n.name + \
                    ' in quantized environment:' + str(quant_env)
                env[n.name] = Proxy(quant_env[n.name]).dequantize().node
            return env[n.name]

        def load_quantized(n):
            if n.name not in quant_env:
                assert n.name in env, \
                    'trying to load quantized node but did not find node:' + n.name + \
                    ' in float environment:' + str(env)
                assert n.name in quants, 'did not find quant object for node:' + n.name
                quant = quants[n.name][0]
                quant_env[n.name] = quant.convert(self, env[n.name])
            return quant_env[n.name]

        def load_x(n):
            assert n.name in env or n.name in quant_env, \
                'node ' + n.name + ' does not exist in either of the environments'
            if n.name in quant_env:
                return quant_env[n.name]
            else:
                return env[n.name]

        def load_arg(quantized):
            """
            if quantized is a list, then arg should be a list and the args with corresponding
            indexes will be quantized
            if quantized is a boolean, then all args will be quantized/not quantized
            if quantized is None, then we'll load the node as long as it exists
            """
            assert quantized is None or isinstance(quantized, (tuple, list, bool)), type(quantized)

            def load_arg_impl(arg):
                if quantized is None:
                    return map_arg(arg, load_x)
                if isinstance(quantized, bool):
                    return map_arg(arg, load_quantized if quantized else load_non_quantized)
                elif isinstance(quantized, (tuple, list)):
                    assert isinstance(arg, (tuple, list)), arg
                    loaded_arg = []
                    # for now, we only support quantizing positional arguments
                    for i, a in enumerate(arg):
                        if i in quantized:
                            loaded_arg.append(map_arg(a, load_quantized))
                        else:
                            loaded_arg.append(map_arg(a, load_non_quantized))
                    return type(arg)(loaded_arg)
            return load_arg_impl
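
        # e.g. load_arg(quantized=[1])(node.args) returns node.args with args[1]
        # loaded from the quantized environment and every other arg loaded as float,
        # while load_arg(quantized=None) takes each arg from whichever environment
        # already holds it.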

        def is_quantized(node):
            assert node.name in env or node.name in quant_env, 'Expecting node to be in the environment'
            # there might be nodes appearing in both environments, but quant_env will take
            # precedence
            if node.name in quant_env:
                return True
            elif node.name in env:
                return False

        for node in observed_graph.nodes:
            root_node, matched, obj, qconfig = matches.get(node.name, (None, None, None, None))
            if root_node is node:
                result = obj.convert(self, node, load_arg, debug=debug)
                quantized = True
                # Need to get correct quantized/non-quantized state for the output of CopyNode
                if isinstance(obj, CopyNode):
                    assert node.op in [
                        'call_module',
                        'call_function',
                        'call_method'], \
                        'CopyNode of type ' + node.op + ' is not handled'
                    quantized = is_quantized(node.args[0])

                if self.quant_type == QuantType.DYNAMIC:
                    quantized = False

                if quantized:
                    quant_env[node.name] = result
                else:
                    env[node.name] = result
                continue
            elif root_node is not None:
                continue

            # handle activation post process calls
            if node.op == 'call_module':
                if node.target.split('.')[-1].startswith('activation_post_process_'):
                    observer_module = self.modules[node.target]
                    prev_node = node.args[0]
                    if prev_node.name in quant_env:
                        # if previous node is already quantized, we'll just remove the activation_post_process
                        quant_env[node.name] = quant_env[prev_node.name]
                        continue
                    # replace activation post process with quantization ops
                    parent_name = ''

                    scale, zero_point = observer_module.calculate_qparams()
                    scale = float(scale)
                    zero_point = int(zero_point)
                    dtype = observer_module.dtype
                    qparams = {'_scale_': scale, '_zero_point_': zero_point, '_dtype_': dtype}
                    i = 0

                    def noattr(module, qparams, i):
                        for name in qparams.keys():
                            if hasattr(module, name + str(i)):
                                return False
                        return True

                    def get_next_i(module, qparams):
                        i = 0
                        while not noattr(module, qparams, i):
                            i += 1
                        return i

                    parent_module = self.modules[parent_name]
                    i = get_next_i(parent_module, qparams)
                    inputs = [load_non_quantized(node.args[0])]
                    for key, value in qparams.items():
                        setattr(parent_module, key + str(i), value)
                        qparam_full_path = key + str(i)
                        if parent_name:
                            qparam_full_path = parent_name + '.' + qparam_full_path
                        inputs.append(self.quantized_graph.get_param(qparam_full_path))
                    quant_env[node.name] = self.quantized_graph.create_node('call_function', torch.quantize_per_tensor, inputs, {})
                    continue
            # dequantize inputs for the node that are not quantized
            env[node.name] = self.quantized_graph.node_copy(node, load_non_quantized)

        self.quantized_graph.output(load_non_quantized(observed_graph.result))

        to_be_removed = []
        for name, _ in observed_root.named_modules():
            if name.split('.')[-1].startswith('activation_post_process_'):
                to_be_removed.append(name)
        for n in to_be_removed:
            delattr(observed_root, n)
        return GraphModule(observed_root, self.quantized_graph)

    def _find_matches(self, graph, modules, patterns):
        match_map = {}  # node name -> (root_node, matched_nodes, pattern_handler, qconfig)
        all_matched = set()

        def record_match(pattern, node, matched):
            if isinstance(pattern, tuple):
                s, *args = pattern
                record_match(s, node, matched)
                if pattern[0] is not getattr:
                    for subpattern, arg in zip(args, node.args):
                        record_match(subpattern, arg, matched)
            else:
                matched.append(node)

        for node in reversed(graph.nodes):
            if node.name not in match_map and node.name not in all_matched:
                for pattern, value in patterns.items():
                    if matches(modules, node, pattern):
                        matched = []
                        record_match(pattern, node, matched)
                        for n in matched:
                            match_map[n.name] = (node, matched, value(self, node), self.qconfig_map[n.name])
                            all_matched.add(n.name)
                        # break after finding the first match
                        break
        return match_map

    def _find_quants(self, graph, matches):
        quants = {}

        def visit(node, qconfig):
            def visit_arg(arg):
                # note: we have to measure quantization information
                # even for nodes where we might not use it because it is already
                # quantized. This is because each match has the option to
                # say NotImplemented (if for instance, it is an __add__ and the data type is not appropriate)
                is_weight = False
                if isinstance(node, Node) and node.op == 'call_function' and node.target in WEIGHT_INDEX_DICT:
                    for i, node_arg in enumerate(node.args):
                        if arg is node_arg and i in WEIGHT_INDEX_DICT[node.target]:
                            is_weight = True
                if self.quant_type != QuantType.DYNAMIC or is_weight:
                    # overwrite previous quant config
                    quants[arg.name] = (DefaultQuant(self, arg), qconfig, is_weight)
            return visit_arg

        for node in graph.nodes:
            if node.name in matches:
                root_node, matched, obj, qconfig = matches[node.name]
                # don't attach observer/fake_quant for CopyNode
                if isinstance(obj, CopyNode):
                    qconfig = None
                if root_node is node:
                    # matched[-1] is the first op in the sequence and
                    # matched[0] is the last op in the sequence
                    # inputs
                    map_arg(matched[-1].args, visit(matched[-1], qconfig))
                    map_arg(matched[-1].kwargs, visit(matched[-1], qconfig))
                    # output
                    map_arg(matched[0], visit(None, qconfig))
        return quants

7  torch/quantization/fx/utils.py  Normal file

@@ -0,0 +1,7 @@
# turn 'foo.bar' into ('foo', 'bar')
def _parent_name(target):
    r = target.rsplit('.', 1)
    if len(r) == 1:
        return '', r[0]
    else:
        return r[0], r[1]
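
# e.g. _parent_name('features.conv1') == ('features', 'conv1')
#      _parent_name('conv1') == ('', 'conv1')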