filter out alloc-free pairs from trace plot (#165752)

Summary:
When dealing with a large memory trace, the resulting plot can be challenging to interpret and analyze.
This commit introduces a feature that enables filtering of allocations that have already been freed, providing a more focused view.
The remaining events in the plot often warrant closer examination, as they may be indicative of potential out-of-memory (OOM) issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165752
Approved by: https://github.com/zdevito
This commit is contained in:
Way Wang 2025-10-29 12:44:51 +00:00 committed by PyTorch MergeBot
parent 5e7272b60a
commit 4a94591321
2 changed files with 57 additions and 2 deletions

View File

@ -4222,6 +4222,7 @@ class TestCudaMallocAsync(TestCase):
ss = torch.cuda.memory._snapshot()
trace_plot(ss)
trace_plot(ss, filter_freed=True)
segment_plot(ss)
text = json.dumps(ss)

View File

@ -446,7 +446,43 @@ def _format_viz(data, viz_kind, device):
)
def trace_plot(data, device=None, plot_segments=False):
def filter_alloc_free_pairs(data):
    """Return a copy of a memory snapshot with alloc/free paired events removed.

    For every device trace in ``data["device_traces"]`` the events are scanned
    in order and the following are dropped:

    * an ``alloc`` event together with the later ``free_requested`` event for
      the same ``addr``;
    * a ``free_completed`` event whose ``addr`` had a preceding
      ``free_requested`` event.

    All other events (including allocations that were never freed) are kept in
    their original order.

    Args:
        data: snapshot mapping with a ``"device_traces"`` entry holding one
            list of trace events per device. Each ``alloc``,
            ``free_requested`` and ``free_completed`` event must provide
            ``"action"`` and ``"addr"`` keys.

    Returns:
        dict: a shallow copy of ``data`` whose ``"device_traces"`` entry holds
        the filtered per-device event lists. The input snapshot and its trace
        lists are left unmodified so the caller can keep using them (e.g. to
        render the unfiltered plot afterwards).
    """
    filtered_traces = []
    for trace in data["device_traces"]:
        # indexes of the trace events that belong to an alloc-free pair
        paired_indexes = set()
        # addr -> index of the latest alloc event that has not been freed yet
        live_allocs = {}
        # addrs with a free_requested event still waiting for free_completed
        pending_frees = set()
        for idx, event in enumerate(trace):
            action = event["action"]
            if action == "alloc":
                live_allocs[event["addr"]] = idx
            elif action == "free_requested":
                addr = event["addr"]
                pending_frees.add(addr)
                alloc_idx = live_allocs.pop(addr, None)
                if alloc_idx is not None:
                    paired_indexes.add(idx)
                    paired_indexes.add(alloc_idx)
            elif action == "free_completed":
                addr = event["addr"]
                if addr in pending_frees:
                    pending_frees.remove(addr)
                    paired_indexes.add(idx)
                else:
                    # The event is kept in the output; only report it.
                    print(f"free_completed without free_requested: {event}")
        # Always build a new list so the caller's trace is never aliased
        # by a list this function hands out as "filtered".
        filtered_traces.append(
            [
                event
                for idx, event in enumerate(trace)
                if idx not in paired_indexes
            ]
        )
    # Shallow copy: only the "device_traces" entry is replaced, every other
    # entry of the snapshot is shared with the input.
    filtered = dict(data)
    filtered["device_traces"] = filtered_traces
    return filtered
def trace_plot(data, device=None, plot_segments=False, filter_freed=False):
"""Generate a visualization over time of the memory usage recorded by the trace as an html file.
Args:
@ -454,10 +490,15 @@ def trace_plot(data, device=None, plot_segments=False):
device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
Defaults to False.
filter_freed (bool, optional): Filter out alloc-free paired events to only plot allocations that are not freed yet.
Defaults to False to plot all trace events.
Returns:
str: HTML of visualization
"""
if filter_freed:
data = filter_alloc_free_pairs(data)
return _format_viz(
data,
"Active Memory Timeline"
@ -698,6 +739,14 @@ if __name__ == "__main__":
"-s", "--segments", action="store_true", help=help
)
# Help text for the --filter_freed flag. The two literals are concatenated
# implicitly, so the first one ends with a space to keep the rendered
# sentence readable.
help = (
    "filter out allocation-free pairs to only visualize the allocations that are not freed yet; "
    "useful to reduce the number of events for large traces for debugging OOM"
)
# Boolean switch: absent -> False (plot every trace event), present -> True.
trace_plot_a.add_argument(
    "-f", "--filter_freed", action="store_true", help=help
)
args = parser.parse_args()
def _read(name):
@ -734,7 +783,12 @@ if __name__ == "__main__":
data = _read(args.input)
_write(
args.output,
trace_plot(data, device=args.device, plot_segments=args.segments),
trace_plot(
data,
device=args.device,
plot_segments=args.segments,
filter_freed=args.filter_freed,
),
)
elif args.action == "segment_plot":
data = _read(args.input)