mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Back out "Revert D23323486: DPP Async Tracing" plus windows build fix. (#44702)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44702 Original commit changeset: c6bd6d277aca This diff caused windows build to fail due to a compiler bug in VS2019 (lambda capture constant int value). This back out works around the issue with explicit capture of const int value. Test Plan: Tested and previously landed. Reviewed By: mruberry Differential Revision: D23703215 fbshipit-source-id: f9ef23be97540bc9cf78a855295fb8c69f360459
This commit is contained in:
parent
ced8727d88
commit
eb75cfb9c0
|
|
@ -92,11 +92,14 @@ class CallbackManager {
|
|||
bool found_needs_ids = false;
|
||||
auto init_handles = [
|
||||
scope, &found_active_cb, &found_needs_inputs, &found_needs_ids](
|
||||
CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
|
||||
CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) {
|
||||
handles.clear();
|
||||
|
||||
size_t num_callbacks = 0;
|
||||
for (const auto& cb : cbs) {
|
||||
if (cb.first.shouldRun(scope)) {
|
||||
handles.push_back(cb.second);
|
||||
++num_callbacks;
|
||||
found_active_cb = true;
|
||||
if (cb.first.needsInputs()) {
|
||||
found_needs_inputs = true;
|
||||
|
|
@ -106,10 +109,12 @@ class CallbackManager {
|
|||
}
|
||||
}
|
||||
}
|
||||
// Pre-allocate observer context list with nullptr.
|
||||
ctx_list.resize(num_callbacks);
|
||||
};
|
||||
|
||||
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
|
||||
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
|
||||
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_);
|
||||
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_);
|
||||
rec_fn.active = found_active_cb;
|
||||
rec_fn.needs_inputs = found_needs_inputs;
|
||||
if (found_needs_ids && found_active_cb) {
|
||||
|
|
@ -121,11 +126,13 @@ class CallbackManager {
|
|||
mergeRunCallbacks(
|
||||
sorted_global_callbacks_,
|
||||
rf.sorted_active_global_handles_,
|
||||
rf.global_ctx_,
|
||||
/* is_start */ true,
|
||||
rf);
|
||||
mergeRunCallbacks(
|
||||
sorted_tls_callbacks_,
|
||||
rf.sorted_active_tls_handles_,
|
||||
rf.tls_ctx_,
|
||||
/* is_start */ true,
|
||||
rf);
|
||||
rf.called_start_callbacks_ = true;
|
||||
|
|
@ -135,21 +142,30 @@ class CallbackManager {
|
|||
mergeRunCallbacks(
|
||||
sorted_global_callbacks_,
|
||||
rf.sorted_active_global_handles_,
|
||||
rf.global_ctx_,
|
||||
/* is_start */ false,
|
||||
rf);
|
||||
mergeRunCallbacks(
|
||||
sorted_tls_callbacks_,
|
||||
rf.sorted_active_tls_handles_,
|
||||
rf.tls_ctx_,
|
||||
/* is_start */ false,
|
||||
rf);
|
||||
}
|
||||
|
||||
private:
|
||||
bool tryRunCallback(
|
||||
const std::function<void(const RecordFunction&)>& fn,
|
||||
RecordFunction& rf) {
|
||||
const RecordFunctionCallback& rfcb,
|
||||
RecordFunction& rf,
|
||||
std::unique_ptr<ObserverContext>& ctx,
|
||||
bool is_start) {
|
||||
try {
|
||||
fn(rf);
|
||||
if (is_start) {
|
||||
ctx = rfcb.start()(rf);
|
||||
}
|
||||
else {
|
||||
rfcb.end()(rf, ctx.get());
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception &e) {
|
||||
LOG(WARNING) << "Exception in RecordFunction callback: "
|
||||
|
|
@ -165,11 +181,12 @@ class CallbackManager {
|
|||
void mergeRunCallbacks(
|
||||
const RecordFunctionCallbacks& sorted_callbacks,
|
||||
const CallbackHandles& sorted_handles,
|
||||
ObserverContextList& ctx_list,
|
||||
bool is_start,
|
||||
RecordFunction& rf) {
|
||||
size_t num_executed = 0;
|
||||
size_t idx_c = 0;
|
||||
for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
|
||||
for (size_t idx_h = 0; idx_h < sorted_handles.size() && idx_h < ctx_list.size(); ++idx_h) {
|
||||
while (idx_c < sorted_callbacks.size() &&
|
||||
sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
|
||||
++idx_c;
|
||||
|
|
@ -178,11 +195,7 @@ class CallbackManager {
|
|||
break;
|
||||
}
|
||||
if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
|
||||
if (is_start) {
|
||||
tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
|
||||
} else {
|
||||
tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
|
||||
}
|
||||
tryRunCallback(sorted_callbacks[idx_c].first, rf, ctx_list[idx_h], is_start);
|
||||
++num_executed;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,7 +67,16 @@ struct TORCH_API StringView {
|
|||
// Soft limit on the number of callbacks to use;
|
||||
constexpr std::size_t kSoftLimitCallbacks = 4;
|
||||
|
||||
// An abstract base class for various observer contexts that can be attached to
|
||||
// the RecordFunction.
|
||||
struct ObserverContext {
|
||||
virtual ~ObserverContext() {}
|
||||
protected:
|
||||
ObserverContext() {}
|
||||
};
|
||||
|
||||
typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
|
||||
typedef std::vector<std::unique_ptr<ObserverContext>> ObserverContextList;
|
||||
typedef uint64_t RecordFunctionHandle;
|
||||
|
||||
struct TORCH_API RecordFunction {
|
||||
|
|
@ -164,6 +173,15 @@ struct TORCH_API RecordFunction {
|
|||
// public because of anonymous "friend" class
|
||||
CallbackHandles sorted_active_tls_handles_;
|
||||
CallbackHandles sorted_active_global_handles_;
|
||||
|
||||
// Stores various ObserverContext objects with event metadata for thread local
|
||||
// callbacks.
|
||||
ObserverContextList tls_ctx_;
|
||||
|
||||
// Stores various ObserverContext objects with event metadata for global
|
||||
// callbacks.
|
||||
ObserverContextList global_ctx_;
|
||||
|
||||
// Whether this RecordFunction runs any callbacks
|
||||
bool active = false;
|
||||
/// Whether any of the picked callbacks require inputs
|
||||
|
|
@ -198,6 +216,8 @@ struct TORCH_API RecordFunction {
|
|||
* RecordFunctionCallback represents a pair of callbacks to be used with
|
||||
* RecordFunction, members:
|
||||
* start, end - the callbacks to run when entering and exiting the scope;
|
||||
* optionally, the start callback may return an ObserverContext which will
|
||||
* be passed to the end callback, use appropriate constructor accordingly.
|
||||
* needs_inputs - whether the callbacks need the inputs passed from the observed
|
||||
* function/range; NOTE: passing the inputs incurs an additional overhead;
|
||||
* sampling_probability - if not 1.0, then the callback is probabilistically sampled
|
||||
|
|
@ -211,12 +231,25 @@ struct TORCH_API RecordFunction {
|
|||
*/
|
||||
class TORCH_API RecordFunctionCallback {
|
||||
public:
|
||||
// This interface supports observers that require passing an ObserverContext
|
||||
// between start and end callbacks.
|
||||
explicit RecordFunctionCallback(
|
||||
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
|
||||
std::function<void(const RecordFunction&, ObserverContext*)> end =
|
||||
[](const RecordFunction&, ObserverContext*) {}):
|
||||
start_(std::move(start)),
|
||||
end_(std::move(end)) {
|
||||
scopes_.fill(true);
|
||||
}
|
||||
|
||||
// This interface is for observers that do not pass an ObserverContext object
|
||||
// between start and end callbacks.
|
||||
explicit RecordFunctionCallback(
|
||||
std::function<void(const RecordFunction&)> start,
|
||||
std::function<void(const RecordFunction&)> end =
|
||||
[](const RecordFunction&) {}):
|
||||
start_(std::move(start)),
|
||||
end_(std::move(end)) {
|
||||
start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }},
|
||||
end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} {
|
||||
scopes_.fill(true);
|
||||
}
|
||||
|
||||
|
|
@ -272,11 +305,11 @@ class TORCH_API RecordFunctionCallback {
|
|||
return scopes_[(size_t)sc];
|
||||
}
|
||||
|
||||
inline const std::function<void(const RecordFunction&)>& start() const {
|
||||
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
|
||||
return start_;
|
||||
}
|
||||
|
||||
inline const std::function<void(const RecordFunction&)>& end() const {
|
||||
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
|
||||
return end_;
|
||||
}
|
||||
|
||||
|
|
@ -284,8 +317,8 @@ class TORCH_API RecordFunctionCallback {
|
|||
bool shouldRun(RecordScope scope) const;
|
||||
|
||||
private:
|
||||
std::function<void(const RecordFunction&)> start_;
|
||||
std::function<void(const RecordFunction&)> end_;
|
||||
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
|
||||
std::function<void(const RecordFunction&, ObserverContext*)> end_;
|
||||
std::function<bool(const RecordFunctionCallback&)> should_run_;
|
||||
bool needs_inputs_ = false;
|
||||
bool needs_ids_ = false;
|
||||
|
|
|
|||
|
|
@ -304,6 +304,9 @@ __host__ __device__
|
|||
#endif // ANDROID / IOS
|
||||
|
||||
// Portably determine if a type T is trivially copyable or not.
|
||||
// Warning: __has_trivial_copy for GCC may not always detect the non-POD
|
||||
// correctly. For example, T = std::unique_ptr may evaluate to true and be
|
||||
// treated as POD. This can cause unexpected behavior.
|
||||
#if defined(__GNUG__) && __GNUC__ < 5
|
||||
#define C10_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T)
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -378,6 +378,9 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
|
|||
|
||||
/// This class consists of common code factored out of the SmallVector class to
|
||||
/// reduce code duplication based on the SmallVector 'N' template parameter.
|
||||
/// Warning: C10_IS_TRIVIALLY_COPYABLE may not always detect non-POD
|
||||
/// type correctly. For example, std::unique_ptr may be treated as POD and cause
|
||||
/// memory leaks.
|
||||
template <typename T>
|
||||
class SmallVectorImpl
|
||||
: public SmallVectorTemplateBase<T, C10_IS_TRIVIALLY_COPYABLE(T)> {
|
||||
|
|
|
|||
|
|
@ -1036,6 +1036,69 @@ void testRecordFunction() {
|
|||
|
||||
clearCallbacks();
|
||||
|
||||
// START: thread local / global context check callbacks
|
||||
struct TestContext : public ObserverContext {
|
||||
int a{0};
|
||||
std::string b;
|
||||
};
|
||||
ids.clear();
|
||||
{ // START: global test
|
||||
const int test_val = 123;
|
||||
const std::string test_str = "test str";
|
||||
addGlobalCallback(RecordFunctionCallback(
|
||||
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
|
||||
auto ctx = std::make_unique<TestContext>();
|
||||
ctx->a = test_val;
|
||||
ctx->b = test_str;
|
||||
ids.push_back(1);
|
||||
return ctx;
|
||||
},
|
||||
[test_val, test_str](
|
||||
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
||||
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
||||
TORCH_CHECK(ctx_ptr != nullptr);
|
||||
TORCH_CHECK(ctx->a == test_val);
|
||||
TORCH_CHECK(ctx->b == test_str);
|
||||
}));
|
||||
|
||||
{ RECORD_USER_SCOPE("test"); }
|
||||
|
||||
TORCH_CHECK(ids.size() == 1);
|
||||
TORCH_CHECK(ids[0] == 1);
|
||||
ids.clear();
|
||||
} // END: global test
|
||||
{ // START: thread local test
|
||||
auto ctx_th = std::thread([&ids]() {
|
||||
const int test_val = 234;
|
||||
const std::string test_str = "test thread str";
|
||||
addThreadLocalCallback(RecordFunctionCallback(
|
||||
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
|
||||
auto ctx = std::make_unique<TestContext>();
|
||||
ctx->a = test_val;
|
||||
ctx->b = test_str;
|
||||
ids.push_back(2);
|
||||
return ctx;
|
||||
},
|
||||
[test_val, test_str](
|
||||
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
||||
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
||||
TORCH_CHECK(ctx_ptr != nullptr);
|
||||
TORCH_CHECK(ctx->a == test_val);
|
||||
TORCH_CHECK(ctx->b == test_str);
|
||||
}));
|
||||
|
||||
// Will call both global and thread local callbacks.
|
||||
{ RECORD_USER_SCOPE("test_thread"); }
|
||||
});
|
||||
ctx_th.join();
|
||||
TORCH_CHECK(ids.size() == 2);
|
||||
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
|
||||
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
|
||||
ids.clear();
|
||||
} // END: thread local test
|
||||
|
||||
clearCallbacks();
|
||||
|
||||
// test should_run
|
||||
|
||||
bool ran = false;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user