mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[FX][export] DCE pass, check schema for node impurity (#130395)"
This reverts commit e22b0acc76.
Reverted https://github.com/pytorch/pytorch/pull/130395 on behalf of https://github.com/yushangdi because it broke tests; it needs to be rebased and fixed ([comment](https://github.com/pytorch/pytorch/pull/130395#issuecomment-2235192986))
This commit is contained in:
parent
bd56bcf0ab
commit
433ef4e444
|
|
@ -12,9 +12,11 @@ class TestDCE(TestCase):
|
||||||
def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
|
def _custom_is_impure_node(self, node: torch.fx.Node) -> bool:
|
||||||
if node.is_impure():
|
if node.is_impure():
|
||||||
return True
|
return True
|
||||||
# a custom function that defines add operators as impure.
|
|
||||||
if node.target == torch.ops.aten.add:
|
if node.op == "call_function":
|
||||||
return True
|
schema = getattr(node.target, "_schema", None)
|
||||||
|
schema_mutable = schema is not None and schema.is_mutable
|
||||||
|
return schema_mutable
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False):
|
def _has_nodes_without_users(self, m: torch.fx.GraphModule, custom: bool = False):
|
||||||
|
|
@ -203,7 +205,7 @@ class TestDCE(TestCase):
|
||||||
return a * 2
|
return a * 2
|
||||||
|
|
||||||
# %add_ node should not be removed because it has side effects.
|
# %add_ node should not be removed because it has side effects.
|
||||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)
|
||||||
|
|
||||||
def test_impure_kwargs(self):
|
def test_impure_kwargs(self):
|
||||||
"""
|
"""
|
||||||
|
|
@ -216,19 +218,5 @@ class TestDCE(TestCase):
|
||||||
torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
|
torch._ops.ops.aten.add.out(b, b, out=a, alpha=2)
|
||||||
return a
|
return a
|
||||||
|
|
||||||
# %add_out node should not be removed because it has side effects.
|
|
||||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False)
|
|
||||||
|
|
||||||
def test_impure_custom(self):
|
|
||||||
"""
|
|
||||||
Test that DCE doesn't remove nodes marked as impure by a custom function.
|
|
||||||
"""
|
|
||||||
|
|
||||||
class TestModule(torch.nn.Module):
|
|
||||||
def forward(self, a: torch.Tensor) -> torch.Tensor:
|
|
||||||
b = a + 1
|
|
||||||
c = torch._ops.ops.aten.add(b, b)
|
|
||||||
return a
|
|
||||||
|
|
||||||
# The aten.add node should not be removed because the custom function marks it as impure.
|
# %add_out node should not be removed because it has side effects.
|
||||||
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)
|
self._run_dce_and_test(TestModule(), expect_dce_changes=False, custom=True)
|
||||||
|
|
|
||||||
|
|
@ -540,21 +540,6 @@ class X86InductorQuantTestCase(QuantizationTestCase):
|
||||||
is_qat=False,
|
is_qat=False,
|
||||||
debug=False,
|
debug=False,
|
||||||
):
|
):
|
||||||
def recreate_m(m_eager, is_qat, run_convert_pt2e):
|
|
||||||
m = copy.deepcopy(m_eager)
|
|
||||||
m = capture_pre_autograd_graph(
|
|
||||||
m,
|
|
||||||
example_inputs,
|
|
||||||
)
|
|
||||||
|
|
||||||
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
|
|
||||||
# Calibrate
|
|
||||||
m(*example_inputs)
|
|
||||||
if run_convert_pt2e:
|
|
||||||
torch.ao.quantization.move_exported_model_to_eval(m)
|
|
||||||
m = convert_pt2e(m)
|
|
||||||
return m
|
|
||||||
|
|
||||||
m_eager = model.train() if is_qat else model.eval()
|
m_eager = model.train() if is_qat else model.eval()
|
||||||
|
|
||||||
# program capture
|
# program capture
|
||||||
|
|
@ -569,12 +554,9 @@ class X86InductorQuantTestCase(QuantizationTestCase):
|
||||||
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
|
m = prepare_qat_pt2e(m, quantizer) if is_qat else prepare_pt2e(m, quantizer)
|
||||||
# Calibrate
|
# Calibrate
|
||||||
m(*example_inputs)
|
m(*example_inputs)
|
||||||
prepare_model = recreate_m(m_eager, is_qat, False)
|
prepare_model = copy.deepcopy(m)
|
||||||
# Change mutable operations, e.g. change `aten._native_batch_norm_legit.default`
|
|
||||||
# to `aten._native_batch_norm_legit_no_training.default`, for DCE pass.
|
|
||||||
torch.ao.quantization.move_exported_model_to_eval(m)
|
|
||||||
m = convert_pt2e(m)
|
m = convert_pt2e(m)
|
||||||
convert_model = recreate_m(m_eager, is_qat, True)
|
convert_model = copy.deepcopy(m)
|
||||||
if debug:
|
if debug:
|
||||||
convert_model.print_readable(True)
|
convert_model.print_readable(True)
|
||||||
pt2_quant_output = m(*example_inputs)
|
pt2_quant_output = m(*example_inputs)
|
||||||
|
|
@ -1738,7 +1720,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
||||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
||||||
# BN should be folded into Conv
|
# BN should be folded into Conv
|
||||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
|
||||||
}
|
}
|
||||||
node_list = [
|
node_list = [
|
||||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||||
|
|
@ -1812,7 +1793,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
||||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
||||||
# BN should be folded into Conv
|
# BN should be folded into Conv
|
||||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
|
||||||
}
|
}
|
||||||
node_list = [
|
node_list = [
|
||||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||||
|
|
@ -1858,7 +1838,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
||||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
||||||
# BN should be folded into Conv
|
# BN should be folded into Conv
|
||||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
|
||||||
}
|
}
|
||||||
node_list = [
|
node_list = [
|
||||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||||
|
|
@ -1907,7 +1886,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
||||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: 2,
|
||||||
# BN should be folded into Conv
|
# BN should be folded into Conv
|
||||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
|
||||||
}
|
}
|
||||||
node_list = [
|
node_list = [
|
||||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||||
|
|
@ -1953,7 +1931,6 @@ class TestQuantizePT2EX86Inductor(X86InductorQuantTestCase):
|
||||||
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
torch.ops.quantized_decomposed.dequantize_per_channel.default: 1,
|
||||||
# BN should be folded into Conv
|
# BN should be folded into Conv
|
||||||
torch.ops.aten._native_batch_norm_legit.default: 0,
|
torch.ops.aten._native_batch_norm_legit.default: 0,
|
||||||
torch.ops.aten._native_batch_norm_legit_no_training.default: 0,
|
|
||||||
}
|
}
|
||||||
node_list = [
|
node_list = [
|
||||||
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
torch.ops.quantized_decomposed.quantize_per_tensor.default,
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,20 @@ from .graph_signature import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_impure_node(node: torch.fx.Node) -> bool:
|
||||||
|
"""
|
||||||
|
Check the schema of node target to detect side-effectful nodes.
|
||||||
|
"""
|
||||||
|
if node.is_impure():
|
||||||
|
return True
|
||||||
|
|
||||||
|
if node.op == "call_function":
|
||||||
|
schema = getattr(node.target, "_schema", None)
|
||||||
|
schema_mutable = schema is not None and schema.is_mutable
|
||||||
|
return schema_mutable
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _remove_effect_tokens_from_graph_helper(
|
def _remove_effect_tokens_from_graph_helper(
|
||||||
ep, num_tokens, input_token_names, output_token_names
|
ep, num_tokens, input_token_names, output_token_names
|
||||||
):
|
):
|
||||||
|
|
@ -114,7 +128,7 @@ def _remove_effect_tokens_from_graph_helper(
|
||||||
assert inp_token.name in input_token_names
|
assert inp_token.name in input_token_names
|
||||||
ep.graph.erase_node(inp_token)
|
ep.graph.erase_node(inp_token)
|
||||||
|
|
||||||
ep.graph.eliminate_dead_code()
|
ep.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
|
||||||
|
|
||||||
|
|
||||||
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
|
def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram:
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import torch.utils._pytree as pytree
|
||||||
from torch._export.utils import _check_input_constraints_for_graph
|
from torch._export.utils import _check_input_constraints_for_graph
|
||||||
from torch.export.unflatten import _assign_attr, _AttrKind
|
from torch.export.unflatten import _assign_attr, _AttrKind
|
||||||
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||||
from ._remove_effect_tokens_pass import _remove_effect_tokens
|
from ._remove_effect_tokens_pass import _is_impure_node, _remove_effect_tokens
|
||||||
|
|
||||||
from .exported_program import (
|
from .exported_program import (
|
||||||
ExportedProgram,
|
ExportedProgram,
|
||||||
|
|
@ -184,7 +184,7 @@ def _unlift(
|
||||||
)
|
)
|
||||||
gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
|
gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names)
|
||||||
gm.graph.lint()
|
gm.graph.lint()
|
||||||
gm.graph.eliminate_dead_code()
|
gm.graph.eliminate_dead_code(is_impure_node=_is_impure_node)
|
||||||
gm.recompile()
|
gm.recompile()
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,9 @@ _side_effectful_functions: Set[Callable] = {
|
||||||
torch._assert_async,
|
torch._assert_async,
|
||||||
_ops.aten._assert_async.msg,
|
_ops.aten._assert_async.msg,
|
||||||
_ops.aten._assert_scalar.default,
|
_ops.aten._assert_scalar.default,
|
||||||
|
_ops.aten.copy_.default,
|
||||||
|
_ops.aten.set_.source_Tensor,
|
||||||
|
_ops.aten.index_put_.default,
|
||||||
_ops.aten.sym_constrain_range.default,
|
_ops.aten.sym_constrain_range.default,
|
||||||
_ops.aten.sym_constrain_range_for_size.default,
|
_ops.aten.sym_constrain_range_for_size.default,
|
||||||
_ops.profiler._record_function_enter,
|
_ops.profiler._record_function_enter,
|
||||||
|
|
@ -645,11 +648,9 @@ class Node(_NodeBase):
|
||||||
if self.op in {"placeholder", "output"}:
|
if self.op in {"placeholder", "output"}:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Check if an impure function based on schema.
|
# Check if an impure function.
|
||||||
if self.op == "call_function":
|
if self.op == "call_function":
|
||||||
schema = getattr(self.target, "_schema", None)
|
return self.target in _side_effectful_functions
|
||||||
schema_mutable = schema is not None and schema.is_mutable
|
|
||||||
return schema_mutable or self.target in _side_effectful_functions
|
|
||||||
|
|
||||||
# Check if an impure module.
|
# Check if an impure module.
|
||||||
if self.op == "call_module":
|
if self.op == "call_module":
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user