[quant][graphmode][fx] Remove inplace option for fuse_fx (#46953)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46953

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D24579402

fbshipit-source-id: 5e0b8abf682287ab3c7dd54c2fc2cf309295e147
This commit is contained in:
Jerry Zhang 2020-10-27 22:32:31 -07:00 committed by Facebook GitHub Bot
parent e299393fd5
commit d92bf921db
2 changed files with 6 additions and 10 deletions

View File

@ -12,13 +12,10 @@ from .pattern_utils import (
from .fusion_patterns import * # noqa: F401
import copy
class Fuser:
def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
def fuse(self, model, fuse_custom_config_dict=None):
if fuse_custom_config_dict is None:
fuse_custom_config_dict = {}
if not inplace:
model = copy.deepcopy(model)
input_root = model
input_graph = model.graph

View File

@ -27,7 +27,7 @@ def _swap_ff_with_fxff(model):
del model._modules[name]
model._modules[name] = torch.nn.quantized.FXFloatFunctional()
def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None):
def _fuse_fx(graph_module, fuse_custom_config_dict=None):
r""" Internal helper function to fuse modules in preparation for quantization
Args:
@ -35,7 +35,7 @@ def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None):
"""
_check_is_graph_module(graph_module)
fuser = Fuser()
return fuser.fuse(graph_module, inplace, fuse_custom_config_dict)
return fuser.fuse(graph_module, fuse_custom_config_dict)
class CustomTracer(Tracer):
def __init__(self, skipped_module_names, skipped_module_classes):
@ -80,7 +80,7 @@ forward graph of the parent module,
skipped_module_classes += float_custom_module_classes
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
graph_module = GraphModule(model, tracer.trace(model))
graph_module = _fuse_fx(graph_module, inplace, prepare_custom_config_dict)
graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
quantizer = Quantizer()
return quantizer.prepare(
graph_module,
@ -108,12 +108,11 @@ def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_cu
return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module=True)
def fuse_fx(model, inplace=False, fuse_custom_config_dict=None):
def fuse_fx(model, fuse_custom_config_dict=None):
r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
Args:
`model`: a torch.nn.Module model
`inplace`: flag for whether we fuse modules inplace or out of place
`fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.
fuse_custom_config_dict = {
"additional_fuser_method_mapping": {
@ -131,7 +130,7 @@ def fuse_fx(model, inplace=False, fuse_custom_config_dict=None):
torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
assert not model.training, 'fuse_fx only works on models in eval mode'
graph_module = torch.fx.symbolic_trace(model)
return _fuse_fx(graph_module, inplace, fuse_custom_config_dict)
return _fuse_fx(graph_module, fuse_custom_config_dict)
def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
r""" Prepare a model for post training static quantization