# Owner(s): ["oncall: quantization"]

# torch
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_quantization import QuantizationTestCase


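# These tests script small modules that call quantized add variants followed
# by torch.relu, run the JIT inline and quantized add/relu fusion passes, and
# use FileCheck to verify that the graph contains the fused op (and no longer
# contains a standalone aten::relu) while the outputs stay unchanged.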
class TestFusionPasses(QuantizationTestCase):
    def test_quantized_add_relu_fusion(self):
        class MAdd(torch.nn.Module):
            def forward(self, x, y):
                a = torch.ops.quantized.add(x, y, 1.0, 0)
                relu_out = torch.relu(a)
                return relu_out

        A = torch.arange(-128, 130, dtype=torch.float)
        B = torch.arange(-128, 130, dtype=torch.float)
        scale = 2.0
        zero_point = 127
        qA = torch.quantize_per_tensor(
            A, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        qB = torch.quantize_per_tensor(
            B, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
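        # quint8 affine quantization maps x to
        # clamp(round(x / scale) + zero_point, 0, 255); with scale=2.0 and
        # zero_point=127, x = -128.0 quantizes to round(-64) + 127 = 63.
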
        # Check quantized add + relu fusion
        m = MAdd()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB)

        # Must inline the graph.
        # In this test case it does not matter, since we call the ops directly;
        # however, if we were calling nn modules we would have to inline the
        # graph.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check("quantized::add_relu").run(
            scripted_m.graph
        )
        output = scripted_m(qA, qB)
        self.assertEqual(ref_output, output)

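        # The _out variants take a preallocated quantized output tensor as the
        # last argument and write the result into it, so the test allocates qC
        # with torch._empty_affine_quantized before scripting the module.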
        class MAddOut(torch.nn.Module):
            def forward(self, x, y, z):
                a = torch.ops.quantized.add_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(
            qA.shape, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        # Check quantized add + relu fusion
        m = MAddOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, qB, qC)
        # Must inline the graph.
        # In this test case it does not matter, since we call the ops directly;
        # however, if we were calling nn modules we would have to inline the
        # graph.
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not("quantized::add_out").check(
            "quantized::add_relu_out"
        ).run(scripted_m.graph)
        output = scripted_m(qA, qB, qC)
        self.assertEqual(ref_output, output)

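        # add_scalar adds a Python float to a quantized tensor, so the scripted
        # module takes the scalar operand as a typed argument (y: float).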
        class MAddScalar(torch.nn.Module):
            def forward(self, x, y: float):
                a = torch.ops.quantized.add_scalar(x, y)
                relu_out = torch.relu(a)
                return relu_out

        # Check quantized add + relu fusion
        m = MAddScalar()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.0)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
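        # "quantized::add_scalar" is a prefix of the fused
        # "quantized::add_scalar_relu" op name, so the negative check keeps the
        # trailing "(" to avoid also matching the fused call.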
        FileCheck().check_not("aten::relu").check_not("quantized::add_scalar(").check(
            "quantized::add_scalar_relu"
        ).run(scripted_m.graph)
        output = scripted_m(qA, 3.0)
        self.assertEqual(ref_output, output)

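        # MAddScalarOut combines the two cases above: a scalar operand plus a
        # preallocated quantized output tensor.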
        class MAddScalarOut(torch.nn.Module):
            def forward(self, x, y: float, z):
                a = torch.ops.quantized.add_scalar_out(x, y, z)
                relu_out = torch.relu(a)
                return relu_out

        qC = torch._empty_affine_quantized(
            qA.shape, scale=scale, zero_point=zero_point, dtype=torch.quint8
        )
        m = MAddScalarOut()
        scripted_m = torch.jit.script(m)
        ref_output = scripted_m(qA, 3.0, qC)
        torch._C._jit_pass_inline(scripted_m.graph)
        torch._C._jit_pass_fuse_quantized_add_relu(scripted_m.graph)
        FileCheck().check_not("aten::relu").check_not(
            "quantized::add_scalar_out"
        ).check("quantized::add_scalar_relu_out").run(scripted_m.graph)
        output = scripted_m(qA, 3.0, qC)
        self.assertEqual(ref_output, output)