[Memento] Add PT2 to Memory Snapshot (#152707)

Summary:
To add PT2 information to the memory snapshot, we piggyback off of the Kineto implementation by using record_function, similar to how the user annotations are added. To do this, we add the following:

1. A stack implementation that we instantiate to keep track of which compile context we are currently in (the top element of the stack). The stack is per device and thread-local, since different threads of a process can be in different compile contexts at a given time. For this reason, we do not need to add mutexes to our stack implementation, since no two threads will ever touch a given stack.
2. RecordFunction hooks to pipe the correct events to the compile context stack. These hooks are similar to the annotation ones in that we register them lazily and DO NOT unregister them; this is done out of convenience, and in the future we should save the handles and unregister them to minimize overhead once profiling is finished. As of now, we register the callbacks at the FUNCTION scope, which is wide; however, we treat any function whose name does not start with "Torch-Compiled Region" as a no-op, so we anticipate the performance difference to be negligible both during and after profiling. We also hide this feature behind a flag that is off by default, so existing jobs will be unaffected.
3. Piping for the compile context into the pickle output (a usage sketch follows below).
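
Below is a minimal usage sketch of the end-to-end feature. It is not taken verbatim from this diff (the model and tensor shapes are illustrative); it only uses the existing private memory-history APIs plus the new compile_context flag and the new "compile_context" key on each trace entry introduced here:

```python
import torch

# Sketch: record a memory snapshot with compile contexts attached.
model = torch.compile(torch.nn.Linear(10, 10).cuda())
x = torch.randn(1, 10, device="cuda")

# compile_context=True is the new flag; it is off by default.
torch.cuda.memory._record_memory_history(context="all", compile_context=True)
model(x)

snapshot = torch.cuda.memory._snapshot()
for trace in snapshot["device_traces"]:
    for entry in trace:
        # Allocations made outside a compiled region are tagged "N/A".
        if entry["compile_context"] != "N/A":
            print(entry["action"], entry["compile_context"])

# Stop recording.
torch.cuda.memory._record_memory_history(None)
```

Allocations made while no compiled region is on the stack keep the default "N/A" compile context, and when the flag is left off the hooks are never registered, so existing workflows see no change.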

Test Plan:
In D74039793, we add CompileContext to the visualizer and see the following: {F1977654658}

Differential Revision: D74028214

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152707
Approved by: https://github.com/eqy
Shivam Raikundalia 2025-05-12 21:12:47 +00:00 committed by PyTorch MergeBot
parent f78e4529a9
commit dbb4444ce3
8 changed files with 174 additions and 12 deletions

View File

@ -33,6 +33,7 @@
#include <new>
#include <regex>
#include <set>
#include <stack>
#include <utility>
#include <vector>
@ -989,7 +990,6 @@ class RingBuffer {
// deallocation it can hold references to Python state which
// will already be destroyed when we are in exit handlers
};
} // anonymous namespace
} // namespace Native
@ -1128,6 +1128,9 @@ class DeviceCachingAllocator {
// was used while cudagraph capturing
std::unordered_map<Block*, stream_set> block_to_cudagraph_stream_uses;
// thread local compile context for each device
static thread_local std::stack<std::string> compile_context;
public:
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
DeviceCachingAllocator()
@ -1158,6 +1161,16 @@ class DeviceCachingAllocator {
return record_history;
}
void pushCompileContext(std::string& md) {
compile_context.push(md);
}
void popCompileContext() {
if (!compile_context.empty()) {
compile_context.pop();
}
}
bool checkPoolLiveAllocations(
MempoolId_t mempool_id,
const std::unordered_set<void*>& expected_live_allocations) {
@ -3294,7 +3307,10 @@ class DeviceCachingAllocator {
std::shared_ptr<GatheredContext> context) {
if (!record_history && trace_trackers_.empty())
return;
std::string compile_string = "N/A";
if (!compile_context.empty()) {
compile_string = compile_context.top();
}
auto te = TraceEntry(
action,
device,
@ -3303,7 +3319,8 @@ class DeviceCachingAllocator {
stream,
mempool_id,
getApproximateTime(),
record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr);
record_context_ >= RecordContext::ALLOC ? std::move(context) : nullptr,
compile_string);
// Callbacks should not include any Pytorch call
for (const auto& cb : trace_trackers_) {
@ -3357,7 +3374,7 @@ static void uncached_delete(void* ptr) {
}
static void local_raw_delete(void* ptr);
thread_local std::stack<std::string> DeviceCachingAllocator::compile_context;
#ifdef __cpp_lib_hardware_interference_size
using std::hardware_destructive_interference_size;
#else
@ -3526,6 +3543,24 @@ class NativeCachingAllocator : public CUDAAllocator {
annotation_buffer.insertEntries(ae);
}
void pushCompileContext(std::string& md) override {
if (!record_history) {
return;
}
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
device_allocator[device]->pushCompileContext(md);
}
void popCompileContext() override {
if (!record_history) {
return;
}
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
device_allocator[device]->popCompileContext();
}
bool isHistoryEnabled() override {
c10::DeviceIndex device = 0;
C10_CUDA_CHECK(c10::cuda::GetDevice(&device));

View File

@ -117,14 +117,16 @@ struct TraceEntry {
cudaStream_t stream,
MempoolId_t mempool,
approx_time_t time,
std::shared_ptr<GatheredContext> context = nullptr)
std::shared_ptr<GatheredContext> context = nullptr,
std::string compile_context = "")
: action_(action),
device_(device),
addr_(addr),
context_(std::move(context)),
stream_(stream),
size_(size),
mempool_(std::move(mempool)) {
mempool_(std::move(mempool)),
compile_context_(std::move(compile_context)) {
time_.approx_t_ = time;
}
Action action_;
@ -135,6 +137,7 @@ struct TraceEntry {
size_t size_;
MempoolId_t mempool_;
trace_time_ time_{};
std::string compile_context_{};
};
// Calls made by record_function will save annotations
@ -284,6 +287,8 @@ class CUDAAllocator : public Allocator {
bool clearHistory) = 0;
virtual void recordAnnotation(
const std::vector<std::pair<std::string, std::string>>& /*md*/) {}
virtual void pushCompileContext(std::string& md) {}
virtual void popCompileContext() {}
virtual void attachOutOfMemoryObserver(OutOfMemoryObserver observer) = 0;
// Attached AllocatorTraceTracker callbacks will be called while the
@ -442,6 +447,14 @@ inline void recordAnnotation(
return get()->recordAnnotation(md);
}
inline void pushCompileContext(std::string& md) {
return get()->pushCompileContext(md);
}
inline void popCompileContext() {
return get()->popCompileContext();
}
inline bool isHistoryEnabled() {
return get()->isHistoryEnabled();
}

View File

@ -22,6 +22,7 @@ import psutil
import torch
import torch.cuda
import torch.nn as nn
from torch import inf, nan
from torch.cuda._memory_viz import (
_profile_to_snapshot,
@ -79,6 +80,7 @@ from torch.testing._internal.common_utils import (
TEST_WITH_ROCM,
TestCase,
)
from torch.utils._triton import has_triton
from torch.utils.checkpoint import checkpoint_sequential
from torch.utils.viz._cycles import observe_tensor_cycles
@ -3961,6 +3963,67 @@ class TestCudaMallocAsync(TestCase):
finally:
torch.cuda.memory._record_memory_history(None)
@unittest.skipIf(
    TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
)
@unittest.skipIf(not has_triton(), "test needs triton")
@requiresCppContext
def test_memory_compile_regions(self):
    expected_allocation_sequence = [
        "Torch-Compiled Region: 0/0",
        "Torch-Compiled Region: 1/0",
        "Torch-Compiled Region: 0/0",
    ]

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = nn.Linear(10, 10)
            self.linear2 = nn.Linear(10, 10)

        def forward(self, x):
            x = self.linear1(x)
            if x.sum() > 0:
                x = x + 1
            else:
                x = x - 1
            x = self.linear2(x)
            return x

    try:
        torch.cuda.memory.empty_cache()
        input_tensor = torch.randn(1, 10, device="cuda")
        # Create an instance of the model
        model = MyModel()
        model.to("cuda")
        # Compile the model using torch.compile
        compiled_model = torch.compile(model)
        # Create a sample input tensor
        torch.cuda.memory._record_memory_history(
            context="all", compile_context=True
        )
        compiled_model(input_tensor)
        ss = torch.cuda.memory._snapshot()["device_traces"]
        device_idx = 0
        allocation_sequence = []
        while len(ss[device_idx]) == 0:
            device_idx = device_idx + 1
        for s in ss[device_idx]:
            context = s["compile_context"]
            if context == "N/A":
                continue
            if len(allocation_sequence) > 0 and allocation_sequence[-1] == context:
                continue
            allocation_sequence.append(context)
        self.assertTrue(allocation_sequence == expected_allocation_sequence)
    except RuntimeError as e:
        pass
    finally:
        torch.cuda.memory._record_memory_history(None)
@unittest.skipIf(
TEST_CUDAMALLOCASYNC, "setContextRecorder not supported by CUDAMallocAsync"
)

View File

@ -1949,6 +1949,7 @@ def _cuda_record_memory_history_legacy(
alloc_trace_max_entries: _int,
alloc_trace_record_context: _bool,
clear_history: _bool,
compile_context: _bool,
) -> None: ...
def _cuda_record_memory_history(
enabled: Optional[str],
@ -1956,6 +1957,7 @@ def _cuda_record_memory_history(
stacks: str,
max_entries: _int,
clear_history: _bool,
compile_context: _bool,
) -> None: ...
def _cuda_isHistoryEnabled() -> _bool: ...

View File

@ -749,6 +749,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
py::str is_expandable_s = "is_expandable";
py::str frames_s = "frames";
py::str time_us_s = "time_us";
py::str compile_context_s = "compile_context";
py::list empty_frames;
std::vector<CapturedTraceback*> to_gather_frames;
@ -865,6 +866,7 @@ PyObject* THCPModule_memorySnapshot(PyObject* _unused, PyObject* noargs) {
trace_entry[size_s] = te.size_;
trace_entry[stream_s] = int64_t(te.stream_);
trace_entry[time_us_s] = te.time_.t_;
trace_entry[compile_context_s] = te.compile_context_;
trace.append(trace_entry);
}
traces.append(trace);
@ -1107,7 +1109,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
m.def(
"_cuda_record_memory_history_legacy",
static_cast<void (*)(bool, bool, int64_t, bool, bool, bool)>(
static_cast<void (*)(bool, bool, int64_t, bool, bool, bool, bool)>(
torch::cuda::_record_memory_history));
m.def(
@ -1117,6 +1119,7 @@ static void registerCudaDeviceProperties(PyObject* module) {
std::optional<std::string>,
const std::string&,
size_t,
bool,
bool)>(torch::cuda::_record_memory_history));
m.def("_cuda_isHistoryEnabled", []() {

View File

@ -117,6 +117,35 @@ void _initRecordAnnotations() {
}();
}
void _initCompileContexts() {
static auto init_placeholder [[maybe_unused]] = [&] {
// Save PT2 Compile Contexts to CCA memory snapshot tool
at::addGlobalCallback(
at::RecordFunctionCallback(
[](const at::RecordFunction& fn)
-> std::unique_ptr<at::ObserverContext> {
std::string functionName = fn.name();
const std::string functionNamePrefix = "Torch-Compiled Region";
if (functionName.compare(
0, functionNamePrefix.size(), functionNamePrefix) == 0) {
c10::cuda::CUDACachingAllocator::pushCompileContext(
functionName);
}
return nullptr;
},
[](const at::RecordFunction& fn, at::ObserverContext* ctx_ptr) {
std::string functionName = fn.name();
const std::string functionNamePrefix = "Torch-Compiled Region";
if (functionName.compare(
0, functionNamePrefix.size(), functionNamePrefix) == 0) {
c10::cuda::CUDACachingAllocator::popCompileContext();
}
})
.scopes({at::RecordScope::FUNCTION}));
return true;
}();
}
} // namespace
void _record_memory_history(
@ -125,7 +154,8 @@ void _record_memory_history(
int64_t trace_alloc_max_entries,
bool trace_alloc_record_context,
bool record_cpp_context,
bool clearHistory) {
bool clearHistory,
bool compileContext) {
c10::cuda::CUDACachingAllocator::CreateContextFn recorder = gather;
if (enabled && record_cpp_context &&
(trace_alloc_record_context || record_context)) {
@ -141,6 +171,9 @@ void _record_memory_history(
}
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
_initRecordAnnotations();
if (compileContext) {
_initCompileContexts();
}
c10::cuda::CUDACachingAllocator::recordHistory(
enabled, recorder, trace_alloc_max_entries, when, clearHistory);
}
@ -158,7 +191,8 @@ void _record_memory_history(
std::optional<std::string> context,
const std::string& stacks,
size_t max_entries,
bool clearHistory) {
bool clearHistory,
bool compileContext) {
if (enabled) {
checkOptionIn(
*enabled,
@ -193,6 +227,9 @@ void _record_memory_history(
}
at::globalContext().lazyInitDevice(c10::DeviceType::CUDA);
_initRecordAnnotations();
if (compileContext) {
_initCompileContexts();
}
c10::cuda::CUDACachingAllocator::recordHistory(
enabled.has_value(), recorder, max_entries, when, clearHistory);
}
@ -222,6 +259,7 @@ std::string _memory_snapshot_pickled() {
IValue blocks_s = "blocks";
IValue is_expandable_s = "is_expandable";
IValue time_us_s = "time_us";
IValue compile_contexts_s = "compile_context";
auto empty_frames = new_list();
@ -338,6 +376,7 @@ std::string _memory_snapshot_pickled() {
static_cast<int64_t>(te.addr_));
trace_entry.insert(size_s, (int64_t)te.size_);
trace_entry.insert(stream_s, int64_t(te.stream_));
trace_entry.insert(compile_contexts_s, te.compile_context_);
if (te.context_) {
auto sc = getFromContext(te.context_);
frame_tracebacks.push_back(sc);

View File

@ -15,14 +15,16 @@ TORCH_CUDA_CU_API void _record_memory_history(
int64_t trace_alloc_max_entries = 1,
bool trace_alloc_record_context = false,
bool record_cpp_context = false,
bool clearHistory = false);
bool clearHistory = false,
bool compileContext = false);
TORCH_CUDA_CU_API void _record_memory_history(
std::optional<std::string> enabled = "all",
std::optional<std::string> context = "all",
const std::string& stacks = "all",
size_t max_entries = SIZE_MAX,
bool clearHistory = false);
bool clearHistory = false,
bool compileContext = false);
TORCH_CUDA_CU_API std::string _memory_snapshot_pickled();

View File

@ -845,6 +845,7 @@ def _record_memory_history_legacy(
device: "Device" = None,
record_context_cpp=False,
clear_history=False,
compile_context=False,
):
_C._cuda_record_memory_history_legacy(
enabled,
@ -853,6 +854,7 @@ def _record_memory_history_legacy(
trace_alloc_record_context,
record_context_cpp,
clear_history,
compile_context,
)
@ -908,8 +910,11 @@ def _record_memory_history_impl(
max_entries: int = sys.maxsize,
device: "Device" = None,
clear_history: bool = False,
compile_context: bool = False,
):
_C._cuda_record_memory_history(enabled, context, stacks, max_entries, clear_history)
_C._cuda_record_memory_history(
enabled, context, stacks, max_entries, clear_history, compile_context
)
_record_memory_history.__signature__ = signature(_record_memory_history_impl) # type: ignore[attr-defined]