diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 5c0d77b1e8e..803e3a56fc8 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -989,7 +990,6 @@ class RingBuffer { // deallocation it can hold references to Python state which // will already be destroyed when we are in exit handlers }; - } // anonymous namespace } // namespace Native @@ -1128,6 +1128,9 @@ class DeviceCachingAllocator { // was used while cudagraph capturing std::unordered_map block_to_cudagraph_stream_uses; + // thread local compile context for each device + static thread_local std::stack compile_context; + public: // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init) DeviceCachingAllocator() @@ -1158,6 +1161,16 @@ class DeviceCachingAllocator { return record_history; } + void pushCompileContext(std::string& md) { + compile_context.push(md); + } + + void popCompileContext() { + if (!compile_context.empty()) { + compile_context.pop(); + } + } + bool checkPoolLiveAllocations( MempoolId_t mempool_id, const std::unordered_set& expected_live_allocations) { @@ -3294,7 +3307,10 @@ class DeviceCachingAllocator { std::shared_ptr context) { if (!record_history && trace_trackers_.empty()) return; - + std::string compile_string = "N/A"; + if (!compile_context.empty()) { + compile_string = compile_context.top(); + } auto te = TraceEntry( action, device, @@ -3303,7 +3319,8 @@ class DeviceCachingAllocator { stream, mempool_id, getApproximateTime(), - record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr); + record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr, + compile_string); // Callbacks should not include any Pytorch call for (const auto& cb : trace_trackers_) { @@ -3357,7 +3374,7 @@ static void uncached_delete(void* ptr) { } static void local_raw_delete(void* ptr); - +thread_local std::stack DeviceCachingAllocator::compile_context; #ifdef __cpp_lib_hardware_interference_size using std::hardware_destructive_interference_size; #else @@ -3526,6 +3543,24 @@ class NativeCachingAllocator : public CUDAAllocator { annotation_buffer.insertEntries(ae); } + void pushCompileContext(std::string& md) override { + if (!record_history) { + return; + } + c10::DeviceIndex device = 0; + C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + device_allocator[device]->pushCompileContext(md); + } + + void popCompileContext() override { + if (!record_history) { + return; + } + c10::DeviceIndex device = 0; + C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + device_allocator[device]->popCompileContext(); + } + bool isHistoryEnabled() override { c10::DeviceIndex device = 0; C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); diff --git a/c10/cuda/CUDACachingAllocator.h b/c10/cuda/CUDACachingAllocator.h index d3facd929db..2d01d75bb92 100644 --- a/c10/cuda/CUDACachingAllocator.h +++ b/c10/cuda/CUDACachingAllocator.h @@ -117,14 +117,16 @@ struct TraceEntry { cudaStream_t stream, MempoolId_t mempool, approx_time_t time, - std::shared_ptr context = nullptr) + std::shared_ptr context = nullptr, + std::string compile_context = "") : action_(action), device_(device), addr_(addr), context_(std::move(context)), stream_(stream), size_(size), - mempool_(std::move(mempool)) { + mempool_(std::move(mempool)), + compile_context_(std::move(compile_context)) { time_.approx_t_ = time; } Action action_; @@ -135,6 +137,7 @@ struct TraceEntry { size_t size_; MempoolId_t mempool_; trace_time_ time_{}; + std::string compile_context_{}; }; // Calls made by record_function will save annotations @@ -284,6 +287,8 @@ class CUDAAllocator : public Allocator { bool clearHistory) = 0; virtual void recordAnnotation( const std::vector>& /*md*/) {} + virtual void pushCompileContext(std::string& md) {} + virtual void popCompileContext() {} virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0; // Attached AllocatorTraceTracker callbacks will be called while the @@ -442,6 +447,14 @@ inline void recordAnnotation( return get()->recordAnnotation(md); } +inline void pushCompileContext(std::string& md) { + return get()->pushCompileContext(md); +} + +inline void popCompileContext() { + return get()->popCompileContext(); +} + inline bool isHistoryEnabled() { return get()->isHistoryEnabled(); } diff --git a/test/test_cuda.py b/test/test_cuda.py index 391b45b4596..ec84b9fb170 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -22,6 +22,7 @@ import psutil import torch import torch.cuda +import torch.nn as nn from torch import inf, nan from torch.cuda._memory_viz import ( _profile_to_snapshot, @@ -79,6 +80,7 @@ from torch.testing._internal.common_utils import ( TEST_WITH_ROCM, TestCase, ) +from torch.utils._triton import has_triton from torch.utils.checkpoint import checkpoint_sequential from torch.utils.viz._cycles import observe_tensor_cycles @@ -3961,6 +3963,67 @@ class TestCudaMallocAsync(TestCase): finally: torch.cuda.memory._record_memory_history(None) + @unittest.skipIf( + TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" + ) + @unittest.skipIf(not has_triton(), "test needs triton") + @requiresCppContext + def test_memory_compile_regions(self): + expected_allocation_sequence = [ + "Torch-Compiled Region: 0/0", + "Torch-Compiled Region: 1/0", + "Torch-Compiled Region: 0/0", + ] + + class MyModel(nn.Module): + def __init__(self): + super().__init__() + self.linear1 = nn.Linear(10, 10) + self.linear2 = nn.Linear(10, 10) + + def forward(self, x): + x = self.linear1(x) + + if x.sum() > 0: + x = x + 1 + else: + x = x - 1 + + x = self.linear2(x) + + return x + + try: + torch.cuda.memory.empty_cache() + input_tensor = torch.randn(1, 10, device="cuda") + # Create an instance of the model + model = MyModel() + model.to("cuda") + # Compile the model using torch.compile + compiled_model = torch.compile(model) + # Create a sample input tensor + torch.cuda.memory._record_memory_history( + context="all", compile_context=True + ) + compiled_model(input_tensor) + ss = torch.cuda.memory._snapshot()["device_traces"] + device_idx = 0 + allocation_sequence = [] + while len(ss[device_idx]) == 0: + device_idx = device_idx + 1 + for s in ss[device_idx]: + context = s["compile_context"] + if context == "N/A": + continue + if len(allocation_sequence) > 0 and allocation_sequence[-1] == context: + continue + allocation_sequence.append(context) + self.assertTrue(allocation_sequence == expected_allocation_sequence) + except RuntimeError as e: + pass + finally: + torch.cuda.memory._record_memory_history(None) + @unittest.skipIf( TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync" ) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index f325da9135b..6ffdb5f2af2 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1949,6 +1949,7 @@ def _cuda_record_memory_history_legacy( alloc_trace_max_entries: _int, alloc_trace_record_context: _bool, clear_history: _bool, + compile_context: _bool, ) -> None: ... def _cuda_record_memory_history( enabled: Optional[str], @@ -1956,6 +1957,7 @@ def _cuda_record_memory_history( stacks: str, max_entries: _int, clear_history: _bool, + compile_context: _bool, ) -> None: ... def _cuda_isHistoryEnabled() -> _bool: ... diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index ae71e1cd9ce..0391067d886 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -749,6 +749,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) { py::str is_expandable_s = "is_expandable"; py::str frames_s = "frames"; py::str time_us_s = "time_us"; + py::str compile_context_s = "compile_context"; py::list empty_frames; std::vector to_gather_frames; @@ -865,6 +866,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) { trace_entry[size_s] = te.size_; trace_entry[stream_s] = int64_t(te.stream_); trace_entry[time_us_s] = te.time_.t_; + trace_entry[compile_context_s] = te.compile_context_; trace.append(trace_entry); } traces.append(trace); @@ -1107,7 +1109,7 @@ static void registerCudaDeviceProperties(PyObject* module) { m.def( "_cuda_record_memory_history_legacy", - static_cast( + static_cast( torch::cuda::_record_memory_history)); m.def( @@ -1117,6 +1119,7 @@ static void registerCudaDeviceProperties(PyObject* module) { std::optional, const std::string&, size_t, + bool, bool)>(torch::cuda::_record_memory_history)); m.def("_cuda_isHistoryEnabled", []() { diff --git a/torch/csrc/cuda/memory_snapshot.cpp b/torch/csrc/cuda/memory_snapshot.cpp index fe5eb955740..fe19e803da2 100644 --- a/torch/csrc/cuda/memory_snapshot.cpp +++ b/torch/csrc/cuda/memory_snapshot.cpp @@ -117,6 +117,35 @@ void _initRecordAnnotations() { }(); } +void _initCompileContexts() { + static auto init_placeholder [[maybe_unused]] = [&] { + // Save PT2 Compile Contexts to CCA memory snapshot tool + at::addGlobalCallback( + at::RecordFunctionCallback( + [](const at::RecordFunction& fn) + -> std::unique_ptr { + std::string functionName = fn.name(); + const std::string functionNamePrefix = "Torch-Compiled Region"; + if (functionName.compare( + 0, functionNamePrefix.size(), functionNamePrefix) == 0) { + c10::cuda::CUDACachingAllocator::pushCompileContext( + functionName); + } + return nullptr; + }, + [](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) { + std::string functionName = fn.name(); + const std::string functionNamePrefix = "Torch-Compiled Region"; + if (functionName.compare( + 0, functionNamePrefix.size(), functionNamePrefix) == 0) { + c10::cuda::CUDACachingAllocator::popCompileContext(); + } + }) + .scopes({at::RecordScope::FUNCTION})); + return true; + }(); +} + } // namespace void _record_memory_history( @@ -125,7 +154,8 @@ void _record_memory_history( int64_t trace_alloc_max_entries, bool trace_alloc_record_context, bool record_cpp_context, - bool clearHistory) { + bool clearHistory, + bool compileContext) { c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather; if (enabled && record_cpp_context && (trace_alloc_record_context || record_context)) { @@ -141,6 +171,9 @@ void _record_memory_history( } at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); _initRecordAnnotations(); + if (compileContext) { + _initCompileContexts(); + } c10::cuda::CUDACachingAllocator::recordHistory( enabled, recorder, trace_alloc_max_entries, when, clearHistory); } @@ -158,7 +191,8 @@ void _record_memory_history( std::optional context, const std::string& stacks, size_t max_entries, - bool clearHistory) { + bool clearHistory, + bool compileContext) { if (enabled) { checkOptionIn( *enabled, @@ -193,6 +227,9 @@ void _record_memory_history( } at::globalContext().lazyInitDevice(c10::DeviceType::CUDA); _initRecordAnnotations(); + if (compileContext) { + _initCompileContexts(); + } c10::cuda::CUDACachingAllocator::recordHistory( enabled.has_value(), recorder, max_entries, when, clearHistory); } @@ -222,6 +259,7 @@ std::string _memory_snapshot_pickled() { IValue blocks_s = "blocks"; IValue is_expandable_s = "is_expandable"; IValue time_us_s = "time_us"; + IValue compile_contexts_s = "compile_context"; auto empty_frames = new_list(); @@ -338,6 +376,7 @@ std::string _memory_snapshot_pickled() { static_cast(te.addr_)); trace_entry.insert(size_s, (int64_t)te.size_); trace_entry.insert(stream_s, int64_t(te.stream_)); + trace_entry.insert(compile_contexts_s, te.compile_context_); if (te.context_) { auto sc = getFromContext(te.context_); frame_tracebacks.push_back(sc); diff --git a/torch/csrc/cuda/memory_snapshot.h b/torch/csrc/cuda/memory_snapshot.h index 28dfbedc7b2..5d89f2f6534 100644 --- a/torch/csrc/cuda/memory_snapshot.h +++ b/torch/csrc/cuda/memory_snapshot.h @@ -15,14 +15,16 @@ TORCH_CUDA_CU_API void _record_memory_history( int64_t trace_alloc_max_entries = 1, bool trace_alloc_record_context = false, bool record_cpp_context = false, - bool clearHistory = false); + bool clearHistory = false, + bool compileContext = false); TORCH_CUDA_CU_API void _record_memory_history( std::optional enabled = "all", std::optional context = "all", const std::string& stacks = "all", size_t max_entries = SIZE_MAX, - bool clearHistory = false); + bool clearHistory = false, + bool compileContext = false); TORCH_CUDA_CU_API std::string _memory_snapshot_pickled(); diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 736ff875fb7..71093e8039c 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -845,6 +845,7 @@ def _record_memory_history_legacy( device: "Device" = None, record_context_cpp=False, clear_history=False, + compile_context=False, ): _C._cuda_record_memory_history_legacy( enabled, @@ -853,6 +854,7 @@ def _record_memory_history_legacy( trace_alloc_record_context, record_context_cpp, clear_history, + compile_context, ) @@ -908,8 +910,11 @@ def _record_memory_history_impl( max_entries: int = sys.maxsize, device: "Device" = None, clear_history: bool = False, + compile_context: bool = False, ): - _C._cuda_record_memory_history(enabled, context, stacks, max_entries, clear_history) + _C._cuda_record_memory_history( + enabled, context, stacks, max_entries, clear_history, compile_context + ) _record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]