[quant][graphmode][fx] Remove inplace option for fuse_fx (#46953)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46953

Test Plan: Imported from OSS

Reviewed By: supriyar

Differential Revision: D24579402

fbshipit-source-id: 5e0b8abf682287ab3c7dd54c2fc2cf309295e147
This commit is contained in:
parent e299393fd5
commit d92bf921db
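Since this commit drops the `inplace` flag and the internal `copy.deepcopy`, callers that previously relied on `inplace=False` to keep the original model untouched now need to copy it themselves. A minimal migration sketch (the example model is hypothetical, not from the diff):

import copy
import torch.nn as nn
from torch.quantization import quantize_fx

# Hypothetical float model; fuse_fx requires eval mode.
float_model = nn.Sequential(
    nn.Conv2d(3, 8, 3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
).eval()

# Before: fused = quantize_fx.fuse_fx(float_model, inplace=False)
# After: copy explicitly if the original must be preserved.
original = copy.deepcopy(float_model)
fused = quantize_fx.fuse_fx(float_model)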
@@ -12,13 +12,10 @@ from .pattern_utils import (
 from .fusion_patterns import *  # noqa: F401
 
-import copy
 
 class Fuser:
-    def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
+    def fuse(self, model, fuse_custom_config_dict=None):
         if fuse_custom_config_dict is None:
             fuse_custom_config_dict = {}
-        if not inplace:
-            model = copy.deepcopy(model)
 
         input_root = model
         input_graph = model.graph
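With the `copy.deepcopy` call gone, `Fuser.fuse` now operates directly on the `GraphModule` it is given. A rough sketch of the new two-argument call (internal API; the import path is an assumption and may change between releases):

import torch
from torch.quantization.fx.fuse import Fuser  # internal, assumed path

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
).eval()

graph_module = torch.fx.symbolic_trace(model)
# Old: Fuser().fuse(graph_module, inplace, fuse_custom_config_dict)
fused = Fuser().fuse(graph_module, fuse_custom_config_dict=None)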
@@ -27,7 +27,7 @@ def _swap_ff_with_fxff(model):
         del model._modules[name]
         model._modules[name] = torch.nn.quantized.FXFloatFunctional()
 
-def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None):
+def _fuse_fx(graph_module, fuse_custom_config_dict=None):
     r""" Internal helper function to fuse modules in preparation for quantization
 
     Args:
@@ -35,7 +35,7 @@ def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None):
     """
     _check_is_graph_module(graph_module)
     fuser = Fuser()
-    return fuser.fuse(graph_module, inplace, fuse_custom_config_dict)
+    return fuser.fuse(graph_module, fuse_custom_config_dict)
 
 class CustomTracer(Tracer):
     def __init__(self, skipped_module_names, skipped_module_classes):
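`_fuse_fx` asserts via `_check_is_graph_module` that its input is already traced, so a caller-side sketch of the new signature would trace first (internal helper, shown for illustration only):

import torch
from torch.quantization.quantize_fx import _fuse_fx  # internal helper

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
).eval()

graph_module = torch.fx.symbolic_trace(model)  # must be a GraphModule
fused = _fuse_fx(graph_module)  # inplace argument no longer accepted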
@@ -80,7 +80,7 @@ forward graph of the parent module,
     skipped_module_classes += float_custom_module_classes
     tracer = CustomTracer(skipped_module_names, skipped_module_classes)
     graph_module = GraphModule(model, tracer.trace(model))
-    graph_module = _fuse_fx(graph_module, inplace, prepare_custom_config_dict)
+    graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
     quantizer = Quantizer()
     return quantizer.prepare(
         graph_module,
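This hunk is the prepare path: trace, fuse, then hand off to the quantizer. The corresponding public workflow looks roughly like the sketch below (the qconfig choice is an assumption):

import torch
from torch.quantization import get_default_qconfig, quantize_fx

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
    torch.nn.ReLU(),
).eval()

qconfig_dict = {"": get_default_qconfig("fbgemm")}  # assumed backend
# prepare_fx traces and fuses internally before inserting observers.
prepared = quantize_fx.prepare_fx(model, qconfig_dict)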
@@ -108,12 +108,11 @@ def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_cu
     return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module=True)
 
 
-def fuse_fx(model, inplace=False, fuse_custom_config_dict=None):
+def fuse_fx(model, fuse_custom_config_dict=None):
     r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
     Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
     Args:
         `model`: a torch.nn.Module model
-        `inplace`: flag for whether we fuse modules inplace or out of place
         `fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.
         fuse_custom_config_dict = {
             "additional_fuser_method_mapping": {
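The `additional_fuser_method_mapping` entry in the docstring lets a caller register extra fusion rules. A hedged sketch of such a config (the pattern-key format and the `fuse_linear_relu` function are assumptions modeled on the default mappings):

import torch.nn as nn
import torch.nn.intrinsic as nni
from torch.quantization import quantize_fx

def fuse_linear_relu(linear, relu):
    # Assumed fuser-method contract: return a fused module for the pair.
    return nni.LinearReLU(linear, relu)

fuse_custom_config_dict = {
    "additional_fuser_method_mapping": {
        (nn.Linear, nn.ReLU): fuse_linear_relu,  # assumed key format
    },
}

model = nn.Sequential(nn.Linear(4, 4), nn.ReLU()).eval()
fused = quantize_fx.fuse_fx(model, fuse_custom_config_dict)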
@@ -131,7 +130,7 @@ def fuse_fx(model, inplace=False, fuse_custom_config_dict=None):
     torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
     assert not model.training, 'fuse_fx only works on models in eval mode'
     graph_module = torch.fx.symbolic_trace(model)
-    return _fuse_fx(graph_module, inplace, fuse_custom_config_dict)
+    return _fuse_fx(graph_module, fuse_custom_config_dict)
 
 def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
     r""" Prepare a model for post training static quantization
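Given the assertion in this hunk, calling the new two-argument `fuse_fx` on a model in training mode should fail fast. A quick sketch of the expected behavior:

import torch
from torch.quantization import quantize_fx

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 8, 3),
    torch.nn.BatchNorm2d(8),
)

try:
    quantize_fx.fuse_fx(model.train())  # training mode trips the assert
except AssertionError as err:
    print(err)  # fuse_fx only works on models in eval mode

fused = quantize_fx.fuse_fx(model.eval())  # eval mode: conv+bn fuses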