pytorch/test/inductor/test_inductor_annotations.py
Boyuan Feng 5f1010fbb3 [Graph Partition] Pass all OSS unit tests (#154667)
Graph partition leads to a 6.2% speedup on vision_maskrcnn and a 5.8% speedup on yolov3 [P1819700563](https://www.internalfb.com/phabricator/paste/view/P1819700563), a 39.5% speedup on speech_transformer inference [P1830602200](https://www.internalfb.com/phabricator/paste/view/P1830602200), and an 85% speedup on speech_transformer training [P1831115315](https://www.internalfb.com/phabricator/paste/view/P1831115315).
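The partitioning idea can be sketched in plain Python. This is an illustrative toy, not Inductor's actual implementation: the `partition` helper and the op names are hypothetical, and the real pass operates on Inductor's scheduler nodes. The core idea it mirrors is splitting a program into maximal contiguous runs of ops that are safe to compile or capture together, with unsupported ops (e.g. CPU-side ops) left to run on their own:

```python
# Hypothetical sketch of graph partitioning: split a linear op sequence
# into maximal runs of "supported" ops; each unsupported op becomes its
# own single-op partition so the supported runs can be handled as units.
def partition(ops, is_supported):
    partitions, current = [], []
    for op in ops:
        if is_supported(op):
            current.append(op)
        else:
            if current:
                partitions.append(current)  # close the current supported run
                current = []
            partitions.append([op])  # unsupported op stands alone
    if current:
        partitions.append(current)
    return partitions


ops = ["add", "mul", "cpu_item", "matmul", "relu"]
print(partition(ops, lambda op: not op.startswith("cpu")))
# [['add', 'mul'], ['cpu_item'], ['matmul', 'relu']]
```

The single unsupported op in the middle splits what would otherwise be one partition into three, which is why reducing the number of graph breaks matters for the speedups reported above.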

Running the same diff on two different days shows a speedup on average in both benchmark runs.

[first TorchInductor Benchmark CI run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Mon%2C%2021%20Jul%202025%2016%3A37%3A55%20GMT&stopTime=Mon%2C%2028%20Jul%202025%2016%3A37%3A55%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=75ef90fe89b82c967362a2d40fdf1af047202bc2&rBranch=main&rCommit=abcb24f4de11f8fedf2c2c9ff53b6092ef42306d)
<img width="1885" height="752" alt="image" src="https://github.com/user-attachments/assets/13bba9fc-5dbf-42ad-8558-d54f7e367b41" />

[second TorchInductor Benchmark CI run](https://hud.pytorch.org/benchmark/compilers?dashboard=torchinductor&startTime=Wed%2C%2023%20Jul%202025%2016%3A38%3A27%20GMT&stopTime=Wed%2C%2030%20Jul%202025%2016%3A38%3A27%20GMT&granularity=hour&mode=inference&dtype=bfloat16&deviceName=cuda%20(h100)&lBranch=bf/partition-turn-on&lCommit=66de27e29338c26b1be94733049868cb0309ea52&rBranch=main&rCommit=70d2e9ba455c3c910f6f95b24171c8eee7bc00bf)
<img width="2513" height="1030" alt="image" src="https://github.com/user-attachments/assets/3a413dcb-2314-4292-919a-7ca181f9eeac" />

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154667
Approved by: https://github.com/eellison
2025-08-12 04:37:58 +00:00


# Owner(s): ["module: inductor"]
import torch
import torch._inductor.config as inductor_config
from torch._inductor.test_case import run_tests, TestCase
from torch._inductor.utils import run_and_get_code
from torch.testing._internal.triton_utils import requires_cuda_and_triton


class InductorAnnotationTestCase(TestCase):
    def get_code(self):
        def f(a, b):
            return a + b, a * b

        a = torch.randn(5, device="cuda")
        b = torch.randn(5, device="cuda")
        f_comp = torch.compile(f)
        _, code = run_and_get_code(f_comp, a, b)
        return code[0]

    @requires_cuda_and_triton
    def test_no_annotations(self):
        code = self.get_code()
        self.assertTrue("from torch.cuda import nvtx" not in code)
        self.assertTrue("training_annotation" not in code)

    @inductor_config.patch(annotate_training=True)
    @requires_cuda_and_triton
    def test_training_annotation(self):
        code = self.get_code()
        self.assertTrue("from torch.cuda import nvtx" in code)
        self.assertTrue(
            code.count("training_annotation = nvtx._device_range_start('inference')")
            >= 1
        )
        self.assertTrue(code.count("nvtx._device_range_end(training_annotation)") >= 1)


if __name__ == "__main__":
    run_tests()