mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/43898 Adding with_source parameter to enable tracking source code (filename and line) in profiler for eager, torchscript and autograd modes Test Plan: python test/test_profiler.py ``` Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls Source Location ----------------------------------- --------------- --------------- --------------- --------------- --------------- --------------- -------------------------------------------- ts_method_1 10.43% 235.364us 36.46% 822.920us 822.920us 1 test/test_profiler.py(70): test_source aten::add 7.52% 169.833us 8.88% 200.439us 200.439us 1 test/test_profiler.py(69): test_source aten::normal_ 6.26% 141.380us 6.26% 141.380us 141.380us 1 test/test_profiler.py(67): test_source aten::add 5.80% 130.830us 8.41% 189.800us 63.267us 3 test/test_profiler.py(72): test_source aten::sum 5.02% 113.340us 8.39% 189.475us 189.475us 1 test/test_profiler.py(64): ts_method_1 aten::add 4.58% 103.346us 6.33% 142.847us 142.847us 1 test/test_profiler.py(62): ts_method_1 aten::mul 4.05% 91.498us 9.62% 217.113us 217.113us 1 test/test_profiler.py(71): test_source aten::add 4.03% 90.880us 5.60% 126.405us 126.405us 1 test/test_profiler.py(58): ts_method_2 aten::empty 3.49% 78.735us 3.49% 78.735us 19.684us 4 test/test_profiler.py(72): test_source ``` Reviewed By: ngimel Differential Revision: D23432664 Pulled By: ilia-cher fbshipit-source-id: 83ad7ebe0c2502494d3b48c4e687802db9c77615
105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
import collections
|
|
import gc
|
|
import unittest
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from torch.testing._internal.common_utils import (
|
|
TestCase, run_tests, TEST_WITH_ASAN, IS_WINDOWS)
|
|
from torch.autograd.profiler import profile
|
|
|
|
# psutil is an optional dependency: record whether it could be imported so
# tests that need it can be skipped instead of failing at import time.
try:
    import psutil
except ImportError:
    HAS_PSUTIL = False
else:
    HAS_PSUTIL = True
|
|
|
|
|
|
@unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run")
@unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN")
@unittest.skipIf(IS_WINDOWS, "Test is flaky on Windows")
@unittest.skipIf(not torch.cuda.is_available(), "CUDA is required")
class TestProfilerCUDA(TestCase):
    def test_mem_leak(self):
        """Verify that repeated CUDA profiling does not steadily grow process RSS.

        Runs ten profiled rounds of matrix multiplies, sampling the resident
        set size after each round, and inspects the last five samples.
        """
        tensor = torch.rand(1, 1).cuda()
        process = psutil.Process()
        # Only the five most recent RSS samples are retained.
        rss_window = collections.deque(maxlen=5)

        for _round in range(10):
            with profile(use_cuda=True):
                for _step in range(1024):
                    tensor = torch.mm(tensor, tensor)

            # Release whatever can be released before sampling memory.
            gc.collect()
            torch.cuda.empty_cache()
            rss_window.append(process.memory_info().rss)

        # When CUDA events were leaking, memory grew by roughly 7 MB between
        # the profiler invocations above.
        samples = list(rss_window)
        deltas = [later - earlier for earlier, later in zip(samples, samples[1:])]
        grows_every_step = all(delta > 0 for delta in deltas)
        largest_step = max(deltas, default=-1)

        # Fail only if RSS rose on every step AND some step exceeded 100 KiB.
        self.assertTrue(
            not grows_every_step or largest_step <= 100 * 1024,
            msg='memory usage is increasing, {}'.format(str(rss_window)))
|
|
|
|
class TestProfiler(TestCase):
|
|
def test_source(self):
|
|
"""Checks that source code attribution works for eager, TS and autograd mode
|
|
"""
|
|
# avoid automatic inlining
|
|
prev_opt = torch._C._get_graph_executor_optimize()
|
|
torch._C._set_graph_executor_optimize(False)
|
|
|
|
@torch.jit.script
|
|
def ts_method_2(x, y):
|
|
return torch.matmul(x, y)
|
|
|
|
@torch.jit.script
|
|
def ts_method_1(x, y, z):
|
|
a = x + z
|
|
w = ts_method_2(x, y) + a
|
|
return w.sum()
|
|
|
|
class DummyModule(nn.Module):
|
|
def __init__(self):
|
|
super(DummyModule, self).__init__()
|
|
self.conv = torch.nn.Conv2d(3, 2, kernel_size=1, stride=2, padding=3, bias=False)
|
|
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
|
|
mod = DummyModule()
|
|
|
|
with profile(with_stack=True) as p:
|
|
x = torch.randn(10, 10, requires_grad=True)
|
|
y = torch.randn(10, 10, requires_grad=True)
|
|
z = x + y
|
|
w = ts_method_1(x, y, z)
|
|
v = 2 * w
|
|
v.backward()
|
|
a = torch.randn(2, 3, 2, 2, requires_grad=True)
|
|
b = mod(a)
|
|
c = b.sum()
|
|
c.backward()
|
|
|
|
print(p.key_averages(
|
|
group_by_stack_n=5).table(
|
|
sort_by="self_cpu_time_total", row_limit=-1))
|
|
|
|
for e in p.function_events:
|
|
if "aten::add" in e.name or "AddBackward" in e.name:
|
|
self.assertTrue(any(["test_profiler" in entry for entry in e.stack]))
|
|
self.assertTrue(any([(
|
|
"test_source" in entry or
|
|
"ts_method_1" in entry or
|
|
"ts_method_2" in entry) for entry in e.stack]))
|
|
|
|
torch._C._set_graph_executor_optimize(prev_opt)
|
|
|
|
|
|
# When this file is executed directly (rather than imported), hand control to
# run_tests(), imported above from torch.testing._internal.common_utils.
if __name__ == '__main__':
    run_tests()
|