[aoti] follow up to use new api in test_provenance_tracing.py (#149387)

Summary:
As titled. Follow-up to D71181284, plus some minor refactoring.

Context: D69609685 (update the test runner to use the new API) / https://github.com/pytorch/pytorch/pull/147105
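
For context, here is a condensed structural sketch of the refactoring this diff performs (not part of the change itself): the per-device tests are folded into a single helper parameterized by `device`, and the check helpers now take the expected mappings as arguments instead of hard-coding them. `TestCase`, `requires_cuda`, `HAS_GPU`, and `Model` are assumed to come from the test file's existing imports and fixtures, which are elided here.

```python
import unittest

# Structural sketch only: TestCase, requires_cuda, HAS_GPU, and Model are
# provided by the test file's existing imports/fixtures (elided here).

class TestProvenanceTracingArtifact(TestCase):
    def _check_provenance_tracing_artifact(self, filepath, expected_data):
        # Load inductor_triton_kernel_to_post_grad_nodes.json and compare it
        # against the caller-supplied expected_data.
        ...

    def _check_provenance_tracking_node_mappings(self, filepath, expected_mapping):
        # Load inductor_provenance_tracking_node_mappings.json and compare it
        # against the caller-supplied expected_mapping.
        ...

    def _test_triton_kernel_to_post_grad_tracing(self, device):
        # Build Model(), run it through the "aot_inductor" and "inductor"
        # backends, then select the device-specific expected mappings and
        # call the two check helpers above.
        ...

    @requires_cuda
    def test_triton_kernel_to_post_grad_tracing_cuda(self):
        self._test_triton_kernel_to_post_grad_tracing(device="cuda")

    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
    def test_triton_kernel_to_post_grad_tracing_cpu(self):
        self._test_triton_kernel_to_post_grad_tracing(device="cpu")
```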

Test Plan:
```
buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:provenance_tracing -- -r test_triton_kernel_to_post_grad_tracing_cpu
```

Differential Revision: D71375725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149387
Approved by: https://github.com/yushangdi
Rachel Guo 2025-03-21 04:37:50 +00:00 committed by PyTorch MergeBot
parent 5327894812
commit ccd5d811e8


@@ -40,89 +40,31 @@ class TestProvenanceTracingArtifact(TestCase):
corresponding "inductor triton kernel node" is expected.
"""
@requires_cuda
def _check_provenance_tracing_artifact(self, filepath):
def _check_provenance_tracing_artifact(self, filepath, expected_data):
self.assertTrue(filepath.is_dir())
filename = Path(filepath) / "inductor_triton_kernel_to_post_grad_nodes.json"
with open(filename) as f:
actual_data = json.load(f)
# check that the generated provenance tracing artifact is expected
expected_data = {
"triton_poi_fused_mul_0": ["mul"],
"triton_poi_fused_addmm_gelu_1": [
"mul_3",
"mul_1",
"add_tensor",
"add",
"erf",
"mul_2",
],
}
self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))
def _check_provenance_tracking_node_mappings(self, filepath, expected_mapping):
self.assertTrue(filepath.is_dir())
filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
with open(filename) as f:
actual_data = json.load(f)
# check that the generated provenance tracing artifact is expected
expected_data = [
(
"cppCodeToPost",
{
"triton_poi_fused_mul_0": ["mul"],
"triton_poi_fused_addmm_gelu_1": [
"mul_3",
"mul_1",
"add_tensor",
"add",
"erf",
"mul_2",
],
},
),
(
"postToCppCode",
{
"mul": ["triton_poi_fused_mul_0"],
"mul_3": ["triton_poi_fused_addmm_gelu_1"],
"mul_1": ["triton_poi_fused_addmm_gelu_1"],
"add_tensor": ["triton_poi_fused_addmm_gelu_1"],
"add": ["triton_poi_fused_addmm_gelu_1"],
"erf": ["triton_poi_fused_addmm_gelu_1"],
"mul_2": ["triton_poi_fused_addmm_gelu_1"],
},
),
(
"postToPre",
{
"mul": ["mul"],
"mm_default": ["addmm"],
"add_tensor": ["addmm"],
"mul_1": ["gelu"],
"mul_2": ["gelu"],
"erf": ["gelu"],
"add": ["gelu"],
"mul_3": ["gelu"],
},
),
(
"preToPost",
{
"mul": ["mul"],
"addmm": ["mm_default", "add_tensor"],
"gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
},
),
]
self.assertEqual(sorted(actual_data.items()), sorted(expected_data))
# check that the generated provenance tracing node mapping is expected
self.assertEqual(sorted(actual_data.items()), sorted(expected_mapping))
@requires_cuda
def test_triton_kernel_to_post_grad_tracing(self):
a = torch.randn(10, 20, device="cuda")
b = torch.randn(20, 30, device="cuda")
c = torch.randn(10, 30, device="cuda")
def _test_triton_kernel_to_post_grad_tracing(self, device):
a = torch.randn(10, 20, device=device)
b = torch.randn(20, 30, device=device)
c = torch.randn(10, 30, device=device)
example_inputs = (a, b, c)
model = Model()
filepath = None
for backend in ["aot_inductor", "inductor"]:
try:
with config.patch(
@@ -145,62 +87,98 @@ class TestProvenanceTracingArtifact(TestCase):
m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
self.assertTrue(m)
filepath = Path(m.group(1))
self._check_provenance_tracing_artifact(filepath)
finally:
shutil.rmtree(filepath)
@unittest.skipIf(HAS_GPU, "the test is only for cpu")
def test_triton_kernel_to_post_grad_tracing_cpu(self):
a = torch.randn(10, 20, device="cpu")
b = torch.randn(20, 30, device="cpu")
c = torch.randn(10, 30, device="cpu")
example_inputs = (a, b, c)
model = Model()
ep = torch.export._trace._export(model, example_inputs)
gm = ep.module()
filepath = None
for backend in ["aot_inductor", "inductor"]:
try:
with config.patch(
{
"trace.debug_dir": tempfile.mkdtemp(),
"force_disable_caches": True,
}
):
with self.assertLogs(
logging.getLogger("torch._inductor.debug"),
level=logging.WARNING,
) as cm:
if backend == "aot_inductor":
AOTIRunnerUtil.run(model, example_inputs)
else:
compiled = torch.compile(gm, backend=backend)
compiled(*example_inputs)
self.assertEqual(len(cm.output), 1)
m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
self.assertTrue(m)
filepath = Path(m.group(1))
filename = (
Path(filepath)
/ "inductor_triton_kernel_to_post_grad_nodes.json"
)
with open(filename) as f:
actual_data = json.load(f)
# check the inductor kernel to post grad nodes mapping is expected for cpu
expected_data = {
"cpp_fused_mul_0": ["mul"],
"cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
}
self.assertEqual(
sorted(actual_data.items()), sorted(expected_data.items())
)
if device == "cuda":
expected_data = {
"triton_poi_fused_mul_0": ["mul"],
"triton_poi_fused_addmm_gelu_1": [
"mul_3",
"mul_1",
"add_tensor",
"add",
"erf",
"mul_2",
],
}
self._check_provenance_tracing_artifact(filepath, expected_data)
expected_mapping = [
(
"cppCodeToPost",
{
"triton_poi_fused_mul_0": ["mul"],
"triton_poi_fused_addmm_gelu_1": [
"mul_3",
"mul_1",
"add_tensor",
"add",
"erf",
"mul_2",
],
},
),
(
"postToCppCode",
{
"mul": ["triton_poi_fused_mul_0"],
"mul_3": ["triton_poi_fused_addmm_gelu_1"],
"mul_1": ["triton_poi_fused_addmm_gelu_1"],
"add_tensor": ["triton_poi_fused_addmm_gelu_1"],
"add": ["triton_poi_fused_addmm_gelu_1"],
"erf": ["triton_poi_fused_addmm_gelu_1"],
"mul_2": ["triton_poi_fused_addmm_gelu_1"],
},
),
(
"postToPre",
{
"mul": ["mul"],
"mm_default": ["addmm"],
"add_tensor": ["addmm"],
"mul_1": ["gelu"],
"mul_2": ["gelu"],
"erf": ["gelu"],
"add": ["gelu"],
"mul_3": ["gelu"],
},
),
(
"preToPost",
{
"mul": ["mul"],
"addmm": ["mm_default", "add_tensor"],
"gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
},
),
]
self._check_provenance_tracking_node_mappings(
filepath, expected_mapping
)
else:
assert device == "cpu"
# check the inductor kernel to post grad nodes mapping is expected for cpu
expected_data = {
"cpp_fused_mul_0": ["mul"],
"cpp_fused_gelu_1": [
"mul_3",
"mul_1",
"add",
"erf",
"mul_2",
],
}
self._check_provenance_tracing_artifact(filepath, expected_data)
finally:
if filepath:
shutil.rmtree(filepath)
@requires_cuda
def test_triton_kernel_to_post_grad_tracing_cuda(self):
self._test_triton_kernel_to_post_grad_tracing(device="cuda")
@unittest.skipIf(HAS_GPU, "the test is only for cpu")
def test_triton_kernel_to_post_grad_tracing_cpu(self):
self._test_triton_kernel_to_post_grad_tracing(device="cpu")
class TestProvenanceTracingNodeMapping(TestCase):
def test_create_node_mapping(self):