graph mode: add hardswish inplace handling (#40284)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/40284

Adds graph mode handling for inplace hardswish, and test coverage for functional hardswish.

Test Plan:
```
python test/test_quantization.py TestQuantizeScriptPTSQOps.test_hardswish
```

Imported from OSS

Differential Revision: D22140628

fbshipit-source-id: 55a514f7dc1130d510f69ee4e611d7cb5e08d02e
This commit is contained in:
Vasiliy Kuznetsov 2020-06-21 09:35:44 -07:00 committed by Facebook GitHub Bot
parent c6dbfcaf9e
commit ab8a99bd36
4 changed files with 27 additions and 5 deletions

View File

@ -2177,11 +2177,23 @@ class TestQuantizeJitOps(QuantizationTestCase):
.run(m.graph) .run(m.graph)
def test_hardswish(self): def test_hardswish(self):
data = [(torch.rand((1, 2, 5, 5), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)] class FunctionalHardswish(torch.nn.Module):
hardswish = torch.nn.Hardswish() def __init__(self, inplace):
for tracing in [True, False]: super(FunctionalHardswish, self).__init__()
m = self.checkGraphModeOp(hardswish, data, "quantized::hardswish", tracing) self.inplace = inplace
def forward(self, input):
return torch.nn.functional.hardswish(input, inplace=self.inplace)
modules = [torch.nn.Hardswish(), FunctionalHardswish(True),
FunctionalHardswish(False)]
for test_case in itertools.product([True, False], modules):
tracing, m = test_case
m = self.checkGraphModeOp(
m, self.img_data, "quantized::hardswish", tracing)
FileCheck().check_not("aten::hardswish") \ FileCheck().check_not("aten::hardswish") \
.check_not("aten::hardswish_") \
.run(m.graph) .run(m.graph)
def test_elu(self): def test_elu(self):

View File

@ -34,6 +34,7 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
"addmm", "addmm",
"matmul", "matmul",
"hardswish", "hardswish",
"hardswish_",
"elu", "elu",
"elu_", "elu_",
"batch_norm", "batch_norm",

View File

@ -848,6 +848,9 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
auto hardswish = getObservedQParamOpFusionInfo( auto hardswish = getObservedQParamOpFusionInfo(
"aten::hardswish", "quantized::hardswish", {}, {}); "aten::hardswish", "quantized::hardswish", {}, {});
auto hardswish_ = getObservedQParamOpFusionInfo(
"aten::hardswish_", "quantized::hardswish", {}, {});
auto layer_norm = getObservedQParamOpFusionInfo( auto layer_norm = getObservedQParamOpFusionInfo(
"aten::layer_norm", "aten::layer_norm",
"quantized::layer_norm", "quantized::layer_norm",
@ -968,6 +971,7 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
{"quantized::mul", mul, quantized_mul}, {"quantized::mul", mul, quantized_mul},
{"quantized::mul", inplace_mul, quantized_mul}, {"quantized::mul", inplace_mul, quantized_mul},
hardswish, hardswish,
hardswish_,
layer_norm, layer_norm,
group_norm, group_norm,
instance_norm, instance_norm,

View File

@ -7,6 +7,7 @@ r"""Importing this file includes common utility methods and base clases for
checking quantization api and properties of resulting modules. checking quantization api and properties of resulting modules.
""" """
import copy
import io import io
import functools import functools
import torch import torch
@ -282,8 +283,12 @@ class QuantizationTestCase(TestCase):
# make sure it runs # make sure it runs
outputs[d] = models[d](inputs) outputs[d] = models[d](inputs)
else: else:
# module under test can contain in-place ops, and we depend on
# input data staying constant for comparisons
data_copy = copy.deepcopy(data)
models[d] = quantize_jit( models[d] = quantize_jit(
model, qconfig_dict, test_only_eval_fn, [data], inplace=False, debug=d) model, qconfig_dict, test_only_eval_fn, [data_copy], inplace=False,
debug=d)
# make sure it runs # make sure it runs
outputs[d] = models[d](*inputs) outputs[d] = models[d](*inputs)