[profiler][retry] don't disable CUPTI_LAZY_REINIT for cuda >= 12.6 (#151124)

Retry of https://github.com/pytorch/pytorch/pull/150957, which was reverted due to internal meta failures

Credit to @mgmtea who wrote the initial version of this PR: https://github.com/pytorch/pytorch/pull/146604

Context: CUPTI is the NVIDIA library that Kineto uses for collecting GPU-side info during profiling. The intended usage is to register a callback while you want profiling to occur, and then unregister the callback when you want profiling to stop. But a bug would cause crashes if CUPTI callbacks were de-registered when used with cudagraphs. The workaround was to disable "CUPTI_LAZY_REINIT" and "CUPTI_TEARDOWN" in Kineto - which prevents crashes, but can result in slower execution after profiling has occurred and completed.

This bug is believed to be fixed in CUDA >= 12.6, so this PR makes the workaround conditional: DISABLE_CUPTI_LAZY_REINIT=1 and CUPTI_TEARDOWN=0 are now only applied if CUDA < 12.6. Additionally, `profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()` is added as an escape hatch so that we can add a killswitch in case we see more crashes related to this.

Differential Revision: [D72842114](https://our.internmc.facebook.com/intern/diff/D72842114/)

**NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D72842114/)!

Differential Revision: [D72842114](https://our.internmc.facebook.com/intern/diff/D72842114)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151124
Approved by: https://github.com/sraikund16
This commit is contained in:
David Berard 2025-04-14 14:11:53 -07:00 committed by PyTorch MergeBot
parent c5de6ff079
commit 7d205b22b5
4 changed files with 44 additions and 2 deletions

View File

@ -13,6 +13,7 @@ from torch._inductor import config
from torch.profiler import ProfilerActivity
from torch.testing._internal.common_utils import TemporaryFileName
from torch.testing._internal.inductor_utils import HAS_CUDA
from torch.torch_version import TorchVersion
from torch.utils._triton import has_triton
@ -280,6 +281,23 @@ class DynamoProfilerTests(torch._inductor.test_case.TestCase):
for e in triton_events:
check_triton_event(e)
@unittest.skipIf(not HAS_TRITON, "requires cuda & triton")
def test_cupti_lazy_reinit(self):
    """Profile a cudagraph-compiled function, then check whether the
    CUPTI lazy-reinit workaround env var was applied.

    Per the change under test, DISABLE_CUPTI_LAZY_REINIT=1 should only
    be set when CUDA < 12.6 (the CUPTI/cudagraph crash is believed fixed
    in CUDA >= 12.6 — see commit message).
    """
    inputs = [torch.randn(4, 4, device="cuda") for _ in range(2)]

    def add_sin(a, b):
        return (a + b).sin()

    compiled = torch.compile(add_sin, mode="reduce-overhead")
    with torch.profiler.profile():
        compiled(*inputs)

    # On newer CUDA the workaround must NOT have been enabled; on older
    # CUDA it must have been. Default "0" means "never set".
    expected = "0" if TorchVersion(torch.version.cuda) >= "12.6" else "1"
    self.assertEqual(expected, os.environ.get("DISABLE_CUPTI_LAZY_REINIT", "0"))
if __name__ == "__main__":
from torch._inductor.test_case import run_tests

View File

@ -54,6 +54,7 @@ from torch._utils import (
from torch._utils_internal import (
get_file_path,
prepare_multiprocessing_environment,
profiler_allow_cudagraph_cupti_lazy_reinit_cuda12,
USE_GLOBAL_DEPS,
USE_RTLD_GLOBAL_WITH_LIBTORCH,
)
@ -2301,7 +2302,16 @@ class _TorchCompileInductorWrapper:
self.apply_options(options)
self.apply_options(CompilerBisector.get_config_change("inductor"))
if self.config.get("triton.cudagraphs", False):
cuda_version = None
if hasattr(torch, "version"):
from torch.torch_version import TorchVersion
cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
if self.config.get("triton.cudagraphs", False) and (
(cuda_version and cuda_version < "12.6")
or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
):
os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
# FIXME: CUDA Graph does not work well with CUPTI teardown.
# 1) crashes on 1st lazy CUPTI re-init after teardown (CUDA 11)

View File

@ -274,3 +274,7 @@ def record_chromium_event_internal(
event: dict[str, Any],
):
return None
def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12():
    """Escape hatch controlling whether CUPTI lazy re-init stays enabled
    under cudagraphs on CUDA >= 12.6.

    Always returns ``True`` in OSS; internal builds may override this as a
    killswitch if CUPTI-related crashes reappear (see commit message).
    """
    return True

View File

@ -23,6 +23,7 @@ from torch._C._profiler import (
_remove_execution_trace_observer,
)
from torch._environment import is_fbcode
from torch._utils_internal import profiler_allow_cudagraph_cupti_lazy_reinit_cuda12
from torch.autograd import kineto_available, ProfilerActivity
from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
@ -223,7 +224,16 @@ class _KinetoProfile:
if hasattr(torch, "_inductor"):
import torch._inductor.config as inductor_config
if inductor_config.triton.cudagraphs:
cuda_version = None
if hasattr(torch, "version"):
from torch.torch_version import TorchVersion
cuda_version = TorchVersion(getattr(torch.version, "cuda", "0.0"))
if inductor_config.triton.cudagraphs and (
(cuda_version and cuda_version < "12.6")
or not profiler_allow_cudagraph_cupti_lazy_reinit_cuda12()
):
os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
self.add_metadata_json("DISABLE_CUPTI_LAZY_REINIT", "1")
# FIXME: CUDA Graph does not work well with CUPTI teardown.