Revert "[FX][export] DCE pass, check schema for node impurity (#130395)"

This reverts commit e22b0acc76.

Reverted https://github.com/pytorch/pytorch/pull/130395 on behalf of https://github.com/yushangdi due to breaking tests, need to rebase and fix ([comment](https://github.com/pytorch/pytorch/pull/130395#issuecomment-2235192986))
PyTorch MergeBot 2024-07-18 02:46:03 +00:00
parent bd56bcf0ab
commit 433ef4e444
5 changed files with 30 additions and 50 deletions

--- a/test/fx/test_dce_pass.py
+++ b/test/fx/test_dce_pass.py

@@ -12,9 +12,11 @@ class TestDCE(TestCase):
     def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
         if node.is_impure():
             return True
-        # a custom function that defines add operators as impure.
-        if node.target == torch.ops.aten.add:
-            return True
+        if node.op == "call_function":
+            schema = getattr(node.target, "_schema", None)
+            schema_mutable = schema is not None and schema.is_mutable
+            return schema_mutable
         return False

     def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False):
@@ -203,7 +205,7 @@ class TestDCE(TestCase):
                 return a * 2

         # %add_ node should not be removed because it has side effects.
-        self._run_dce_and_test(TestModule(), expect_dce_changes=False)
+        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)

     def test_impure_kwargs(self):
         """
@@ -216,19 +218,5 @@ class TestDCE(TestCase):
                 torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
                 return a

         # %add_out node should not be removed because it has side effects.
         self._run_dce_and_test(TestModule(), expect_dce_changes=False)
-
-    def test_impure_custom(self):
-        """
-        Test that DCE doesn't remove nodes marked as impure by a custom function.
-        """
-
-        class TestModule(torch.nn.Module):
-            def forward(self, a: torch.Tensor) -> torch.Tensor:
-                b = a + 1
-                c = torch._ops.ops.aten.add(b, b)
-                return a
-
-        # %add node should not be removed because the custom function marks it as impure.
-        self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)

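A note on the hook these tests exercise: `Graph.eliminate_dead_code` takes an optional `is_impure_node` callback (the `custom=True` path above feeds `_custom_is_impure_node` into it), and the schema check in the restored helper is what keeps mutating ops alive through DCE. A minimal sketch of the mechanism; the module and callback here are illustrative, not part of this commit:

    import torch
    import torch.fx

    class M(torch.nn.Module):
        def forward(self, a: torch.Tensor) -> torch.Tensor:
            torch.ops.aten.add_.Tensor(a, a)  # in-place; its FX node ends up with no users
            b = a + 2  # pure and unused -> genuinely dead
            return a

    gm = torch.fx.symbolic_trace(M())

    def is_impure(node: torch.fx.Node) -> bool:
        # Mirror the restored helper: fall back to the op schema so that
        # mutable ops such as aten.add_ count as side-effectful.
        if node.is_impure():
            return True
        if node.op == "call_function":
            schema = getattr(node.target, "_schema", None)
            return schema is not None and schema.is_mutable
        return False

    gm.graph.eliminate_dead_code(is_impure_node=is_impure)
    gm.recompile()
    # `b = a + 2` is eliminated, while the mutating add_ call survives.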
--- a/test/quantization/pt2e/test_x86inductor_quantizer.py
+++ b/test/quantization/pt2e/test_x86inductor_quantizer.py

@@ -540,21 +540,6 @@ class X86InductorQuantTestCase(QuantizationTestCase):
         is_qat=False,
         debug=False,
     ):
-        def recreate_m(m_eager, is_qat, run_convert_pt2e):
-            m = copy.deepcopy(m_eager)
-            m = capture_pre_autograd_graph(
-                m,
-                example_inputs,
-            )
-            m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
-            # Calibrate
-            m(*example_inputs)
-            if run_convert_pt2e:
-                torch.ao.quantization.move_exported_model_to_eval(m)
-                m = convert_pt2e(m)
-            return m
-
         m_eager = model.train() if is_qat else model.eval()

         # program capture
@@ -569,12 +554,9 @@ class X86InductorQuantTestCase(QuantizationTestCase):
         m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
         # Calibrate
         m(*example_inputs)
-        prepare_model = recreate_m(m_eager, is_qat, False)
+        # Change mutable operations, e.g. change `aten._native_batch_norm_legit.default`
+        # to `aten._native_batch_norm_legit_no_training.default`, for DCE pass.
+        torch.ao.quantization.move_exported_model_to_eval(m)
+        prepare_model = copy.deepcopy(m)
         m = convert_pt2e(m)
-        convert_model = recreate_m(m_eager, is_qat, True)
+        convert_model = copy.deepcopy(m)
         if debug:
             convert_model.print_readable(True)
         pt2_quant_output = m(*example_inputs)
@@ -1738,7 +1720,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
             # BN should be folded into Conv
             torch.ops.aten._native_batch_norm_legit.default: 0,
-            torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
         }
         node_list = [
             torch.ops.quantized_decomposed.quantize_per_tensor.default,
@@ -1812,7 +1793,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
             # BN should be folded into Conv
             torch.ops.aten._native_batch_norm_legit.default: 0,
-            torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
         }
         node_list = [
             torch.ops.quantized_decomposed.quantize_per_tensor.default,
@@ -1858,7 +1838,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
             # BN should be folded into Conv
             torch.ops.aten._native_batch_norm_legit.default: 0,
-            torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
         }
         node_list = [
             torch.ops.quantized_decomposed.quantize_per_tensor.default,
@@ -1907,7 +1886,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
             # BN should be folded into Conv
             torch.ops.aten._native_batch_norm_legit.default: 0,
-            torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
         }
         node_list = [
             torch.ops.quantized_decomposed.quantize_per_tensor.default,
@@ -1953,7 +1931,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
             torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
             # BN should be folded into Conv
             torch.ops.aten._native_batch_norm_legit.default: 0,
-            torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
         }
         node_list = [
             torch.ops.quantized_decomposed.quantize_per_tensor.default,

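The restored comment is the crux of this file's changes: `aten._native_batch_norm_legit.default` mutates its running stats, so a schema-aware DCE pass must keep it, whereas the `_no_training` variant swapped in by `move_exported_model_to_eval` is functional and can be folded into the convolution. A quick way to confirm this from the op schemas (assuming both overloads exist in the build at hand):

    import torch

    # The training variant mutates running_mean/running_var in place ...
    print(torch.ops.aten._native_batch_norm_legit.default._schema.is_mutable)  # True
    # ... while the inference variant is purely functional, so DCE may remove it.
    print(torch.ops.aten._native_batch_norm_legit_no_training.default._schema.is_mutable)  # False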
--- a/torch/export/_remove_effect_tokens_pass.py
+++ b/torch/export/_remove_effect_tokens_pass.py

@@ -15,6 +15,20 @@ from .graph_signature import (
 )


+def _is_impure_node(node: torch.fx.Node) -> bool:
+    """
+    Check the schema of the node target to detect side-effectful nodes.
+    """
+    if node.is_impure():
+        return True
+    if node.op == "call_function":
+        schema = getattr(node.target, "_schema", None)
+        schema_mutable = schema is not None and schema.is_mutable
+        return schema_mutable
+    return False
+
+
 def _remove_effect_tokens_from_graph_helper(
     ep, num_tokens, input_token_names, output_token_names
 ):
@@ -114,7 +128,7 @@ def _remove_effect_tokens_from_graph_helper(
             assert inp_token.name in input_token_names
             ep.graph.erase_node(inp_token)

-    ep.graph.eliminate_dead_code()
+    ep.graph.eliminate_dead_code(is_impure_node=_is_impure_node)


 def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:

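One detail worth noting in the restored `_is_impure_node`: not every `call_function` target is an ATen `OpOverload`. Plain Python callables such as `operator.add` carry no `_schema` attribute, which is why the helper probes with `getattr` rather than reading `node.target._schema` directly. A small illustration:

    import operator

    import torch

    # ATen OpOverloads expose a FunctionSchema ...
    print(getattr(torch.ops.aten.add_.Tensor, "_schema", None) is not None)  # True
    # ... but ordinary Python functions do not, so the getattr guard yields None.
    print(getattr(operator.add, "_schema", None) is not None)  # False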
--- a/torch/export/_unlift.py
+++ b/torch/export/_unlift.py

@@ -8,7 +8,7 @@ import torch.utils._pytree as pytree
 from torch._export.utils import _check_input_constraints_for_graph
 from torch.export.unflatten import _assign_attr, _AttrKind
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

-from ._remove_effect_tokens_pass import _remove_effect_tokens
+from ._remove_effect_tokens_pass import _is_impure_node, _remove_effect_tokens
 from .exported_program import (
     ExportedProgram,
@@ -184,7 +184,7 @@ def _unlift(
     )
     gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
     gm.graph.lint()
-    gm.graph.eliminate_dead_code()
+    gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
     gm.recompile()
     return gm

--- a/torch/fx/node.py
+++ b/torch/fx/node.py

@@ -47,6 +47,9 @@ _side_effectful_functions: Set[Callable] = {
     torch._assert_async,
     _ops.aten._assert_async.msg,
     _ops.aten._assert_scalar.default,
+    _ops.aten.copy_.default,
+    _ops.aten.set_.source_Tensor,
+    _ops.aten.index_put_.default,
     _ops.aten.sym_constrain_range.default,
     _ops.aten.sym_constrain_range_for_size.default,
     _ops.profiler._record_function_enter,
@@ -645,11 +648,9 @@ class Node(_NodeBase):
         if self.op in {"placeholder", "output"}:
             return True

-        # Check if an impure function based on schema.
+        # Check if an impure function.
         if self.op == "call_function":
-            schema = getattr(self.target, "_schema", None)
-            schema_mutable = schema is not None and schema.is_mutable
-            return schema_mutable or self.target in _side_effectful_functions
+            return self.target in _side_effectful_functions

         # Check if an impure module.
         if self.op == "call_module":
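The practical effect of this last hunk: with the revert applied, `Node.is_impure()` for `call_function` nodes once again consults only the explicit `_side_effectful_functions` list, which is why `copy_`, `set_`, and `index_put_` have to be re-added to it. A hypothetical sketch of the resulting behavior, assuming a build with this revert applied:

    import torch
    from torch.fx import symbolic_trace

    class M(torch.nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # index_put_ mutates x; the node has no users, so only its
            # impurity protects it from dead-code elimination.
            torch.ops.aten.index_put_.default(x, [torch.tensor([0])], torch.tensor(1.0))
            return x

    gm = symbolic_trace(M())
    node = next(n for n in gm.graph.nodes if n.target is torch.ops.aten.index_put_.default)
    print(node.is_impure())  # True: matched via the explicit list, not schema inspection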