mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[quant][graphmode][fx] Remove inplace option for fuse_fx (#46953)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46953 Test Plan: Imported from OSS Reviewed By: supriyar Differential Revision: D24579402 fbshipit-source-id: 5e0b8abf682287ab3c7dd54c2fc2cf309295e147
This commit is contained in:
parent
e299393fd5
commit
d92bf921db
|
|
@ -12,13 +12,10 @@ from .pattern_utils import (
|
||||||
|
|
||||||
from .fusion_patterns import * # noqa: F401
|
from .fusion_patterns import * # noqa: F401
|
||||||
|
|
||||||
import copy
|
|
||||||
class Fuser:
|
class Fuser:
|
||||||
def fuse(self, model, inplace=False, fuse_custom_config_dict=None):
|
def fuse(self, model, fuse_custom_config_dict=None):
|
||||||
if fuse_custom_config_dict is None:
|
if fuse_custom_config_dict is None:
|
||||||
fuse_custom_config_dict = {}
|
fuse_custom_config_dict = {}
|
||||||
if not inplace:
|
|
||||||
model = copy.deepcopy(model)
|
|
||||||
|
|
||||||
input_root = model
|
input_root = model
|
||||||
input_graph = model.graph
|
input_graph = model.graph
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ def _swap_ff_with_fxff(model):
|
||||||
del model._modules[name]
|
del model._modules[name]
|
||||||
model._modules[name] = torch.nn.quantized.FXFloatFunctional()
|
model._modules[name] = torch.nn.quantized.FXFloatFunctional()
|
||||||
|
|
||||||
def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None):
|
def _fuse_fx(graph_module, fuse_custom_config_dict=None):
|
||||||
r""" Internal helper function to fuse modules in preparation for quantization
|
r""" Internal helper function to fuse modules in preparation for quantization
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -35,7 +35,7 @@ def _fuse_fx(graph_module, inplace=False, fuse_custom_config_dict=None):
|
||||||
"""
|
"""
|
||||||
_check_is_graph_module(graph_module)
|
_check_is_graph_module(graph_module)
|
||||||
fuser = Fuser()
|
fuser = Fuser()
|
||||||
return fuser.fuse(graph_module, inplace, fuse_custom_config_dict)
|
return fuser.fuse(graph_module, fuse_custom_config_dict)
|
||||||
|
|
||||||
class CustomTracer(Tracer):
|
class CustomTracer(Tracer):
|
||||||
def __init__(self, skipped_module_names, skipped_module_classes):
|
def __init__(self, skipped_module_names, skipped_module_classes):
|
||||||
|
|
@ -80,7 +80,7 @@ forward graph of the parent module,
|
||||||
skipped_module_classes += float_custom_module_classes
|
skipped_module_classes += float_custom_module_classes
|
||||||
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
|
tracer = CustomTracer(skipped_module_names, skipped_module_classes)
|
||||||
graph_module = GraphModule(model, tracer.trace(model))
|
graph_module = GraphModule(model, tracer.trace(model))
|
||||||
graph_module = _fuse_fx(graph_module, inplace, prepare_custom_config_dict)
|
graph_module = _fuse_fx(graph_module, prepare_custom_config_dict)
|
||||||
quantizer = Quantizer()
|
quantizer = Quantizer()
|
||||||
return quantizer.prepare(
|
return quantizer.prepare(
|
||||||
graph_module,
|
graph_module,
|
||||||
|
|
@ -108,12 +108,11 @@ def _prepare_standalone_module_fx(model, qconfig_dict, inplace=False, prepare_cu
|
||||||
return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module=True)
|
return _prepare_fx(model, qconfig_dict, inplace, prepare_custom_config_dict, is_standalone_module=True)
|
||||||
|
|
||||||
|
|
||||||
def fuse_fx(model, inplace=False, fuse_custom_config_dict=None):
|
def fuse_fx(model, fuse_custom_config_dict=None):
|
||||||
r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
|
r""" Fuse modules like conv+bn, conv+bn+relu etc, model must be in eval mode.
|
||||||
Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
|
Fusion rules are defined in torch.quantization.fx.fusion_pattern.py
|
||||||
Args:
|
Args:
|
||||||
`model`: a torch.nn.Module model
|
`model`: a torch.nn.Module model
|
||||||
`inplace`: flag for whether we fuse modules inplace or out of place
|
|
||||||
`fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.
|
`fuse_custom_config_dict`: Dictionary for custom configurations for fuse_fx, e.g.
|
||||||
fuse_custom_config_dict = {
|
fuse_custom_config_dict = {
|
||||||
"additional_fuser_method_mapping": {
|
"additional_fuser_method_mapping": {
|
||||||
|
|
@ -131,7 +130,7 @@ def fuse_fx(model, inplace=False, fuse_custom_config_dict=None):
|
||||||
torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
|
torch._C._log_api_usage_once("quantization_api.quantize_fx.fuse_fx")
|
||||||
assert not model.training, 'fuse_fx only works on models in eval mode'
|
assert not model.training, 'fuse_fx only works on models in eval mode'
|
||||||
graph_module = torch.fx.symbolic_trace(model)
|
graph_module = torch.fx.symbolic_trace(model)
|
||||||
return _fuse_fx(graph_module, inplace, fuse_custom_config_dict)
|
return _fuse_fx(graph_module, fuse_custom_config_dict)
|
||||||
|
|
||||||
def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
|
def prepare_fx(model, qconfig_dict, inplace=False, prepare_custom_config_dict=None):
|
||||||
r""" Prepare a model for post training static quantization
|
r""" Prepare a model for post training static quantization
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user