mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
graph mode: add hardswish inplace handling (#40284)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/40284 Adds graph mode handling for inplace hardswish, and test coverage for functional hardswish. Test Plan: ``` python test/test_quantization.py TestQuantizeScriptPTSQOps.test_hardswish ``` Imported from OSS Differential Revision: D22140628 fbshipit-source-id: 55a514f7dc1130d510f69ee4e611d7cb5e08d02e
This commit is contained in:
parent
c6dbfcaf9e
commit
ab8a99bd36
|
|
@ -2177,11 +2177,23 @@ class TestQuantizeJitOps(QuantizationTestCase):
|
||||||
.run(m.graph)
|
.run(m.graph)
|
||||||
|
|
||||||
def test_hardswish(self):
|
def test_hardswish(self):
|
||||||
data = [(torch.rand((1, 2, 5, 5), dtype=torch.float), torch.randint(0, 1, (1,), dtype=torch.long)) for _ in range(2)]
|
class FunctionalHardswish(torch.nn.Module):
|
||||||
hardswish = torch.nn.Hardswish()
|
def __init__(self, inplace):
|
||||||
for tracing in [True, False]:
|
super(FunctionalHardswish, self).__init__()
|
||||||
m = self.checkGraphModeOp(hardswish, data, "quantized::hardswish", tracing)
|
self.inplace = inplace
|
||||||
|
|
||||||
|
def forward(self, input):
|
||||||
|
return torch.nn.functional.hardswish(input, inplace=self.inplace)
|
||||||
|
|
||||||
|
modules = [torch.nn.Hardswish(), FunctionalHardswish(True),
|
||||||
|
FunctionalHardswish(False)]
|
||||||
|
|
||||||
|
for test_case in itertools.product([True, False], modules):
|
||||||
|
tracing, m = test_case
|
||||||
|
m = self.checkGraphModeOp(
|
||||||
|
m, self.img_data, "quantized::hardswish", tracing)
|
||||||
FileCheck().check_not("aten::hardswish") \
|
FileCheck().check_not("aten::hardswish") \
|
||||||
|
.check_not("aten::hardswish_") \
|
||||||
.run(m.graph)
|
.run(m.graph)
|
||||||
|
|
||||||
def test_elu(self):
|
def test_elu(self):
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ std::vector<std::string> _static_quantizable_aten_funcs = {
|
||||||
"addmm",
|
"addmm",
|
||||||
"matmul",
|
"matmul",
|
||||||
"hardswish",
|
"hardswish",
|
||||||
|
"hardswish_",
|
||||||
"elu",
|
"elu",
|
||||||
"elu_",
|
"elu_",
|
||||||
"batch_norm",
|
"batch_norm",
|
||||||
|
|
|
||||||
|
|
@ -848,6 +848,9 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
|
||||||
auto hardswish = getObservedQParamOpFusionInfo(
|
auto hardswish = getObservedQParamOpFusionInfo(
|
||||||
"aten::hardswish", "quantized::hardswish", {}, {});
|
"aten::hardswish", "quantized::hardswish", {}, {});
|
||||||
|
|
||||||
|
auto hardswish_ = getObservedQParamOpFusionInfo(
|
||||||
|
"aten::hardswish_", "quantized::hardswish", {}, {});
|
||||||
|
|
||||||
auto layer_norm = getObservedQParamOpFusionInfo(
|
auto layer_norm = getObservedQParamOpFusionInfo(
|
||||||
"aten::layer_norm",
|
"aten::layer_norm",
|
||||||
"quantized::layer_norm",
|
"quantized::layer_norm",
|
||||||
|
|
@ -968,6 +971,7 @@ graph(%a_quant, %alpha, %scale, %input_scale, %r_scale, %r_zero_point, %r_dtype)
|
||||||
{"quantized::mul", mul, quantized_mul},
|
{"quantized::mul", mul, quantized_mul},
|
||||||
{"quantized::mul", inplace_mul, quantized_mul},
|
{"quantized::mul", inplace_mul, quantized_mul},
|
||||||
hardswish,
|
hardswish,
|
||||||
|
hardswish_,
|
||||||
layer_norm,
|
layer_norm,
|
||||||
group_norm,
|
group_norm,
|
||||||
instance_norm,
|
instance_norm,
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ r"""Importing this file includes common utility methods and base clases for
|
||||||
checking quantization api and properties of resulting modules.
|
checking quantization api and properties of resulting modules.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import copy
|
||||||
import io
|
import io
|
||||||
import functools
|
import functools
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -282,8 +283,12 @@ class QuantizationTestCase(TestCase):
|
||||||
# make sure it runs
|
# make sure it runs
|
||||||
outputs[d] = models[d](inputs)
|
outputs[d] = models[d](inputs)
|
||||||
else:
|
else:
|
||||||
|
# module under test can contain in-place ops, and we depend on
|
||||||
|
# input data staying constant for comparisons
|
||||||
|
data_copy = copy.deepcopy(data)
|
||||||
models[d] = quantize_jit(
|
models[d] = quantize_jit(
|
||||||
model, qconfig_dict, test_only_eval_fn, [data], inplace=False, debug=d)
|
model, qconfig_dict, test_only_eval_fn, [data_copy], inplace=False,
|
||||||
|
debug=d)
|
||||||
# make sure it runs
|
# make sure it runs
|
||||||
outputs[d] = models[d](*inputs)
|
outputs[d] = models[d](*inputs)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user