# -*- coding: utf-8 -*-
# torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.jit
import torch.jit.quantized
# torch.quantization
from torch.quantization import (
QConfig,
default_dynamic_qconfig,
float16_dynamic_qconfig,
default_observer,
per_channel_dynamic_qconfig,
default_per_channel_weight_observer,
default_qconfig,
get_default_qconfig,
quantize,
quantize_dynamic,
default_weight_observer,
default_histogram_observer,
fuse_modules,
quantize_jit,
quantize_dynamic_jit,
PlaceholderObserver,
)
# torch.quantization.quantize_jit
from torch.quantization.quantize_jit import (
convert_jit,
convert_dynamic_jit,
fuse_conv_bn_jit,
prepare_jit,
prepare_dynamic_jit,
script_qconfig,
)
# Testing utils
from torch.testing._internal.common_quantized import (
override_qengines,
qengine_is_fbgemm,
qengine_is_qnnpack,
)
from torch.testing._internal.common_quantization import (
QuantizationTestCase,
skipIfNoFBGEMM,
get_script_module,
SingleLayerLinearModel,
SkipQuantModel,
NestedModel,
ConvModel,
ConvTransposeModel,
default_per_channel_qconfig,
test_only_eval_fn,
ConvBnModel,
)
# Annotated models
from torch.testing._internal.common_quantization import (
AnnotatedSingleLayerLinearModel,
AnnotatedSkipQuantModel,
AnnotatedNestedModel,
AnnotatedConvModel,
AnnotatedConvTransposeModel,
AnnotatedConvBnModel,
)
from torch.testing import FileCheck
from torch.testing._internal.jit_utils import attrs_with_prefix
from torch.testing._internal.jit_utils import get_forward
from torch.testing._internal.jit_utils import get_forward_graph
from torch.jit._recursive import wrap_cpp_module
# Standard library
from typing import List, Tuple
import io
import itertools
import unittest
class TestQuantizeJitPasses(QuantizationTestCase):
"""Test graph mode quantization passes used by quantize_jit"""
def test_foldbn_trivial(self):
bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
# Test trivial case
class TestModule(torch.nn.Module):
def __init__(self, dim):
super(TestModule, self).__init__()
self.conv = conv_module[dim](1, 20, 5, 1)
self.bn = bn_module[dim](num_features=20)
self.bn.eps = 0.0023
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
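# fuse_conv_bn_jit should fold the BatchNorm into the Conv. Roughly, with
# BN statistics (mean, var) and affine parameters (gamma, beta), the folded
# conv uses:
#   w_folded = w_conv * gamma / sqrt(var + eps)
#   b_folded = (b_conv - mean) * gamma / sqrt(var + eps) + beta
# so the fused module should match the eager Conv+BN numerics checked below.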
options = itertools.product([True, False], [2, 3])
data = {2: torch.rand(1, 1, 6, 6), 3: torch.rand(1, 1, 6, 6, 6)}
# Check that the transformation doesn't change numerics
for tracing, dim in options:
eager = TestModule(dim).eval()
x = data[dim]
scripted_or_traced = get_script_module(eager, tracing, x).eval()
# Check that in the original script module's forward we have two
# CallMethod nodes. One of them should be for conv.forward and the other
# for bn.forward.
FileCheck().check_count(
'prim::CallMethod[name="forward"]', 2, exactly=True
).run(str(get_forward(scripted_or_traced._c).graph))
# Run FoldConvBatchnorm pass.
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
# Check that after the pass one of the CallMethods is gone (supposedly,
# the bn.forward).
FileCheck().check_count(
'prim::CallMethod[name="forward"]', 1, exactly=True
).run(str(get_forward_graph(scripted_or_traced._c)))
# Check that the transformation doesn't change numerics
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_trivial_nobias(self):
bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
# Test trivial case
class TestModule(torch.nn.Module):
def __init__(self, dim):
super(TestModule, self).__init__()
self.conv = conv_module[dim](1, 20, 5, 1, bias=False)
self.bn = bn_module[dim](num_features=20)
# to make sure new bias is not zero
self.bn.eps = 0.0027
self.bn.bias = torch.nn.Parameter(torch.rand([20]))
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
options = itertools.product([True, False], [2, 3])
data = {2: torch.rand(1, 1, 6, 6), 3: torch.rand(1, 1, 6, 6, 6)}
for tracing, dim in options:
eager = TestModule(dim).eval()
x = data[dim]
scripted_or_traced = get_script_module(eager, tracing, x).eval()
# Check that in the original script module's forward we have two
# CallMethod nodes. One of them should be for conv.forward and the other
# for bn.forward.
FileCheck().check_count(
'prim::CallMethod[name="forward"]', 2, exactly=True
).run(str(get_forward_graph(scripted_or_traced._c)))
# Run FoldConvBatchnorm pass.
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
# Check that after the pass one of the CallMethods is gone (supposedly,
# the bn.forward).
FileCheck().check_count(
'prim::CallMethod[name="forward"]', 1, exactly=True
).run(str(get_forward_graph(scripted_or_traced._c)))
# Check that the transformation doesn't change numerics
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_in_submodule(self):
bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
# Test that we find Conv-BN patterns in submodules
class SubModule(torch.nn.Module):
def __init__(self, dim):
super(SubModule, self).__init__()
self.conv = conv_module[dim](1, 20, 5, 1)
self.bn = bn_module[dim](num_features=20)
def forward(self, x):
x = self.conv(x)
x = self.bn(x)
return x
class TestModule(torch.nn.Module):
def __init__(self, dim):
super(TestModule, self).__init__()
self.sub = SubModule(dim)
def forward(self, x):
x = self.sub(x)
return x
options = itertools.product([True, False], [2, 3])
data = {2: torch.rand(1, 1, 10, 10), 3: torch.rand(1, 1, 10, 10, 10)}
for tracing, dim in options:
eager = TestModule(dim).eval()
x = data[dim]
scripted_or_traced = get_script_module(eager, tracing, x).eval()
FileCheck().check_count(
'prim::CallMethod[name="forward"]', 2, exactly=True
).run(str(get_forward_graph(scripted_or_traced.sub._c)))
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
FileCheck().check_count(
'prim::CallMethod[name="forward"]', 1, exactly=True
).run(str(get_forward_graph(scripted_or_traced.sub._c)))
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_shared_classtype(self):
bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
class TestModule(torch.nn.Module):
def __init__(self, dim, bias=False):
super(TestModule, self).__init__()
self.conv1 = conv_module[dim](5, 5, 3, bias=bias)
self.bn1 = bn_module[dim](num_features=5)
self.bn1.running_mean.fill_(-0.2)
self.bn1.bias = torch.nn.Parameter(torch.rand([5]))
# to make sure new bias is not zero
self.bn1.eps = 0.0023
self.conv2 = conv_module[dim](5, 5, 3, bias=bias)
self.bn2 = bn_module[dim](num_features=5)
self.bn2.eps = 0.0029
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
return x
options = itertools.product([True, False], [2, 2], [True, False])
data = {2: torch.rand(1, 5, 6, 6), 3: torch.rand(1, 5, 6, 6, 6)}
for tracing, dim, bias in options:
eager = TestModule(dim, bias).eval()
x = data[dim]
scripted_or_traced = get_script_module(eager, tracing, x)
folded = fuse_conv_bn_jit(scripted_or_traced)
self.assertEqual(eager(x), scripted_or_traced(x))
def test_foldbn_no_fusion(self):
"""Test that we don't fuse the cases when module type does not match"""
class CustomConv(torch.nn.Module):
def __init__(self):
super(CustomConv, self).__init__()
def forward(self, x):
return x
class CustomBn(torch.nn.Module):
def __init__(self):
super(CustomBn, self).__init__()
def forward(self, x):
return x
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = CustomConv()
self.bn = CustomBn()
def forward(self, x):
return self.bn(self.conv(x))
m = torch.jit.script(M())
m = fuse_conv_bn_jit(m)
FileCheck().check_count("prim::CallMethod", 2, exactly=True).run(m.graph)
def test_foldbn_complex_cases(self):
# This test case tries combinations of conv2d/conv3d with bias/no-bias and
# BatchNorm with affine/no-affine, while varying the number of layers.
# It only works when the default dtype is double.
torch.set_default_dtype(torch.double)
bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
conv_module = {2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
class SubModule(torch.nn.Module):
def __init__(self, dim, num_blocks, enable_bias, enable_affine):
super(SubModule, self).__init__()
layers = []
for i in range(num_blocks):
layers.append(conv_module[dim](20, 20, 5, 1, bias=enable_bias))
bn_obj = bn_module[dim](num_features=20, affine=enable_affine)
if enable_affine:
bn_obj.weight = torch.nn.Parameter(
torch.rand_like(bn_obj.weight)
)
bn_obj.bias = torch.nn.Parameter(torch.rand_like(bn_obj.bias))
bn_obj.running_mean = torch.rand_like(bn_obj.running_mean)
bn_obj.running_var = torch.rand_like(bn_obj.running_var)
layers.append(bn_obj)
self.layers = nn.Sequential(*layers)
def forward(self, x):
return self.layers(x)
class TestModule(torch.nn.Module):
def __init__(self, dim, num_blocks, enable_bias, enable_affine):
super(TestModule, self).__init__()
self.sub = SubModule(dim, num_blocks, enable_bias, enable_affine)
def forward(self, x):
x = self.sub(x)
return x
options = itertools.product(
[True, False], [2, 3], [True, False], [True, False], [1, 2]
)
data = {2: torch.rand(1, 20, 10, 10), 3: torch.rand(1, 20, 10, 10, 10)}
for tracing, dim, enable_bias, enable_bn_affine, num_layers in options:
eager = TestModule(dim, num_layers, enable_bias, enable_bn_affine).eval()
x = data[dim]
scripted_or_traced = get_script_module(eager, tracing, x).eval()
FileCheck().check_count(
'prim::CallMethod[name="forward"]', num_layers * 2, exactly=True
).run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
scripted_or_traced = fuse_conv_bn_jit(scripted_or_traced)
FileCheck().check_count(
'prim::CallMethod[name="forward"]', num_layers, exactly=True
).run(str(get_forward_graph(scripted_or_traced.sub.layers._c)))
self.assertEqual(eager(x), scripted_or_traced(x))
torch.set_default_dtype(torch.float)
def test_fuse_linear(self):
class FunctionalLinear(torch.nn.Module):
def __init__(self, weight, bias):
super(FunctionalLinear, self).__init__()
self.weight = weight
self.bias = bias
def forward(self, x):
res = torch.matmul(x, self.weight.t())
if self.bias is not None:
res.add_(self.bias)
return res
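# torch._C._jit_pass_fuse_linear is expected to rewrite the traced
# matmul (+ add_ for bias) pattern above into a single aten::linear call,
# while preserving the source range of the original matmul (checked via
# sourceRange below). Batched (3-d) matmuls, as in the Matmul module further
# down, should be left alone.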
x1 = torch.rand(3)
w1 = torch.rand(5, 3)
b1 = torch.rand(5)
x2 = torch.rand(5, 5)
w2 = torch.rand(5, 5)
b2 = torch.rand(5)
x3 = torch.rand(5, 5, 5)
w3 = torch.rand(5, 5)
b3 = torch.rand(5)
for has_bias, (x, weight, b) in itertools.product(
[True, False], [(x1, w1, b1), (x2, w2, b2), (x3, w3, b3)]
):
bias = b if has_bias else None
model = torch.jit.trace(FunctionalLinear(weight, bias), [x])
for node in model.graph.nodes():
if node.kind() == "aten::matmul":
source_range_1 = node.sourceRange()
torch._C._jit_pass_fuse_linear(model.graph)
for node in model.graph.nodes():
if node.kind() == "aten::linear":
source_range_2 = node.sourceRange()
FileCheck().check("aten::linear").run(model.graph)
check_not = ["aten::matmul", "aten::addmm", "aten::add_", "aten::t("]
for cn in check_not:
FileCheck().check_not(cn).run(model.graph)
self.assertTrue(source_range_1 == source_range_2)
# make sure it runs
model(x)
# check matmuls are not fused
class Matmul(torch.nn.Module):
def __init__(self, weight):
super(Matmul, self).__init__()
self.weight = weight
def forward(self, x):
return torch.matmul(x, self.weight)
x = torch.rand(5, 6, 5)
w = torch.rand(5, 5, 100)
model = torch.jit.trace(Matmul(w), [x])
torch._C._jit_pass_fuse_linear(model.graph)
# check 3d matmul is not fused
FileCheck().check("aten::matmul").run(model.graph)
FileCheck().check_not("aten::linear").run(model.graph)
# make sure it runs
model(x)
def test_insert_observers(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
def forward(self, x):
return self.conv(x)
m = torch.jit.script(M())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
# for input and output of conv
assert len(attrs_with_prefix(m, "_observer_")) == 2
# for weight
assert len(attrs_with_prefix(m.conv, "_observer_")) == 1
def test_insert_observers_interface(self):
@torch.jit.interface
class SubInterface(torch.nn.Module):
def addOne(self, inp) -> torch.Tensor:
pass
class Sub(torch.nn.Module):
def __init__(self):
super(Sub, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def addOne(self, inp):
return self.fc(inp) + 1
def forward(self, x):
return self.addOne(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
self.sub = Sub()
def forward(self, x):
return self.sub(self.conv(x))
m = torch.jit.script(M())
qconfig_dict = {"sub.conv": default_qconfig}
m = prepare_jit(m, qconfig_dict)
def test_insert_observers_interface_unshare_type(self):
@torch.jit.interface
class OperatorIf(nn.Module):
def forward(self, inp: torch.Tensor) -> torch.Tensor:
pass
class Operator(nn.Module):
def __init__(self, a):
super().__init__()
self.a = a
def forward(self, inp: torch.Tensor) -> torch.Tensor:
return self.a * (inp + self.a)
class Inner(nn.Module):
op: OperatorIf
def __init__(self, op):
super().__init__()
self.op = op
def forward(self, inp):
return self.op(inp)
class Outer(nn.Module):
def __init__(self):
super().__init__()
self.inner_a = Inner(Operator(1))
self.inner_b = Inner(Operator(3.0))
def forward(self, inp):
return self.inner_a(inp) + self.inner_b(inp)
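# inner_a wraps an Operator with an int attribute (1) while inner_b wraps one
# with a float attribute (3.0), so the two Inner instances cannot share a
# ClassType. The test just makes sure prepare_jit handles per-instance
# qconfigs in that situation without erroring out.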
qconfig_dict = {"inner_a": default_qconfig, "inner_b": default_qconfig}
eager_model = Outer()
for tracing in [True, False]:
x = torch.rand(3)
script_model = get_script_module(eager_model, tracing, x)
# make sure it runs
prepare_jit(script_model, qconfig_dict)
def test_insert_observers_child_qconfig(self):
class Sub(torch.nn.Module):
def __init__(self):
super(Sub, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
self.sub = Sub()
def forward(self, x):
return self.sub(self.conv(x))
m = torch.jit.script(M())
qconfig_dict = {"sub.fc": default_qconfig}
m = prepare_jit(m, qconfig_dict)
# input and output of sub
assert len(attrs_with_prefix(m, "_observer_")) == 2
# not quantized
assert len(attrs_with_prefix(m.conv, "_observer_")) == 0
# no observers since we observe at the outermost call site
assert len(attrs_with_prefix(m.sub, "_observer_")) == 0
# weight of linear
assert len(attrs_with_prefix(m.sub.fc, "_observer_")) == 1
@unittest.skipUnless(
"fbgemm" in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.",
)
def test_insert_observers_skip_values(self):
class ConvFunctionalReLU(torch.nn.Module):
def __init__(self):
super(ConvFunctionalReLU, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
def forward(self, x):
return F.relu(self.conv(x))
class ConvReLUModule(torch.nn.Module):
def __init__(self):
super(ConvReLUModule, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(self.conv(x))
class AddReLUModule(torch.nn.Module):
def __init__(self):
super(AddReLUModule, self).__init__()
self.relu = torch.nn.ReLU()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
out = self.conv(x)
out += x
return self.relu(out)
class AddFunctionalReLU(torch.nn.Module):
def __init__(self):
super(AddFunctionalReLU, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
out = self.conv(x)
out += x
return F.relu(out)
def attrs_with_prefix(module, prefix):
return [x for x, _ in module._modules._c.items() if x.startswith(prefix)]
qconfig_dict = {"": default_qconfig}
m = torch.jit.script(ConvFunctionalReLU())
m = prepare_jit(m, qconfig_dict)
# observer for weight of conv
assert len(attrs_with_prefix(m.conv, "_observer_")) == 1
# observer for input of conv and output of relu
assert len(attrs_with_prefix(m, "_observer_")) == 2
m = torch.jit.script(ConvReLUModule())
m = prepare_jit(m, qconfig_dict)
# observer for input of conv and output of relu
assert len(attrs_with_prefix(m, "_observer_")) == 2
# observer for weight of conv
assert len(attrs_with_prefix(m.conv, "_observer_")) == 1
# observer for output of relu
assert len(attrs_with_prefix(m.relu, "_observer_")) == 0
m = torch.jit.script(AddReLUModule())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
assert len(attrs_with_prefix(m, "_observer")) == 3
assert len(attrs_with_prefix(m.relu, "_observer")) == 0
FileCheck().check("aten::add_").check_not(
'Observer = prim::GetAttr[name="_observer_'
).check("ReLU = prim::GetAttr").run(str(get_forward_graph(m._c)))
m = torch.jit.script(AddFunctionalReLU())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
assert len(attrs_with_prefix(m, "_observer")) == 3
FileCheck().check("aten::add_").check_not(
'Observer = prim::GetAttr[name="_observer_'
).check("CallFunction").check('Observer = prim::GetAttr[name="_observer_').run(
str(get_forward_graph(m._c))
)
def test_insert_observers_weight_dtype(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
def forward(self, x):
return F.relu(self.conv(x))
m = torch.jit.script(M())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
activation_dtypes = set(
obs.getattr("dtype")
for x, obs in m._modules._c.items()
if x.startswith("_observer_")
)
weight_dtypes = set(
obs.getattr("dtype")
for x, obs in m.conv._modules._c.items()
if x.startswith("_observer_")
)
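# With default_qconfig, activations are expected to be observed for quint8
# and weights for qint8, so the two dtype sets collected above should differ.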
assert len(activation_dtypes) == 1, "Expected to have 1 activation dtype"
assert len(weight_dtypes) == 1, "Expected to have 1 weight dtype"
assert (
list(activation_dtypes)[0] != list(weight_dtypes)[0]
), "Expected activation dtype to be different from weight dtype"
def test_insert_observers_for_reused_weight(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
def forward(self, x, y, weight):
x = F.conv2d(x, weight)
y = F.conv2d(y, weight)
return x + y
m = torch.jit.script(M()).eval()
m = prepare_jit(m, {"": default_qconfig})
# 3 for x, y and weight, one for the output of each F.conv2d, and one for the output of add
assert len(attrs_with_prefix(m, "_observer")) == 6
def test_insert_observers_shared_class_type(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 5, 3).float()
self.conv2 = torch.nn.Conv2d(3, 5, 3).float()
def forward(self, x):
return self.conv2(self.conv1(x))
m = torch.jit.script(M())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
# conv1 and conv2 share the same class type; we need to
# make sure we didn't insert observers for the type twice
conv1_observers = attrs_with_prefix(m.conv1, "_observer_")
conv2_observers = attrs_with_prefix(m.conv2, "_observer_")
assert len(conv1_observers) == 1, "Expected to have 1 observer submodule"
assert len(conv2_observers) == 1, "Expected to have 1 observer submodule"
assert (
conv1_observers == conv2_observers
), "Expect conv1 and conv2 to have same observers since the class type is shared"
def test_insert_observers_for_general_ops(self):
"""Make sure we skip observers for ops that doesn't require
observation, e.g. flatten
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
x = self.conv(x)
x = torch.flatten(x)
return x
m = torch.jit.script(M())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
# input and output of conv
assert len(attrs_with_prefix(m, "_observer_")) == 2
FileCheck().check('Observer = prim::GetAttr[name="_observer_').check(
'prim::GetAttr[name="conv"]'
).check("prim::CallMethod").check(
'Observer = prim::GetAttr[name="_observer_'
).check(
"aten::flatten"
).check_not(
'Observer = prim::GetAttr[name="_observer_'
).run(
m.graph
)
# TODO: this is too long; split it into test_insert_observers.py and remove
# the insert_observers prefix
def test_insert_observers_propagate_observed(self):
"""Make sure we propagate observed property through general ops"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
x = self.conv1(x)
x = torch.flatten(x)
# we don't want to insert observer for input of self.conv2
# because output of self.conv1 is already observed
x = self.conv2(x)
return x
m = torch.jit.script(M())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
# input and output of conv
assert len(attrs_with_prefix(m, "_observer_")) == 3
FileCheck().check('Observer = prim::GetAttr[name="_observer_').check(
'prim::GetAttr[name="conv1"]'
).check("prim::CallMethod").check(
'Observer = prim::GetAttr[name="_observer_'
).check(
"aten::flatten"
).check_not(
'Observer = prim::GetAttr[name="_observer_'
).check(
'prim::GetAttr[name="conv2"]'
).check(
'Observer = prim::GetAttr[name="_observer_'
).run(
m.graph
)
def test_insert_observers_propagate_observed_in_submodule(self):
"""Make sure we propagate observed property through general ops"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
self.avgpool = torch.nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.conv1(x)
x = self.avgpool(x)
# we don't want to insert observer for input of self.conv2
# because output of self.conv1 is already observed
x = self.conv2(x)
return x
m = torch.jit.script(M())
qconfig_dict = {"": default_qconfig}
m = prepare_jit(m, qconfig_dict)
# input and output of conv
assert len(attrs_with_prefix(m, "_observer_")) == 3
FileCheck().check('Observer = prim::GetAttr[name="_observer_').check(
'prim::GetAttr[name="conv1"]'
).check("prim::CallMethod").check(
'Observer = prim::GetAttr[name="_observer_'
).check(
"prim::CallMethod"
).check_not(
'Observer = prim::GetAttr[name="_observer_'
).check(
'prim::GetAttr[name="conv2"]'
).check(
'Observer = prim::GetAttr[name="_observer_'
).run(
m.graph
)
def test_insert_observers_propagate_observed_for_function(self):
def channel_shuffle(x: torch.Tensor, groups: int) -> torch.Tensor:
batchsize, num_channels, height, width = x.data.size()
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups, channels_per_group, height, width)
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 1).float()
self.conv2 = torch.nn.Conv2d(3, 3, 1).float()
def forward(self, x):
x = self.conv1(x)
x = channel_shuffle(x, 1)
x = self.conv2(x)
return x
data = [
(
torch.rand((1, 3, 10, 10), dtype=torch.float),
torch.randint(0, 1, (1,), dtype=torch.long),
)
for _ in range(2)
]
m = torch.jit.script(M()).eval()
m = prepare_jit(m, {"": default_qconfig})
# we want to test that channel_shuffle is going to pass
# the observed property from the output of conv1 to input of conv2
# so that we don't insert observers for input of conv2
assert (
len(
attrs_with_prefix(
m,
"_observer_",
)
)
== 3
)
def test_insert_observers_for_if(self):
class QuantProp(torch.nn.Module):
def __init__(self, use_skip):
super(QuantProp, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.use_skip = use_skip
def forward(self, x):
if self.use_skip:
x = self.conv(x)
return torch.reshape(x, x.shape)
else:
x = self.conv(x)
return torch.reshape(x, x.shape)
class Res(torch.nn.Module):
def __init__(self, use_skip):
super(Res, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.use_skip = use_skip
def forward(self, x):
if self.use_skip:
return self.conv(x)
else:
return self.conv(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.quant_prop = QuantProp(True)
self.res = Res(False)
def forward(self, x):
x = self.quant_prop(x)
x = self.res(x)
return x
data = [torch.rand(1, 3, 10, 10, dtype=torch.float)]
result = {False: [1, 2, 2], True: [2, 1, 0]}
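# Observer counts differ between tracing and scripting: roughly, tracing
# specializes the `if self.use_skip` branches away (no prim::If in the traced
# graph), which changes whether observers end up on the parent module or on
# the submodules. The expected counts are [top-level, quant_prop, res].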
for tracing in [True, False]:
if tracing:
m = torch.jit.trace(M(), data).eval()
else:
m = torch.jit.script(M()).eval()
m = prepare_jit(m, {"": default_qconfig})
assert (
len(
attrs_with_prefix(
m,
"_observer_",
)
)
== result[tracing][0]
)
assert (
len(
attrs_with_prefix(
m.quant_prop,
"_observer_",
)
)
== result[tracing][1]
)
assert (
len(
attrs_with_prefix(
m.res,
"_observer_",
)
)
== result[tracing][2]
)
def test_insert_observers_for_nested_if(self):
class Res(torch.nn.Module):
def __init__(self, use_skip):
super(Res, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.cond = use_skip
self.use_skip = use_skip
def forward(self, x):
if self.use_skip:
if self.cond:
return self.conv(x)
else:
return self.conv(x)
else:
return self.conv(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.res1 = Res(True)
self.res2 = Res(False)
def forward(self, x):
x = self.res1(x)
x = self.res2(x)
return x
data = torch.rand((1, 3, 10, 10), dtype=torch.float)
result = {True: 3, False: 1}
for tracing in [True, False]:
if tracing:
m = torch.jit.trace(M(), data).eval()
else:
m = torch.jit.script(M()).eval()
m = prepare_jit(m, {"": default_qconfig})
assert len(attrs_with_prefix(m, "_observer_")) == result[tracing]
def test_insert_observers_for_if_consistent_observation(self):
"""check quantization for if works as long as
output of all branches are quantized/observed consistently
"""
class M(torch.nn.Module):
def __init__(self, cond):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
self.cond = cond
def forward(self, x):
x = self.conv(x)
# x is already observed
if self.cond:
x = torch.flatten(x)
return x
class M2(torch.nn.Module):
def __init__(self, cond):
super(M2, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
self.cond = cond
def forward(self, x):
x = self.conv1(x)
if self.cond:
x = self.conv2(x)
# x will be observed in the branch
else:
x = torch.flatten(x)
# since the outputs of both branches are quantized
# the if node is quantized consistently
return x
data = torch.rand((1, 3, 5, 5), dtype=torch.float)
options = list(itertools.product([True, False], [True, False]))
for cond, tracing in options:
if tracing:
m = torch.jit.trace(M(cond), data)
else:
m = torch.jit.script(M(cond))
m = prepare_jit(m, {"": default_qconfig})
assert len(attrs_with_prefix(m, "_observer_")) == 2
for cond, tracing in options:
if tracing:
m = torch.jit.trace(M2(cond), data)
else:
m = torch.jit.script(M2(cond))
m = prepare_jit(m, {"": default_qconfig})
num_observers = 2 if tracing and not cond else 3
assert len(attrs_with_prefix(m, "_observer_")) == num_observers
def test_insert_quant_dequant(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3).float()
def forward(self, x):
return self.conv(x)
for is_per_channel in [True, False]:
m = torch.jit.script(M())
observer = (
default_per_channel_weight_observer.with_args(ch_axis=1)
if is_per_channel
else default_observer
)
qconfig_dict = {"": QConfig(activation=observer, weight=observer)}
m = prepare_jit(m, qconfig_dict)
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
m(data)
m = convert_jit(m, debug=True)
assert (
len(m._modules._c.items()) == 1
), "Expected to have single submodule of conv"
# make sure the quantized model is executable
m(data)
quant_func = (
"aten::quantize_per_channel"
if is_per_channel
else "aten::quantize_per_tensor"
)
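# In debug mode the three quantize calls correspond to the conv input, the
# conv weight, and the conv output; each observed value gets an explicit
# quantize/dequantize pair instead of being folded into a quantized op.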
FileCheck().check_count(quant_func, 3, exactly=True).run(m.graph)
def test_insert_quant_dequant_shared_class_type(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
return self.conv2(self.conv1(x))
for is_per_channel in [True, False]:
m = torch.jit.script(M())
observer = (
default_per_channel_weight_observer.with_args(ch_axis=1)
if is_per_channel
else default_observer
)
qconfig = QConfig(activation=observer, weight=observer)
qconfig_dict = {"": qconfig}
m = prepare_jit(m, qconfig_dict)
# observers for input, output and value between conv1/conv2
assert (
len(attrs_with_prefix(m, "_observer_")) == 3
), "Expected to have 3 obervers"
# observer for weight
assert (
len(attrs_with_prefix(m.conv1, "_observer_")) == 1
), "Expected to have 1 obervers"
# observer for weight
assert (
len(attrs_with_prefix(m.conv2, "_observer_")) == 1
), "Expected to have 1 obervers"
data = torch.randn(1, 3, 10, 10, dtype=torch.float)
m(data)
m = convert_jit(m, debug=True)
m(data)
assert m.conv1._c._type() == m.conv2._c._type()
# check all observers have been removed
assert (
len(attrs_with_prefix(m, "_observer_")) == 0
), "Expected to have 0 obervers"
assert (
len(attrs_with_prefix(m.conv1, "_observer_")) == 0
), "Expected to have 0 obervers"
assert (
len(attrs_with_prefix(m.conv2, "_observer_")) == 0
), "Expected to have 0 obervers"
quant_func = (
"aten::quantize_per_channel"
if is_per_channel
else "aten::quantize_per_tensor"
)
for module in ["conv1", "conv2"]:
conv = m._c.getattr(module)
# quantize weight
FileCheck().check(quant_func).check_next("aten::dequantize").check(
'prim::CallMethod[name="_conv_forward"]'
).check("return").run(get_forward_graph(conv))
# no quantize node in _conv_forward
FileCheck().check_not(quant_func).check("aten::conv2d").check_not(
quant_func
).check("return").run(conv._get_method("_conv_forward").graph)
def test_dedup_module_uses(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.relu(x)
x -= 0.5
return self.relu(x)
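# self.relu is invoked twice in forward. _jit_pass_dedup_module_uses is
# expected to clone the module so that each call site gets its own instance
# (hence 2 relu submodules afterwards), which lets each use be quantized
# independently later. Numerics must be unchanged.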
data = torch.randn((2, 2))
m = torch.jit.script(M())
ref_res = m(data)
assert (
len([x for x, _ in m._modules._c.items() if x.startswith("relu")]) == 1
), "Expected to have 1 relu modules after dedup module uses"
torch._C._jit_pass_dedup_module_uses(m._c)
m = torch.jit._recursive.wrap_cpp_module(m._c)
res = m(data)
assert (
len([x for x, _ in m._modules._c.items() if x.startswith("relu")]) == 2
), "Expected to have 2 relu modules after dedup module uses"
self.assertEqual(res, ref_res)
def test_replicate_dequantize(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
def forward(self, x):
x = torch.dequantize(x)
r = self.conv(x)
r += x
return r
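# x is dequantized once but used twice (by the conv and by the add).
# _jit_pass_replicate_dequantize should duplicate the dequantize node so each
# use gets its own copy, which makes per-use quantization patterns matchable.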
x = torch.randn([1, 3, 10, 10], dtype=torch.float)
x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8)
m = torch.jit.script(M())
ref_res = m(x)
FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)
torch._C._jit_pass_replicate_dequantize(m.graph)
FileCheck().check_count("aten::dequantize", 2, exactly=True).run(m.graph)
res = get_forward(m._c)(x)
self.assertEqual(res, ref_res)
def test_replicate_dequantize_in_block(self):
class M(torch.nn.Module):
def __init__(self, cond):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.cond = cond
def forward(self, x):
x = torch.dequantize(x)
if self.cond:
x = self.conv(x)
else:
x = x + 3
return x
x = torch.randn([1, 3, 10, 10], dtype=torch.float)
x = torch.quantize_per_tensor(x, 0.5, 1, torch.quint8)
m = torch.jit.script(M(True))
ref_res = m(x)
FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)
torch._C._jit_pass_replicate_dequantize(m.graph)
FileCheck().check_count("aten::dequantize", 2, exactly=True).run(m.graph)
# check dequantize is right before CallMethod of conv
FileCheck().check("aten::dequantize").check_next("CallMethod").run(m.graph)
# check dequantize is right before add
FileCheck().check("aten::dequantize").check("aten::dequantize").check_next(
"aten::add"
).run(m.graph)
res = get_forward(m._c)(x)
self.assertEqual(res, ref_res)
def test_swap_functional_linear(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
def forward(self, x, weight, bias):
x = torch.dequantize(x)
weight = torch.dequantize(weight)
x = F.linear(x, weight, bias)
x = torch.quantize_per_tensor(
x, scale=1.0, zero_point=0, dtype=torch.quint8
)
return x
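# _jit_pass_swap_functional_linear is expected to replace the
# prim::CallFunction to F.linear with a direct aten::linear call so that the
# later quant fusion passes can match the linear pattern.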
x = torch.rand((10, 5), dtype=torch.float)
x = torch.quantize_per_tensor(x, scale=0.5, zero_point=1, dtype=torch.quint8)
weight = torch.rand((5, 5), dtype=torch.float)
weight = torch.quantize_per_tensor(
weight, scale=0.5, zero_point=1, dtype=torch.qint8
)
bias = torch.rand((5), dtype=torch.float)
m = torch.jit.script(M())
ref_res = m(x, weight, bias)
FileCheck().check("CallFunction").run(m.graph)
torch._C._jit_pass_swap_functional_linear(m.graph)
FileCheck().check("aten::linear").check_not("CallFunction").run(m.graph)
res = m(x, weight, bias)
self.assertEqual(res, ref_res)
def test_replicate_quantize_for_if(self):
"""We want to move quantize nodes for output of prim::If
inside the prim::If blocks so that we can match quantization
patterns.
"""
class Res(torch.nn.Module):
def __init__(self):
super(Res, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 1).float()
self.conv2 = torch.nn.Conv2d(3, 3, 1).float()
self.use_skip = True
def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
# to avoid being frozen
self.use_skip = cond
if self.use_skip:
return self.conv(x)
else:
return self.conv2(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.res1 = Res()
self.res2 = Res()
def forward(self, x):
x = self.res1(x, True)
x = self.res2(x, False)
return x
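# Each Res module has a conv in both branches of its prim::If, and there are
# two Res modules, so after moving the output quantize into the branches we
# expect 4 fused quantized::conv2d calls in total.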
data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]]
qconfig_dict = {"": default_qconfig}
m = torch.jit.script(M()).eval()
m = quantize_jit(m, qconfig_dict, test_only_eval_fn, [data])
# make sure patterns in both branches are fused
FileCheck().check_count("quantized::conv2d(", 4, exactly=True).run(m.graph)
def test_finalize_for_linear(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5).float()
def forward(self, x):
return self.fc(x)
data = [[torch.rand((1, 5), dtype=torch.float)]]
qconfig_dict = {"": default_qconfig}
model = torch.jit.script(M()).eval()
model = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data])
# make sure there is only one quantize_per_tensor for input
# and linear_prepack is folded
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).check_not(
"quantized::linear_prepack"
).check("quantized::linear").run(model.graph)
def test_inplace_option(self):
for tracing in [True, False]:
model = get_script_module(
torch.nn.Conv2d(3, 3, 3).float(), tracing, self.img_data_2d[0][0]
)
qconfig_dict = {"": default_qconfig}
quantize_jit(
model, qconfig_dict, test_only_eval_fn, [self.img_data_2d], inplace=True
)
FileCheck().check("quantized::conv2d").run(model.graph)
FileCheck().check_not("aten::conv2d").run(model.graph)
def test_finalize_debug(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
self.avgpool = torch.nn.AvgPool2d(3)
def forward(self, x):
x = self.conv(x)
x = self.avgpool(x)
return x
data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]]
qconfig_dict = {"": default_qconfig}
model = torch.jit.script(M()).eval()
model = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data], debug=True)
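# In debug mode the avg_pool stays as aten::avg_pool2d; its output is
# requantized by reusing the scale/zero_point/dtype of the quantized input
# (hence the q_scale / q_zero_point / prim::dtype / quantize_per_tensor
# sequence checked below), followed by a dequantize.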
FileCheck().check_not("quantized::conv2d").check("aten::conv2d").check(
"aten::avg_pool2d"
).check("aten::q_scale").check_next("aten::q_zero_point").check_next(
"prim::dtype"
).check_next(
"aten::quantize_per_tensor"
).check(
"aten::dequantize"
).run(
model.graph
)
def test_module_list(self):
class SimpleLinearLayer(torch.nn.Module):
def __init__(self):
super(SimpleLinearLayer, self).__init__()
self.fc = torch.nn.Linear(5, 5).float()
def forward(self, x):
return self.fc(x)
class ComplexModel(torch.nn.Module):
def __init__(self):
super(ComplexModel, self).__init__()
self.layers = torch.nn.ModuleList(
[SimpleLinearLayer() for i in range(2)]
)
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
states = []
for layer in self.layers:
val = layer(x)
states.append(val)
return states
data = torch.rand((1, 5), dtype=torch.float)
qconfig_dict = {"": default_qconfig}
model = torch.jit.script(ComplexModel()).eval()
model = prepare_jit(model, qconfig_dict)
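# 3 top-level observers are expected: one for the shared input x and one for
# each layer's output.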
assert len(attrs_with_prefix(model, "_observer")) == 3
model(data)
model = convert_jit(model, debug=False)
FileCheck().check("quantized::linear").check("quantized::linear").run(
model.graph
)
def test_conv_trace(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1d = torch.nn.Conv1d(3, 3, 3).float()
self.conv2d = torch.nn.Conv2d(3, 3, 3).float()
self.conv3d = torch.nn.Conv3d(3, 3, 3).float()
def forward(self, x, y, z):
a = self.conv1d(x)
b = self.conv2d(y)
c = self.conv3d(z)
return (a, b, c)
qconfig_dict = {"": default_qconfig}
inputs = (
torch.rand((1, 3, 10), dtype=torch.float),
torch.rand((1, 3, 10, 10), dtype=torch.float),
torch.rand((1, 3, 10, 10, 10), dtype=torch.float),
)
model = torch.jit.trace(M(), inputs).eval()
m = prepare_jit(model, qconfig_dict)
FileCheck().check("aten::conv1d").check_not("aten::_convolution").run(
str(get_forward_graph(m.conv1d._c))
)
FileCheck().check("aten::conv2d").check_not("aten::_convolution").run(
str(get_forward_graph(m.conv2d._c))
)
FileCheck().check("aten::conv3d").check_not("aten::_convolution").run(
str(get_forward_graph(m.conv3d._c))
)
@unittest.skipUnless(
"fbgemm" in torch.backends.quantized.supported_engines,
" Quantized operations require FBGEMM. FBGEMM is only optimized for CPUs"
" with instruction set support avx2 or newer.",
)
def test_replicate_dequant_same_value(self):
class Mul(torch.nn.Module):
def __init__(self):
super(Mul, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x):
x = self.conv(x)
return x * x
data = [[torch.rand((1, 3, 10, 10), dtype=torch.float)]]
qconfig_dict = {"": default_qconfig}
model = torch.jit.script(Mul()).eval()
m = quantize_jit(model, qconfig_dict, test_only_eval_fn, [data])
FileCheck().check("quantized::mul(").check_not("aten::mul").run(m.graph)
def test_interface_with_fork(self):
class SubModule(torch.nn.Module):
def __init__(self):
super(SubModule, self).__init__()
self.embedding1 = torch.nn.EmbeddingBag(
num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
sparse=False,
mode="sum",
)
def forward(self, x, y):
return self.embedding1(x, y)
class OrigMod(torch.nn.Module):
def __init__(self):
super(OrigMod, self).__init__()
self.embedding1 = torch.nn.EmbeddingBag(
num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
sparse=False,
mode="sum",
)
def forward(self, x, y):
return self.embedding1(x, y)
@torch.jit.interface
class ModInterface(torch.nn.Module):
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
pass
class TestModule(torch.nn.Module):
proxy_mod: ModInterface
def __init__(self):
super(TestModule, self).__init__()
self.proxy_mod = OrigMod()
self.sub = SubModule()
def forward(self, x, y):
a = self.proxy_mod(x, y)
b = self.sub(x, y)
return b
class MainModule(torch.nn.Module):
def __init__(self):
super(MainModule, self).__init__()
self.test = TestModule()
def forward(self, x, y):
fut = torch.jit._fork(self.test.forward, x, y)
z = torch.jit._wait(fut)
return z
indices = torch.tensor(
[9, 6, 5, 7, 8, 8, 9, 2, 8, 6, 6, 9, 1, 6, 8, 8,
3, 2, 3, 6, 3, 6, 5, 7, 0, 8, 4, 6, 5, 8, 2, 3]
)
offsets = torch.tensor([0, 19, 20, 28, 28, 32])
m = torch.jit.trace(MainModule(), (indices, offsets))
m.eval()
int8_qconfig = QConfig(
activation=PlaceholderObserver.with_args(
dtype=torch.float, custom_op_name="embedding_bag_byte"
),
weight=PlaceholderObserver.with_args(custom_op_name="embedding_bag_byte"),
)
m = prepare_jit(m, {"": int8_qconfig})
m = convert_jit(m)
FileCheck().check("quantized::embedding_bag_byte_rowwise_offsets").run(m.graph)
@skipIfNoFBGEMM
def test_quantize_fork_wait(self):
"""Tests the case where fork and wait calls are in different subgraphs
Calling inline fork-wait only removes the fork call and leaves aten::wait
calls in the graph, with Tensor as input (instead of Future[Tensor])
"""
class MainModule(nn.Module):
def __init__(self):
super(MainModule, self).__init__()
self.fork_ops = ForkModule()
def init_values(self, x):
shared_module = self.fork_ops(x)
self.fork_dict = shared_module
def forward(self, x):
val = torch.jit._wait(self.fork_ops(x))
return val
class TestModule(torch.nn.Module):
def __init__(self):
super(TestModule, self).__init__()
def forward(self, x):
w = torch.ones(5, 5)
b = torch.zeros(5)
return torch.nn.functional.linear(x, w, b)
class ForkModule(nn.Module):
def __init__(self):
super(ForkModule, self).__init__()
self.test = TestModule()
def forward(self, x):
fut = torch.jit._fork(self.test.forward, x)
return fut
model = MainModule().eval()
traced = torch.jit.trace(model, (torch.randn(5, 5),))
model = prepare_dynamic_jit(traced, {"": default_qconfig})
model = convert_dynamic_jit(model)
FileCheck().check("quantized::linear_dynamic").run(model.graph)
# Make sure model save works
b = io.BytesIO()
torch.jit.save(model, b)
class TestQuantizeJitOps(QuantizationTestCase):
"""Test graph mode post training static quantization works
for individual ops end to end.
"""
@skipIfNoFBGEMM
def test_linear(self):
class ModuleLinear(torch.nn.Module):
def __init__(self, has_relu=False, f_relu=False):
super(ModuleLinear, self).__init__()
self.linear = torch.nn.Linear(30, 4).float()
if has_relu:
if f_relu:
self.relu = F.relu
else:
self.relu = torch.nn.ReLU()
else:
self.relu = torch.nn.Identity()
def forward(self, x):
return self.relu(self.linear(x))
class FuncLinear(torch.nn.Module):
def __init__(self, has_relu=False, f_relu=False):
super(FuncLinear, self).__init__()
self.w = torch.randn(4, 30)
self.b = torch.randn(4)
if has_relu:
if f_relu:
self.relu = F.relu
else:
self.relu = torch.nn.ReLU()
else:
self.relu = torch.nn.Identity()
def forward(self, x):
return self.relu(F.linear(x, self.w, self.b))
data = [[torch.rand((1, 30), dtype=torch.float)]]
for model, tracing in itertools.product(
[ModuleLinear(has_relu=False), FuncLinear(has_relu=False)], [True, False]
):
model = self.checkGraphModeOp(model, data, "quantized::linear", tracing)
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
model.graph
)
FileCheck().check_not("quantized::linear_prepack").run(model.graph)
for f_relu, tracing in itertools.product([True, False], [True, False]):
for model in [
ModuleLinear(has_relu=True, f_relu=f_relu),
FuncLinear(has_relu=True, f_relu=f_relu),
]:
model = self.checkGraphModeOp(
model, data, "quantized::linear_relu", tracing
)
checker = (
FileCheck()
.check_not("aten::linear")
.check_not("aten::relu")
.check_not("quantized::linear(")
.check_not("quantized::relu(")
.run(model.graph)
)
@skipIfNoFBGEMM
def test_quantized_conv(self):
conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
class Conv(torch.nn.Module):
def __init__(self, dim):
super(Conv, self).__init__()
self.conv = conv_module[dim](3, 3, 3).float()
def forward(self, x):
return self.conv(x)
options = itertools.product([1, 2, 3], [True, False])
for dim, tracing in options:
model = self.checkGraphModeOp(
Conv(dim),
self.img_data_dict[dim],
"quantized::conv{}d".format(dim),
tracing,
)
# make sure there is only one quantize_per_tensor for the input
# and that the conv{dim}d_prepack op is folded
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
model.graph
)
FileCheck().check_not("quantized::conv{}d_prepack".format(dim)).run(
model.graph
)
@skipIfNoFBGEMM
def test_quantized_conv_relu(self):
"""tests for conv1d_relu/conv2d_relu/conv3d_relu"""
conv_module = {1: torch.nn.Conv1d, 2: torch.nn.Conv2d, 3: torch.nn.Conv3d}
class ConvNdRelu(torch.nn.Module):
def __init__(self, dim, inplace):
super(ConvNdRelu, self).__init__()
self.conv = conv_module[dim](3, 3, 3).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x):
return self.relu(self.conv(x))
class ConvNdFunctionalRelu(torch.nn.Module):
def __init__(self, dim):
super(ConvNdFunctionalRelu, self).__init__()
self.conv = conv_module[dim](3, 3, 3).float()
def forward(self, x):
return F.relu(self.conv(x))
class ConvNdInplaceFunctionalRelu(torch.nn.Module):
def __init__(self, dim):
super(ConvNdInplaceFunctionalRelu, self).__init__()
self.conv = conv_module[dim](3, 3, 3).float()
def forward(self, x):
return F.relu(self.conv(x), True)
options = itertools.product([1, 2, 3], [True, False])
for dim, tracing in options:
for orig_m in [
ConvNdRelu(dim, True),
ConvNdRelu(dim, False),
ConvNdFunctionalRelu(dim),
ConvNdInplaceFunctionalRelu(dim),
]:
conv_name = "conv{}d".format(dim)
m = self.checkGraphModeOp(
orig_m,
self.img_data_dict[dim],
"quantized::conv{}d_relu(".format(dim),
tracing=tracing,
)
FileCheck().check_not("aten::conv{}d(".format(dim)).check_not(
"aten::relu"
).check_not("quantized::conv{}d(".format(dim)).check_not(
"quantized::relu("
).run(
m.graph
)
@skipIfNoFBGEMM
def test_quantized_add_alpha(self):
"""Test quant fusion for multiple aten::add using same
constant alpha as the third argument
"""
class QuantizedAdd(torch.nn.Module):
def __init__(self):
super(QuantizedAdd, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
z = x + y
w = y + z
return z + w
data = [
[
torch.randn(1, 2, 5, 5, dtype=torch.float),
torch.randn(1, 2, 5, 5, dtype=torch.float),
]
]
for tracing in [True, False]:
m = self.checkGraphModeOp(QuantizedAdd(), data, "quantized::add", tracing)
FileCheck().check_count("quantized::add", 3, exactly=True).run(m.graph)
FileCheck().check_not("aten::add").check_not("aten::add_").run(m.graph)
@skipIfNoFBGEMM
def test_quantized_add_relu_alpha(self):
"""Test quant fusion for multiple aten::add using same
constant alpha as the third argument in add_relu pattern
"""
class AddRelu(torch.nn.Module):
def __init__(self, inplace):
super(AddRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x + y
x = self.relu(x)
x = x + y
return self.relu(x)
class InplaceAddRelu(torch.nn.Module):
def __init__(self, inplace):
super(InplaceAddRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x += y
x = self.relu(x)
x += y
return self.relu(x)
class AddFunctionalRelu(torch.nn.Module):
def __init__(self):
super(AddFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x + y
x = F.relu(x)
x = x + y
return F.relu(x)
class InplaceAddFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceAddFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x += y
x = F.relu(x)
x += y
return F.relu(x)
class AddInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(AddInplaceFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x + y
x = F.relu(x, True)
x = x + y
return F.relu(x, True)
class InplaceAddInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceAddInplaceFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x += y
x = F.relu(x, True)
x += y
return F.relu(x, True)
data = [
[
torch.rand((1, 2, 5, 5), dtype=torch.float),
torch.rand((1, 2, 5, 5), dtype=torch.float),
]
]
for m_orig in [
AddRelu(True),
AddRelu(False),
InplaceAddRelu(True),
InplaceAddRelu(False),
AddFunctionalRelu(),
InplaceAddFunctionalRelu(),
AddInplaceFunctionalRelu(),
InplaceAddInplaceFunctionalRelu(),
]:
for tracing in [True, False]:
m = self.checkGraphModeOp(
m_orig, data, "quantized::add_relu(", tracing=tracing
)
FileCheck().check_count("quantized::add_relu(", 2, exactly=True).run(
m.graph
)
FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not(
"aten::relu("
).check_not("aten::relu_(").check_not("quantized::add(").check_not(
"quantized::relu("
).run(
m.graph
)
@skipIfNoFBGEMM
def test_quantized_add(self):
class QuantizedAdd(torch.nn.Module):
def __init__(self):
super(QuantizedAdd, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
return x + y
class QuantizedInplaceAdd(torch.nn.Module):
def __init__(self):
super(QuantizedInplaceAdd, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x += y
return x
class NonQuantizedAdd(torch.nn.Module):
def __init__(self):
super(NonQuantizedAdd, self).__init__()
def forward(self, x, y):
return x + y
class NonQuantizedInplaceAdd(torch.nn.Module):
def __init__(self):
super(NonQuantizedInplaceAdd, self).__init__()
def forward(self, x, y):
x += y
return x
data = [
[
torch.randn(1, 2, 3, 3, dtype=torch.float),
torch.randn(1, 2, 3, 3, dtype=torch.float),
]
]
for m, quantized in [
(QuantizedAdd(), True),
(QuantizedInplaceAdd(), True),
(NonQuantizedAdd(), False),
(NonQuantizedInplaceAdd(), False),
]:
for tracing in [True, False]:
op = "quantized::add" if quantized else "aten::add"
m = self.checkGraphModeOp(m, data, op, tracing)
# TODO: remove after refactor of checkGraphModeOp
if quantized:
FileCheck().check_not("aten::add").check_not("aten::add_").run(
m.graph
)
else:
FileCheck().check_not("quantized::add").run(m.graph)
@skipIfNoFBGEMM
def test_quantized_add_scalar(self):
class QuantizedAddScalar(torch.nn.Module):
def __init__(self):
super(QuantizedAddScalar, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
return x + 3
class QuantizedInplaceAddScalar(torch.nn.Module):
def __init__(self):
super(QuantizedInplaceAddScalar, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
x += 3
return x
class NonQuantizedAddScalar(torch.nn.Module):
def __init__(self):
super(NonQuantizedAddScalar, self).__init__()
def forward(self, x):
return x + 3
class NonQuantizedInplaceAddScalar(torch.nn.Module):
def __init__(self):
super(NonQuantizedInplaceAddScalar, self).__init__()
def forward(self, x):
x += 3
return x
data = [[torch.randn(1, 2, 3, 3, dtype=torch.float)]]
for m, quantized in [
(QuantizedAddScalar(), True),
(QuantizedInplaceAddScalar(), True),
(NonQuantizedAddScalar(), False),
(NonQuantizedInplaceAddScalar(), False),
]:
for tracing in [True, False]:
op = "quantized::add_scalar" if quantized else "aten::add"
# we don't check the numerical consistency for add_scalar
# since it's not supported
m = self.checkGraphModeOp(m, data, op, tracing, check=False)
# TODO: remove after refactor of checkGraphModeOp
if quantized:
FileCheck().check_not("aten::add").check_not("aten::add_").run(
m.graph
)
else:
FileCheck().check_not("quantized::add_scalar").run(m.graph)
@skipIfNoFBGEMM
def test_quantized_add_relu(self):
class AddRelu(torch.nn.Module):
def __init__(self, inplace):
super(AddRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x + y
return self.relu(x)
class InplaceAddRelu(torch.nn.Module):
def __init__(self, inplace):
super(InplaceAddRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x += y
return self.relu(x)
class AddFunctionalRelu(torch.nn.Module):
def __init__(self):
super(AddFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x + y
return F.relu(x)
class InplaceAddFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceAddFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x += y
return F.relu(x)
class AddInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(AddInplaceFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x + y
return F.relu(x, True)
class InplaceAddInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceAddInplaceFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x += y
return F.relu(x, True)
data = [
[
torch.rand((1, 2, 5, 5), dtype=torch.float),
torch.rand((1, 2, 5, 5), dtype=torch.float),
]
]
for m in [
AddRelu(True),
AddRelu(False),
InplaceAddRelu(True),
InplaceAddRelu(False),
AddFunctionalRelu(),
InplaceAddFunctionalRelu(),
AddInplaceFunctionalRelu(),
InplaceAddInplaceFunctionalRelu(),
]:
for tracing in [True, False]:
m = self.checkGraphModeOp(m, data, "quantized::add_relu(", tracing)
FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not(
"aten::relu("
).check_not("aten::relu_(").check_not("quantized::add(").check_not(
"quantized::relu("
).run(
m.graph
)
@skipIfNoFBGEMM
def test_quantized_add_scalar_relu(self):
class AddScalarRelu(torch.nn.Module):
def __init__(self, inplace):
super(AddScalarRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x):
x = self.conv(x)
return self.relu(x + 3)
class InplaceAddScalarRelu(torch.nn.Module):
def __init__(self, inplace):
super(InplaceAddScalarRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x):
x = self.conv(x)
x += 3
return self.relu(x)
class AddScalarFunctionalRelu(torch.nn.Module):
def __init__(self):
super(AddScalarFunctionalRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
return F.relu(x + 3)
class InplaceAddScalarFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceAddScalarFunctionalRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
x += 3
return F.relu(x)
class AddScalarInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(AddScalarInplaceFunctionalRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
return F.relu(x + 3, True)
class InplaceAddScalarInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceAddScalarInplaceFunctionalRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
x += 3
return F.relu(x, True)
data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)]]
for m in [
AddScalarRelu(True),
AddScalarRelu(False),
InplaceAddScalarRelu(True),
InplaceAddScalarRelu(False),
AddScalarFunctionalRelu(),
InplaceAddScalarFunctionalRelu(),
AddScalarInplaceFunctionalRelu(),
InplaceAddScalarInplaceFunctionalRelu(),
]:
for tracing in [True, False]:
# quantized::add_scalar_relu or quantized::add_scalar_relu_out
# TODO: split this after refactor of checkGraphModeOp
m = self.checkGraphModeOp(
m, data, "quantized::add_scalar_relu", tracing, check=False
)
FileCheck().check_not("aten::add(").check_not("aten::add_(").check_not(
"aten::relu("
).check_not("aten::relu_(").check_not(
"quantized::add_scalar("
).check_not(
"quantized::relu("
).run(
m.graph
)
@skipIfNoFBGEMM
def test_quantized_cat(self):
"""quantization of the output of cat will be depend on the
input of cat. we only quantize the output of cat when its inputs are quantized.
"""
class QuantizedCat(torch.nn.Module):
def __init__(self):
super(QuantizedCat, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
return torch.cat([x, y], 1)
class NonQuantizedCat(torch.nn.Module):
def __init__(self):
super(NonQuantizedCat, self).__init__()
def forward(self, x, y):
return torch.cat([x, y], 1)
data = [
[
torch.randn(1, 2, 5, 5, dtype=torch.float),
torch.randn(1, 2, 5, 5, dtype=torch.float),
]
]
for tracing in [True, False]:
m = self.checkGraphModeOp(QuantizedCat(), data, "quantized::cat", tracing)
FileCheck().check_not("aten::cat").run(m.graph)
m = self.checkGraphModeOp(NonQuantizedCat(), data, "aten::cat", tracing)
FileCheck().check_not("quantized::cat").run(m.graph)
@skipIfNoFBGEMM
def test_qbatch_norm(self):
bn_module = {
1: torch.nn.BatchNorm1d,
2: torch.nn.BatchNorm2d,
3: torch.nn.BatchNorm3d,
}
class M(torch.nn.Module):
def __init__(self, dim):
super(M, self).__init__()
self.bn = bn_module[dim](3).to(torch.float)
def forward(self, x):
return self.bn(x)
options = itertools.product([True, False], [1, 2, 3])
for tracing, dim in options:
model = self.checkGraphModeOp(
M(dim), self.img_data_dict[dim], "quantized::batch_norm", tracing
)
FileCheck().check_not("aten::batch_norm").run(model.graph)
@skipIfNoFBGEMM
def test_qbatch_norm_relu_BNRelu(self):
bn_module = {2: torch.nn.BatchNorm2d, 3: torch.nn.BatchNorm3d}
class BNRelu(torch.nn.Module):
def __init__(self, dim, inplace):
super(BNRelu, self).__init__()
self.bn = bn_module[dim](3).to(torch.float)
self.relu = torch.nn.ReLU(inplace=inplace)
def forward(self, x):
return self.relu(self.bn(x))
options = itertools.product([True, False], [2, 3])
for tracing, dim in options:
for instance in [BNRelu(dim, True), BNRelu(dim, False)]:
model = self.checkGraphModeOp(instance, self.img_data_dict[dim],
"quantized::batch_norm_relu", tracing)
FileCheck().check_not("aten::batch_norm") \
.check_not("aten::relu") \
.check_not("aten::relu_") \
.run(model.graph)
@skipIfNoFBGEMM
def test_qbatch_norm_relu_BNFuncRelu(self):
bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
class BNFuncRelu(torch.nn.Module):
def __init__(self, dim):
super(BNFuncRelu, self).__init__()
self.bn = bn_module[dim](3).to(torch.float)
def forward(self, x):
return F.relu(self.bn(x), False)
options = itertools.product([True, False], [2, 3])
for tracing, dim in options:
instance = BNFuncRelu(dim)
model = self.checkGraphModeOp(instance, self.img_data_dict[dim],
"quantized::batch_norm_relu", tracing)
FileCheck().check_not("aten::batch_norm") \
.check_not("aten::relu") \
.check_not("aten::relu_") \
.run(model.graph)
@skipIfNoFBGEMM
def test_qbatch_norm_relu_BNFuncInplaceRelu(self):
bn_module = {2 : torch.nn.BatchNorm2d, 3 : torch.nn.BatchNorm3d}
class BNFuncInplaceRelu(torch.nn.Module):
def __init__(self, dim):
super(BNFuncInplaceRelu, self).__init__()
self.bn = bn_module[dim](3).to(torch.float)
def forward(self, x):
return F.relu(self.bn(x), True)
options = itertools.product([True, False], [2, 3])
for tracing, dim in options:
instance = BNFuncInplaceRelu(dim)
model = self.checkGraphModeOp(instance, self.img_data_dict[dim],
"quantized::batch_norm_relu", tracing)
FileCheck().check_not("aten::batch_norm") \
.check_not("aten::relu") \
.check_not("aten::relu_") \
.run(model.graph)
@skipIfNoFBGEMM
def test_quantized_mul(self):
class QuantizedMul(torch.nn.Module):
def __init__(self):
super(QuantizedMul, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
return x * y
class QuantizedInplaceMul(torch.nn.Module):
def __init__(self):
super(QuantizedInplaceMul, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x *= y
return x
class NonQuantizedMul(torch.nn.Module):
def __init__(self):
super(NonQuantizedMul, self).__init__()
def forward(self, x, y):
return x * y
class NonQuantizedInplaceMul(torch.nn.Module):
def __init__(self):
super(NonQuantizedInplaceMul, self).__init__()
def forward(self, x, y):
x *= y
return x
data = [
[
torch.randn(1, 2, 10, 10, dtype=torch.float),
torch.randn(1, 2, 10, 10, dtype=torch.float),
]
]
for m, quantized in [
(QuantizedMul(), True),
(QuantizedInplaceMul(), True),
(NonQuantizedMul(), False),
(NonQuantizedInplaceMul(), False),
]:
for tracing in [True, False]:
op = "quantized::mul" if quantized else "aten::mul"
m = self.checkGraphModeOp(m, data, op, tracing)
# TODO: remove after refactor of checkGraphModeOp
if quantized:
FileCheck().check_not("aten::mul").check_not("aten::mul_").run(
m.graph
)
else:
FileCheck().check_not("quantized::mul").run(m.graph)
@skipIfNoFBGEMM
def test_quantized_mul_scalar(self):
class QuantizedMulScalar(torch.nn.Module):
def __init__(self):
super(QuantizedMulScalar, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
return x * 3
class QuantizedInplaceMulScalar(torch.nn.Module):
def __init__(self):
super(QuantizedInplaceMulScalar, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
x *= 3
return x
class NonQuantizedMulScalar(torch.nn.Module):
def __init__(self):
super(NonQuantizedMulScalar, self).__init__()
def forward(self, x):
return x * 3
class NonQuantizedInplaceMulScalar(torch.nn.Module):
def __init__(self):
super(NonQuantizedInplaceMulScalar, self).__init__()
def forward(self, x):
x *= 3
return x
data = [[torch.randn(1, 2, 5, 5, dtype=torch.float)]]
for m, quantized in [
(QuantizedMulScalar(), True),
(QuantizedInplaceMulScalar(), True),
(NonQuantizedMulScalar(), False),
(NonQuantizedInplaceMulScalar(), False),
]:
for tracing in [True, False]:
op = "quantized::mul_scalar" if quantized else "aten::mul"
# we don't check the numerical consistency for add_scalar
# since it's not supported
m = self.checkGraphModeOp(m, data, op, tracing, check=False)
# TODO: remove after refactor of checkGraphModeOp
if quantized:
FileCheck().check_not("aten::mul").check_not("aten::mul_").run(
m.graph
)
else:
FileCheck().check_not("quantized::mul_scalar").run(m.graph)
@skipIfNoFBGEMM
def test_quantized_mul_relu(self):
class MulRelu(torch.nn.Module):
def __init__(self, inplace):
super(MulRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x * y
return self.relu(x)
class InplaceMulRelu(torch.nn.Module):
def __init__(self, inplace):
super(InplaceMulRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x *= y
return self.relu(x)
class MulFunctionalRelu(torch.nn.Module):
def __init__(self):
super(MulFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x * y
return F.relu(x)
class InplaceMulFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceMulFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x *= y
return F.relu(x)
class MulInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(MulInplaceFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x = x * y
return F.relu(x, True)
class InplaceMulInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceMulInplaceFunctionalRelu, self).__init__()
self.conv1 = torch.nn.Conv2d(2, 2, 2).float()
self.conv2 = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x, y):
x = self.conv1(x)
y = self.conv2(y)
x *= y
return F.relu(x, True)
data = [
[
torch.rand((1, 2, 5, 5), dtype=torch.float),
torch.rand((1, 2, 5, 5), dtype=torch.float),
]
]
for m in [
MulRelu(True),
MulRelu(False),
InplaceMulRelu(True),
InplaceMulRelu(False),
MulFunctionalRelu(),
InplaceMulFunctionalRelu(),
MulInplaceFunctionalRelu(),
InplaceMulInplaceFunctionalRelu(),
]:
for tracing in [True, False]:
m = self.checkGraphModeOp(m, data, "quantized::mul_relu(", tracing)
FileCheck().check_not("aten::mul(").check_not("aten::mul_(").check_not(
"aten::relu("
).check_not("aten::relu_(").check_not("quantized::mul(").check_not(
"quantized::relu("
).run(
m.graph
)
@skipIfNoFBGEMM
def test_quantized_mul_scalar_relu(self):
class MulScalarRelu(torch.nn.Module):
def __init__(self, inplace):
super(MulScalarRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x):
x = self.conv(x)
return self.relu(x * 3)
class InplaceMulScalarRelu(torch.nn.Module):
def __init__(self, inplace):
super(InplaceMulScalarRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
self.relu = torch.nn.ReLU(inplace)
def forward(self, x):
x = self.conv(x)
x *= 3
return self.relu(x)
class MulScalarFunctionalRelu(torch.nn.Module):
def __init__(self):
super(MulScalarFunctionalRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
return F.relu(x * 3)
class InplaceMulScalarFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceMulScalarFunctionalRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
x *= 3
return F.relu(x)
class MulScalarInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(MulScalarInplaceFunctionalRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
return F.relu(x * 3, True)
class InplaceMulScalarInplaceFunctionalRelu(torch.nn.Module):
def __init__(self):
super(InplaceMulScalarInplaceFunctionalRelu, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
def forward(self, x):
x = self.conv(x)
x *= 3
return F.relu(x, True)
data = [[torch.randn(1, 2, 5, 5, dtype=torch.float)]]
for m in [
MulScalarRelu(True),
MulScalarRelu(False),
InplaceMulScalarRelu(True),
InplaceMulScalarRelu(False),
MulScalarFunctionalRelu(),
InplaceMulScalarFunctionalRelu(),
MulScalarInplaceFunctionalRelu(),
InplaceMulScalarInplaceFunctionalRelu(),
]:
for tracing in [True, False]:
# quantized::mul_scalar_relu or quantized::mul_scalar_relu_out
m = self.checkGraphModeOp(
m, data, "quantized::mul_scalar_relu", tracing, check=False
)
FileCheck().check_not("aten::mul(").check_not("aten::mul_(").check_not(
"aten::relu("
).check_not("aten::relu_(").check_not(
"quantized::mul_scalar("
).check_not(
"quantized::relu("
).run(
m.graph
)
def test_hardswish(self):
class FunctionalHardswish(torch.nn.Module):
def __init__(self, inplace):
super(FunctionalHardswish, self).__init__()
self.inplace = inplace
def forward(self, input):
return torch.nn.functional.hardswish(input, inplace=self.inplace)
modules = [
torch.nn.Hardswish(),
FunctionalHardswish(True),
FunctionalHardswish(False),
]
for test_case in itertools.product([True, False], modules):
tracing, m = test_case
m = self.checkGraphModeOp(
m, self.img_data_2d, "quantized::hardswish", tracing
)
FileCheck().check_not("aten::hardswish").check_not("aten::hardswish_").run(
m.graph
)
def test_elu(self):
class FunctionalELU(torch.nn.Module):
def __init__(self, inplace=False):
super(FunctionalELU, self).__init__()
self.inplace = inplace
def forward(self, input):
return torch.nn.functional.elu(input, inplace=self.inplace)
modules = [torch.nn.ELU, FunctionalELU]
for test_case in itertools.product([True, False], [True, False], modules):
tracing, inplace, mod_class = test_case
m = mod_class(inplace=inplace)
m = self.checkGraphModeOp(m, self.img_data_2d, "quantized::elu", tracing)
FileCheck().check_not("aten::elu").check_not("aten::elu_").run(m.graph)
def test_layer_norm(self):
data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)] for _ in range(2)]
layer_norm = torch.nn.LayerNorm([2, 5, 5])
for tracing in [True, False]:
m = self.checkGraphModeOp(
layer_norm, data, "quantized::layer_norm", tracing
)
FileCheck().check_not("aten::layer_norm").run(m.graph)
def test_group_norm(self):
data = [[torch.rand((1, 4, 5, 5), dtype=torch.float)] for _ in range(2)]
group_norm = torch.nn.GroupNorm(2, 4)
for tracing in [True, False]:
m = self.checkGraphModeOp(
group_norm, data, "quantized::group_norm", tracing
)
FileCheck().check_not("aten::group_norm").run(m.graph)
def test_instance_norm(self):
data_1d = [[torch.rand((1, 4, 5), dtype=torch.float)] for _ in range(2)]
data_2d = [[torch.rand((1, 4, 5, 1), dtype=torch.float)] for _ in range(2)]
data_3d = [[torch.rand((1, 4, 5, 1, 1), dtype=torch.float)] for _ in range(2)]
data = {1: data_1d, 2: data_2d, 3: data_3d}
instance_norm_modules = {
1: torch.nn.InstanceNorm1d,
2: torch.nn.InstanceNorm2d,
3: torch.nn.InstanceNorm3d,
}
options = itertools.product([1, 2, 3], [True, False])
for dim, tracing in options:
instance_norm = instance_norm_modules[dim](4)
m = self.checkGraphModeOp(
instance_norm, data[dim], "quantized::instance_norm", tracing
)
FileCheck().check_not("aten::instance_norm").run(m.graph)
@skipIfNoFBGEMM
def test_dequantize_tuple(self):
"""Make sure dequantize can support Tuple of tensor"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv1 = torch.nn.Conv2d(3, 3, 3).float()
self.conv2 = torch.nn.Conv2d(3, 3, 3).float()
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
x1 = self.conv1(x)
x2 = self.conv2(x)
return x1, x2
for tracing in [True, False]:
self.checkGraphModeOp(M(), self.img_data_2d, "quantized::conv2d", tracing)
@skipIfNoFBGEMM
def test_clamp(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(2, 2, 2).float()
self.relu6 = torch.nn.ReLU6()
self.relu6_ = torch.nn.ReLU6(True)
self.hardtanh = torch.nn.Hardtanh()
self.hardtanh_ = torch.nn.Hardtanh(inplace=True)
def forward(self, x):
x = self.conv(x)
x = self.relu6(x)
self.relu6_(x)
x = F.relu6(x)
x = torch.clamp(x, -3, 3)
x = x.clamp(-2.5, 2.5)
# x = x.clamp_(-2, 2) # Enable when quantized `clamp_` is ready
x = self.hardtanh(x)
self.hardtanh_(x)
x = F.hardtanh(x)
F.hardtanh_(x)
return x
data = [[torch.rand((1, 2, 5, 5), dtype=torch.float)]]
options = itertools.product(
["aten::clamp", "aten::hardtanh", "aten::hardtanh_"], [True, False]
)
for op, tracing in options:
m = self.checkGraphModeOp(M(), data, op, tracing)
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
m.graph
)
FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)
def test_general_shape_ops(self):
"""A test that checks dequantize will be swapped for
all supported general shape ops like aten::flatten
without actually checking for execution of these ops
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.maxpool1d = torch.nn.MaxPool1d(kernel_size=3)
self.maxpool2d = torch.nn.MaxPool2d(kernel_size=3)
self.maxpool3d = torch.nn.MaxPool3d(kernel_size=3)
self.dropout = torch.nn.Dropout()
self.conv1 = torch.nn.Conv2d(3, 3, 3)
self.conv2 = torch.nn.Conv2d(3, 3, 3)
self.relu = torch.nn.ReLU()
def forward(self, x):
x = self.conv1(x)
# add_scalar
x = x + 3
# mul_scalar
x = x * 3
# add_scalar_out
x += 3
# mul_scalar_out
x *= 3
# add_scalar_relu
x = x + 3
x = F.relu(x)
# add_scalar_relu_out
x += 3
x = F.relu(x)
# mul_scalar_relu
x = x * 3
x = F.relu(x)
# mul_scalar_relu_out
x *= 3
x = F.relu(x)
x = self.maxpool1d(x)
x = self.maxpool2d(x)
x = self.maxpool3d(x)
x = torch.flatten(x)
x = torch.max(x)
x = torch.min(x)
x = x.reshape([-1])
x = x.resize_(1, 1, x.numel())
x = x.view(-1)
# prim::ListConstruct
xs = [x, x]
# prim::ListUnpack
x, y = xs
# prim::TupleConstruct
xs = (x, x)
# prim::TupleUnpack
x, y = xs
x = x.transpose(1, 2)
x = x.contiguous()
x, y = torch.chunk(x, 2)
x = F.dropout(x)
x = self.dropout(x)
x, _ = torch.sort(x)
x = x.permute(0, 2, 3, 1)
x = torch.repeat_interleave(x, 3, 1)
x = self.relu(x)
x = F.relu(x)
x.relu_()
x = x.squeeze(0)
x.squeeze_(0)
x = torch.squeeze(x, 0)
x = x.unsqueeze(0)
x.unsqueeze_(0)
x = torch.unsqueeze(x, 0)
x = x.detach()
x.detach_()
x = x.repeat(4, 2)
y = []
y.append(x)
z = torch.stack(y, 0)
z = [z, z]
x, _ = z
x = self.conv2(x)
return x
data = torch.rand(1, 3, 10, 10)
# This model is not executable since we just put all ops
# in the same forward, therefore we only test scripting
m = torch.jit.script(M())
qconfig = script_qconfig(default_qconfig)
# dummy data to suppress warning
get_forward(qconfig.activation)(data)
get_forward(qconfig.weight)(data)
m = wrap_cpp_module(
torch._C._jit_pass_insert_observers(
m._c, "forward", {"": qconfig}, inplace=False
)
)
m = convert_jit(m)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers and also successfully fused two quantized::conv2d
# patterns
# one quantize_per_tensor for input
FileCheck().check_count("aten::quantize_per_tensor", 1, exactly=True).run(
m.graph
)
FileCheck().check_count("quantized::conv2d(", 2, exactly=True).run(m.graph)
FileCheck().check_count("aten::dequantize", 1, exactly=True).run(m.graph)
FileCheck().check("quantized::add_scalar").check("quantized::mul_scalar").run(
m.graph
)
def test_general_value_ops(self):
""" A test that checks correct patterns are produced for
all supported general value ops like aten::avg_pool2d \
without actually checking for execution of these ops
"""
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
self.avg_pool1d = torch.nn.AvgPool1d(3)
self.avg_pool2d = torch.nn.AvgPool2d(3)
self.avg_pool3d = torch.nn.AvgPool3d(3)
self.adaptive_avg_pool1d = torch.nn.AdaptiveAvgPool1d((1))
self.adaptive_avg_pool2d = torch.nn.AdaptiveAvgPool2d((1, 1))
self.adaptive_avg_pool3d = torch.nn.AdaptiveAvgPool3d((1, 1, 1))
self.leaky_relu = torch.nn.LeakyReLU()
self.hardsigmoid = torch.nn.Hardsigmoid()
self.sigmoid = torch.nn.Sigmoid()
self.tanh = torch.nn.Tanh()
def forward(self, x):
x = self.conv(x)
x = self.avg_pool1d(x)
x = self.avg_pool2d(x)
x = self.avg_pool3d(x)
x = self.adaptive_avg_pool1d(x)
x = self.adaptive_avg_pool2d(x)
x = self.adaptive_avg_pool3d(x)
x = F.avg_pool1d(x, 3)
x = F.avg_pool2d(x, 3)
x = F.avg_pool3d(x, 3)
x = F.adaptive_avg_pool1d(x, (1))
x = F.adaptive_avg_pool2d(x, (1, 1))
x = F.adaptive_avg_pool3d(x, (1, 1, 1))
x = torch.mean(x)
x = torch.mean(x, [2, 3], False)
x = x.mean()
x = x.mean([2, 3], True)
# interpolate node will introduce 3 quantize_per_tensor ops
x = F.interpolate(x, 4, mode="nearest") # interpolate node
x = F.upsample(x, (32, 32)) # interpolate node
x = F.upsample_nearest(x, (32, 32)) # interpolate node
x = F.interpolate(x, 4, mode="linear") # common node
x = F.upsample_bilinear(x, (32, 32)) # common node
x = self.leaky_relu(x)
x = F.leaky_relu(x)
x.leaky_relu_()
x = self.hardsigmoid(x)
x = F.hardsigmoid(x)
x.hardsigmoid_()
x = self.sigmoid(x)
x = torch.sigmoid(x)
# F.sigmoid is deprecated
x = x.sigmoid()
x.sigmoid_()
x = self.tanh(x)
# F.tanh is deprecated
x = torch.tanh(x)
x = x.tanh()
x.tanh_()
x = self.conv(x)
return x
# This model is not executable since we just put all ops
# in the same forward, therefore we only test scripting
m = torch.jit.script(M())
qconfig = script_qconfig(default_qconfig)
# dummy data to suppress warning
data = torch.rand(1, 3, 10, 10)
get_forward(qconfig.activation)(data)
get_forward(qconfig.weight)(data)
m = wrap_cpp_module(
torch._C._jit_pass_insert_observers(
m._c, "forward", {"": qconfig}, inplace=False
)
)
# Checking the model before fianlize contain unfused patterns
# that numerically matches the model after quantize by checking
# number of aten::quantize_per_tensor functions
# conv has 3 quantize_per_tensor for activations and 1 for weight
# and for N general value op between conv we should have
# N + 1 quantize_per_tensor between these ops
m1 = convert_jit(m, debug=True)
# NB: This Needs to be updated when we add more ops to test
# mapping from number of quant for the op to the number of these ops
# for example, for `3` in the key means for this type of op
# we'll have 3 quantize_per_tensor
num_op_by_num_quant = {1: 32, 2: 2, 3: 3}
num_quantize_per_tensor = 1 # for output
for num_quant, num_op in num_op_by_num_quant.items():
num_quantize_per_tensor += num_op * num_quant
num_quantize_per_tensor -= 4 # constant propagation removes some prepacks
FileCheck().check_count(
"aten::quantize_per_tensor(", num_quantize_per_tensor, exactly=True
).run(m1.graph)
# This checks that the dequantize from the output of first conv
# is being propagated to the end, so that we don't insert extra
# observers and also successfully fused two quantized::conv2d
# patterns
# one quantize_per_tensor for input
m2 = convert_jit(m, debug=False)
FileCheck().check_count("aten::quantize_per_tensor(", 1, exactly=True).run(
m2.graph
)
FileCheck().check_count("quantized::conv2d(", 2, exactly=True).check(
"aten::dequantize("
).run(m2.graph)
@override_qengines
def test_conv_with_benchmark_flag(self):
r"""Verifies that convolutions get quantized when
torch.backends.cudnn.benchmark is enabled
"""
if not qengine_is_qnnpack():
return
with torch.backends.cudnn.flags(enabled=True):
m = torch.nn.Sequential(torch.nn.Conv2d(1, 1, 1))
m.eval()
m = torch.jit.trace(m, torch.rand(4, 1, 4, 4))
qconfig = torch.quantization.get_default_qconfig("qnnpack")
prepared_model = torch.quantization.prepare_jit(m, {"": qconfig})
prepared_model(torch.rand(4, 1, 4, 4))
converted_model = torch.quantization.convert_jit(prepared_model)
FileCheck().check("quantized::conv2d").run(converted_model.graph)
@skipIfNoFBGEMM
def test_cat_linear(self):
class LinearModel(torch.nn.Module):
def __init__(self):
super(LinearModel, self).__init__()
self.weight = torch.randn(5, 5)
def forward(self, x, y):
a = torch.cat([x, y])
b = F.linear(a, self.weight)
c = F.linear(b, self.weight)
return b, c
model = LinearModel().eval()
qconfig = {"": default_qconfig}
float_model = torch.jit.script(model)
prepared_model = prepare_jit(float_model, qconfig)
prepared_model(torch.rand(5, 5), torch.rand(5, 5))
converted_model = convert_jit(prepared_model)
FileCheck().check("quantized::linear").check("quantized::linear").run(
converted_model.graph
)
class TestQuantizeDynamicJitPasses(QuantizationTestCase):
def test_prepare_dynamic(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
model = torch.jit.script(M())
for qconfig in [float16_dynamic_qconfig, default_dynamic_qconfig]:
m = prepare_dynamic_jit(model, {"": qconfig})
# observer for weight
assert len(attrs_with_prefix(m.fc, "_observer_")) == 1
if qconfig == float16_dynamic_qconfig:
observer_name = 'PlaceholderObserver = prim::GetAttr[name="_observer_'
FileCheck().check(observer_name).run(m.fc.graph)
else:
# for input of FC for dynamic quant
assert len(attrs_with_prefix(m, "_observer_")) == 1
observer_name = 'Observer = prim::GetAttr[name="_observer_'
FileCheck().check(observer_name).check(
'prim::GetAttr[name="fc"]'
).check("prim::CallMethod").check_not(observer_name).run(m.graph)
def test_prepare_dynamic_child_qconfig(self):
class Sub(torch.nn.Module):
def __init__(self):
super(Sub, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.conv = torch.nn.Conv2d(3, 5, 3)
self.sub = Sub()
def forward(self, x):
return self.sub(self.conv(x))
m = torch.jit.script(M())
# only quantize child module.
m = prepare_dynamic_jit(m, {"sub.fc": default_dynamic_qconfig})
# input of sub for dynamic quant
assert len(attrs_with_prefix(m, "_observer_")) == 1
# not quantized
assert len(attrs_with_prefix(m.conv, "_observer_")) == 0
# no observers since we observe in the outer most call site
assert len(attrs_with_prefix(m.sub, "_observer_")) == 0
# weight of linear
assert len(attrs_with_prefix(m.sub.fc, "_observer_")) == 1
FileCheck().check('prim::GetAttr[name="sub').check("prim::CallMethod").check(
'Observer = prim::GetAttr[name="_observer_'
).check("prim::CallMethod").check_not(
'Observer = prim::GetAttr[name="_observer_'
).run(
m.graph
)
def test_insert_quant_dequant_linear_dynamic(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc1 = torch.nn.Linear(5, 5).float()
self.fc2 = torch.nn.Linear(5, 5).float()
def forward(self, x):
x = self.fc1(x)
return self.fc2(x)
for is_per_channel in [True, False]:
m = torch.jit.script(M())
qconfig = (
per_channel_dynamic_qconfig
if is_per_channel is True
else default_dynamic_qconfig
)
m = quantize_dynamic_jit(m, {"": qconfig}, debug=True)
assert (
len(m._modules._c.items()) == 2
), "Expected to have two submodule of linear"
wt_quant_func = (
"aten::quantize_per_channel"
if is_per_channel
else "aten::quantize_per_tensor"
)
act_quant_func = "aten::quantize_per_tensor"
# quantizing activations
FileCheck().check("aten::_choose_qparams_per_tensor").check_next(
act_quant_func
).check_next("aten::dequantize").check(
"aten::_choose_qparams_per_tensor"
).check_next(
act_quant_func
).check_next(
"aten::dequantize"
).check(
wt_quant_func
).check_next(
"aten::dequantize"
).check_not(
wt_quant_func
).check(
"return"
).run(
m.graph
)
@override_qengines
def test_dynamic_multi_op(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc1 = torch.nn.Linear(5, 5).to(dtype=torch.float)
def forward(self, x):
x = x + 5
return self.fc1(x)
x = torch.randn(5, 5)
for tracing in [True, False]:
model = self.checkGraphModeOp(
M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True
)
# add op is not dynamically quantized.
FileCheck().check("aten::add").run(model.graph)
@override_qengines
def test_dynamic_quant_multi_uses(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5).float()
def forward(self, x):
size1 = x.size()
size2 = x.size()
return self.fc(x), size1, size2
x = torch.randn(5, 5)
for tracing in [True, False]:
model = self.checkGraphModeOp(
M(), x, "quantized::linear_dynamic", tracing=tracing, dynamic=True
)
FileCheck().check_not("aten::_choose_qparams_per_tensor").run(model.graph)
@override_qengines
def test_dynamic_shared_weights(self):
class myMod(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.linear = nn.Linear(5, 5)
self.linear.weight = weight
def forward(self, x):
return self.linear(x)
class DynamicModel(torch.nn.Module):
def __init__(self):
super(DynamicModel, self).__init__()
self.weight = torch.nn.Parameter(torch.ones(5, 5))
self.mod1 = myMod(self.weight)
def forward(self, x):
y = self.mod1(x)
z = torch.nn.functional.linear(y, self.weight)
return z
model = torch.jit.script(DynamicModel()).eval()
data = torch.randn(5, 5, dtype=torch.float)
quant_ops = ["mod1", ""]
counts = [1, 2]
for op, count in zip(quant_ops, counts):
qconfig_dict = {op: default_dynamic_qconfig}
m1 = quantize_dynamic_jit(model, qconfig_dict)
out_graph = m1(data)
FileCheck().check_count(
"quantized::linear_dynamic(", count, exactly=True
).check_not("aten::_choose_qparams_per_tensor").run(m1.graph)
# Explicitly call forward on model before convert
m2 = prepare_dynamic_jit(model, qconfig_dict)
m2(data)
m2 = convert_dynamic_jit(m2, debug=False)
out_ref = m2(data)
self.assertEqual(out_graph, out_ref)
@override_qengines
def test_dynamic_with_if(self):
class Res(torch.nn.Module):
def __init__(self):
super(Res, self).__init__()
self.weight = torch.nn.Parameter(torch.ones(5, 5))
def forward(self, x: torch.Tensor, cond: bool) -> torch.Tensor:
if cond:
return torch.nn.functional.linear(x, self.weight)
else:
return torch.nn.functional.linear(x, self.weight)
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.res1 = Res()
self.res2 = Res()
def forward(self, x):
x = self.res1(x, True)
x = self.res2(x, False)
return x
model = torch.jit.script(M()).eval()
data = torch.randn(5, 5, dtype=torch.float)
qconfig_dict = {"": default_dynamic_qconfig}
for tracing in [True, False]:
m1 = self.checkGraphModeOp(
M(), data, "quantized::linear_dynamic", tracing=tracing, dynamic=True
)
FileCheck().check_count(
"quantized::linear_dynamic(", 2, exactly=True
).check_not("aten::_choose_qparams_per_tensor").run(m1.graph)
# Check to make sure weight observers run correctly
ref_qparams = []
qconfig = script_qconfig(default_dynamic_qconfig)
wt_module = wrap_cpp_module(qconfig.weight)
for wt in [model.res1.weight, model.res2.weight]:
wt_module(wt)
qparams = wt_module.calculate_qparams()
ref_qparams.append((qparams[0].item(), qparams[1].item()))
m2 = quantize_dynamic_jit(model, qconfig_dict, debug=True)
graph_params = []
for x, obs in m2._modules._c.items():
if x == "res1":
graph_params.append(
(obs.getattr("weight.2_scale_0"), obs.getattr("weight.2_zero_point_0"))
)
elif x == "res2":
graph_params.append(
(obs.getattr("weight.4_scale_0"), obs.getattr("weight.4_zero_point_0"))
)
self.assertEqual(ref_qparams, graph_params)
def test_dynamic_weight_observer(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5).float()
self.fc2 = torch.nn.Linear(5, 5).float()
def forward(self, x):
x = self.fc(x)
return self.fc2(x)
qconfig_dict = {"": default_dynamic_qconfig}
eager_model = M().eval()
for tracing in [True, False]:
x = torch.rand(5, 5)
model = get_script_module(eager_model, tracing, x)
ref_qparams = []
for wt in [model.fc.weight, model.fc2.weight]:
wt_module = default_dynamic_qconfig.weight()
wt_module(wt)
qparams = wt_module.calculate_qparams()
ref_qparams.append((qparams[0].item(), qparams[1].item()))
model = quantize_dynamic_jit(model, qconfig_dict, debug=True)
graph_qparams = []
for x, obs in model._modules._c.items():
if x == 'fc' and tracing:
graph_qparams.append(
(obs.getattr("weight.6_scale_0"), obs.getattr("weight.6_zero_point_0"))
)
else:
graph_qparams.append(
(obs.getattr("weight.1_scale_0"), obs.getattr("weight.1_zero_point_0"))
)
self.assertEqual(ref_qparams, graph_qparams)
def test_convert_dynamic_fp16(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
m = torch.jit.script(M())
m = quantize_dynamic_jit(m, {"": float16_dynamic_qconfig}, debug=True)
FileCheck().check("aten::_saturate_weight_to_fp16").check(
"aten::linear"
).check_not("aten::dequantize").check_not("aten::quantize").run(m.graph)
def test_quantize_dynamic_fp16(self):
class M(torch.nn.Module):
def __init__(self):
super(M, self).__init__()
self.fc = torch.nn.Linear(5, 5)
def forward(self, x):
return self.fc(x)
m = torch.jit.script(M())
m = quantize_dynamic_jit(m, {"": float16_dynamic_qconfig})
FileCheck().check("quantized::linear_dynamic_fp16").check_not(
"aten::linear"
).check_not("aten::dequantize").check_not("aten::quantize").run(m.graph)
class TestQuantizeDynamicJitOps(QuantizationTestCase):
"""Test graph mode post training dynamic quantization works
for individual ops end to end.
"""
@override_qengines
def test_linear(self):
class FunctionalLinear(torch.nn.Module):
def __init__(self, weight, bias):
super(FunctionalLinear, self).__init__()
self.weight = weight
self.bias = bias
def forward(self, x):
return F.linear(x, self.weight, self.bias)
x = torch.rand(5, 5)
for tracing in [True, False]:
model = self.checkGraphModeOp(
torch.nn.Linear(5, 5),
x,
"quantized::linear_dynamic",
tracing=tracing,
dynamic=True,
)
weight = torch.rand(5, 5)
b = torch.rand(5)
for tracing, has_bias in itertools.product([True, False], [True, False]):
bias = b if has_bias else None
model = self.checkGraphModeOp(
FunctionalLinear(weight, bias),
x,
"quantized::linear_dynamic",
tracing=tracing,
dynamic=True,
)
@skipIfNoFBGEMM
def test_embedding_bag(self):
class M(torch.nn.Module):
def __init__(self, weights):
super(M, self).__init__()
self.embedding1 = torch.nn.EmbeddingBag(
num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
sparse=True,
_weight=weights,
mode="sum",
)
self.embedding2 = torch.nn.EmbeddingBag(
num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
sparse=True,
_weight=weights,
mode="sum",
)
def forward(self, indices1, offsets1, indices2, offsets2):
e1 = self.embedding1(indices1, offsets1)
e2 = self.embedding2(indices2, offsets2)
return e1, e2
weights = torch.randn(10, 12, dtype=torch.float32)
module = M(weights)
indices = torch.tensor(
[
9,
6,
5,
7,
8,
8,
9,
2,
8,
6,
6,
9,
1,
6,
8,
8,
3,
2,
3,
6,
3,
6,
5,
7,
0,
8,
4,
6,
5,
8,
2,
3,
]
)
offsets = torch.tensor([0, 19, 20, 28, 28, 32])
dummy_inputs = (indices, offsets, indices, offsets)
for trace in [True, False]:
if trace:
m = torch.jit.trace(module, dummy_inputs)
else:
m = torch.jit.script(module)
int4_qconfig = QConfig(
activation=PlaceholderObserver.with_args(
dtype=torch.float, custom_op_name="embedding_bag_4bit"
),
weight=PlaceholderObserver.with_args(
custom_op_name="embedding_bag_4bit"
),
)
int8_qconfig = QConfig(
activation=PlaceholderObserver.with_args(
dtype=torch.float, custom_op_name="embedding_bag_byte"
),
weight=PlaceholderObserver.with_args(
custom_op_name="embedding_bag_byte"
),
)
m = prepare_jit(m, {"embedding1": int4_qconfig, "embedding2": int8_qconfig})
m = convert_jit(m)
FileCheck().check("quantized::embedding_bag_4bit_rowwise_offsets").check(
"quantized::embedding_bag_byte_rowwise_offsets"
).run(m.graph)
m(*dummy_inputs)
# Ensure that attempting to quantize an EmbeddingBag throws an error if
# padding_idx is not None
@skipIfNoFBGEMM
def test_embedding_bag_padding_idx_error(self):
class M(torch.nn.Module):
def __init__(self, weights):
super(M, self).__init__()
self.embedding = torch.nn.EmbeddingBag(
num_embeddings=10,
embedding_dim=12,
include_last_offset=True,
sparse=True,
_weight=weights,
mode="sum",
padding_idx=0,
)
def forward(self, indices, offsets):
e = self.embedding(indices, offsets)
return e
weights = torch.randn(10, 12, dtype=torch.float32)
module = M(weights)
indices = torch.tensor([0, 1, 2, 3, 4])
offsets = torch.tensor([0, 2, 5])
dummy_inputs = (indices, offsets)
int4_qconfig = QConfig(
activation=PlaceholderObserver.with_args(
dtype=torch.float, custom_op_name="embedding_bag_4bit"
),
weight=PlaceholderObserver.with_args(
custom_op_name="embedding_bag_4bit"
),
)
int8_qconfig = QConfig(
activation=PlaceholderObserver.with_args(
dtype=torch.float, custom_op_name="embedding_bag_byte"
),
weight=PlaceholderObserver.with_args(
custom_op_name="embedding_bag_byte"
),
)
error_msg = r'Expected aten::embedding_bag padding_idx input to be None'
for trace, qconfig in itertools.product([True, False], [int4_qconfig, int8_qconfig]):
if trace:
m = torch.jit.trace(module, dummy_inputs)
else:
m = torch.jit.script(module)
m = prepare_jit(m, {"embedding": qconfig})
with self.assertRaisesRegex(RuntimeError, error_msg):
m = convert_jit(m)
class TestQuantizeJit(QuantizationTestCase):
@override_qengines
def test_single_linear(self):
r"""Compare the result of quantizing single linear layer in
eager mode and graph mode
"""
# eager mode
annotated_linear_model = AnnotatedSingleLayerLinearModel(
torch.backends.quantized.engine
).eval()
linear_model = SingleLayerLinearModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
linear_model.fc1.weight = torch.nn.Parameter(
annotated_linear_model.fc1.module.weight.detach()
)
linear_model.fc1.bias = torch.nn.Parameter(
annotated_linear_model.fc1.module.bias.detach()
)
model_eager = quantize(
annotated_linear_model, test_only_eval_fn, [self.calib_data]
)
qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
model_script = torch.jit.script(linear_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_jit(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.calib_data],
inplace=False,
)
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
@skipIfNoFBGEMM
def test_observer_with_ignored_function(self):
r"""Test observers with ignored function and make sure it works in
graph mode
"""
# eager mode
annotated_linear_model = AnnotatedSingleLayerLinearModel("fbgemm").eval()
for qconfig in [
QConfig(activation=default_observer, weight=default_weight_observer),
QConfig(
activation=default_histogram_observer, weight=default_weight_observer
),
QConfig(
activation=default_observer, weight=default_per_channel_weight_observer
),
]:
annotated_linear_model.qconfig = qconfig
linear_model = SingleLayerLinearModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
linear_model.fc1.weight = torch.nn.Parameter(
annotated_linear_model.fc1.module.weight.detach()
)
linear_model.fc1.bias = torch.nn.Parameter(
annotated_linear_model.fc1.module.bias.detach()
)
model_eager = quantize(
annotated_linear_model, test_only_eval_fn, [self.calib_data]
)
qconfig_dict = {"": qconfig}
model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
model_script = torch.jit.script(linear_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_jit(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.calib_data],
inplace=False,
)
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
@override_qengines
def test_conv(self):
r"""Compare the result of quantizing conv layer in
eager mode and graph mode
"""
# eager mode
annotated_conv_model = AnnotatedConvModel(
torch.backends.quantized.engine
).eval()
conv_model = ConvModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
conv_model.conv.weight = torch.nn.Parameter(
annotated_conv_model.conv.weight.detach()
)
model_eager = quantize(
annotated_conv_model, test_only_eval_fn, [self.img_data_2d]
)
qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
model_script = torch.jit.script(conv_model)
result_eager = model_eager(self.img_data_2d[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_jit(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.img_data_2d],
inplace=False,
)
self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager)
@override_qengines
def test_conv_transpose(self):
r"""Compare the result of quantizing conv_transpose layer in
eager mode and graph mode
"""
if not qengine_is_qnnpack():
return # Currently only qnnpack is supported
# eager mode
annotated_conv_model = AnnotatedConvTransposeModel(
torch.backends.quantized.engine
).eval()
conv_model = ConvTransposeModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
conv_model.conv.weight = torch.nn.Parameter(
annotated_conv_model.conv.weight.detach()
)
model_eager = quantize(
annotated_conv_model, test_only_eval_fn, [self.img_data_2d]
)
qconfig_dict = {"": get_default_qconfig(torch.backends.quantized.engine)}
model_traced = torch.jit.trace(conv_model, self.img_data_2d[0][0])
model_script = torch.jit.script(conv_model)
result_eager = model_eager(self.img_data_2d[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_jit(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.img_data_2d],
inplace=False,
)
self.assertEqual(model_quantized(self.img_data_2d[0][0]), result_eager)
@override_qengines
def test_conv_bn(self):
r"""Compare the result of quantizing conv + bn layer in
eager mode and graph mode
"""
# eager mode
conv_model = AnnotatedConvBnModel().eval()
conv_model_to_script = ConvBnModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
conv_model_to_script.conv.weight = torch.nn.Parameter(
conv_model.conv.weight.detach()
)
fuse_modules(conv_model, ["conv", "bn"], inplace=True)
model_eager = quantize(conv_model, test_only_eval_fn, [self.img_data_2d])
qconfig_dict = {"": default_qconfig}
model_script = quantize_jit(
torch.jit.script(conv_model_to_script),
qconfig_dict,
test_only_eval_fn,
[self.img_data_2d],
inplace=False,
)
result_eager = model_eager(self.img_data_2d[0][0])
result_script = model_script(self.img_data_2d[0][0])
self.assertEqual(result_eager, result_script)
@override_qengines
def test_nested(self):
# Eager mode
eager_model = AnnotatedNestedModel(torch.backends.quantized.engine).eval()
# Graph mode
script_model = NestedModel().eval()
# Copy weights for eager_model
script_model.sub1.fc.weight = torch.nn.Parameter(
eager_model.sub1.fc.weight.detach()
)
script_model.sub1.fc.bias = torch.nn.Parameter(
eager_model.sub1.fc.bias.detach()
)
script_model.sub2.fc1.weight = torch.nn.Parameter(
eager_model.sub2.fc1.module.weight.detach()
)
script_model.sub2.fc1.bias = torch.nn.Parameter(
eager_model.sub2.fc1.module.bias.detach()
)
script_model.sub2.fc2.weight = torch.nn.Parameter(
eager_model.sub2.fc2.weight.detach()
)
script_model.sub2.fc2.bias = torch.nn.Parameter(
eager_model.sub2.fc2.bias.detach()
)
script_model.fc3.weight = torch.nn.Parameter(
eager_model.fc3.module.weight.detach()
)
script_model.fc3.bias = torch.nn.Parameter(eager_model.fc3.module.bias.detach())
model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data])
qconfig_dict = {
"sub2.fc1": default_per_channel_qconfig
if qengine_is_fbgemm()
else default_qconfig,
"fc3": default_qconfig,
}
model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
model_script = torch.jit.script(script_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_jit(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.calib_data],
inplace=False,
)
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
@override_qengines
def test_skip_quant(self):
"""Test None qconfig"""
# Eager mode
eager_model = AnnotatedSkipQuantModel(torch.backends.quantized.engine).eval()
# Graph mode
script_model = SkipQuantModel().eval()
# Copy weights for eager_model
script_model.sub.fc1.weight = torch.nn.Parameter(
eager_model.sub.module.fc1.weight.detach()
)
script_model.sub.fc1.bias = torch.nn.Parameter(
eager_model.sub.module.fc1.bias.detach()
)
script_model.sub.fc2.weight = torch.nn.Parameter(
eager_model.sub.module.fc2.weight.detach()
)
script_model.sub.fc2.bias = torch.nn.Parameter(
eager_model.sub.module.fc2.bias.detach()
)
script_model.fc.weight = torch.nn.Parameter(eager_model.fc.weight.detach())
script_model.fc.bias = torch.nn.Parameter(eager_model.fc.bias.detach())
eager_model.fuse_modules()
model_eager = quantize(eager_model, test_only_eval_fn, [self.calib_data])
qconfig_dict = {
"": get_default_qconfig(torch.backends.quantized.engine),
"fc": None,
}
model_traced = torch.jit.trace(script_model, self.calib_data[0][0])
model_script = torch.jit.script(script_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_jit(
model_under_test,
qconfig_dict,
test_only_eval_fn,
[self.calib_data],
inplace=False,
)
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
@override_qengines
def test_single_linear_dynamic(self):
r"""Compare the result of dynamic quantization of single linear layer in
eager mode and graph mode.
"""
if qengine_is_qnnpack():
# eager mode
annotated_linear_model = AnnotatedSingleLayerLinearModel("qnnpack").eval()
linear_model = SingleLayerLinearModel().eval()
# copy the weight from eager mode so that we can
# compare the result of the two quantized models later
linear_model.fc1.weight = torch.nn.Parameter(
annotated_linear_model.fc1.module.weight.detach()
)
linear_model.fc1.bias = torch.nn.Parameter(
annotated_linear_model.fc1.module.bias.detach()
)
qconfig_dict = {"": default_dynamic_qconfig}
model_eager = quantize_dynamic(annotated_linear_model, qconfig_dict)
model_traced = torch.jit.trace(linear_model, self.calib_data[0][0])
model_script = torch.jit.script(linear_model)
result_eager = model_eager(self.calib_data[0][0])
for model_under_test in [model_traced, model_script]:
model_quantized = quantize_dynamic_jit(model_under_test, qconfig_dict)
self.assertEqual(model_quantized(self.calib_data[0][0]), result_eager)
# Check to make sure choose_qparams->quant->dequant->linear is numerically
# equivalent to the final quantized model.
model_fake_quantized = quantize_dynamic_jit(
model_under_test, qconfig_dict, debug=True
)
self.assertEqual(
model_fake_quantized(self.calib_data[0][0]), result_eager
)
@skipIfNoFBGEMM
def test_linear_dynamic_fp16(self):
linear_model = SingleLayerLinearModel().eval()
# Create weight tensor values that are beyond fp16 max
x = torch.ones(5, 5) * 65532
linear_model.fc1.weight = torch.nn.Parameter(x)
import warnings
model_eager = quantize_dynamic(linear_model, dtype=torch.float16)
result_eager = model_eager(self.calib_data[0][0])
for trace in [True]:
with warnings.catch_warnings(record=True) as w:
quantized_model = self.checkGraphModeOp(
linear_model,
self.calib_data[0][0],
"quantized::linear_dynamic_fp16",
tracing=trace,
dynamic=True,
qconfig=float16_dynamic_qconfig,
)
# compare result with eager mode
self.assertEqual(quantized_model(self.calib_data[0][0]), result_eager)