[PyTorch] Use plain old function pointer for RecordFunctionCallback (reapply) (#49408)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49408

Nearly every non-test callsite doesn't need to capture any variables anyway, and this saves 48 bytes per callback.
ghstack-source-id: 118665808

Test Plan:
Wait for GitHub CI since we had C++14-specific issues with
this one in previous PR https://github.com/pytorch/pytorch/pull/48629

Reviewed By: malfet

Differential Revision: D25563207

fbshipit-source-id: 6a2831205917d465f8248ca37429ba2428d5626d
This commit is contained in:
Scott Wolchok 2020-12-15 19:14:11 -08:00 committed by Facebook GitHub Bot
parent e9d7d37ad0
commit 22c6dafd33
7 changed files with 193 additions and 141 deletions

View File

@ -277,11 +277,13 @@ class CallbackManager {
bool is_start) {
try {
if (is_start) {
ctx = rfcb.start()(rf);
ctx = rfcb.start() ? rfcb.start()(rf) : nullptr;
}
else {
if (rfcb.end()) {
rfcb.end()(rf, ctx.get());
}
}
return true;
} catch (const std::exception &e) {
LOG(WARNING) << "Exception in RecordFunction callback: "

View File

@ -305,14 +305,16 @@ struct TORCH_API RecordFunction {
*/
class TORCH_API RecordFunctionCallback {
public:
using StartCallback = std::unique_ptr<ObserverContext>(*)(const RecordFunction&);
using EndCallback = void (*)(const RecordFunction&, ObserverContext*);
// This interface supports observers that require passing an ObserverContext
// between start and end callbacks.
explicit RecordFunctionCallback(
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
std::function<void(const RecordFunction&, ObserverContext*)> end =
[](const RecordFunction&, ObserverContext*) {}):
start_(std::move(start)),
end_(std::move(end)) {
StartCallback start,
EndCallback end = nullptr) :
start_(start),
end_(end) {
scopes_.fill(true);
}
@ -368,18 +370,18 @@ class TORCH_API RecordFunctionCallback {
return scopes_[(size_t)sc];
}
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
inline StartCallback start() const {
return start_;
}
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
inline EndCallback end() const {
return end_;
}
private:
friend class CallbackManager;
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
std::function<void(const RecordFunction&, ObserverContext*)> end_;
StartCallback start_;
EndCallback end_;
bool(*should_run_)(const RecordFunctionCallback&) = nullptr;
double sampling_prob_ = 1.0;
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};

View File

@ -19,10 +19,10 @@ const float kLowSamplingProb = 0.0001;
void addTestCallback(
double sampling_prob = 1.0,
std::function<std::unique_ptr<at::ObserverContext>(const at::RecordFunction&)> fn =
[](const at::RecordFunction&) { return nullptr; }) {
at::RecordFunctionCallback::StartCallback fn =
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
auto cb = at::RecordFunctionCallback(
std::move(fn),
fn,
[](const at::RecordFunction&, at::ObserverContext*) {})
.needsInputs(false);
if (sampling_prob < 1.0) {
@ -106,10 +106,10 @@ int main(int argc, char** argv) {
at::clearCallbacks();
std::cout << "Checking number of sampled observer invocations" << std::endl;
int cb_count = 0;
static int cb_count = 0;
addTestCallback(
kLowSamplingProb,
[&](const at::RecordFunction& fn) {
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
++cb_count;
return nullptr;
}

View File

@ -721,12 +721,40 @@ void checkTracedInputs(const TracedTestInputs& inputs) {
TORCH_CHECK(found_mul);
}
static bool bad_scope = false;
template <RecordScope scope, size_t* cnt>
std::unique_ptr<at::ObserverContext> checkScopeCallback(
const at::RecordFunction& fn) {
if (fn.scope() == scope) {
++(*cnt);
} else {
bad_scope = true;
}
return nullptr;
}
template <RecordScope scope, size_t* cnt>
void pushScopedCallback() {
at::addGlobalCallback(
at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
.scopes({scope}));
}
// These cannot be function-local because that would prohibit them
// from being used as template arguments prior to C++17.
static size_t fun_cnt;
static size_t ts_fun_cnt;
static size_t user_scope_cnt;
void checkScopeCallbacks() {
bool found_function_scope = false;
bool found_method_scope = false;
bool found_user_scope = false;
static bool found_function_scope;
static bool found_method_scope;
static bool found_user_scope;
found_function_scope = false;
found_method_scope = false;
found_user_scope = false;
at::addGlobalCallback(at::RecordFunctionCallback(
[&](const at::RecordFunction& fn) {
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
if (fn.scope() == at::RecordScope::FUNCTION &&
std::string(fn.name().str()) == "test_function") {
found_function_scope = true;
@ -742,27 +770,13 @@ void checkScopeCallbacks() {
return nullptr;
}));
bool bad_scope = false;
auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) {
at::addGlobalCallback(
at::RecordFunctionCallback(
[&bad_scope, &cnt, scope](const at::RecordFunction& fn) {
if (fn.scope() == scope) {
++cnt;
} else {
bad_scope = true;
}
return nullptr;
})
.scopes({scope}));
};
size_t fun_cnt = 0;
pushScopedCallback(at::RecordScope::FUNCTION, fun_cnt);
size_t ts_fun_cnt = 0;
pushScopedCallback(at::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt);
size_t user_scope_cnt = 0;
pushScopedCallback(at::RecordScope::USER_SCOPE, user_scope_cnt);
bad_scope = false;
fun_cnt = 0;
pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
ts_fun_cnt = 0;
pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
user_scope_cnt = 0;
pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
TORCH_CHECK(at::hasCallbacks());
@ -788,16 +802,11 @@ static bool shouldRunCallback(const RecordFunctionCallback&) {
return should_run;
}
TEST(RecordFunctionTest, Basic) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static TracedTestInputs traced_inputs;
static std::unordered_set<std::string> ts_names;
// [(fn, [[sizes], [sizes], ...]), ...]
TracedTestInputs traced_inputs;
std::unordered_set<std::string> ts_names;
addGlobalCallback(
RecordFunctionCallback(
[&](const RecordFunction& fn) {
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
const RecordFunction& fn) {
if (fn.scope() == RecordScope::FUNCTION) {
auto inputs = fn.inputs();
std::vector<std::vector<int64_t>> sizes;
@ -813,8 +822,15 @@ TEST(RecordFunctionTest, Basic) {
ts_names.insert(fn.name().str());
}
return nullptr;
})
.needsInputs(true));
}
TEST(RecordFunctionTest, TracedTestInputs) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
// [(fn, [[sizes], [sizes], ...]), ...]
addGlobalCallback(
RecordFunctionCallback(tracedInputsCallback).needsInputs(true));
TracedTestInputs eager_inputs, jit_inputs;
{
@ -841,28 +857,36 @@ TEST(RecordFunctionTest, Basic) {
checkTracedInputs(eager_inputs);
checkTracedInputs(jit_inputs);
at::clearCallbacks();
}
// test sampled callbacks
int sampled_cb_ctr = 0;
auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) {
return addGlobalCallback(RecordFunctionCallback(
[&sampled_cb_ctr](const RecordFunction& fn) {
static int sampled_cb_ctr = 0;
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
if (std::string(fn.name().str()) == "test") {
++sampled_cb_ctr;
}
return nullptr;
})
.samplingProb(sampling_prob));
};
}
int non_sampled_cb_ctr = 0;
addGlobalCallback(RecordFunctionCallback(
[&non_sampled_cb_ctr](const RecordFunction& fn) {
static int non_sampled_cb_ctr = 0;
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
if (std::string(fn.name().str()) == "test") {
++non_sampled_cb_ctr;
}
return nullptr;
}));
}
TEST(RecordFunctionTest, SampledCallbacks) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
// test sampled callbacks
sampled_cb_ctr = 0;
auto setup_sampled_callback = [](double sampling_prob) {
return addGlobalCallback(
RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
};
addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
auto handle = setup_sampled_callback(0.5);
@ -897,13 +921,19 @@ TEST(RecordFunctionTest, Basic) {
// test the scope of the callbacks
checkScopeCallbacks();
clearCallbacks();
}
TEST(RecordFunctionTest, RecordFunctionGuard) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static std::vector<std::string> fn_names;
static std::mutex guard_mtx;
// check record function guard
std::vector<std::string> fn_names;
std::mutex mtx;
addGlobalCallback(RecordFunctionCallback(
[&fn_names, &mtx](const RecordFunction& fn) {
std::lock_guard<std::mutex> lock(mtx);
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
std::lock_guard<std::mutex> lock(guard_mtx);
fn_names.push_back(fn.name().str());
return nullptr;
}));
@ -925,20 +955,26 @@ TEST(RecordFunctionTest, Basic) {
TORCH_CHECK(fn_names.size() == 1);
TORCH_CHECK(fn_names[0] == "B");
clearCallbacks();
}
// test add/remove
std::vector<size_t> ids;
auto add_remove_test_add_cb = [&ids](size_t id) {
static std::vector<size_t> ids;
template <size_t id>
auto add_remove_test_add_cb() {
return addGlobalCallback(RecordFunctionCallback(
[&ids, id](const RecordFunction& fn) {
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ids.push_back(id);
return nullptr;
}));
};
}
auto h1 = add_remove_test_add_cb(1);
auto h2 = add_remove_test_add_cb(2);
auto h3 = add_remove_test_add_cb(3);
TEST(RecordFunctionTest, Callbacks) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
auto h1 = add_remove_test_add_cb<1>();
auto h2 = add_remove_test_add_cb<2>();
auto h3 = add_remove_test_add_cb<3>();
{ RECORD_USER_SCOPE("test"); }
@ -969,8 +1005,7 @@ TEST(RecordFunctionTest, Basic) {
// thread local / global callbacks
ids.clear();
addGlobalCallback(RecordFunctionCallback(
[&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; }));
add_remove_test_add_cb<1>();
{ RECORD_USER_SCOPE("test"); }
@ -978,9 +1013,12 @@ TEST(RecordFunctionTest, Basic) {
TORCH_CHECK(ids[0] == 1);
ids.clear();
auto th = std::thread([&ids]() {
auto th = std::thread([]() {
addThreadLocalCallback(RecordFunctionCallback(
[&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; }));
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ids.push_back(2);
return nullptr;
}));
{ RECORD_USER_SCOPE("test_thread"); }
});
@ -1005,22 +1043,20 @@ TEST(RecordFunctionTest, Basic) {
};
ids.clear();
{ // START: global test
const int test_val = 123;
const std::string test_str = "test str";
addGlobalCallback(RecordFunctionCallback(
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
[](const RecordFunction &
/* unused */) -> std::unique_ptr<at::ObserverContext> {
auto ctx = std::make_unique<TestContext>();
ctx->a = test_val;
ctx->b = test_str;
ctx->a = 123;
ctx->b = "test_str";
ids.push_back(1);
return ctx;
},
[test_val, test_str](
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == test_val);
TORCH_CHECK(ctx->b == test_str);
TORCH_CHECK(ctx->a == 123);
TORCH_CHECK(ctx->b == "test_str");
}));
{ RECORD_USER_SCOPE("test"); }
@ -1030,23 +1066,23 @@ TEST(RecordFunctionTest, Basic) {
ids.clear();
} // END: global test
{ // START: thread local test
auto ctx_th = std::thread([&ids]() {
auto ctx_th = std::thread([]() {
const int test_val = 234;
const std::string test_str = "test thread str";
addThreadLocalCallback(RecordFunctionCallback(
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
[](const RecordFunction &
/* unused */) -> std::unique_ptr<at::ObserverContext> {
auto ctx = std::make_unique<TestContext>();
ctx->a = test_val;
ctx->b = test_str;
ctx->a = 234;
ctx->b = "test_thread_str";
ids.push_back(2);
return ctx;
},
[test_val, test_str](
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
TORCH_CHECK(ctx_ptr != nullptr);
TORCH_CHECK(ctx->a == test_val);
TORCH_CHECK(ctx->b == test_str);
TORCH_CHECK(ctx->a == 234);
TORCH_CHECK(ctx->b == "test_thread_str");
}));
// Will call both global and thread local callbacks.
@ -1060,13 +1096,20 @@ TEST(RecordFunctionTest, Basic) {
} // END: thread local test
clearCallbacks();
}
// test should_run
TEST(RecordFunctionTest, ShouldRun) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
bool ran = false;
should_run = false;
addGlobalCallback(RecordFunctionCallback(
[&ran](const RecordFunction& fn) { ran = true; return nullptr; })
static bool ran = false;
addGlobalCallback(
RecordFunctionCallback(
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
ran = true;
return nullptr;
})
.setShouldRun(shouldRunCallback));
{ RECORD_USER_SCOPE("test"); }
@ -1080,13 +1123,20 @@ TEST(RecordFunctionTest, Basic) {
TORCH_CHECK(ran);
clearCallbacks();
}
TEST(RecordFunctionTest, Basic) {
// disabling the inlining of method calls
GraphOptimizerEnabledGuard opt_guard(false);
static std::string recorded_op;
static bool has_ids = false;
// test propagation of TLS callbacks
std::thread t([]() {
RecordFunctionGuard enable_rec_fn;
std::string recorded_op;
auto handle = addThreadLocalCallback(RecordFunctionCallback(
[&recorded_op](const RecordFunction& fn) {
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
recorded_op = fn.name().str();
return nullptr;
}));
@ -1096,17 +1146,16 @@ TEST(RecordFunctionTest, Basic) {
RECORD_USER_SCOPE("test_in_thread");
});
t_child.join();
TORCH_CHECK(recorded_op == "test_in_thread");
EXPECT_EQ(recorded_op, "test_in_thread");
removeCallback(handle);
});
t.join();
clearCallbacks();
// test set ids
bool has_ids = false;
addGlobalCallback(
RecordFunctionCallback(
[&has_ids](const RecordFunction& fn) {
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
has_ids = fn.handle() > 0;
return nullptr;
})
@ -1116,7 +1165,7 @@ TEST(RecordFunctionTest, Basic) {
clearCallbacks();
has_ids = false;
addGlobalCallback(RecordFunctionCallback(
[&has_ids](const RecordFunction& fn) {
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
has_ids = fn.handle() > 0;
return nullptr;
}));
@ -1126,10 +1175,10 @@ TEST(RecordFunctionTest, Basic) {
}
TEST(RecordFunctionTest, OperatorNameOverload) {
std::set<std::string> operator_names;
static std::set<std::string> operator_names;
at::addGlobalCallback(at::RecordFunctionCallback(
[&operator_names](const at::RecordFunction& fn) {
[](const at::RecordFunction& fn)
-> std::unique_ptr<at::ObserverContext> {
c10::optional<c10::OperatorName> op_name =
fn.operator_name();
if (op_name.has_value()) {
@ -1178,6 +1227,8 @@ void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
}
TEST(ThreadLocalDebugInfoTest, Basic) {
static std::atomic<bool> done{false};
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
@ -1190,10 +1241,9 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
// check that thread local debug info is propagated through fork calls
TORCH_CHECK(
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
std::atomic<bool> done{false};
{
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
at::launch([&done]() {
at::launch([]() {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
done = true;
});
@ -1206,7 +1256,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
done = false;
auto handle = addGlobalCallback(RecordFunctionCallback(
[&done](const RecordFunction&) {
[](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
done = true;
return nullptr;
@ -1236,7 +1286,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
done = false;
at::launch([&done]() {
at::launch([]() {
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
done = true;

View File

@ -172,9 +172,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
at::enableRecordFunction(enable);
});
m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
auto cb = at::RecordFunctionCallback(
[](const at::RecordFunction&) { return nullptr; },
[](const at::RecordFunction&, at::ObserverContext*) {})
auto cb = at::RecordFunctionCallback(nullptr)
.needsInputs(true)
.samplingProb(sampling_prob);
if (is_global) {

View File

@ -136,7 +136,7 @@ void pushProfilingCallbacks() {
auto state_ptr = getProfilerTLSState();
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) {
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
auto state_ptr = getProfilerTLSState();
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
return std::make_unique<KinetoObserverContext>();

View File

@ -414,7 +414,7 @@ void pushProfilingCallbacksLegacy() {
auto state_ptr = getProfilerTLSState();
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
[](const at::RecordFunction& fn) {
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
auto state_ptr = getProfilerTLSState();
if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) {
return nullptr;