diff --git a/test/fx/test_dce_pass.py b/test/fx/test_dce_pass.py index 34dde98658c..b090072949b 100644 --- a/test/fx/test_dce_pass.py +++ b/test/fx/test_dce_pass.py @@ -12,9 +12,11 @@ class TestDCE(TestCase): def _custom_is_impure_node(self, node: torch.fx.Node) -> bool: if node.is_impure(): return True - # a custom function that defines add operators as impure. - if node.target == torch.ops.aten.add: - return True + + if node.op == "call_function": + schema = getattr(node.target, "_schema", None) + schema_mutable = schema is not None and schema.is_mutable + return schema_mutable return False def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False): @@ -203,7 +205,7 @@ class TestDCE(TestCase): return a * 2 # %add_ node should not be removed because it has side effects. - self._run_dce_and_test(TestModule(), expect_dce_changes=False) + self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True) def test_impure_kwargs(self): """ @@ -216,19 +218,5 @@ class TestDCE(TestCase): torch._ops.ops.aten.add.out(b, b, out=a, alpha=2) return a - # %add_out node should not be removed because it has side effects. - self._run_dce_and_test(TestModule(), expect_dce_changes=False) - - def test_impure_custom(self): - """ - Test that DCE doesn't remove nodes marked as impure by a custom function. - """ - - class TestModule(torch.nn.Module): - def forward(self, a: torch.Tensor) -> torch.Tensor: - b = a + 1 - c = torch._ops.ops.aten.add(b, b) - return a - # %add_out node should not be removed because it has side effects. self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True) diff --git a/test/quantization/pt2e/test_x86inductor_quantizer.py b/test/quantization/pt2e/test_x86inductor_quantizer.py index a3b1b6409ed..669ab687be0 100644 --- a/test/quantization/pt2e/test_x86inductor_quantizer.py +++ b/test/quantization/pt2e/test_x86inductor_quantizer.py @@ -540,21 +540,6 @@ class X86InductorQuantTestCase(QuantizationTestCase): is_qat=False, debug=False, ): - def recreate_m(m_eager, is_qat, run_convert_pt2e): - m = copy.deepcopy(m_eager) - m = capture_pre_autograd_graph( - m, - example_inputs, - ) - - m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer) - # Calibrate - m(*example_inputs) - if run_convert_pt2e: - torch.ao.quantization.move_exported_model_to_eval(m) - m = convert_pt2e(m) - return m - m_eager = model.train() if is_qat else model.eval() # program capture @@ -569,12 +554,9 @@ class X86InductorQuantTestCase(QuantizationTestCase): m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer) # Calibrate m(*example_inputs) - prepare_model = recreate_m(m_eager, is_qat, False) - # Change mutable operations, e.g. change `aten._native_batch_norm_legit.default` - # to `aten._native_batch_norm_legit_no_training.default`, for DCE pass. - torch.ao.quantization.move_exported_model_to_eval(m) + prepare_model = copy.deepcopy(m) m = convert_pt2e(m) - convert_model = recreate_m(m_eager, is_qat, True) + convert_model = copy.deepcopy(m) if debug: convert_model.print_readable(True) pt2_quant_output = m(*example_inputs) @@ -1738,7 +1720,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, - torch.ops.aten._native_batch_norm_legit_no_training.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, @@ -1812,7 +1793,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, - torch.ops.aten._native_batch_norm_legit_no_training.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, @@ -1858,7 +1838,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, - torch.ops.aten._native_batch_norm_legit_no_training.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, @@ -1907,7 +1886,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, - torch.ops.aten._native_batch_norm_legit_no_training.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, @@ -1953,7 +1931,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase): torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, # BN should be folded into Conv torch.ops.aten._native_batch_norm_legit.default: 0, - torch.ops.aten._native_batch_norm_legit_no_training.default: 0, } node_list = [ torch.ops.quantized_decomposed.quantize_per_tensor.default, diff --git a/torch/export/_remove_effect_tokens_pass.py b/torch/export/_remove_effect_tokens_pass.py index 14187b426e4..c93ff0a1eb3 100644 --- a/torch/export/_remove_effect_tokens_pass.py +++ b/torch/export/_remove_effect_tokens_pass.py @@ -15,6 +15,20 @@ from .graph_signature import ( ) +def _is_impure_node(node: torch.fx.Node) -> bool: + """ + Check the schema of node target to detect side-effectful nodes. + """ + if node.is_impure(): + return True + + if node.op == "call_function": + schema = getattr(node.target, "_schema", None) + schema_mutable = schema is not None and schema.is_mutable + return schema_mutable + return False + + def _remove_effect_tokens_from_graph_helper( ep, num_tokens, input_token_names, output_token_names ): @@ -114,7 +128,7 @@ def _remove_effect_tokens_from_graph_helper( assert inp_token.name in input_token_names ep.graph.erase_node(inp_token) - ep.graph.eliminate_dead_code() + ep.graph.eliminate_dead_code(is_impure_node=_is_impure_node) def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: diff --git a/torch/export/_unlift.py b/torch/export/_unlift.py index 5a8f144b04e..f0956195f22 100644 --- a/torch/export/_unlift.py +++ b/torch/export/_unlift.py @@ -8,7 +8,7 @@ import torch.utils._pytree as pytree from torch._export.utils import _check_input_constraints_for_graph from torch.export.unflatten import _assign_attr, _AttrKind from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo -from ._remove_effect_tokens_pass import _remove_effect_tokens +from ._remove_effect_tokens_pass import _is_impure_node, _remove_effect_tokens from .exported_program import ( ExportedProgram, @@ -184,7 +184,7 @@ def _unlift( ) gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names) gm.graph.lint() - gm.graph.eliminate_dead_code() + gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node) gm.recompile() return gm diff --git a/torch/fx/node.py b/torch/fx/node.py index 7bcb4233c75..010502594e0 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -47,6 +47,9 @@ _side_effectful_functions: Set[Callable] = { torch._assert_async, _ops.aten._assert_async.msg, _ops.aten._assert_scalar.default, + _ops.aten.copy_.default, + _ops.aten.set_.source_Tensor, + _ops.aten.index_put_.default, _ops.aten.sym_constrain_range.default, _ops.aten.sym_constrain_range_for_size.default, _ops.profiler._record_function_enter, @@ -645,11 +648,9 @@ class Node(_NodeBase): if self.op in {"placeholder", "output"}: return True - # Check if an impure function based on schema. + # Check if an impure function. if self.op == "call_function": - schema = getattr(self.target, "_schema", None) - schema_mutable = schema is not None and schema.is_mutable - return schema_mutable or self.target in _side_effectful_functions + return self.target in _side_effectful_functions # Check if an impure module. if self.op == "call_module":