mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/87791 Fixing the interface so that the fuse_func is honored and not replaced but the default fuse_known_method. Test Plan: Wait for sandcastle Reviewed By: jerryzh168 Differential Revision: D40722395 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88193 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
433746300d
commit
65de9a2b81
|
|
@ -28,6 +28,7 @@ from torch.testing._internal.common_quantization import (
|
|||
ModelForLinearBNFusion,
|
||||
ModelForFusionWithBias,
|
||||
ModelForConvTransposeBNFusion,
|
||||
SingleLayerLinearModel,
|
||||
test_only_eval_fn,
|
||||
test_only_train_fn,
|
||||
skipIfNoFBGEMM,
|
||||
|
|
@ -363,6 +364,17 @@ class TestFuseEager(QuantizationTestCase):
|
|||
|
||||
self.assertEqual(golden, model(inp2))
|
||||
|
||||
def test_fuse_function_customization(self):
|
||||
dummy_model = SingleLayerLinearModel().train()
|
||||
dummy_model.eval()
|
||||
|
||||
# A custom fuse funct
|
||||
def custom_fuse_func(module, is_qat, add_fuser_mapping):
|
||||
return [torch.nn.Identity()]
|
||||
|
||||
dummy_model = fuse_modules(dummy_model, [["fc1"]], fuser_func=custom_fuse_func)
|
||||
self.assertEqual(type(dummy_model.fc1), nn.Identity)
|
||||
|
||||
def test_forward_hooks_preserved(self):
|
||||
r"""Test case that checks whether forward pre hooks of the first module and
|
||||
post forward hooks of the last module in modules list passed to fusion function preserved.
|
||||
|
|
|
|||
|
|
@ -160,7 +160,7 @@ def fuse_modules(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_mo
|
|||
modules_to_fuse,
|
||||
is_qat=False,
|
||||
inplace=inplace,
|
||||
fuser_func=fuse_known_modules,
|
||||
fuser_func=fuser_func,
|
||||
fuse_custom_config_dict=None)
|
||||
|
||||
def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_known_modules, fuse_custom_config_dict=None):
|
||||
|
|
@ -171,5 +171,5 @@ def fuse_modules_qat(model, modules_to_fuse, inplace=False, fuser_func=fuse_know
|
|||
modules_to_fuse,
|
||||
is_qat=True,
|
||||
inplace=inplace,
|
||||
fuser_func=fuse_known_modules,
|
||||
fuser_func=fuser_func,
|
||||
fuse_custom_config_dict=None)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user