Fix fuse_func method overwrite (#87791) (#88193)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87791

Fixing the interface so that the fuse_func is honored and not replaced but the default fuse_known_method.

Test Plan: Wait for sandcastle

Reviewed By: jerryzh168

Differential Revision: D40722395

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88193
Approved by: https://github.com/jerryzh168
This commit is contained in:
Sam Tsai 2022-11-03 20:32:54 +00:00 committed by PyTorch MergeBot
parent 433746300d
commit 65de9a2b81
2 changed files with 14 additions and 2 deletions

View File

@ -28,6 +28,7 @@ from torch.testing._internal.common_quantization import (
ModelForLinearBNFusion,
ModelForFusionWithBias,
ModelForConvTransposeBNFusion,
SingleLayerLinearModel,
test_only_eval_fn,
test_only_train_fn,
skipIfNoFBGEMM,
@ -363,6 +364,17 @@ class TestFuseEager(QuantizationTestCase):
self.assertEqual(golden, model(inp2))
def test_fuse_function_customization(self):
dummy_model = SingleLayerLinearModel().train()
dummy_model.eval()
# A custom fuse funct
def custom_fuse_func(module, is_qat, add_fuser_mapping):
return [torch.nn.Identity()]
dummy_model = fuse_modules(dummy_model, [["fc1"]], fuser_func=custom_fuse_func)
self.assertEqual(type(dummy_model.fc1), nn.Identity)
def test_forward_hooks_preserved(self):
r"""Test case that checks whether forward pre hooks of the first module and
post forward hooks of the last module in modules list passed to fusion function preserved.

View File

@ -160,7 +160,7 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
modules_to_fuse,
is_qat=False,
inplace=inplace,
fuser_func=fuse_known_modules,
fuser_func=fuser_func,
fuse_custom_config_dict=None)
def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
@ -171,5 +171,5 @@ def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_know
modules_to_fuse,
is_qat=True,
inplace=inplace,
fuser_func=fuse_known_modules,
fuser_func=fuser_func,
fuse_custom_config_dict=None)