[PyTorch] Use plain old function pointer for RecordFunctionCallback (reapply) (#49408)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49408

Replace the `std::function` members of `RecordFunctionCallback` with plain function pointers. Nearly every non-test call site captures no variables anyway, and this saves 48 bytes per callback.
ghstack-source-id: 118665808

Test Plan:
Wait for GitHub CI, since we had C++14-specific issues with
this change in the previous PR https://github.com/pytorch/pytorch/pull/48629

Reviewed By: malfet

Differential Revision: D25563207

fbshipit-source-id: 6a2831205917d465f8248ca37429ba2428d5626d
This commit is contained in:
Scott Wolchok 2020-12-15 19:14:11 -08:00 committed by Facebook GitHub Bot
parent e9d7d37ad0
commit 22c6dafd33
7 changed files with 193 additions and 141 deletions

View File

@ -277,10 +277,12 @@ class CallbackManager {
bool is_start) { bool is_start) {
try { try {
if (is_start) { if (is_start) {
ctx = rfcb.start()(rf); ctx = rfcb.start() ? rfcb.start()(rf) : nullptr;
} }
else { else {
rfcb.end()(rf, ctx.get()); if (rfcb.end()) {
rfcb.end()(rf, ctx.get());
}
} }
return true; return true;
} catch (const std::exception &e) { } catch (const std::exception &e) {

View File

@ -305,14 +305,16 @@ struct TORCH_API RecordFunction {
*/ */
class TORCH_API RecordFunctionCallback { class TORCH_API RecordFunctionCallback {
public: public:
using StartCallback = std::unique_ptr<ObserverContext>(*)(const RecordFunction&);
using EndCallback = void (*)(const RecordFunction&, ObserverContext*);
// This interface supports observers that require passing an ObserverContext // This interface supports observers that require passing an ObserverContext
// between start and end callbacks. // between start and end callbacks.
explicit RecordFunctionCallback( explicit RecordFunctionCallback(
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start, StartCallback start,
std::function<void(const RecordFunction&, ObserverContext*)> end = EndCallback end = nullptr) :
[](const RecordFunction&, ObserverContext*) {}): start_(start),
start_(std::move(start)), end_(end) {
end_(std::move(end)) {
scopes_.fill(true); scopes_.fill(true);
} }
@ -368,18 +370,18 @@ class TORCH_API RecordFunctionCallback {
return scopes_[(size_t)sc]; return scopes_[(size_t)sc];
} }
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const { inline StartCallback start() const {
return start_; return start_;
} }
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const { inline EndCallback end() const {
return end_; return end_;
} }
private: private:
friend class CallbackManager; friend class CallbackManager;
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_; StartCallback start_;
std::function<void(const RecordFunction&, ObserverContext*)> end_; EndCallback end_;
bool(*should_run_)(const RecordFunctionCallback&) = nullptr; bool(*should_run_)(const RecordFunctionCallback&) = nullptr;
double sampling_prob_ = 1.0; double sampling_prob_ = 1.0;
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {}; std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};

View File

@ -19,10 +19,10 @@ const float kLowSamplingProb = 0.0001;
void addTestCallback( void addTestCallback(
double sampling_prob = 1.0, double sampling_prob = 1.0,
std::function<std::unique_ptr<at::ObserverContext>(const at::RecordFunction&)> fn = at::RecordFunctionCallback::StartCallback fn =
[](const at::RecordFunction&) { return nullptr; }) { [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
auto cb = at::RecordFunctionCallback( auto cb = at::RecordFunctionCallback(
std::move(fn), fn,
[](const at::RecordFunction&, at::ObserverContext*) {}) [](const at::RecordFunction&, at::ObserverContext*) {})
.needsInputs(false); .needsInputs(false);
if (sampling_prob < 1.0) { if (sampling_prob < 1.0) {
@ -106,10 +106,10 @@ int main(int argc, char** argv) {
at::clearCallbacks(); at::clearCallbacks();
std::cout << "Checking number of sampled observer invocations" << std::endl; std::cout << "Checking number of sampled observer invocations" << std::endl;
int cb_count = 0; static int cb_count = 0;
addTestCallback( addTestCallback(
kLowSamplingProb, kLowSamplingProb,
[&](const at::RecordFunction& fn) { [](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
++cb_count; ++cb_count;
return nullptr; return nullptr;
} }

View File

@ -721,12 +721,40 @@ void checkTracedInputs(const TracedTestInputs& inputs) {
TORCH_CHECK(found_mul); TORCH_CHECK(found_mul);
} }
static bool bad_scope = false;
// Start callback used by the scope tests.
//
// `scope` is the record scope this instantiation expects to observe and `cnt`
// points at the counter to bump for each matching call. Any call arriving with
// a different scope sets the file-level `bad_scope` flag instead.
// Always returns nullptr, i.e. no observer context is created.
template <RecordScope scope, size_t* cnt>
std::unique_ptr<at::ObserverContext> checkScopeCallback(
    const at::RecordFunction& fn) {
  const bool scope_matches = (fn.scope() == scope);
  if (!scope_matches) {
    // Unexpected scope: remember the violation and skip counting.
    bad_scope = true;
    return nullptr;
  }
  ++(*cnt);
  return nullptr;
}
// Registers checkScopeCallback<scope, cnt> as a global RecordFunction
// callback, configured with `.scopes({scope})`.
//
// `scope` and `cnt` are template (not runtime) arguments so that the callback
// carries no captured state and can be passed as a plain function.
template <RecordScope scope, size_t* cnt>
void pushScopedCallback() {
at::addGlobalCallback(
at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
.scopes({scope}));
}
// These cannot be function-local because that would prohibit them
// from being used as template arguments prior to C++17.
static size_t fun_cnt;
static size_t ts_fun_cnt;
static size_t user_scope_cnt;
void checkScopeCallbacks() { void checkScopeCallbacks() {
bool found_function_scope = false; static bool found_function_scope;
bool found_method_scope = false; static bool found_method_scope;
bool found_user_scope = false; static bool found_user_scope;
found_function_scope = false;
found_method_scope = false;
found_user_scope = false;
at::addGlobalCallback(at::RecordFunctionCallback( at::addGlobalCallback(at::RecordFunctionCallback(
[&](const at::RecordFunction& fn) { [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
if (fn.scope() == at::RecordScope::FUNCTION && if (fn.scope() == at::RecordScope::FUNCTION &&
std::string(fn.name().str()) == "test_function") { std::string(fn.name().str()) == "test_function") {
found_function_scope = true; found_function_scope = true;
@ -742,27 +770,13 @@ void checkScopeCallbacks() {
return nullptr; return nullptr;
})); }));
bool bad_scope = false; bad_scope = false;
auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) { fun_cnt = 0;
at::addGlobalCallback( pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
at::RecordFunctionCallback( ts_fun_cnt = 0;
[&bad_scope, &cnt, scope](const at::RecordFunction& fn) { pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
if (fn.scope() == scope) { user_scope_cnt = 0;
++cnt; pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
} else {
bad_scope = true;
}
return nullptr;
})
.scopes({scope}));
};
size_t fun_cnt = 0;
pushScopedCallback(at::RecordScope::FUNCTION, fun_cnt);
size_t ts_fun_cnt = 0;
pushScopedCallback(at::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt);
size_t user_scope_cnt = 0;
pushScopedCallback(at::RecordScope::USER_SCOPE, user_scope_cnt);
TORCH_CHECK(at::hasCallbacks()); TORCH_CHECK(at::hasCallbacks());
@ -788,33 +802,35 @@ static bool shouldRunCallback(const RecordFunctionCallback&) {
return should_run; return should_run;
} }
TEST(RecordFunctionTest, Basic) { static TracedTestInputs traced_inputs;
static std::unordered_set<std::string> ts_names;
// Start callback that records what each observed call looked like.
//
// - FUNCTION scope: appends (name, sizes) to the file-level `traced_inputs`,
//   where `sizes` holds one entry per tensor input (its sizes) and one empty
//   entry per scalar input; other input kinds contribute nothing.
// - TORCHSCRIPT_FUNCTION scope: inserts the name into the file-level
//   `ts_names` set.
// - Any other scope is ignored.
//
// Always returns nullptr, i.e. no observer context is created.
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
    const RecordFunction& fn) {
  const auto observed_scope = fn.scope();

  if (observed_scope == RecordScope::TORCHSCRIPT_FUNCTION) {
    ts_names.insert(fn.name().str());
    return nullptr;
  }
  if (observed_scope != RecordScope::FUNCTION) {
    return nullptr;
  }

  std::vector<std::vector<int64_t>> input_sizes;
  auto call_inputs = fn.inputs();
  for (const auto& arg : call_inputs) {
    if (arg.isTensor()) {
      input_sizes.push_back(arg.toTensor().sizes().vec());
      continue;
    }
    if (arg.isScalar()) {
      // Scalars are recorded as an empty size list.
      input_sizes.push_back(std::vector<int64_t>());
    }
  }
  traced_inputs.push_back(std::make_tuple(fn.name().str(), input_sizes));
  return nullptr;
}
TEST(RecordFunctionTest, TracedTestInputs) {
// disabling the inlining of method calls // disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false); GraphOptimizerEnabledGuard opt_guard(false);
// [(fn, [[sizes], [sizes], ...]), ...] // [(fn, [[sizes], [sizes], ...]), ...]
TracedTestInputs traced_inputs;
std::unordered_set<std::string> ts_names;
addGlobalCallback( addGlobalCallback(
RecordFunctionCallback( RecordFunctionCallback(tracedInputsCallback).needsInputs(true));
[&](const RecordFunction& fn) {
if (fn.scope() == RecordScope::FUNCTION) {
auto inputs = fn.inputs();
std::vector<std::vector<int64_t>> sizes;
for (const auto& input : inputs) {
if (input.isTensor()) {
sizes.push_back(input.toTensor().sizes().vec());
} else if (input.isScalar()) {
sizes.push_back(std::vector<int64_t>());
}
}
traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
ts_names.insert(fn.name().str());
}
return nullptr;
})
.needsInputs(true));
TracedTestInputs eager_inputs, jit_inputs; TracedTestInputs eager_inputs, jit_inputs;
{ {
@ -841,28 +857,36 @@ TEST(RecordFunctionTest, Basic) {
checkTracedInputs(eager_inputs); checkTracedInputs(eager_inputs);
checkTracedInputs(jit_inputs); checkTracedInputs(jit_inputs);
at::clearCallbacks(); at::clearCallbacks();
}
static int sampled_cb_ctr = 0;
// Start callback for the sampled-callback test: bumps the file-level
// `sampled_cb_ctr` whenever the observed record is named "test".
// Always returns nullptr, i.e. no observer context is created.
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
  const std::string observed_name(fn.name().str());
  if (observed_name == "test") {
    ++sampled_cb_ctr;
  }
  return nullptr;
}
static int non_sampled_cb_ctr = 0;
// Start callback registered without sampling: bumps the file-level
// `non_sampled_cb_ctr` whenever the observed record is named "test".
// Always returns nullptr, i.e. no observer context is created.
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
  const std::string observed_name(fn.name().str());
  if (observed_name == "test") {
    ++non_sampled_cb_ctr;
  }
  return nullptr;
}
TEST(RecordFunctionTest, SampledCallbacks) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
// test sampled callbacks // test sampled callbacks
int sampled_cb_ctr = 0; sampled_cb_ctr = 0;
auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) { auto setup_sampled_callback = [](double sampling_prob) {
return addGlobalCallback(RecordFunctionCallback( return addGlobalCallback(
[&sampled_cb_ctr](const RecordFunction& fn) { RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
if (std::string(fn.name().str()) == "test") {
++sampled_cb_ctr;
}
return nullptr;
})
.samplingProb(sampling_prob));
}; };
int non_sampled_cb_ctr = 0; addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
addGlobalCallback(RecordFunctionCallback(
[&non_sampled_cb_ctr](const RecordFunction& fn) {
if (std::string(fn.name().str()) == "test") {
++non_sampled_cb_ctr;
}
return nullptr;
}));
auto handle = setup_sampled_callback(0.5); auto handle = setup_sampled_callback(0.5);
@ -897,13 +921,19 @@ TEST(RecordFunctionTest, Basic) {
// test the scope of the callbacks // test the scope of the callbacks
checkScopeCallbacks(); checkScopeCallbacks();
clearCallbacks(); clearCallbacks();
}
TEST(RecordFunctionTest, RecordFunctionGuard) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static std::vector<std::string> fn_names;
static std::mutex guard_mtx;
// check record function guard // check record function guard
std::vector<std::string> fn_names;
std::mutex mtx;
addGlobalCallback(RecordFunctionCallback( addGlobalCallback(RecordFunctionCallback(
[&fn_names, &mtx](const RecordFunction& fn) { [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
std::lock_guard<std::mutex> lock(mtx); std::lock_guard<std::mutex> lock(guard_mtx);
fn_names.push_back(fn.name().str()); fn_names.push_back(fn.name().str());
return nullptr; return nullptr;
})); }));
@ -925,20 +955,26 @@ TEST(RecordFunctionTest, Basic) {
TORCH_CHECK(fn_names.size() == 1); TORCH_CHECK(fn_names.size() == 1);
TORCH_CHECK(fn_names[0] == "B"); TORCH_CHECK(fn_names[0] == "B");
clearCallbacks(); clearCallbacks();
}
// test add/remove static std::vector<size_t> ids;
std::vector<size_t> ids;
auto add_remove_test_add_cb = [&ids](size_t id) {
return addGlobalCallback(RecordFunctionCallback(
[&ids, id](const RecordFunction& fn) {
ids.push_back(id);
return nullptr ;
}));
};
auto h1 = add_remove_test_add_cb(1); template <size_t id>
auto h2 = add_remove_test_add_cb(2); auto add_remove_test_add_cb() {
auto h3 = add_remove_test_add_cb(3); return addGlobalCallback(RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ids.push_back(id);
return nullptr;
}));
}
TEST(RecordFunctionTest, Callbacks) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
auto h1 = add_remove_test_add_cb<1>();
auto h2 = add_remove_test_add_cb<2>();
auto h3 = add_remove_test_add_cb<3>();
{ RECORD_USER_SCOPE("test"); } { RECORD_USER_SCOPE("test"); }
@ -969,8 +1005,7 @@ TEST(RecordFunctionTest, Basic) {
// thread local / global callbacks // thread local / global callbacks
ids.clear(); ids.clear();
addGlobalCallback(RecordFunctionCallback( add_remove_test_add_cb<1>();
[&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; }));
{ RECORD_USER_SCOPE("test"); } { RECORD_USER_SCOPE("test"); }
@ -978,9 +1013,12 @@ TEST(RecordFunctionTest, Basic) {
TORCH_CHECK(ids[0] == 1); TORCH_CHECK(ids[0] == 1);
ids.clear(); ids.clear();
auto th = std::thread([&ids]() { auto th = std::thread([]() {
addThreadLocalCallback(RecordFunctionCallback( addThreadLocalCallback(RecordFunctionCallback(
[&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; })); [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ids.push_back(2);
return nullptr;
}));
{ RECORD_USER_SCOPE("test_thread"); } { RECORD_USER_SCOPE("test_thread"); }
}); });
@ -1005,22 +1043,20 @@ TEST(RecordFunctionTest, Basic) {
}; };
ids.clear(); ids.clear();
{ // START: global test { // START: global test
const int test_val = 123;
const std::string test_str = "test str";
addGlobalCallback(RecordFunctionCallback( addGlobalCallback(RecordFunctionCallback(
[test_val, test_str, &ids](const RecordFunction& /* unused */) { [](const RecordFunction &
/* unused */) -> std::unique_ptr<at::ObserverContext> {
auto ctx = std::make_unique<TestContext>(); auto ctx = std::make_unique<TestContext>();
ctx->a = test_val; ctx->a = 123;
ctx->b = test_str; ctx->b = "test_str";
ids.push_back(1); ids.push_back(1);
return ctx; return ctx;
}, },
[test_val, test_str]( [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr); auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr); TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == test_val); TORCH_CHECK(ctx->a == 123);
TORCH_CHECK(ctx->b == test_str); TORCH_CHECK(ctx->b == "test_str");
})); }));
{ RECORD_USER_SCOPE("test"); } { RECORD_USER_SCOPE("test"); }
@ -1030,23 +1066,23 @@ TEST(RecordFunctionTest, Basic) {
ids.clear(); ids.clear();
} // END: global test } // END: global test
{ // START: thread local test { // START: thread local test
auto ctx_th = std::thread([&ids]() { auto ctx_th = std::thread([]() {
const int test_val = 234; const int test_val = 234;
const std::string test_str = "test thread str"; const std::string test_str = "test thread str";
addThreadLocalCallback(RecordFunctionCallback( addThreadLocalCallback(RecordFunctionCallback(
[test_val, test_str, &ids](const RecordFunction& /* unused */) { [](const RecordFunction &
/* unused */) -> std::unique_ptr<at::ObserverContext> {
auto ctx = std::make_unique<TestContext>(); auto ctx = std::make_unique<TestContext>();
ctx->a = test_val; ctx->a = 234;
ctx->b = test_str; ctx->b = "test_thread_str";
ids.push_back(2); ids.push_back(2);
return ctx; return ctx;
}, },
[test_val, test_str]( [](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr); auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr); TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == test_val); TORCH_CHECK(ctx->a == 234);
TORCH_CHECK(ctx->b == test_str); TORCH_CHECK(ctx->b == "test_thread_str");
})); }));
// Will call both global and thread local callbacks. // Will call both global and thread local callbacks.
@ -1060,14 +1096,21 @@ TEST(RecordFunctionTest, Basic) {
} // END: thread local test } // END: thread local test
clearCallbacks(); clearCallbacks();
}
// test should_run TEST(RecordFunctionTest, ShouldRun) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
bool ran = false;
should_run = false; should_run = false;
addGlobalCallback(RecordFunctionCallback( static bool ran = false;
[&ran](const RecordFunction& fn) { ran = true; return nullptr; }) addGlobalCallback(
.setShouldRun(shouldRunCallback)); RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ran = true;
return nullptr;
})
.setShouldRun(shouldRunCallback));
{ RECORD_USER_SCOPE("test"); } { RECORD_USER_SCOPE("test"); }
@ -1080,13 +1123,20 @@ TEST(RecordFunctionTest, Basic) {
TORCH_CHECK(ran); TORCH_CHECK(ran);
clearCallbacks(); clearCallbacks();
}
TEST(RecordFunctionTest, Basic) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static std::string recorded_op;
static bool has_ids = false;
// test propagation of TLS callbacks // test propagation of TLS callbacks
std::thread t([]() { std::thread t([]() {
RecordFunctionGuard enable_rec_fn; RecordFunctionGuard enable_rec_fn;
std::string recorded_op;
auto handle = addThreadLocalCallback(RecordFunctionCallback( auto handle = addThreadLocalCallback(RecordFunctionCallback(
[&recorded_op](const RecordFunction& fn) { [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
recorded_op = fn.name().str(); recorded_op = fn.name().str();
return nullptr; return nullptr;
})); }));
@ -1096,17 +1146,16 @@ TEST(RecordFunctionTest, Basic) {
RECORD_USER_SCOPE("test_in_thread"); RECORD_USER_SCOPE("test_in_thread");
}); });
t_child.join(); t_child.join();
TORCH_CHECK(recorded_op == "test_in_thread"); EXPECT_EQ(recorded_op, "test_in_thread");
removeCallback(handle); removeCallback(handle);
}); });
t.join(); t.join();
clearCallbacks(); clearCallbacks();
// test set ids // test set ids
bool has_ids = false;
addGlobalCallback( addGlobalCallback(
RecordFunctionCallback( RecordFunctionCallback(
[&has_ids](const RecordFunction& fn) { [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
has_ids = fn.handle() > 0; has_ids = fn.handle() > 0;
return nullptr; return nullptr;
}) })
@ -1116,7 +1165,7 @@ TEST(RecordFunctionTest, Basic) {
clearCallbacks(); clearCallbacks();
has_ids = false; has_ids = false;
addGlobalCallback(RecordFunctionCallback( addGlobalCallback(RecordFunctionCallback(
[&has_ids](const RecordFunction& fn) { [](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
has_ids = fn.handle() > 0; has_ids = fn.handle() > 0;
return nullptr; return nullptr;
})); }));
@ -1126,10 +1175,10 @@ TEST(RecordFunctionTest, Basic) {
} }
TEST(RecordFunctionTest, OperatorNameOverload) { TEST(RecordFunctionTest, OperatorNameOverload) {
std::set<std::string> operator_names; static std::set<std::string> operator_names;
at::addGlobalCallback(at::RecordFunctionCallback( at::addGlobalCallback(at::RecordFunctionCallback(
[&operator_names](const at::RecordFunction& fn) { [](const at::RecordFunction& fn)
-> std::unique_ptr<at::ObserverContext> {
c10::optional<c10::OperatorName> op_name = c10::optional<c10::OperatorName> op_name =
fn.operator_name(); fn.operator_name();
if (op_name.has_value()) { if (op_name.has_value()) {
@ -1178,6 +1227,8 @@ void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
} }
TEST(ThreadLocalDebugInfoTest, Basic) { TEST(ThreadLocalDebugInfoTest, Basic) {
static std::atomic<bool> done{false};
TORCH_CHECK( TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>(); auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
@ -1190,10 +1241,9 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
// check that thread local debug info is propagated through fork calls // check that thread local debug info is propagated through fork calls
TORCH_CHECK( TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
std::atomic<bool> done{false};
{ {
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info); c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
at::launch([&done]() { at::launch([]() {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
done = true; done = true;
}); });
@ -1206,7 +1256,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr); c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
done = false; done = false;
auto handle = addGlobalCallback(RecordFunctionCallback( auto handle = addGlobalCallback(RecordFunctionCallback(
[&done](const RecordFunction&) { [](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
done = true; done = true;
return nullptr; return nullptr;
@ -1236,7 +1286,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
done = false; done = false;
at::launch([&done]() { at::launch([]() {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42); checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314); checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
done = true; done = true;

View File

@ -172,9 +172,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
at::enableRecordFunction(enable); at::enableRecordFunction(enable);
}); });
m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) { m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
auto cb = at::RecordFunctionCallback( auto cb = at::RecordFunctionCallback(nullptr)
[](const at::RecordFunction&) { return nullptr; },
[](const at::RecordFunction&, at::ObserverContext*) {})
.needsInputs(true) .needsInputs(true)
.samplingProb(sampling_prob); .samplingProb(sampling_prob);
if (is_global) { if (is_global) {

View File

@ -136,7 +136,7 @@ void pushProfilingCallbacks() {
auto state_ptr = getProfilerTLSState(); auto state_ptr = getProfilerTLSState();
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) { [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
auto state_ptr = getProfilerTLSState(); auto state_ptr = getProfilerTLSState();
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) { if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
return std::make_unique<KinetoObserverContext>(); return std::make_unique<KinetoObserverContext>();

View File

@ -414,7 +414,7 @@ void pushProfilingCallbacksLegacy() {
auto state_ptr = getProfilerTLSState(); auto state_ptr = getProfilerTLSState();
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set"); TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback( auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) { [](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
auto state_ptr = getProfilerTLSState(); auto state_ptr = getProfilerTLSState();
if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) { if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) {
return nullptr; return nullptr;