[Quant] [PT2] Enable Inplace Dropout in _move_exported_model_to_eval (#114725)

**Summary**
Enable inplace dropout replacement in `_move_exported_model_to_eval`: the eval rewrite now matches `torch.nn.Dropout(p, inplace=True)` (captured as `aten.dropout_`) in addition to the existing out-of-place pattern (`aten.native_dropout`).
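
For context: inplace dropout is captured as the inplace ATen op `aten.dropout_` rather than `aten.native_dropout`, so the previous eval rewrite, which only matched the out-of-place pattern, left inplace dropout in train mode. A minimal sketch of the flow this pass sits in, assuming the pre-autograd export APIs of this release (`capture_pre_autograd_graph`, `move_exported_model_to_eval`):

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # inplace=True captures as aten.dropout_ instead of aten.native_dropout
        self.dropout = torch.nn.Dropout(0.5, inplace=True)

    def forward(self, x):
        return self.dropout(x)

example_inputs = (torch.randn(1),)
m = torch._export.capture_pre_autograd_graph(M(), example_inputs)
# With this PR, the eval rewrite also flips the inplace variant out of train mode.
m = torch.ao.quantization.move_exported_model_to_eval(m)
```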

**Test Plan**
```
python -u -m pytest -s -v test_quantize_pt2e.py -k test_move_exported_model_to_eval
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114725
Approved by: https://github.com/andrewor14, https://github.com/jgong5
ghstack dependencies: #114547
Author: leslie-fang-intel
Date: 2023-11-30 09:45:33 +08:00
Committed by: PyTorch MergeBot
Commit: fd7201029a (parent: 06eb28c32a)
2 changed files with 35 additions and 21 deletions

test_quantize_pt2e.py

@@ -1606,11 +1606,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
             qconfig_mapping,
         )
 
-    def test_move_exported_model_to_eval(self):
+    def _test_move_exported_model_to_eval_dropout(self, inplace=False):
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
-                self.dropout = torch.nn.Dropout(0.5)
+                self.dropout = torch.nn.Dropout(0.5, inplace=inplace)
 
             def forward(self, x):
                 return self.dropout(x)
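
The test body is now a helper parametrized on `inplace` (see the wrapper added in the last hunk below). As a quick reminder of what the flag changes in eager mode (a standalone sketch, unrelated to the export pipeline):

```python
import torch

drop = torch.nn.Dropout(p=0.5, inplace=True)  # a bare module defaults to train mode
x = torch.ones(4)
out = drop(x)
assert out is x  # the inplace variant mutates and returns its input tensor
```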
@@ -1622,7 +1622,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         # Assert that dropout op exists and is in train mode
         dropout_node = None
         for n in m.graph.nodes:
-            if n.target == torch.ops.aten.native_dropout.default:
+            if n.target == torch.ops.aten.native_dropout.default or n.target == torch.ops.aten.dropout_.default:
                 dropout_node = n
                 break
         self.assertTrue(dropout_node is not None)
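
The check above accepts either ATen target because the two `Dropout` flavors capture to different ops. A hypothetical helper mirroring the test's node scan (`find_dropout_node` is not part of the PR, just an illustration):

```python
import torch

def find_dropout_node(graph_module):
    """Return the first dropout node in a captured FX graph, or None."""
    dropout_targets = {
        torch.ops.aten.native_dropout.default,  # out-of-place capture
        torch.ops.aten.dropout_.default,        # inplace capture
    }
    for node in graph_module.graph.nodes:
        if node.target in dropout_targets:
            return node
    return None
```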
@@ -1633,8 +1633,20 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
         # Assert that dropout op is now replaced with a clone op
         targets = [n.target for n in m.graph.nodes]
-        self.assertTrue(torch.ops.aten.clone.default in targets)
-        self.assertTrue(torch.ops.aten.native_dropout.default not in targets)
+        if inplace:
+            dropout_eval_node = None
+            for node in m.graph.nodes:
+                if node.target == torch.ops.aten.dropout_.default:
+                    dropout_eval_node = node
+            self.assertTrue(dropout_eval_node is not None)
+            self.assertFalse(dropout_eval_node.args[2])
+        else:
+            self.assertTrue(torch.ops.aten.clone.default in targets)
+            self.assertTrue(torch.ops.aten.native_dropout.default not in targets)
+
+    def test_move_exported_model_to_eval(self):
+        self._test_move_exported_model_to_eval_dropout(inplace=False)
+        self._test_move_exported_model_to_eval_dropout(inplace=True)
 
     def test_bn_move_exported_model_to_eval(self):
         class M(torch.nn.Module):
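
The two branches assert different post-rewrite shapes: with `inplace=False` the dropout node is replaced outright by `aten.clone`, while with `inplace=True` the graph keeps an `aten.dropout_` node whose `train` argument is flipped to `False`. The ATen schema is `dropout_(Tensor(a!) self, float p, bool train)`, which is why the test inspects `args[2]`. A sketch of that check, reusing the hypothetical `find_dropout_node` helper above:

```python
# After move_exported_model_to_eval on a graph captured with inplace=True:
node = find_dropout_node(m)  # hypothetical helper from the sketch above
assert node is not None and node.target == torch.ops.aten.dropout_.default
_, p, train = node.args      # schema: (self, p, train)
assert train is False        # eval mode: inplace dropout becomes a no-op
```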

(second changed file: the PT2E quantization helper defining `_replace_dropout_for_eval`)

@@ -21,26 +21,28 @@ def _replace_dropout_for_eval(m: torch.fx.GraphModule):
     m.graph.eliminate_dead_code()
     m.recompile()
 
-    def dropout_train(x):
-        return F.dropout(x, p=0.5, training=True)
+    for inplace in [False, True]:
 
-    def dropout_eval(x):
-        return F.dropout(x, p=0.5, training=False)
+        def dropout_train(x):
+            return F.dropout(x, p=0.5, training=True, inplace=inplace)
 
-    example_inputs = (torch.randn(1),)
-    match_pattern = get_aten_graph_module(dropout_train, example_inputs)
-    replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
+        def dropout_eval(x):
+            return F.dropout(x, p=0.5, training=False, inplace=inplace)
 
-    from torch.fx.subgraph_rewriter import replace_pattern_with_filters
-    replace_pattern_with_filters(
-        m,
-        match_pattern,
-        replacement_pattern,
-        match_filters=[],
-        ignore_literals=True,
-    )
-    m.recompile()
+        example_inputs = (torch.randn(1),)
+        match_pattern = get_aten_graph_module(dropout_train, example_inputs)
+        replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
+
+        from torch.fx.subgraph_rewriter import replace_pattern_with_filters
+        replace_pattern_with_filters(
+            m,
+            match_pattern,
+            replacement_pattern,
+            match_filters=[],
+            ignore_literals=True,
+        )
+        m.recompile()
 
 def _replace_batchnorm_for_eval(m: torch.fx.GraphModule):
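
For reference, the machinery here is FX subgraph rewriting: `get_aten_graph_module` traces the small `dropout_train`/`dropout_eval` callables into ATen-level pattern graphs, and `replace_pattern_with_filters(..., ignore_literals=True)` swaps every match, the `ignore_literals=True` flag letting the hard-coded `p=0.5` match any dropout probability. The patterns are re-traced per `inplace` value because the flag changes which ATen op appears in the trace. A standalone sketch of the same rewrite idea using the public `torch.fx.subgraph_rewriter.replace_pattern` on an eager-traced toy module (an illustration, not the pass's actual internals):

```python
import torch
import torch.fx
import torch.nn.functional as F
from torch.fx.subgraph_rewriter import replace_pattern

class Toy(torch.nn.Module):
    def forward(self, x):
        return F.dropout(x, p=0.5, training=True) + 1

def pattern(x):
    return F.dropout(x, p=0.5, training=True)

def replacement(x):
    return F.dropout(x, p=0.5, training=False)

gm = torch.fx.symbolic_trace(Toy())
matches = replace_pattern(gm, pattern, replacement)  # rewrites the train-mode call
assert len(matches) == 1
print(gm.code)  # forward now calls dropout with its training flag set to False
```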