DPP Async Tracing (#44252)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44252

Add tracing to DPP client. Because DPP requests are async, we need to be able to start a trace event in one thread and potentially end in a different thread. RecordFunction and LibgpumonObserver previously assume each trace event starts and finishes in the same thread. So they use a thread local context to track enter and exit call backs. Async events breaks this assumption. This change attaches the event context to the RecordFunction object so we do not need to use thread local context.

Test Plan:
Tested with dpp perf test and able to collect trace.

{F307824044}

Reviewed By: ilia-cher

Differential Revision: D23323486

fbshipit-source-id: 4b6ca6c0e32028fb38a476cd1f44c17a001fc03b
This commit is contained in:
Louis Feng 2020-09-14 18:38:02 -07:00 committed by Facebook GitHub Bot
parent e107ef5ca2
commit 71673b31f9
5 changed files with 133 additions and 18 deletions

View File

@ -92,11 +92,14 @@ class CallbackManager {
bool found_needs_ids = false;
auto init_handles = [
scope, &found_active_cb, &found_needs_inputs, &found_needs_ids](
CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) {
handles.clear();
size_t num_callbacks = 0;
for (const auto& cb : cbs) {
if (cb.first.shouldRun(scope)) {
handles.push_back(cb.second);
++num_callbacks;
found_active_cb = true;
if (cb.first.needsInputs()) {
found_needs_inputs = true;
@ -106,10 +109,12 @@ class CallbackManager {
}
}
}
// Pre-allocate observer context list with nullptr.
ctx_list.resize(num_callbacks);
};
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_);
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_);
rec_fn.active = found_active_cb;
rec_fn.needs_inputs = found_needs_inputs;
if (found_needs_ids && found_active_cb) {
@ -121,11 +126,13 @@ class CallbackManager {
mergeRunCallbacks(
sorted_global_callbacks_,
rf.sorted_active_global_handles_,
rf.global_ctx_,
/* is_start */ true,
rf);
mergeRunCallbacks(
sorted_tls_callbacks_,
rf.sorted_active_tls_handles_,
rf.tls_ctx_,
/* is_start */ true,
rf);
rf.called_start_callbacks_ = true;
@ -135,21 +142,30 @@ class CallbackManager {
mergeRunCallbacks(
sorted_global_callbacks_,
rf.sorted_active_global_handles_,
rf.global_ctx_,
/* is_start */ false,
rf);
mergeRunCallbacks(
sorted_tls_callbacks_,
rf.sorted_active_tls_handles_,
rf.tls_ctx_,
/* is_start */ false,
rf);
}
private:
bool tryRunCallback(
const std::function<void(const RecordFunction&)>& fn,
RecordFunction& rf) {
const RecordFunctionCallback& rfcb,
RecordFunction& rf,
std::unique_ptr<ObserverContext>& ctx,
bool is_start) {
try {
fn(rf);
if (is_start) {
ctx = rfcb.start()(rf);
}
else {
rfcb.end()(rf, ctx.get());
}
return true;
} catch (const std::exception &e) {
LOG(WARNING) << "Exception in RecordFunction callback: "
@ -165,11 +181,12 @@ class CallbackManager {
void mergeRunCallbacks(
const RecordFunctionCallbacks& sorted_callbacks,
const CallbackHandles& sorted_handles,
ObserverContextList& ctx_list,
bool is_start,
RecordFunction& rf) {
size_t num_executed = 0;
size_t idx_c = 0;
for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
for (size_t idx_h = 0; idx_h < sorted_handles.size() && idx_h < ctx_list.size(); ++idx_h) {
while (idx_c < sorted_callbacks.size() &&
sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
++idx_c;
@ -178,11 +195,7 @@ class CallbackManager {
break;
}
if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
if (is_start) {
tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
} else {
tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
}
tryRunCallback(sorted_callbacks[idx_c].first, rf, ctx_list[idx_h], is_start);
++num_executed;
}
}

View File

@ -67,7 +67,16 @@ struct TORCH_API StringView {
// Soft limit on the number of callbacks to use;
constexpr std::size_t kSoftLimitCallbacks = 4;
// An abstract base class for various observer contexts that can be attached to
// the RecordFunction.
struct ObserverContext {
virtual ~ObserverContext() {}
protected:
ObserverContext() {}
};
typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
typedef std::vector<std::unique_ptr<ObserverContext>> ObserverContextList;
typedef uint64_t RecordFunctionHandle;
struct TORCH_API RecordFunction {
@ -164,6 +173,15 @@ struct TORCH_API RecordFunction {
// public because of anonymous "friend" class
CallbackHandles sorted_active_tls_handles_;
CallbackHandles sorted_active_global_handles_;
// Stores various ObserverContext objects with event metadata for thread local
// callbacks.
ObserverContextList tls_ctx_;
// Stores various ObserverContext objects with event metadata for global
// callbacks.
ObserverContextList global_ctx_;
// Whether this RecordFunction runs any callbacks
bool active = false;
/// Whether any of the picked callbacks require inputs
@ -198,6 +216,8 @@ struct TORCH_API RecordFunction {
* RecordFunctionCallback represents a pair of callbacks to be used with
* RecordFunction, members:
* start, end - the callbacks to run when entering and exiting the scope;
* optionally, the start callback may return an ObserverContext which will
* be passed to the end callback, use appropriate constructor accordingly.
* needs_inputs - whether the callbacks need the inputs passed from the observed
* function/range; NOTE: passing the inputs incurs an additional overhead;
* sampling_probability - if not 1.0, then the callback is probabilistically sampled
@ -211,12 +231,25 @@ struct TORCH_API RecordFunction {
*/
class TORCH_API RecordFunctionCallback {
public:
// This interface supports observers that require passing an ObserverContext
// between start and end callbacks.
explicit RecordFunctionCallback(
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
std::function<void(const RecordFunction&, ObserverContext*)> end =
[](const RecordFunction&, ObserverContext*) {}):
start_(std::move(start)),
end_(std::move(end)) {
scopes_.fill(true);
}
// This interface is for observers that do not pass an ObserverContext object
// between start and end callbacks.
explicit RecordFunctionCallback(
std::function<void(const RecordFunction&)> start,
std::function<void(const RecordFunction&)> end =
[](const RecordFunction&) {}):
start_(std::move(start)),
end_(std::move(end)) {
start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }},
end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} {
scopes_.fill(true);
}
@ -272,11 +305,11 @@ class TORCH_API RecordFunctionCallback {
return scopes_[(size_t)sc];
}
inline const std::function<void(const RecordFunction&)>& start() const {
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
return start_;
}
inline const std::function<void(const RecordFunction&)>& end() const {
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
return end_;
}
@ -284,8 +317,8 @@ class TORCH_API RecordFunctionCallback {
bool shouldRun(RecordScope scope) const;
private:
std::function<void(const RecordFunction&)> start_;
std::function<void(const RecordFunction&)> end_;
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
std::function<void(const RecordFunction&, ObserverContext*)> end_;
std::function<bool(const RecordFunctionCallback&)> should_run_;
bool needs_inputs_ = false;
bool needs_ids_ = false;

View File

@ -304,6 +304,9 @@ __host__ __device__
#endif // ANDROID / IOS
// Portably determine if a type T is trivially copyable or not.
// Warning: __has_trivial_copy for GCC may not always detect the non-POD
// correctly. For example, T = std::unique_ptr may evaluate to true and be
// treated as POD. This can cause unexpected behavior.
#if defined(__GNUG__) && __GNUC__ < 5
#define C10_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T)
#else

View File

@ -378,6 +378,9 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
/// This class consists of common code factored out of the SmallVector class to
/// reduce code duplication based on the SmallVector 'N' template parameter.
/// Warning: C10_IS_TRIVIALLY_COPYABLE may not always detect non-POD
/// type correctly. For example, std::unique_ptr may be treated as POD and cause
/// memory leaks.
template <typename T>
class SmallVectorImpl
: public SmallVectorTemplateBase<T, C10_IS_TRIVIALLY_COPYABLE(T)> {

View File

@ -1036,6 +1036,69 @@ void testRecordFunction() {
clearCallbacks();
// START: thread local / global context check callbacks
struct TestContext : public ObserverContext {
int a{0};
std::string b;
};
ids.clear();
{ // START: global test
const int test_val = 123;
const std::string test_str = "test str";
addGlobalCallback(RecordFunctionCallback(
[test_str, &ids](const RecordFunction& /* unused */) {
auto ctx = std::make_unique<TestContext>();
ctx->a = test_val;
ctx->b = test_str;
ids.push_back(1);
return ctx;
},
[test_str](
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == test_val);
TORCH_CHECK(ctx->b == test_str);
}));
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(ids[0] == 1);
ids.clear();
} // END: global test
{ // START: thread local test
auto ctx_th = std::thread([&ids]() {
const int test_val = 234;
const std::string test_str = "test thread str";
addThreadLocalCallback(RecordFunctionCallback(
[test_str, &ids](const RecordFunction& /* unused */) {
auto ctx = std::make_unique<TestContext>();
ctx->a = test_val;
ctx->b = test_str;
ids.push_back(2);
return ctx;
},
[test_str](
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == test_val);
TORCH_CHECK(ctx->b == test_str);
}));
// Will call both global and thread local callbacks.
{ RECORD_USER_SCOPE("test_thread"); }
});
ctx_th.join();
TORCH_CHECK(ids.size() == 2);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
ids.clear();
} // END: thread local test
clearCallbacks();
// test should_run
bool ran = false;