Revert "[FX][export] DCE pass, check schema for node impurity (#130395)"

This reverts commit e22b0acc76.

Reverted https://github.com/pytorch/pytorch/pull/130395 on behalf of https://github.com/yushangdi due to breaking tests, need to rebase and fix ([comment](https://github.com/pytorch/pytorch/pull/130395#issuecomment-2235192986))
This commit is contained in:
PyTorch MergeBot 2024-07-18 02:46:03 +00:00
parent bd56bcf0ab
commit 433ef4e444
5 changed files with 30 additions and 50 deletions

View File

@ -12,9 +12,11 @@ class TestDCE(TestCase):
def _custom_is_impure_node(self, node: torch.fx.Node) -> bool: def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
if node.is_impure(): if node.is_impure():
return True return True
# a custom function that defines add operators as impure.
if node.target == torch.ops.aten.add: if node.op == "call_function":
return True schema = getattr(node.target, "_schema", None)
schema_mutable = schema is not None and schema.is_mutable
return schema_mutable
return False return False
def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False): def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False):
@ -203,7 +205,7 @@ class TestDCE(TestCase):
return a * 2 return a * 2
# %add_ node should not be removed because it has side effects. # %add_ node should not be removed because it has side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False) self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)
def test_impure_kwargs(self): def test_impure_kwargs(self):
""" """
@ -216,19 +218,5 @@ class TestDCE(TestCase):
torch._ops.ops.aten.add.out(b, b, out=a, alpha=2) torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
return a return a
# %add_out node should not be removed because it has side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
def test_impure_custom(self):
"""
Test that DCE doesn't remove nodes marked as impure by a custom function.
"""
class TestModule(torch.nn.Module):
def forward(self, a: torch.Tensor) -> torch.Tensor:
b = a + 1
c = torch._ops.ops.aten.add(b, b)
return a
# %add_out node should not be removed because it has side effects. # %add_out node should not be removed because it has side effects.
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True) self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)

View File

@ -540,21 +540,6 @@ class X86InductorQuantTestCase(QuantizationTestCase):
is_qat=False, is_qat=False,
debug=False, debug=False,
): ):
def recreate_m(m_eager, is_qat, run_convert_pt2e):
m = copy.deepcopy(m_eager)
m = capture_pre_autograd_graph(
m,
example_inputs,
)
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
# Calibrate
m(*example_inputs)
if run_convert_pt2e:
torch.ao.quantization.move_exported_model_to_eval(m)
m = convert_pt2e(m)
return m
m_eager = model.train() if is_qat else model.eval() m_eager = model.train() if is_qat else model.eval()
# program capture # program capture
@ -569,12 +554,9 @@ class X86InductorQuantTestCase(QuantizationTestCase):
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer) m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
# Calibrate # Calibrate
m(*example_inputs) m(*example_inputs)
prepare_model = recreate_m(m_eager, is_qat, False) prepare_model = copy.deepcopy(m)
# Change mutable operations, e.g. change `aten._native_batch_norm_legit.default`
# to `aten._native_batch_norm_legit_no_training.default`, for DCE pass.
torch.ao.quantization.move_exported_model_to_eval(m)
m = convert_pt2e(m) m = convert_pt2e(m)
convert_model = recreate_m(m_eager, is_qat, True) convert_model = copy.deepcopy(m)
if debug: if debug:
convert_model.print_readable(True) convert_model.print_readable(True)
pt2_quant_output = m(*example_inputs) pt2_quant_output = m(*example_inputs)
@ -1738,7 +1720,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
# BN should be folded into Conv # BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0, torch.ops.aten._native_batch_norm_legit.default: 0,
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
} }
node_list = [ node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.default,
@ -1812,7 +1793,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
# BN should be folded into Conv # BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0, torch.ops.aten._native_batch_norm_legit.default: 0,
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
} }
node_list = [ node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.default,
@ -1858,7 +1838,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
# BN should be folded into Conv # BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0, torch.ops.aten._native_batch_norm_legit.default: 0,
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
} }
node_list = [ node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.default,
@ -1907,7 +1886,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2, torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
# BN should be folded into Conv # BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0, torch.ops.aten._native_batch_norm_legit.default: 0,
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
} }
node_list = [ node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.default,
@ -1953,7 +1931,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1, torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
# BN should be folded into Conv # BN should be folded into Conv
torch.ops.aten._native_batch_norm_legit.default: 0, torch.ops.aten._native_batch_norm_legit.default: 0,
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
} }
node_list = [ node_list = [
torch.ops.quantized_decomposed.quantize_per_tensor.default, torch.ops.quantized_decomposed.quantize_per_tensor.default,

View File

@ -15,6 +15,20 @@ from .graph_signature import (
) )
def _is_impure_node(node: torch.fx.Node) -> bool:
"""
Check the schema of node target to detect side-effectful nodes.
"""
if node.is_impure():
return True
if node.op == "call_function":
schema = getattr(node.target, "_schema", None)
schema_mutable = schema is not None and schema.is_mutable
return schema_mutable
return False
def _remove_effect_tokens_from_graph_helper( def _remove_effect_tokens_from_graph_helper(
ep, num_tokens, input_token_names, output_token_names ep, num_tokens, input_token_names, output_token_names
): ):
@ -114,7 +128,7 @@ def _remove_effect_tokens_from_graph_helper(
assert inp_token.name in input_token_names assert inp_token.name in input_token_names
ep.graph.erase_node(inp_token) ep.graph.erase_node(inp_token)
ep.graph.eliminate_dead_code() ep.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:

View File

@ -8,7 +8,7 @@ import torch.utils._pytree as pytree
from torch._export.utils import _check_input_constraints_for_graph from torch._export.utils import _check_input_constraints_for_graph
from torch.export.unflatten import _assign_attr, _AttrKind from torch.export.unflatten import _assign_attr, _AttrKind
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from ._remove_effect_tokens_pass import _remove_effect_tokens from ._remove_effect_tokens_pass import _is_impure_node, _remove_effect_tokens
from .exported_program import ( from .exported_program import (
ExportedProgram, ExportedProgram,
@ -184,7 +184,7 @@ def _unlift(
) )
gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names) gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
gm.graph.lint() gm.graph.lint()
gm.graph.eliminate_dead_code() gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
gm.recompile() gm.recompile()
return gm return gm

View File

@ -47,6 +47,9 @@ _side_effectful_functions: Set[Callable] = {
torch._assert_async, torch._assert_async,
_ops.aten._assert_async.msg, _ops.aten._assert_async.msg,
_ops.aten._assert_scalar.default, _ops.aten._assert_scalar.default,
_ops.aten.copy_.default,
_ops.aten.set_.source_Tensor,
_ops.aten.index_put_.default,
_ops.aten.sym_constrain_range.default, _ops.aten.sym_constrain_range.default,
_ops.aten.sym_constrain_range_for_size.default, _ops.aten.sym_constrain_range_for_size.default,
_ops.profiler._record_function_enter, _ops.profiler._record_function_enter,
@ -645,11 +648,9 @@ class Node(_NodeBase):
if self.op in {"placeholder", "output"}: if self.op in {"placeholder", "output"}:
return True return True
# Check if an impure function based on schema. # Check if an impure function.
if self.op == "call_function": if self.op == "call_function":
schema = getattr(self.target, "_schema", None) return self.target in _side_effectful_functions
schema_mutable = schema is not None and schema.is_mutable
return schema_mutable or self.target in _side_effectful_functions
# Check if an impure module. # Check if an impure module.
if self.op == "call_module": if self.op == "call_module":