diff --git a/test/quantization/test_quantize_jit.py b/test/quantization/test_quantize_jit.py
index 88848d2267f..e6e02fb811c 100644
--- a/test/quantization/test_quantize_jit.py
+++ b/test/quantization/test_quantize_jit.py
@@ -2177,11 +2177,23 @@ class TestQuantizeJitOps(QuantizationTestCase):
             .run(m.graph)
 
     def test_hardswish(self):
-        data = [(torch.rand((1, 2, 5, 5), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
-        hardswish = torch.nn.Hardswish()
-        for tracing in [True, False]:
-            m = self.checkGraphModeOp(hardswish, data, "quantized::hardswish", tracing)
+        class FunctionalHardswish(torch.nn.Module):
+            def __init__(self, inplace):
+                super(FunctionalHardswish, self).__init__()
+                self.inplace = inplace
+
+            def forward(self, input):
+                return torch.nn.functional.hardswish(input, inplace=self.inplace)
+
+        modules = [torch.nn.Hardswish(), FunctionalHardswish(True),
+                   FunctionalHardswish(False)]
+
+        for test_case in itertools.product([True, False], modules):
+            tracing, m = test_case
+            m = self.checkGraphModeOp(
+                m, self.img_data, "quantized::hardswish", tracing)
             FileCheck().check_not("aten::hardswish") \
+                       .check_not("aten::hardswish_") \
                        .run(m.graph)
 
     def test_elu(self):
diff --git a/torch/csrc/jit/passes/quantization/helper.cpp b/torch/csrc/jit/passes/quantization/helper.cpp
index feae91e5592..a9dfbc6e31f 100644
--- a/torch/csrc/jit/passes/quantization/helper.cpp
+++ b/torch/csrc/jit/passes/quantization/helper.cpp
@@ -34,6 +34,7 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
     "addmm",
     "matmul",
     "hardswish",
+    "hardswish_",
     "elu",
     "elu_",
     "batch_norm",
diff --git a/torch/csrc/jit/passes/quantization/quantization_patterns.h b/torch/csrc/jit/passes/quantization/quantization_patterns.h
index 62bcf1a5d5e..654e9e8f1df 100644
--- a/torch/csrc/jit/passes/quantization/quantization_patterns.h
+++ b/torch/csrc/jit/passes/quantization/quantization_patterns.h
@@ -848,6 +848,9 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
   auto hardswish = getObservedQParamOpFusionInfo(
       "aten::hardswish", "quantized::hardswish", {}, {});
 
+  auto hardswish_ = getObservedQParamOpFusionInfo(
+      "aten::hardswish_", "quantized::hardswish", {}, {});
+
   auto layer_norm = getObservedQParamOpFusionInfo(
       "aten::layer_norm",
       "quantized::layer_norm",
@@ -968,6 +971,7 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
       {"quantized::mul", mul, quantized_mul},
       {"quantized::mul", inplace_mul, quantized_mul},
       hardswish,
+      hardswish_,
       layer_norm,
       group_norm,
       instance_norm,
diff --git a/torch/testing/_internal/common_quantization.py b/torch/testing/_internal/common_quantization.py
index d0576ec99e4..0c97f07c09a 100644
--- a/torch/testing/_internal/common_quantization.py
+++ b/torch/testing/_internal/common_quantization.py
@@ -7,6 +7,7 @@ r"""Importing this file includes common utility methods and base clases for
 checking quantization api and properties of resulting modules.
 """
 
+import copy
 import io
 import functools
 import torch
@@ -282,8 +283,12 @@ class QuantizationTestCase(TestCase):
                 # make sure it runs
                 outputs[d] = models[d](inputs)
             else:
+                # module under test can contain in-place ops, and we depend on
+                # input data staying constant for comparisons
+                data_copy = copy.deepcopy(data)
                 models[d] = quantize_jit(
-                    model, qconfig_dict, test_only_eval_fn, [data], inplace=False, debug=d)
+                    model, qconfig_dict, test_only_eval_fn, [data_copy], inplace=False,
+                    debug=d)
                 # make sure it runs
                 outputs[d] = models[d](*inputs)
 
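
Note: a minimal stand-alone sketch of the path this patch enables, i.e. graph-mode
quantization rewriting the in-place aten::hardswish_ into quantized::hardswish. The
module M, the calibrate function, and the calibration data below are made up for
illustration; quantize_jit and get_default_qconfig are the same public APIs the test
helper above calls.

import torch
from torch.quantization import get_default_qconfig, quantize_jit

class M(torch.nn.Module):
    def forward(self, x):
        # in-place functional variant, which previously stayed unquantized
        return torch.nn.functional.hardswish(x, inplace=True)

def calibrate(model, data):
    # run a few batches so the inserted observers record activation ranges
    for x in data:
        model(x)

m = torch.jit.script(M().eval())
qconfig_dict = {'': get_default_qconfig('fbgemm')}
data = [torch.randn(1, 3, 8, 8) for _ in range(2)]
mq = quantize_jit(m, qconfig_dict, calibrate, [data])
print(mq.graph)  # expect quantized::hardswish; no aten::hardswish or aten::hardswish_

Note that calibrate receives a fresh copy of the data in the test helper (the
copy.deepcopy above) precisely because inplace=True mutates its input tensors.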