Back out "Revert D23323486: DPP Async Tracing" plus windows build fix. (#44702)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44702

Original commit changeset: c6bd6d277aca

This diff caused windows build to fail due to a compiler bug in VS2019 (lambda capture constant int value). This back out works around the issue with explicit capture of const int value.

Test Plan: Tested and previously landed.

Reviewed By: mruberry

Differential Revision: D23703215

fbshipit-source-id: f9ef23be97540bc9cf78a855295fb8c69f360459
This commit is contained in:
Louis Feng 2020-09-16 11:27:46 -07:00 committed by Facebook GitHub Bot
parent ced8727d88
commit eb75cfb9c0
5 changed files with 133 additions and 18 deletions

View File

@ -92,11 +92,14 @@ class CallbackManager {
bool found_needs_ids = false;
auto init_handles = [
scope, &found_active_cb, &found_needs_inputs, &found_needs_ids](
CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) {
handles.clear();
size_t num_callbacks = 0;
for (const auto& cb : cbs) {
if (cb.first.shouldRun(scope)) {
handles.push_back(cb.second);
++num_callbacks;
found_active_cb = true;
if (cb.first.needsInputs()) {
found_needs_inputs = true;
@ -106,10 +109,12 @@ class CallbackManager {
}
}
}
// Pre-allocate observer context list with nullptr.
ctx_list.resize(num_callbacks);
};
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_);
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_);
rec_fn.active = found_active_cb;
rec_fn.needs_inputs = found_needs_inputs;
if (found_needs_ids && found_active_cb) {
@ -121,11 +126,13 @@ class CallbackManager {
mergeRunCallbacks(
sorted_global_callbacks_,
rf.sorted_active_global_handles_,
rf.global_ctx_,
/* is_start */ true,
rf);
mergeRunCallbacks(
sorted_tls_callbacks_,
rf.sorted_active_tls_handles_,
rf.tls_ctx_,
/* is_start */ true,
rf);
rf.called_start_callbacks_ = true;
@ -135,21 +142,30 @@ class CallbackManager {
mergeRunCallbacks(
sorted_global_callbacks_,
rf.sorted_active_global_handles_,
rf.global_ctx_,
/* is_start */ false,
rf);
mergeRunCallbacks(
sorted_tls_callbacks_,
rf.sorted_active_tls_handles_,
rf.tls_ctx_,
/* is_start */ false,
rf);
}
private:
bool tryRunCallback(
const std::function<void(const RecordFunction&)>& fn,
RecordFunction& rf) {
const RecordFunctionCallback& rfcb,
RecordFunction& rf,
std::unique_ptr<ObserverContext>& ctx,
bool is_start) {
try {
fn(rf);
if (is_start) {
ctx = rfcb.start()(rf);
}
else {
rfcb.end()(rf, ctx.get());
}
return true;
} catch (const std::exception &e) {
LOG(WARNING) << "Exception in RecordFunction callback: "
@ -165,11 +181,12 @@ class CallbackManager {
void mergeRunCallbacks(
const RecordFunctionCallbacks& sorted_callbacks,
const CallbackHandles& sorted_handles,
ObserverContextList& ctx_list,
bool is_start,
RecordFunction& rf) {
size_t num_executed = 0;
size_t idx_c = 0;
for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
for (size_t idx_h = 0; idx_h < sorted_handles.size() && idx_h < ctx_list.size(); ++idx_h) {
while (idx_c < sorted_callbacks.size() &&
sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
++idx_c;
@ -178,11 +195,7 @@ class CallbackManager {
break;
}
if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
if (is_start) {
tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
} else {
tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
}
tryRunCallback(sorted_callbacks[idx_c].first, rf, ctx_list[idx_h], is_start);
++num_executed;
}
}

View File

@ -67,7 +67,16 @@ struct TORCH_API StringView {
// Soft limit on the number of callbacks to use;
constexpr std::size_t kSoftLimitCallbacks = 4;
// An abstract base class for various observer contexts that can be attached to
// the RecordFunction.
struct ObserverContext {
virtual ~ObserverContext() {}
protected:
ObserverContext() {}
};
typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
typedef std::vector<std::unique_ptr<ObserverContext>> ObserverContextList;
typedef uint64_t RecordFunctionHandle;
struct TORCH_API RecordFunction {
@ -164,6 +173,15 @@ struct TORCH_API RecordFunction {
// public because of anonymous "friend" class
CallbackHandles sorted_active_tls_handles_;
CallbackHandles sorted_active_global_handles_;
// Stores various ObserverContext objects with event metadata for thread local
// callbacks.
ObserverContextList tls_ctx_;
// Stores various ObserverContext objects with event metadata for global
// callbacks.
ObserverContextList global_ctx_;
// Whether this RecordFunction runs any callbacks
bool active = false;
/// Whether any of the picked callbacks require inputs
@ -198,6 +216,8 @@ struct TORCH_API RecordFunction {
* RecordFunctionCallback represents a pair of callbacks to be used with
* RecordFunction, members:
* start, end - the callbacks to run when entering and exiting the scope;
* optionally, the start callback may return an ObserverContext which will
* be passed to the end callback, use appropriate constructor accordingly.
* needs_inputs - whether the callbacks need the inputs passed from the observed
* function/range; NOTE: passing the inputs incurs an additional overhead;
* sampling_probability - if not 1.0, then the callback is probabilistically sampled
@ -211,12 +231,25 @@ struct TORCH_API RecordFunction {
*/
class TORCH_API RecordFunctionCallback {
public:
// This interface supports observers that require passing an ObserverContext
// between start and end callbacks.
explicit RecordFunctionCallback(
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
std::function<void(const RecordFunction&, ObserverContext*)> end =
[](const RecordFunction&, ObserverContext*) {}):
start_(std::move(start)),
end_(std::move(end)) {
scopes_.fill(true);
}
// This interface is for observers that do not pass an ObserverContext object
// between start and end callbacks.
explicit RecordFunctionCallback(
std::function<void(const RecordFunction&)> start,
std::function<void(const RecordFunction&)> end =
[](const RecordFunction&) {}):
start_(std::move(start)),
end_(std::move(end)) {
start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }},
end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} {
scopes_.fill(true);
}
@ -272,11 +305,11 @@ class TORCH_API RecordFunctionCallback {
return scopes_[(size_t)sc];
}
inline const std::function<void(const RecordFunction&)>& start() const {
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
return start_;
}
inline const std::function<void(const RecordFunction&)>& end() const {
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
return end_;
}
@ -284,8 +317,8 @@ class TORCH_API RecordFunctionCallback {
bool shouldRun(RecordScope scope) const;
private:
std::function<void(const RecordFunction&)> start_;
std::function<void(const RecordFunction&)> end_;
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
std::function<void(const RecordFunction&, ObserverContext*)> end_;
std::function<bool(const RecordFunctionCallback&)> should_run_;
bool needs_inputs_ = false;
bool needs_ids_ = false;

View File

@ -304,6 +304,9 @@ __host__ __device__
#endif // ANDROID / IOS
// Portably determine if a type T is trivially copyable or not.
// Warning: __has_trivial_copy for GCC may not always detect the non-POD
// correctly. For example, T = std::unique_ptr may evaluate to true and be
// treated as POD. This can cause unexpected behavior.
#if defined(__GNUG__) && __GNUC__ < 5
#define C10_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T)
#else

View File

@ -378,6 +378,9 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
/// This class consists of common code factored out of the SmallVector class to
/// reduce code duplication based on the SmallVector 'N' template parameter.
/// Warning: C10_IS_TRIVIALLY_COPYABLE may not always detect non-POD
/// type correctly. For example, std::unique_ptr may be treated as POD and cause
/// memory leaks.
template <typename T>
class SmallVectorImpl
: public SmallVectorTemplateBase<T, C10_IS_TRIVIALLY_COPYABLE(T)> {

View File

@ -1036,6 +1036,69 @@ void testRecordFunction() {
clearCallbacks();
// START: thread local / global context check callbacks
struct TestContext : public ObserverContext {
int a{0};
std::string b;
};
ids.clear();
{ // START: global test
const int test_val = 123;
const std::string test_str = "test str";
addGlobalCallback(RecordFunctionCallback(
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
auto ctx = std::make_unique<TestContext>();
ctx->a = test_val;
ctx->b = test_str;
ids.push_back(1);
return ctx;
},
[test_val, test_str](
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == test_val);
TORCH_CHECK(ctx->b == test_str);
}));
{ RECORD_USER_SCOPE("test"); }
TORCH_CHECK(ids.size() == 1);
TORCH_CHECK(ids[0] == 1);
ids.clear();
} // END: global test
{ // START: thread local test
auto ctx_th = std::thread([&ids]() {
const int test_val = 234;
const std::string test_str = "test thread str";
addThreadLocalCallback(RecordFunctionCallback(
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
auto ctx = std::make_unique<TestContext>();
ctx->a = test_val;
ctx->b = test_str;
ids.push_back(2);
return ctx;
},
[test_val, test_str](
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == test_val);
TORCH_CHECK(ctx->b == test_str);
}));
// Will call both global and thread local callbacks.
{ RECORD_USER_SCOPE("test_thread"); }
});
ctx_th.join();
TORCH_CHECK(ids.size() == 2);
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
ids.clear();
} // END: thread local test
clearCallbacks();
// test should_run
bool ran = false;