mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
filter out alloc-free pairs from trace plot (#165752)
Summary: When dealing with a large memory trace, the resulting plot can be challenging to interpret and analyze. This commit introduces a feature that enables filtering of allocations that have already been freed, providing a more focused view. The remaining events in the plot often warrant closer examination, as they may be indicative of potential out-of-memory (OOM) issues. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165752 Approved by: https://github.com/zdevito
This commit is contained in:
parent
5e7272b60a
commit
4a94591321
|
|
@ -4222,6 +4222,7 @@ class TestCudaMallocAsync(TestCase):
|
|||
ss = torch.cuda.memory._snapshot()
|
||||
|
||||
trace_plot(ss)
|
||||
trace_plot(ss, filter_freed=True)
|
||||
segment_plot(ss)
|
||||
text = json.dumps(ss)
|
||||
|
||||
|
|
|
|||
|
|
@ -446,7 +446,43 @@ def _format_viz(data, viz_kind, device):
|
|||
)
|
||||
|
||||
|
||||
def trace_plot(data, device=None, plot_segments=False):
|
||||
def filter_alloc_free_pairs(data):
    """Return a copy of ``data`` whose device traces omit alloc-free pairs.

    For each per-device trace in ``data["device_traces"]`` the following
    events are dropped:

    * an ``alloc`` event that is later followed by a ``free_requested`` event
      for the same ``addr``, together with that ``free_requested`` event;
    * a ``free_completed`` event whose ``addr`` has an outstanding
      ``free_requested``.

    Every other event is kept, in its original order. A ``free_completed``
    with no preceding ``free_requested`` is kept and reported on stdout.

    Args:
        data (dict): mapping with a ``"device_traces"`` entry holding one list
            of event dicts per device. Every event must have an ``"action"``
            key; ``alloc``, ``free_requested`` and ``free_completed`` events
            must also have an ``"addr"`` key.

    Returns:
        dict: a shallow copy of ``data`` with a new ``"device_traces"`` list.
        The input mapping is left untouched so the caller can keep using the
        unfiltered snapshot; trace lists that needed no filtering are shared
        with the input rather than copied.
    """
    filtered_traces = []
    for trace in data["device_traces"]:
        # indexes of trace events that belong to an alloc-free pair
        paired = set()
        # map from addr to index of the alloc event that is still live
        live_allocs = {}
        # addrs seen in free_requested that still await free_completed
        pending_frees = set()

        for idx, event in enumerate(trace):
            action = event["action"]
            if action == "alloc":
                live_allocs[event["addr"]] = idx
            elif action == "free_requested":
                addr = event["addr"]
                pending_frees.add(addr)
                # The alloc may predate the trace; only pair up when we saw it.
                alloc_idx = live_allocs.pop(addr, None)
                if alloc_idx is not None:
                    paired.add(idx)
                    paired.add(alloc_idx)
            elif action == "free_completed":
                addr = event["addr"]
                if addr in pending_frees:
                    pending_frees.remove(addr)
                    paired.add(idx)
                else:
                    print(f"free_completed without free_requested: {event}")

        if paired:
            # Build a new list instead of editing the caller's trace in place.
            trace = [
                event for idx, event in enumerate(trace) if idx not in paired
            ]
        filtered_traces.append(trace)

    return {**data, "device_traces": filtered_traces}
|
||||
|
||||
|
||||
def trace_plot(data, device=None, plot_segments=False, filter_freed=False):
|
||||
"""Generate a visualization over time of the memory usage recorded by the trace as an html file.
|
||||
|
||||
Args:
|
||||
|
|
@ -454,10 +490,15 @@ def trace_plot(data, device=None, plot_segments=False):
|
|||
device (torch.device, optional): Generate the trace for this device, needed if multiple devices have allocations.
|
||||
plot_segments (bool, optional): Plots memory returned from cudaMalloc, rather than individual allocations.
|
||||
Defaults to False.
|
||||
filter_freed (bool, optional): Filter out alloc-free paired events to only plot allocations that are not freed yet.
|
||||
Defaults to False to plot all trace events.
|
||||
|
||||
Returns:
|
||||
str: HTML of visualization
|
||||
"""
|
||||
if filter_freed:
|
||||
data = filter_alloc_free_pairs(data)
|
||||
|
||||
return _format_viz(
|
||||
data,
|
||||
"Active Memory Timeline"
|
||||
|
|
@ -698,6 +739,14 @@ if __name__ == "__main__":
|
|||
"-s", "--segments", action="store_true", help=help
|
||||
)
|
||||
|
||||
# Help text for the --filter_freed option. The two literals are joined by
# implicit concatenation, so the first must end with a space to keep the
# rendered sentence readable.
help = (
    "filter out allocation-free pairs to only visualize the allocations that are not freed yet; "
    "useful to reduce the number of events for large traces for debugging OOM"
)
|
||||
trace_plot_a.add_argument(
|
||||
"-f", "--filter_freed", action="store_true", help=help
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def _read(name):
|
||||
|
|
@ -734,7 +783,12 @@ if __name__ == "__main__":
|
|||
data = _read(args.input)
|
||||
_write(
|
||||
args.output,
|
||||
trace_plot(data, device=args.device, plot_segments=args.segments),
|
||||
trace_plot(
|
||||
data,
|
||||
device=args.device,
|
||||
plot_segments=args.segments,
|
||||
filter_freed=args.filter_freed,
|
||||
),
|
||||
)
|
||||
elif args.action == "segment_plot":
|
||||
data = _read(args.input)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user