mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[FX][export] DCE pass, check schema for node impurity (#130395)"
This reverts commit e22b0acc76.
Reverted https://github.com/pytorch/pytorch/pull/130395 on behalf of https://github.com/yushangdi due to breaking tests, need to rebase and fix ([comment](https://github.com/pytorch/pytorch/pull/130395#issuecomment-2235192986))
This commit is contained in:
parent
bd56bcf0ab
commit
433ef4e444
|
|
@ -12,9 +12,11 @@ class TestDCE(TestCase):
|
|||
def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
|
||||
if node.is_impure():
|
||||
return True
|
||||
# a custom function that defines add operators as impure.
|
||||
if node.target == torch.ops.aten.add:
|
||||
return True
|
||||
|
||||
if node.op == "call_function":
|
||||
schema = getattr(node.target, "_schema", None)
|
||||
schema_mutable = schema is not None and schema.is_mutable
|
||||
return schema_mutable
|
||||
return False
|
||||
|
||||
def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False):
|
||||
|
|
@ -203,7 +205,7 @@ class TestDCE(TestCase):
|
|||
return a * 2
|
||||
|
||||
# %add_ node should not be removed because it has side effects.
|
||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)
|
||||
|
||||
def test_impure_kwargs(self):
|
||||
"""
|
||||
|
|
@ -216,19 +218,5 @@ class TestDCE(TestCase):
|
|||
torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
|
||||
return a
|
||||
|
||||
# %add_out node should not be removed because it has side effects.
|
||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
||||
|
||||
def test_impure_custom(self):
|
||||
"""
|
||||
Test that DCE doesn't remove nodes marked as impure by a custom function.
|
||||
"""
|
||||
|
||||
class TestModule(torch.nn.Module):
|
||||
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
||||
b = a + 1
|
||||
c = torch._ops.ops.aten.add(b, b)
|
||||
return a
|
||||
|
||||
# %add_out node should not be removed because it has side effects.
|
||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)
|
||||
|
|
|
|||
|
|
@ -540,21 +540,6 @@ class X86InductorQuantTestCase(QuantizationTestCase):
|
|||
is_qat=False,
|
||||
debug=False,
|
||||
):
|
||||
def recreate_m(m_eager, is_qat, run_convert_pt2e):
|
||||
m = copy.deepcopy(m_eager)
|
||||
m = capture_pre_autograd_graph(
|
||||
m,
|
||||
example_inputs,
|
||||
)
|
||||
|
||||
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
|
||||
# Calibrate
|
||||
m(*example_inputs)
|
||||
if run_convert_pt2e:
|
||||
torch.ao.quantization.move_exported_model_to_eval(m)
|
||||
m = convert_pt2e(m)
|
||||
return m
|
||||
|
||||
m_eager = model.train() if is_qat else model.eval()
|
||||
|
||||
# program capture
|
||||
|
|
@ -569,12 +554,9 @@ class X86InductorQuantTestCase(QuantizationTestCase):
|
|||
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
|
||||
# Calibrate
|
||||
m(*example_inputs)
|
||||
prepare_model = recreate_m(m_eager, is_qat, False)
|
||||
# Change mutable operations, e.g. change `aten._native_batch_norm_legit.default`
|
||||
# to `aten._native_batch_norm_legit_no_training.default`, for DCE pass.
|
||||
torch.ao.quantization.move_exported_model_to_eval(m)
|
||||
prepare_model = copy.deepcopy(m)
|
||||
m = convert_pt2e(m)
|
||||
convert_model = recreate_m(m_eager, is_qat, True)
|
||||
convert_model = copy.deepcopy(m)
|
||||
if debug:
|
||||
convert_model.print_readable(True)
|
||||
pt2_quant_output = m(*example_inputs)
|
||||
|
|
@ -1738,7 +1720,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
|||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
||||
# BN should be folded into Conv
|
||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
||||
}
|
||||
node_list = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
|
|
@ -1812,7 +1793,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
|||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
||||
# BN should be folded into Conv
|
||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
||||
}
|
||||
node_list = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
|
|
@ -1858,7 +1838,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
|||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
||||
# BN should be folded into Conv
|
||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
||||
}
|
||||
node_list = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
|
|
@ -1907,7 +1886,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
|||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
|
||||
# BN should be folded into Conv
|
||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
||||
}
|
||||
node_list = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
|
|
@ -1953,7 +1931,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
|||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
||||
# BN should be folded into Conv
|
||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
||||
}
|
||||
node_list = [
|
||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,20 @@ from .graph_signature import (
|
|||
)
|
||||
|
||||
|
||||
def _is_impure_node(node: torch.fx.Node) -> bool:
|
||||
"""
|
||||
Check the schema of node target to detect side-effectful nodes.
|
||||
"""
|
||||
if node.is_impure():
|
||||
return True
|
||||
|
||||
if node.op == "call_function":
|
||||
schema = getattr(node.target, "_schema", None)
|
||||
schema_mutable = schema is not None and schema.is_mutable
|
||||
return schema_mutable
|
||||
return False
|
||||
|
||||
|
||||
def _remove_effect_tokens_from_graph_helper(
|
||||
ep, num_tokens, input_token_names, output_token_names
|
||||
):
|
||||
|
|
@ -114,7 +128,7 @@ def _remove_effect_tokens_from_graph_helper(
|
|||
assert inp_token.name in input_token_names
|
||||
ep.graph.erase_node(inp_token)
|
||||
|
||||
ep.graph.eliminate_dead_code()
|
||||
ep.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
|
||||
|
||||
|
||||
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch.utils._pytree as pytree
|
|||
from torch._export.utils import _check_input_constraints_for_graph
|
||||
from torch.export.unflatten import _assign_attr, _AttrKind
|
||||
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||
from ._remove_effect_tokens_pass import _remove_effect_tokens
|
||||
from ._remove_effect_tokens_pass import _is_impure_node, _remove_effect_tokens
|
||||
|
||||
from .exported_program import (
|
||||
ExportedProgram,
|
||||
|
|
@ -184,7 +184,7 @@ def _unlift(
|
|||
)
|
||||
gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
|
||||
gm.graph.lint()
|
||||
gm.graph.eliminate_dead_code()
|
||||
gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
|
||||
gm.recompile()
|
||||
return gm
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,9 @@ _side_effectful_functions: Set[Callable] = {
|
|||
torch._assert_async,
|
||||
_ops.aten._assert_async.msg,
|
||||
_ops.aten._assert_scalar.default,
|
||||
_ops.aten.copy_.default,
|
||||
_ops.aten.set_.source_Tensor,
|
||||
_ops.aten.index_put_.default,
|
||||
_ops.aten.sym_constrain_range.default,
|
||||
_ops.aten.sym_constrain_range_for_size.default,
|
||||
_ops.profiler._record_function_enter,
|
||||
|
|
@ -645,11 +648,9 @@ class Node(_NodeBase):
|
|||
if self.op in {"placeholder", "output"}:
|
||||
return True
|
||||
|
||||
# Check if an impure function based on schema.
|
||||
# Check if an impure function.
|
||||
if self.op == "call_function":
|
||||
schema = getattr(self.target, "_schema", None)
|
||||
schema_mutable = schema is not None and schema.is_mutable
|
||||
return schema_mutable or self.target in _side_effectful_functions
|
||||
return self.target in _side_effectful_functions
|
||||
|
||||
# Check if an impure module.
|
||||
if self.op == "call_module":
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user