mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
correlate forward and backward op (#62553)
Summary: Use startThreadId+seqNumber of forward-op and fwdThreadId+seqNumber of backward-op to correlate pair of them. third_party/kineto should be updated accordingly: https://github.com/pytorch/kineto/pull/372 Pull Request resolved: https://github.com/pytorch/pytorch/pull/62553 Reviewed By: malfet Differential Revision: D30125728 Pulled By: gdankel fbshipit-source-id: 9877a54392ba043d0eac56ce5b7bbf244277fa7e
This commit is contained in:
parent
f0ada4bd54
commit
d35ee431d8
|
|
@ -709,5 +709,42 @@ class TestProfiler(TestCase):
|
|||
if kineto_available():
|
||||
self._test_profiler_tracing(True)
|
||||
|
||||
def test_profiler_fwd_bwd_link(self):
|
||||
with _profile(use_kineto=True) as prof:
|
||||
t1, t2 = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
|
||||
z = torch.add(t1, t2)
|
||||
y = torch.ones(1)
|
||||
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
|
||||
loss.backward()
|
||||
with TemporaryFileName(mode="w+") as fname:
|
||||
prof.export_chrome_trace(fname)
|
||||
with io.open(fname, 'r') as f:
|
||||
j = json.load(f)
|
||||
events = j["traceEvents"]
|
||||
ts_to_name = {}
|
||||
flow_s_to_ts = {}
|
||||
flow_f_to_ts = {}
|
||||
for e in events:
|
||||
if e["ph"] == "X":
|
||||
ts_to_name[e["ts"]] = e["name"]
|
||||
if "cat" in e and "name" in e and e["cat"] == "forward_backward" and e["name"] == "fwd_bwd":
|
||||
if e["ph"] == "s":
|
||||
flow_s_to_ts[e["id"]] = e["ts"]
|
||||
elif e["ph"] == "f":
|
||||
flow_f_to_ts[e["id"]] = e["ts"]
|
||||
self.assertTrue(len(flow_s_to_ts) == 2)
|
||||
self.assertTrue(len(flow_f_to_ts) == 2)
|
||||
self.assertTrue(1 in flow_s_to_ts.keys())
|
||||
self.assertTrue(1 in flow_f_to_ts.keys())
|
||||
self.assertTrue(2 in flow_s_to_ts.keys())
|
||||
self.assertTrue(2 in flow_f_to_ts.keys())
|
||||
s_ts_1 = flow_s_to_ts[1]
|
||||
f_ts_1 = flow_f_to_ts[1]
|
||||
s_ts_2 = flow_s_to_ts[2]
|
||||
f_ts_2 = flow_f_to_ts[2]
|
||||
self.assertTrue(all([ts in ts_to_name.keys() for ts in [s_ts_1, f_ts_1, s_ts_2, f_ts_2]]))
|
||||
self.assertTrue(ts_to_name[s_ts_1] == "aten::binary_cross_entropy_with_logits")
|
||||
self.assertTrue(ts_to_name[s_ts_2] == "aten::add")
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -47,6 +47,11 @@ std::string stacksToStr(const std::vector<std::string>& stacks, const char* deli
|
|||
std::string dtypesToStr(const std::vector<std::string>& types);
|
||||
std::vector<std::string> inputTypes(const at::RecordFunction& fn);
|
||||
|
||||
// Assumption: Total threads number will not exceed 2^16-1, and total ops will not exceed 2^48 -1.
|
||||
static inline uint64_t getForwardThreadKey(uint64_t tid, uint64_t seqNr) {
|
||||
return (((tid) << 48) | ((seqNr) & (((uint64_t)1 << 48) - 1)));
|
||||
}
|
||||
|
||||
struct KinetoThreadLocalState : public ProfilerThreadLocalState {
|
||||
explicit KinetoThreadLocalState(const ProfilerConfig& config)
|
||||
: ProfilerThreadLocalState(config) {
|
||||
|
|
@ -232,6 +237,11 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
|
|||
|
||||
void finalizeCPUTrace() {
|
||||
TORCH_INTERNAL_ASSERT(cpu_trace->activities.size() == kineto_events_.size());
|
||||
// startThreadId_seqNum to pointer of activity.
|
||||
// Low-16bits of startThreadId and low-48bits seqNum are concatenated into one uint64_t variable as key.
|
||||
std::unordered_map<uint64_t, libkineto::GenericTraceActivity*> tidSeq2activity;
|
||||
uint64_t fwd_bwd_link_id = 1;
|
||||
|
||||
for (size_t idx = 0; idx < cpu_trace->activities.size(); ++idx) {
|
||||
auto& kineto_event = kineto_events_[idx];
|
||||
auto& activity = cpu_trace->activities[idx];
|
||||
|
|
@ -258,6 +268,43 @@ struct KinetoThreadLocalState : public ProfilerThreadLocalState {
|
|||
activity.addMetadata(
|
||||
"Sequence number",
|
||||
std::to_string(kineto_event.sequenceNr()));
|
||||
generateForwardBackwardLink(kineto_event, fwd_bwd_link_id, activity, tidSeq2activity);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void generateForwardBackwardLink(const KinetoEvent &kineto_event,
|
||||
uint64_t &fwd_bwd_link_id,
|
||||
libkineto::GenericTraceActivity &activity,
|
||||
std::unordered_map<uint64_t, libkineto::GenericTraceActivity*> &tidSeq2activity) {
|
||||
if (kineto_event.fwdThreadId() > 0) {
|
||||
// act is backward op.
|
||||
uint64_t key = getForwardThreadKey(kineto_event.fwdThreadId(), kineto_event.sequenceNr());
|
||||
auto iter = tidSeq2activity.find(key);
|
||||
if (iter != tidSeq2activity.end()) {
|
||||
libkineto::GenericTraceActivity* fwd = iter->second;
|
||||
activity.flow.linkedActivity = fwd; // Only destination side set this, to distinguish with start side.
|
||||
activity.flow.id = fwd->flow.id = fwd_bwd_link_id;
|
||||
activity.flow.type = fwd->flow.type = libkineto::kLinkFwdBwd;
|
||||
++fwd_bwd_link_id;
|
||||
}
|
||||
}
|
||||
else if (kineto_event.startThreadId() != 0) {
|
||||
// act is forward op.
|
||||
uint64_t key = getForwardThreadKey(kineto_event.startThreadId(), kineto_event.sequenceNr());
|
||||
// Assumption: Among all ops with same sequence number,
|
||||
// the one with biggest start time is most likely launching backward op.
|
||||
auto iter = tidSeq2activity.find(key);
|
||||
if (iter == tidSeq2activity.end()) {
|
||||
tidSeq2activity[key] = &activity;
|
||||
}
|
||||
else {
|
||||
// Now the sequence number is only incremented on creating a "Node" object for backward pass,
|
||||
// by calling "at::sequence_number::get_and_increment()".
|
||||
// Among all ops with same sequence number, the one with biggest startTime is the one launching backward op.
|
||||
if (activity.startTime >= iter->second->startTime) {
|
||||
tidSeq2activity[key] = &activity;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user