pytorch/test/fx/test_fx_xform_observer.py
Xuehai Pan 76169cf691 [BE][Easy][9/19] enforce style for empty lines in import segments in test/[e-h]*/ (#129760)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129760
Approved by: https://github.com/ezyang
2024-07-17 14:25:29 +00:00

61 lines
1.7 KiB
Python

# Owner(s): ["module: fx"]
import os
import tempfile
import torch
from torch.fx import subgraph_rewriter, symbolic_trace
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.testing._internal.common_utils import TestCase
if __name__ == "__main__":
    # Running this module standalone is unsupported; refuse with a usage
    # hint that points at the test/test_fx.py entry point instead.
    _usage = (
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_fx.py TESTNAME\n\n"
        "instead."
    )
    raise RuntimeError(_usage)
class TestGraphTransformObserver(TestCase):
def test_graph_transform_observer(self):
class M(torch.nn.Module):
def forward(self, x):
val = torch.neg(x)
return torch.add(val, val)
def pattern(x):
return torch.neg(x)
def replacement(x):
return torch.relu(x)
traced = symbolic_trace(M())
log_url = tempfile.mkdtemp()
with GraphTransformObserver(traced, "replace_neg_with_relu", log_url) as ob:
subgraph_rewriter.replace_pattern(traced, pattern, replacement)
self.assertTrue("relu" in ob.created_nodes)
self.assertTrue("neg" in ob.erased_nodes)
current_pass_count = GraphTransformObserver.get_current_pass_count()
self.assertTrue(
os.path.isfile(
os.path.join(
log_url,
f"pass_{current_pass_count}_replace_neg_with_relu_input_graph.dot",
)
)
)
self.assertTrue(
os.path.isfile(
os.path.join(
log_url,
f"pass_{current_pass_count}_replace_neg_with_relu_output_graph.dot",
)
)
)