mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Quant] [PT2] Enable Inplace Dropout in _move_exported_model_to_eval (#114725)
**Summary** Enable Inplace Dropout replacement in `_move_exported_model_to_eval` **Test Plan** ``` python -u -m pytest -s -v test_quantize_pt2e.py -k test_move_exported_model_to_eval ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114725 Approved by: https://github.com/andrewor14, https://github.com/jgong5 ghstack dependencies: #114547
This commit is contained in:
parent
06eb28c32a
commit
fd7201029a
|
|
@ -1606,11 +1606,11 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
qconfig_mapping,
|
||||
)
|
||||
|
||||
def test_move_exported_model_to_eval(self):
|
||||
def _test_move_exported_model_to_eval_dropout(self, inplace=False):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dropout = torch.nn.Dropout(0.5)
|
||||
self.dropout = torch.nn.Dropout(0.5, inplace=inplace)
|
||||
|
||||
def forward(self, x):
|
||||
return self.dropout(x)
|
||||
|
|
@ -1622,7 +1622,7 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
# Assert that dropout op exists and is in train mode
|
||||
dropout_node = None
|
||||
for n in m.graph.nodes:
|
||||
if n.target == torch.ops.aten.native_dropout.default:
|
||||
if n.target == torch.ops.aten.native_dropout.default or n.target == torch.ops.aten.dropout_.default:
|
||||
dropout_node = n
|
||||
break
|
||||
self.assertTrue(dropout_node is not None)
|
||||
|
|
@ -1633,8 +1633,20 @@ class TestQuantizePT2E(PT2EQuantizationTestCase):
|
|||
|
||||
# Assert that dropout op is now replaced with a clone op
|
||||
targets = [n.target for n in m.graph.nodes]
|
||||
self.assertTrue(torch.ops.aten.clone.default in targets)
|
||||
self.assertTrue(torch.ops.aten.native_dropout.default not in targets)
|
||||
if inplace:
|
||||
dropout_eval_node = None
|
||||
for node in m.graph.nodes:
|
||||
if node.target == torch.ops.aten.dropout_.default:
|
||||
dropout_eval_node = node
|
||||
self.assertTrue(dropout_eval_node is not None)
|
||||
self.assertFalse(dropout_eval_node.args[2])
|
||||
else:
|
||||
self.assertTrue(torch.ops.aten.clone.default in targets)
|
||||
self.assertTrue(torch.ops.aten.native_dropout.default not in targets)
|
||||
|
||||
def test_move_exported_model_to_eval(self):
|
||||
self._test_move_exported_model_to_eval_dropout(inplace=False)
|
||||
self._test_move_exported_model_to_eval_dropout(inplace=True)
|
||||
|
||||
def test_bn_move_exported_model_to_eval(self):
|
||||
class M(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -21,26 +21,28 @@ def _replace_dropout_for_eval(m: torch.fx.GraphModule):
|
|||
m.graph.eliminate_dead_code()
|
||||
m.recompile()
|
||||
|
||||
def dropout_train(x):
|
||||
return F.dropout(x, p=0.5, training=True)
|
||||
for inplace in [False, True]:
|
||||
|
||||
def dropout_eval(x):
|
||||
return F.dropout(x, p=0.5, training=False)
|
||||
def dropout_train(x):
|
||||
return F.dropout(x, p=0.5, training=True, inplace=inplace)
|
||||
|
||||
example_inputs = (torch.randn(1),)
|
||||
match_pattern = get_aten_graph_module(dropout_train, example_inputs)
|
||||
replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
|
||||
def dropout_eval(x):
|
||||
return F.dropout(x, p=0.5, training=False, inplace=inplace)
|
||||
|
||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
|
||||
example_inputs = (torch.randn(1),)
|
||||
match_pattern = get_aten_graph_module(dropout_train, example_inputs)
|
||||
replacement_pattern = get_aten_graph_module(dropout_eval, example_inputs)
|
||||
|
||||
replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern,
|
||||
match_filters=[],
|
||||
ignore_literals=True,
|
||||
)
|
||||
m.recompile()
|
||||
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
|
||||
|
||||
replace_pattern_with_filters(
|
||||
m,
|
||||
match_pattern,
|
||||
replacement_pattern,
|
||||
match_filters=[],
|
||||
ignore_literals=True,
|
||||
)
|
||||
m.recompile()
|
||||
|
||||
|
||||
def _replace_batchnorm_for_eval(m: torch.fx.GraphModule):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user