Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[dynamo] Add fx_graph_runnable test coverage (#157021)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157021
Approved by: https://github.com/StrongerXi, https://github.com/xmfan
This commit is contained in:
parent 130d4973bd
commit 20e40492b0
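For context: fx_graph_runnable is the structured-trace artifact through which torch.compile emits the traced FX graph as a standalone, runnable Python script. The test below captures it by hooking the torch.__trace logger; here is a condensed sketch of that capture mechanism (the GrabPayloads helper is illustrative, not part of the test):

import logging

import torch

trace_log = logging.getLogger("torch.__trace")
trace_log.setLevel(logging.DEBUG)


class GrabPayloads(logging.Handler):
    # Illustrative helper: collect only fx_graph_runnable artifact payloads.
    def __init__(self):
        super().__init__()
        self.payloads = []

    def emit(self, record):
        metadata = getattr(record, "metadata", {})
        if "artifact" in metadata and metadata["artifact"]["name"] == "fx_graph_runnable":
            self.payloads.append(record.payload)


handler = GrabPayloads()
trace_log.addHandler(handler)

torch.compile(lambda x: x + 1)(torch.randn(4))
# Each captured payload is a self-contained script that reproduces the FX graph.
assert handler.payloads, "expected a runnable FX graph payload"
print(handler.payloads[0].strip()[:80])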
test/dynamo/test_fx_graph_runnable.py (new file, 99 lines)
@@ -0,0 +1,99 @@
# Owner(s): ["module: dynamo"]
import io
import logging
import subprocess
import sys
import tempfile

import torch
import torch._logging.structured
from torch._inductor.test_case import TestCase
from torch.testing._internal.common_utils import IS_FBCODE


class FxGraphRunnableArtifactFilter(logging.Filter):
    def filter(self, record):
        return (
            "artifact" in record.metadata
            and record.metadata["artifact"]["name"] == "fx_graph_runnable"
        )


class StructuredTracePayloadFormatter(logging.Formatter):
    def format(self, record):
        return record.payload.strip()


trace_log = logging.getLogger("torch.__trace")


class FxGraphRunnableTest(TestCase):
    def setUp(self):
        super().setUp()
        torch._dynamo.reset()
        torch._logging.structured.INTERN_TABLE.clear()
        self.old_level = trace_log.level
        trace_log.setLevel(logging.DEBUG)

        # Create a custom filter specifically for fx_graph_runnable entries
        self.filter = FxGraphRunnableArtifactFilter()

        # Create a separate buffer and handler for capturing fx_graph_runnable entries
        self.buffer = io.StringIO()
        self.handler = logging.StreamHandler(self.buffer)
        self.handler.setFormatter(StructuredTracePayloadFormatter())
        self.handler.addFilter(self.filter)
        trace_log.addHandler(self.handler)

    def tearDown(self):
        trace_log.removeHandler(self.handler)
        trace_log.setLevel(self.old_level)

    def _exec_and_verify_payload(self):
        # Write captured payload & run it in a fresh Python process
        payload = self.buffer.getvalue().strip()
        self.assertTrue(payload, "Expected fx_graph_runnable payload but got nothing")
        self.assertIn("def forward", payload)  # sanity-check for actual FX code

        with tempfile.NamedTemporaryFile("w", suffix=".py") as tmp:
            tmp.write(payload)
            tmp.flush()
            res = subprocess.run(
                [sys.executable, tmp.name], capture_output=True, text=True, timeout=30
            )
            self.assertEqual(
                res.returncode,
                0,
                f"Standalone fx_graph_runnable failed:\nSTDERR:\n{res.stderr}",
            )

    # basic tests
    def test_basic_tensor_add(self):
        def f(x):
            return x + 1

        torch.compile(f)(torch.randn(4))
        self._exec_and_verify_payload()

    def test_two_inputs_matmul(self):
        def f(a, b):
            return (a @ b).relu()

        a, b = torch.randn(2, 3), torch.randn(3, 4)
        torch.compile(f)(a, b)
        self._exec_and_verify_payload()

    def test_scalar_multiply(self):
        def f(x):
            return x * 2

        torch.compile(f)(torch.randn(5))
        self._exec_and_verify_payload()


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    if not IS_FBCODE:
        # fbcode complains about not being able to find torch in subprocess
        run_tests()
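Outside of this test, the same artifact can be collected without any custom logging setup: pointing the TORCH_TRACE environment variable at a directory while running a compiled workload dumps the structured trace to disk, and tlparse should then surface the fx_graph_runnable payloads alongside the other artifacts (mentioned as context only; neither tool is exercised by this test).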