[aoti] follow up to use new api in test_provenance_tracing.py (#149387)
Summary: As the title says: a follow-up to D71181284 that switches test_provenance_tracing.py to the new API, plus some minor refactoring.

Context: D69609685 (update the test runner to use the new API) / https://github.com/pytorch/pytorch/pull/147105

Test Plan:
```
buck2 run -c fbcode.enable_gpu_sections=true -c fbcode.nvcc_arch=h100 @//mode/opt fbcode//caffe2/test/inductor:provenance_tracing -- -r test_triton_kernel_to_post_grad_tracing_cpu
```

Differential Revision: D71375725

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149387
Approved by: https://github.com/yushangdi
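For reference outside the Buck-based fbcode setup, the equivalent OSS invocation would presumably be the standard PyTorch test runner (file path assumed from the commit title and the buck target):
```
python test/inductor/test_provenance_tracing.py -k test_triton_kernel_to_post_grad_tracing_cpu
```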
parent 5327894812
commit ccd5d811e8
test/inductor/test_provenance_tracing.py

@@ -40,89 +40,31 @@ class TestProvenanceTracingArtifact(TestCase):
     corresponding "inductor triton kernel node" is expected.
     """
 
-    @requires_cuda
-    def _check_provenance_tracing_artifact(self, filepath):
+    def _check_provenance_tracing_artifact(self, filepath, expected_data):
         self.assertTrue(filepath.is_dir())
         filename = Path(filepath) / "inductor_triton_kernel_to_post_grad_nodes.json"
         with open(filename) as f:
             actual_data = json.load(f)
         # check that the generated provenance tracing artifact is expected
-        expected_data = {
-            "triton_poi_fused_mul_0": ["mul"],
-            "triton_poi_fused_addmm_gelu_1": [
-                "mul_3",
-                "mul_1",
-                "add_tensor",
-                "add",
-                "erf",
-                "mul_2",
-            ],
-        }
         self.assertEqual(sorted(actual_data.items()), sorted(expected_data.items()))
+
+    def _check_provenance_tracking_node_mappings(self, filepath, expected_mapping):
         self.assertTrue(filepath.is_dir())
         filename = Path(filepath) / "inductor_provenance_tracking_node_mappings.json"
         with open(filename) as f:
             actual_data = json.load(f)
-        # check that the generated provenance tracing artifact is expected
-        expected_data = [
-            (
-                "cppCodeToPost",
-                {
-                    "triton_poi_fused_mul_0": ["mul"],
-                    "triton_poi_fused_addmm_gelu_1": [
-                        "mul_3",
-                        "mul_1",
-                        "add_tensor",
-                        "add",
-                        "erf",
-                        "mul_2",
-                    ],
-                },
-            ),
-            (
-                "postToCppCode",
-                {
-                    "mul": ["triton_poi_fused_mul_0"],
-                    "mul_3": ["triton_poi_fused_addmm_gelu_1"],
-                    "mul_1": ["triton_poi_fused_addmm_gelu_1"],
-                    "add_tensor": ["triton_poi_fused_addmm_gelu_1"],
-                    "add": ["triton_poi_fused_addmm_gelu_1"],
-                    "erf": ["triton_poi_fused_addmm_gelu_1"],
-                    "mul_2": ["triton_poi_fused_addmm_gelu_1"],
-                },
-            ),
-            (
-                "postToPre",
-                {
-                    "mul": ["mul"],
-                    "mm_default": ["addmm"],
-                    "add_tensor": ["addmm"],
-                    "mul_1": ["gelu"],
-                    "mul_2": ["gelu"],
-                    "erf": ["gelu"],
-                    "add": ["gelu"],
-                    "mul_3": ["gelu"],
-                },
-            ),
-            (
-                "preToPost",
-                {
-                    "mul": ["mul"],
-                    "addmm": ["mm_default", "add_tensor"],
-                    "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
-                },
-            ),
-        ]
-        self.assertEqual(sorted(actual_data.items()), sorted(expected_data))
+        # check that the generated provenance tracing node mapping is expected
+        self.assertEqual(sorted(actual_data.items()), sorted(expected_mapping))
 
-    @requires_cuda
-    def test_triton_kernel_to_post_grad_tracing(self):
-        a = torch.randn(10, 20, device="cuda")
-        b = torch.randn(20, 30, device="cuda")
-        c = torch.randn(10, 30, device="cuda")
+    def _test_triton_kernel_to_post_grad_tracing(self, device):
+        a = torch.randn(10, 20, device=device)
+        b = torch.randn(20, 30, device=device)
+        c = torch.randn(10, 30, device=device)
         example_inputs = (a, b, c)
 
         model = Model()
         filepath = None
 
         for backend in ["aot_inductor", "inductor"]:
             try:
                 with config.patch(
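The first hunk inverts the helpers' contract: instead of hard-coding the expected kernel-to-node mapping inside the checker, each caller now supplies its own expectations, which is what lets the CUDA and CPU paths in the second hunk share one test body. A minimal self-contained sketch of the comparison the helpers perform (standalone names, not the test's actual API):
```
import json
from pathlib import Path

def check_json_artifact(filepath: Path, filename: str, expected: dict) -> None:
    # Load the JSON artifact that Inductor's debug tracer wrote into the
    # debug-trace directory and compare it, order-insensitively, against
    # the caller-supplied expectations.
    with open(Path(filepath) / filename) as f:
        actual = json.load(f)
    assert sorted(actual.items()) == sorted(expected.items())
```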
@@ -145,62 +87,98 @@ class TestProvenanceTracingArtifact(TestCase):
                     m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
                     self.assertTrue(m)
                     filepath = Path(m.group(1))
-                    self._check_provenance_tracing_artifact(filepath)
-            finally:
-                shutil.rmtree(filepath)
-
-    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
-    def test_triton_kernel_to_post_grad_tracing_cpu(self):
-        a = torch.randn(10, 20, device="cpu")
-        b = torch.randn(20, 30, device="cpu")
-        c = torch.randn(10, 30, device="cpu")
-        example_inputs = (a, b, c)
-
-        model = Model()
-        ep = torch.export._trace._export(model, example_inputs)
-        gm = ep.module()
-        filepath = None
-
-        for backend in ["aot_inductor", "inductor"]:
-            try:
-                with config.patch(
-                    {
-                        "trace.debug_dir": tempfile.mkdtemp(),
-                        "force_disable_caches": True,
-                    }
-                ):
-                    with self.assertLogs(
-                        logging.getLogger("torch._inductor.debug"),
-                        level=logging.WARNING,
-                    ) as cm:
-                        if backend == "aot_inductor":
-                            AOTIRunnerUtil.run(model, example_inputs)
-                        else:
-                            compiled = torch.compile(gm, backend=backend)
-                            compiled(*example_inputs)
-                    self.assertEqual(len(cm.output), 1)
-                    m = re.match(r"WARNING.* debug trace: (.*)", cm.output[0])
-                    self.assertTrue(m)
-                    filepath = Path(m.group(1))
-                    filename = (
-                        Path(filepath)
-                        / "inductor_triton_kernel_to_post_grad_nodes.json"
-                    )
-                    with open(filename) as f:
-                        actual_data = json.load(f)
-                    # check the inductor kernel to post grad nodes mapping is expected for cpu
-                    expected_data = {
-                        "cpp_fused_mul_0": ["mul"],
-                        "cpp_fused_gelu_1": ["mul_3", "mul_1", "add", "erf", "mul_2"],
-                    }
-                    self.assertEqual(
-                        sorted(actual_data.items()), sorted(expected_data.items())
-                    )
+                    if device == "cuda":
+                        expected_data = {
+                            "triton_poi_fused_mul_0": ["mul"],
+                            "triton_poi_fused_addmm_gelu_1": [
+                                "mul_3",
+                                "mul_1",
+                                "add_tensor",
+                                "add",
+                                "erf",
+                                "mul_2",
+                            ],
+                        }
+                        self._check_provenance_tracing_artifact(filepath, expected_data)
+                        expected_mapping = [
+                            (
+                                "cppCodeToPost",
+                                {
+                                    "triton_poi_fused_mul_0": ["mul"],
+                                    "triton_poi_fused_addmm_gelu_1": [
+                                        "mul_3",
+                                        "mul_1",
+                                        "add_tensor",
+                                        "add",
+                                        "erf",
+                                        "mul_2",
+                                    ],
+                                },
+                            ),
+                            (
+                                "postToCppCode",
+                                {
+                                    "mul": ["triton_poi_fused_mul_0"],
+                                    "mul_3": ["triton_poi_fused_addmm_gelu_1"],
+                                    "mul_1": ["triton_poi_fused_addmm_gelu_1"],
+                                    "add_tensor": ["triton_poi_fused_addmm_gelu_1"],
+                                    "add": ["triton_poi_fused_addmm_gelu_1"],
+                                    "erf": ["triton_poi_fused_addmm_gelu_1"],
+                                    "mul_2": ["triton_poi_fused_addmm_gelu_1"],
+                                },
+                            ),
+                            (
+                                "postToPre",
+                                {
+                                    "mul": ["mul"],
+                                    "mm_default": ["addmm"],
+                                    "add_tensor": ["addmm"],
+                                    "mul_1": ["gelu"],
+                                    "mul_2": ["gelu"],
+                                    "erf": ["gelu"],
+                                    "add": ["gelu"],
+                                    "mul_3": ["gelu"],
+                                },
+                            ),
+                            (
+                                "preToPost",
+                                {
+                                    "mul": ["mul"],
+                                    "addmm": ["mm_default", "add_tensor"],
+                                    "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
+                                },
+                            ),
+                        ]
+                        self._check_provenance_tracking_node_mappings(
+                            filepath, expected_mapping
+                        )
+                    else:
+                        assert device == "cpu"
+                        # check the inductor kernel to post grad nodes mapping is expected for cpu
+                        expected_data = {
+                            "cpp_fused_mul_0": ["mul"],
+                            "cpp_fused_gelu_1": [
+                                "mul_3",
+                                "mul_1",
+                                "add",
+                                "erf",
+                                "mul_2",
+                            ],
+                        }
+                        self._check_provenance_tracing_artifact(filepath, expected_data)
+
+            finally:
+                if filepath:
+                    shutil.rmtree(filepath)
+
+    @requires_cuda
+    def test_triton_kernel_to_post_grad_tracing_cuda(self):
+        self._test_triton_kernel_to_post_grad_tracing(device="cuda")
+
+    @unittest.skipIf(HAS_GPU, "the test is only for cpu")
+    def test_triton_kernel_to_post_grad_tracing_cpu(self):
+        self._test_triton_kernel_to_post_grad_tracing(device="cpu")
 
 
 class TestProvenanceTracingNodeMapping(TestCase):
     def test_create_node_mapping(self):
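One property of the `expected_mapping` data above that the test relies on but never states: `postToCppCode` is the elementwise inverse of `cppCodeToPost`, and `postToPre` is the inverse of `preToPost`. A quick sketch (helper name hypothetical) that reproduces one expected table from the other:
```
def invert(mapping: dict[str, list[str]]) -> dict[str, list[str]]:
    # Reverse a one-to-many mapping: every value becomes a key whose list
    # collects the original keys that pointed at it.
    inverted: dict[str, list[str]] = {}
    for source, targets in mapping.items():
        for target in targets:
            inverted.setdefault(target, []).append(source)
    return inverted

cpp_code_to_post = {
    "triton_poi_fused_mul_0": ["mul"],
    "triton_poi_fused_addmm_gelu_1": ["mul_3", "mul_1", "add_tensor", "add", "erf", "mul_2"],
}
# Matches the "postToCppCode" table asserted in the test above.
assert invert(cpp_code_to_post)["erf"] == ["triton_poi_fused_addmm_gelu_1"]

pre_to_post = {
    "mul": ["mul"],
    "addmm": ["mm_default", "add_tensor"],
    "gelu": ["mul_1", "mul_2", "erf", "add", "mul_3"],
}
# Matches the "postToPre" table: e.g. mm_default traces back to addmm.
assert invert(pre_to_post)["mm_default"] == ["addmm"]
```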