mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch] Use plain old function pointer for RecordFunctionCallback (reapply) (#49408)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49408

Nearly every non-test call site does not need to capture any variables anyway, and using a plain function pointer instead of std::function saves 48 bytes per callback.

ghstack-source-id: 118665808

Test Plan: Wait for GitHub CI, since we had C++14-specific issues with this change in the previous PR: https://github.com/pytorch/pytorch/pull/48629

Reviewed By: malfet

Differential Revision: D25563207

fbshipit-source-id: 6a2831205917d465f8248ca37429ba2428d5626d
This commit is contained in:
parent
e9d7d37ad0
commit
22c6dafd33
|
|
@ -277,10 +277,12 @@ class CallbackManager {
|
||||||
bool is_start) {
|
bool is_start) {
|
||||||
try {
|
try {
|
||||||
if (is_start) {
|
if (is_start) {
|
||||||
ctx = rfcb.start()(rf);
|
ctx = rfcb.start() ? rfcb.start()(rf) : nullptr;
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
rfcb.end()(rf, ctx.get());
|
if (rfcb.end()) {
|
||||||
|
rfcb.end()(rf, ctx.get());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
} catch (const std::exception &e) {
|
} catch (const std::exception &e) {
|
||||||
|
|
|
||||||
|
|
@ -305,14 +305,16 @@ struct TORCH_API RecordFunction {
|
||||||
*/
|
*/
|
||||||
class TORCH_API RecordFunctionCallback {
|
class TORCH_API RecordFunctionCallback {
|
||||||
public:
|
public:
|
||||||
|
using StartCallback = std::unique_ptr<ObserverContext>(*)(const RecordFunction&);
|
||||||
|
using EndCallback = void (*)(const RecordFunction&, ObserverContext*);
|
||||||
|
|
||||||
// This interface supports observers that require passing an ObserverContext
|
// This interface supports observers that require passing an ObserverContext
|
||||||
// between start and end callbacks.
|
// between start and end callbacks.
|
||||||
explicit RecordFunctionCallback(
|
explicit RecordFunctionCallback(
|
||||||
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
|
StartCallback start,
|
||||||
std::function<void(const RecordFunction&, ObserverContext*)> end =
|
EndCallback end = nullptr) :
|
||||||
[](const RecordFunction&, ObserverContext*) {}):
|
start_(start),
|
||||||
start_(std::move(start)),
|
end_(end) {
|
||||||
end_(std::move(end)) {
|
|
||||||
scopes_.fill(true);
|
scopes_.fill(true);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -368,18 +370,18 @@ class TORCH_API RecordFunctionCallback {
|
||||||
return scopes_[(size_t)sc];
|
return scopes_[(size_t)sc];
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
|
inline StartCallback start() const {
|
||||||
return start_;
|
return start_;
|
||||||
}
|
}
|
||||||
|
|
||||||
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
|
inline EndCallback end() const {
|
||||||
return end_;
|
return end_;
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
friend class CallbackManager;
|
friend class CallbackManager;
|
||||||
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
|
StartCallback start_;
|
||||||
std::function<void(const RecordFunction&, ObserverContext*)> end_;
|
EndCallback end_;
|
||||||
bool(*should_run_)(const RecordFunctionCallback&) = nullptr;
|
bool(*should_run_)(const RecordFunctionCallback&) = nullptr;
|
||||||
double sampling_prob_ = 1.0;
|
double sampling_prob_ = 1.0;
|
||||||
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
|
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
|
||||||
|
|
|
||||||
|
|
@ -19,10 +19,10 @@ const float kLowSamplingProb = 0.0001;
|
||||||
|
|
||||||
void addTestCallback(
|
void addTestCallback(
|
||||||
double sampling_prob = 1.0,
|
double sampling_prob = 1.0,
|
||||||
std::function<std::unique_ptr<at::ObserverContext>(const at::RecordFunction&)> fn =
|
at::RecordFunctionCallback::StartCallback fn =
|
||||||
[](const at::RecordFunction&) { return nullptr; }) {
|
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
|
||||||
auto cb = at::RecordFunctionCallback(
|
auto cb = at::RecordFunctionCallback(
|
||||||
std::move(fn),
|
fn,
|
||||||
[](const at::RecordFunction&, at::ObserverContext*) {})
|
[](const at::RecordFunction&, at::ObserverContext*) {})
|
||||||
.needsInputs(false);
|
.needsInputs(false);
|
||||||
if (sampling_prob < 1.0) {
|
if (sampling_prob < 1.0) {
|
||||||
|
|
@ -106,10 +106,10 @@ int main(int argc, char** argv) {
|
||||||
at::clearCallbacks();
|
at::clearCallbacks();
|
||||||
|
|
||||||
std::cout << "Checking number of sampled observer invocations" << std::endl;
|
std::cout << "Checking number of sampled observer invocations" << std::endl;
|
||||||
int cb_count = 0;
|
static int cb_count = 0;
|
||||||
addTestCallback(
|
addTestCallback(
|
||||||
kLowSamplingProb,
|
kLowSamplingProb,
|
||||||
[&](const at::RecordFunction& fn) {
|
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
|
||||||
++cb_count;
|
++cb_count;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -721,12 +721,40 @@ void checkTracedInputs(const TracedTestInputs& inputs) {
|
||||||
TORCH_CHECK(found_mul);
|
TORCH_CHECK(found_mul);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool bad_scope = false;
|
||||||
|
template <RecordScope scope, size_t* cnt>
|
||||||
|
std::unique_ptr<at::ObserverContext> checkScopeCallback(
|
||||||
|
const at::RecordFunction& fn) {
|
||||||
|
if (fn.scope() == scope) {
|
||||||
|
++(*cnt);
|
||||||
|
} else {
|
||||||
|
bad_scope = true;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <RecordScope scope, size_t* cnt>
|
||||||
|
void pushScopedCallback() {
|
||||||
|
at::addGlobalCallback(
|
||||||
|
at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
|
||||||
|
.scopes({scope}));
|
||||||
|
}
|
||||||
|
|
||||||
|
// These cannot be function-local because that would prohibit them
|
||||||
|
// from being used as template arguments prior to C++17.
|
||||||
|
static size_t fun_cnt;
|
||||||
|
static size_t ts_fun_cnt;
|
||||||
|
static size_t user_scope_cnt;
|
||||||
|
|
||||||
void checkScopeCallbacks() {
|
void checkScopeCallbacks() {
|
||||||
bool found_function_scope = false;
|
static bool found_function_scope;
|
||||||
bool found_method_scope = false;
|
static bool found_method_scope;
|
||||||
bool found_user_scope = false;
|
static bool found_user_scope;
|
||||||
|
found_function_scope = false;
|
||||||
|
found_method_scope = false;
|
||||||
|
found_user_scope = false;
|
||||||
at::addGlobalCallback(at::RecordFunctionCallback(
|
at::addGlobalCallback(at::RecordFunctionCallback(
|
||||||
[&](const at::RecordFunction& fn) {
|
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
if (fn.scope() == at::RecordScope::FUNCTION &&
|
if (fn.scope() == at::RecordScope::FUNCTION &&
|
||||||
std::string(fn.name().str()) == "test_function") {
|
std::string(fn.name().str()) == "test_function") {
|
||||||
found_function_scope = true;
|
found_function_scope = true;
|
||||||
|
|
@ -742,27 +770,13 @@ void checkScopeCallbacks() {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}));
|
}));
|
||||||
|
|
||||||
bool bad_scope = false;
|
bad_scope = false;
|
||||||
auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) {
|
fun_cnt = 0;
|
||||||
at::addGlobalCallback(
|
pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
|
||||||
at::RecordFunctionCallback(
|
ts_fun_cnt = 0;
|
||||||
[&bad_scope, &cnt, scope](const at::RecordFunction& fn) {
|
pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
|
||||||
if (fn.scope() == scope) {
|
user_scope_cnt = 0;
|
||||||
++cnt;
|
pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
|
||||||
} else {
|
|
||||||
bad_scope = true;
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
})
|
|
||||||
.scopes({scope}));
|
|
||||||
};
|
|
||||||
|
|
||||||
size_t fun_cnt = 0;
|
|
||||||
pushScopedCallback(at::RecordScope::FUNCTION, fun_cnt);
|
|
||||||
size_t ts_fun_cnt = 0;
|
|
||||||
pushScopedCallback(at::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt);
|
|
||||||
size_t user_scope_cnt = 0;
|
|
||||||
pushScopedCallback(at::RecordScope::USER_SCOPE, user_scope_cnt);
|
|
||||||
|
|
||||||
TORCH_CHECK(at::hasCallbacks());
|
TORCH_CHECK(at::hasCallbacks());
|
||||||
|
|
||||||
|
|
@ -788,33 +802,35 @@ static bool shouldRunCallback(const RecordFunctionCallback&) {
|
||||||
return should_run;
|
return should_run;
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(RecordFunctionTest, Basic) {
|
static TracedTestInputs traced_inputs;
|
||||||
|
static std::unordered_set<std::string> ts_names;
|
||||||
|
|
||||||
|
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
|
||||||
|
const RecordFunction& fn) {
|
||||||
|
if (fn.scope() == RecordScope::FUNCTION) {
|
||||||
|
auto inputs = fn.inputs();
|
||||||
|
std::vector<std::vector<int64_t>> sizes;
|
||||||
|
for (const auto& input : inputs) {
|
||||||
|
if (input.isTensor()) {
|
||||||
|
sizes.push_back(input.toTensor().sizes().vec());
|
||||||
|
} else if (input.isScalar()) {
|
||||||
|
sizes.push_back(std::vector<int64_t>());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
|
||||||
|
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
|
||||||
|
ts_names.insert(fn.name().str());
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RecordFunctionTest, TracedTestInputs) {
|
||||||
// disabling the inlining of method calls
|
// disabling the inlining of method calls
|
||||||
GraphOptimizerEnabledGuard opt_guard(false);
|
GraphOptimizerEnabledGuard opt_guard(false);
|
||||||
|
|
||||||
// [(fn, [[sizes], [sizes], ...]), ...]
|
// [(fn, [[sizes], [sizes], ...]), ...]
|
||||||
TracedTestInputs traced_inputs;
|
|
||||||
std::unordered_set<std::string> ts_names;
|
|
||||||
addGlobalCallback(
|
addGlobalCallback(
|
||||||
RecordFunctionCallback(
|
RecordFunctionCallback(tracedInputsCallback).needsInputs(true));
|
||||||
[&](const RecordFunction& fn) {
|
|
||||||
if (fn.scope() == RecordScope::FUNCTION) {
|
|
||||||
auto inputs = fn.inputs();
|
|
||||||
std::vector<std::vector<int64_t>> sizes;
|
|
||||||
for (const auto& input : inputs) {
|
|
||||||
if (input.isTensor()) {
|
|
||||||
sizes.push_back(input.toTensor().sizes().vec());
|
|
||||||
} else if (input.isScalar()) {
|
|
||||||
sizes.push_back(std::vector<int64_t>());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
traced_inputs.push_back(std::make_tuple(fn.name().str(), sizes));
|
|
||||||
} else if (fn.scope() == RecordScope::TORCHSCRIPT_FUNCTION) {
|
|
||||||
ts_names.insert(fn.name().str());
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
})
|
|
||||||
.needsInputs(true));
|
|
||||||
|
|
||||||
TracedTestInputs eager_inputs, jit_inputs;
|
TracedTestInputs eager_inputs, jit_inputs;
|
||||||
{
|
{
|
||||||
|
|
@ -841,28 +857,36 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
checkTracedInputs(eager_inputs);
|
checkTracedInputs(eager_inputs);
|
||||||
checkTracedInputs(jit_inputs);
|
checkTracedInputs(jit_inputs);
|
||||||
at::clearCallbacks();
|
at::clearCallbacks();
|
||||||
|
}
|
||||||
|
|
||||||
|
static int sampled_cb_ctr = 0;
|
||||||
|
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
|
||||||
|
if (std::string(fn.name().str()) == "test") {
|
||||||
|
++sampled_cb_ctr;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int non_sampled_cb_ctr = 0;
|
||||||
|
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
|
||||||
|
if (std::string(fn.name().str()) == "test") {
|
||||||
|
++non_sampled_cb_ctr;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RecordFunctionTest, SampledCallbacks) {
|
||||||
|
// disabling the inlining of method calls
|
||||||
|
GraphOptimizerEnabledGuard opt_guard(false);
|
||||||
|
|
||||||
// test sampled callbacks
|
// test sampled callbacks
|
||||||
int sampled_cb_ctr = 0;
|
sampled_cb_ctr = 0;
|
||||||
auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) {
|
auto setup_sampled_callback = [](double sampling_prob) {
|
||||||
return addGlobalCallback(RecordFunctionCallback(
|
return addGlobalCallback(
|
||||||
[&sampled_cb_ctr](const RecordFunction& fn) {
|
RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
|
||||||
if (std::string(fn.name().str()) == "test") {
|
|
||||||
++sampled_cb_ctr;
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
})
|
|
||||||
.samplingProb(sampling_prob));
|
|
||||||
};
|
};
|
||||||
|
|
||||||
int non_sampled_cb_ctr = 0;
|
addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
|
||||||
addGlobalCallback(RecordFunctionCallback(
|
|
||||||
[&non_sampled_cb_ctr](const RecordFunction& fn) {
|
|
||||||
if (std::string(fn.name().str()) == "test") {
|
|
||||||
++non_sampled_cb_ctr;
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}));
|
|
||||||
|
|
||||||
auto handle = setup_sampled_callback(0.5);
|
auto handle = setup_sampled_callback(0.5);
|
||||||
|
|
||||||
|
|
@ -897,13 +921,19 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
// test the scope of the callbacks
|
// test the scope of the callbacks
|
||||||
checkScopeCallbacks();
|
checkScopeCallbacks();
|
||||||
clearCallbacks();
|
clearCallbacks();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RecordFunctionTest, RecordFunctionGuard) {
|
||||||
|
// disabling the inlining of method calls
|
||||||
|
GraphOptimizerEnabledGuard opt_guard(false);
|
||||||
|
|
||||||
|
static std::vector<std::string> fn_names;
|
||||||
|
static std::mutex guard_mtx;
|
||||||
|
|
||||||
// check record function guard
|
// check record function guard
|
||||||
std::vector<std::string> fn_names;
|
|
||||||
std::mutex mtx;
|
|
||||||
addGlobalCallback(RecordFunctionCallback(
|
addGlobalCallback(RecordFunctionCallback(
|
||||||
[&fn_names, &mtx](const RecordFunction& fn) {
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
std::lock_guard<std::mutex> lock(mtx);
|
std::lock_guard<std::mutex> lock(guard_mtx);
|
||||||
fn_names.push_back(fn.name().str());
|
fn_names.push_back(fn.name().str());
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}));
|
}));
|
||||||
|
|
@ -925,20 +955,26 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
TORCH_CHECK(fn_names.size() == 1);
|
TORCH_CHECK(fn_names.size() == 1);
|
||||||
TORCH_CHECK(fn_names[0] == "B");
|
TORCH_CHECK(fn_names[0] == "B");
|
||||||
clearCallbacks();
|
clearCallbacks();
|
||||||
|
}
|
||||||
|
|
||||||
// test add/remove
|
static std::vector<size_t> ids;
|
||||||
std::vector<size_t> ids;
|
|
||||||
auto add_remove_test_add_cb = [&ids](size_t id) {
|
|
||||||
return addGlobalCallback(RecordFunctionCallback(
|
|
||||||
[&ids, id](const RecordFunction& fn) {
|
|
||||||
ids.push_back(id);
|
|
||||||
return nullptr ;
|
|
||||||
}));
|
|
||||||
};
|
|
||||||
|
|
||||||
auto h1 = add_remove_test_add_cb(1);
|
template <size_t id>
|
||||||
auto h2 = add_remove_test_add_cb(2);
|
auto add_remove_test_add_cb() {
|
||||||
auto h3 = add_remove_test_add_cb(3);
|
return addGlobalCallback(RecordFunctionCallback(
|
||||||
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
|
ids.push_back(id);
|
||||||
|
return nullptr;
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RecordFunctionTest, Callbacks) {
|
||||||
|
// disabling the inlining of method calls
|
||||||
|
GraphOptimizerEnabledGuard opt_guard(false);
|
||||||
|
|
||||||
|
auto h1 = add_remove_test_add_cb<1>();
|
||||||
|
auto h2 = add_remove_test_add_cb<2>();
|
||||||
|
auto h3 = add_remove_test_add_cb<3>();
|
||||||
|
|
||||||
{ RECORD_USER_SCOPE("test"); }
|
{ RECORD_USER_SCOPE("test"); }
|
||||||
|
|
||||||
|
|
@ -969,8 +1005,7 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
// thread local / global callbacks
|
// thread local / global callbacks
|
||||||
|
|
||||||
ids.clear();
|
ids.clear();
|
||||||
addGlobalCallback(RecordFunctionCallback(
|
add_remove_test_add_cb<1>();
|
||||||
[&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; }));
|
|
||||||
|
|
||||||
{ RECORD_USER_SCOPE("test"); }
|
{ RECORD_USER_SCOPE("test"); }
|
||||||
|
|
||||||
|
|
@ -978,9 +1013,12 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
TORCH_CHECK(ids[0] == 1);
|
TORCH_CHECK(ids[0] == 1);
|
||||||
ids.clear();
|
ids.clear();
|
||||||
|
|
||||||
auto th = std::thread([&ids]() {
|
auto th = std::thread([]() {
|
||||||
addThreadLocalCallback(RecordFunctionCallback(
|
addThreadLocalCallback(RecordFunctionCallback(
|
||||||
[&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; }));
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
|
ids.push_back(2);
|
||||||
|
return nullptr;
|
||||||
|
}));
|
||||||
|
|
||||||
{ RECORD_USER_SCOPE("test_thread"); }
|
{ RECORD_USER_SCOPE("test_thread"); }
|
||||||
});
|
});
|
||||||
|
|
@ -1005,22 +1043,20 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
};
|
};
|
||||||
ids.clear();
|
ids.clear();
|
||||||
{ // START: global test
|
{ // START: global test
|
||||||
const int test_val = 123;
|
|
||||||
const std::string test_str = "test str";
|
|
||||||
addGlobalCallback(RecordFunctionCallback(
|
addGlobalCallback(RecordFunctionCallback(
|
||||||
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
|
[](const RecordFunction &
|
||||||
|
/* unused */) -> std::unique_ptr<at::ObserverContext> {
|
||||||
auto ctx = std::make_unique<TestContext>();
|
auto ctx = std::make_unique<TestContext>();
|
||||||
ctx->a = test_val;
|
ctx->a = 123;
|
||||||
ctx->b = test_str;
|
ctx->b = "test_str";
|
||||||
ids.push_back(1);
|
ids.push_back(1);
|
||||||
return ctx;
|
return ctx;
|
||||||
},
|
},
|
||||||
[test_val, test_str](
|
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
||||||
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
|
||||||
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
||||||
TORCH_CHECK(ctx_ptr != nullptr);
|
TORCH_CHECK(ctx_ptr != nullptr);
|
||||||
TORCH_CHECK(ctx->a == test_val);
|
TORCH_CHECK(ctx->a == 123);
|
||||||
TORCH_CHECK(ctx->b == test_str);
|
TORCH_CHECK(ctx->b == "test_str");
|
||||||
}));
|
}));
|
||||||
|
|
||||||
{ RECORD_USER_SCOPE("test"); }
|
{ RECORD_USER_SCOPE("test"); }
|
||||||
|
|
@ -1030,23 +1066,23 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
ids.clear();
|
ids.clear();
|
||||||
} // END: global test
|
} // END: global test
|
||||||
{ // START: thread local test
|
{ // START: thread local test
|
||||||
auto ctx_th = std::thread([&ids]() {
|
auto ctx_th = std::thread([]() {
|
||||||
const int test_val = 234;
|
const int test_val = 234;
|
||||||
const std::string test_str = "test thread str";
|
const std::string test_str = "test thread str";
|
||||||
addThreadLocalCallback(RecordFunctionCallback(
|
addThreadLocalCallback(RecordFunctionCallback(
|
||||||
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
|
[](const RecordFunction &
|
||||||
|
/* unused */) -> std::unique_ptr<at::ObserverContext> {
|
||||||
auto ctx = std::make_unique<TestContext>();
|
auto ctx = std::make_unique<TestContext>();
|
||||||
ctx->a = test_val;
|
ctx->a = 234;
|
||||||
ctx->b = test_str;
|
ctx->b = "test_thread_str";
|
||||||
ids.push_back(2);
|
ids.push_back(2);
|
||||||
return ctx;
|
return ctx;
|
||||||
},
|
},
|
||||||
[test_val, test_str](
|
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
||||||
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
|
||||||
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
||||||
TORCH_CHECK(ctx_ptr != nullptr);
|
TORCH_CHECK(ctx_ptr != nullptr);
|
||||||
TORCH_CHECK(ctx->a == test_val);
|
TORCH_CHECK(ctx->a == 234);
|
||||||
TORCH_CHECK(ctx->b == test_str);
|
TORCH_CHECK(ctx->b == "test_thread_str");
|
||||||
}));
|
}));
|
||||||
|
|
||||||
// Will call both global and thread local callbacks.
|
// Will call both global and thread local callbacks.
|
||||||
|
|
@ -1060,14 +1096,21 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
} // END: thread local test
|
} // END: thread local test
|
||||||
|
|
||||||
clearCallbacks();
|
clearCallbacks();
|
||||||
|
}
|
||||||
|
|
||||||
// test should_run
|
TEST(RecordFunctionTest, ShouldRun) {
|
||||||
|
// disabling the inlining of method calls
|
||||||
|
GraphOptimizerEnabledGuard opt_guard(false);
|
||||||
|
|
||||||
bool ran = false;
|
|
||||||
should_run = false;
|
should_run = false;
|
||||||
addGlobalCallback(RecordFunctionCallback(
|
static bool ran = false;
|
||||||
[&ran](const RecordFunction& fn) { ran = true; return nullptr; })
|
addGlobalCallback(
|
||||||
.setShouldRun(shouldRunCallback));
|
RecordFunctionCallback(
|
||||||
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
|
ran = true;
|
||||||
|
return nullptr;
|
||||||
|
})
|
||||||
|
.setShouldRun(shouldRunCallback));
|
||||||
|
|
||||||
{ RECORD_USER_SCOPE("test"); }
|
{ RECORD_USER_SCOPE("test"); }
|
||||||
|
|
||||||
|
|
@ -1080,13 +1123,20 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
TORCH_CHECK(ran);
|
TORCH_CHECK(ran);
|
||||||
|
|
||||||
clearCallbacks();
|
clearCallbacks();
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(RecordFunctionTest, Basic) {
|
||||||
|
// disabling the inlining of method calls
|
||||||
|
GraphOptimizerEnabledGuard opt_guard(false);
|
||||||
|
|
||||||
|
static std::string recorded_op;
|
||||||
|
static bool has_ids = false;
|
||||||
|
|
||||||
// test propagation of TLS callbacks
|
// test propagation of TLS callbacks
|
||||||
std::thread t([]() {
|
std::thread t([]() {
|
||||||
RecordFunctionGuard enable_rec_fn;
|
RecordFunctionGuard enable_rec_fn;
|
||||||
std::string recorded_op;
|
|
||||||
auto handle = addThreadLocalCallback(RecordFunctionCallback(
|
auto handle = addThreadLocalCallback(RecordFunctionCallback(
|
||||||
[&recorded_op](const RecordFunction& fn) {
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
recorded_op = fn.name().str();
|
recorded_op = fn.name().str();
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}));
|
}));
|
||||||
|
|
@ -1096,17 +1146,16 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
RECORD_USER_SCOPE("test_in_thread");
|
RECORD_USER_SCOPE("test_in_thread");
|
||||||
});
|
});
|
||||||
t_child.join();
|
t_child.join();
|
||||||
TORCH_CHECK(recorded_op == "test_in_thread");
|
EXPECT_EQ(recorded_op, "test_in_thread");
|
||||||
removeCallback(handle);
|
removeCallback(handle);
|
||||||
});
|
});
|
||||||
t.join();
|
t.join();
|
||||||
clearCallbacks();
|
clearCallbacks();
|
||||||
|
|
||||||
// test set ids
|
// test set ids
|
||||||
bool has_ids = false;
|
|
||||||
addGlobalCallback(
|
addGlobalCallback(
|
||||||
RecordFunctionCallback(
|
RecordFunctionCallback(
|
||||||
[&has_ids](const RecordFunction& fn) {
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
has_ids = fn.handle() > 0;
|
has_ids = fn.handle() > 0;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
})
|
})
|
||||||
|
|
@ -1116,7 +1165,7 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
clearCallbacks();
|
clearCallbacks();
|
||||||
has_ids = false;
|
has_ids = false;
|
||||||
addGlobalCallback(RecordFunctionCallback(
|
addGlobalCallback(RecordFunctionCallback(
|
||||||
[&has_ids](const RecordFunction& fn) {
|
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
has_ids = fn.handle() > 0;
|
has_ids = fn.handle() > 0;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}));
|
}));
|
||||||
|
|
@ -1126,10 +1175,10 @@ TEST(RecordFunctionTest, Basic) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(RecordFunctionTest, OperatorNameOverload) {
|
TEST(RecordFunctionTest, OperatorNameOverload) {
|
||||||
std::set<std::string> operator_names;
|
static std::set<std::string> operator_names;
|
||||||
|
|
||||||
at::addGlobalCallback(at::RecordFunctionCallback(
|
at::addGlobalCallback(at::RecordFunctionCallback(
|
||||||
[&operator_names](const at::RecordFunction& fn) {
|
[](const at::RecordFunction& fn)
|
||||||
|
-> std::unique_ptr<at::ObserverContext> {
|
||||||
c10::optional<c10::OperatorName> op_name =
|
c10::optional<c10::OperatorName> op_name =
|
||||||
fn.operator_name();
|
fn.operator_name();
|
||||||
if (op_name.has_value()) {
|
if (op_name.has_value()) {
|
||||||
|
|
@ -1178,6 +1227,8 @@ void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(ThreadLocalDebugInfoTest, Basic) {
|
TEST(ThreadLocalDebugInfoTest, Basic) {
|
||||||
|
static std::atomic<bool> done{false};
|
||||||
|
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
||||||
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
|
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
|
||||||
|
|
@ -1190,10 +1241,9 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
|
||||||
// check that thread local debug info is propagated through fork calls
|
// check that thread local debug info is propagated through fork calls
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
||||||
std::atomic<bool> done{false};
|
|
||||||
{
|
{
|
||||||
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
|
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
|
||||||
at::launch([&done]() {
|
at::launch([]() {
|
||||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
||||||
done = true;
|
done = true;
|
||||||
});
|
});
|
||||||
|
|
@ -1206,7 +1256,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
|
||||||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
||||||
done = false;
|
done = false;
|
||||||
auto handle = addGlobalCallback(RecordFunctionCallback(
|
auto handle = addGlobalCallback(RecordFunctionCallback(
|
||||||
[&done](const RecordFunction&) {
|
[](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
|
||||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
||||||
done = true;
|
done = true;
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
@ -1236,7 +1286,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
|
||||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
||||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
|
||||||
done = false;
|
done = false;
|
||||||
at::launch([&done]() {
|
at::launch([]() {
|
||||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
||||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
|
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
|
||||||
done = true;
|
done = true;
|
||||||
|
|
|
||||||
|
|
@ -172,9 +172,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
|
||||||
at::enableRecordFunction(enable);
|
at::enableRecordFunction(enable);
|
||||||
});
|
});
|
||||||
m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
|
m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
|
||||||
auto cb = at::RecordFunctionCallback(
|
auto cb = at::RecordFunctionCallback(nullptr)
|
||||||
[](const at::RecordFunction&) { return nullptr; },
|
|
||||||
[](const at::RecordFunction&, at::ObserverContext*) {})
|
|
||||||
.needsInputs(true)
|
.needsInputs(true)
|
||||||
.samplingProb(sampling_prob);
|
.samplingProb(sampling_prob);
|
||||||
if (is_global) {
|
if (is_global) {
|
||||||
|
|
|
||||||
|
|
@ -136,7 +136,7 @@ void pushProfilingCallbacks() {
|
||||||
auto state_ptr = getProfilerTLSState();
|
auto state_ptr = getProfilerTLSState();
|
||||||
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
||||||
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
|
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
|
||||||
[](const at::RecordFunction& fn) {
|
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
auto state_ptr = getProfilerTLSState();
|
auto state_ptr = getProfilerTLSState();
|
||||||
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
|
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
|
||||||
return std::make_unique<KinetoObserverContext>();
|
return std::make_unique<KinetoObserverContext>();
|
||||||
|
|
|
||||||
|
|
@ -414,7 +414,7 @@ void pushProfilingCallbacksLegacy() {
|
||||||
auto state_ptr = getProfilerTLSState();
|
auto state_ptr = getProfilerTLSState();
|
||||||
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
||||||
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
|
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
|
||||||
[](const at::RecordFunction& fn) {
|
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||||
auto state_ptr = getProfilerTLSState();
|
auto state_ptr = getProfilerTLSState();
|
||||||
if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) {
|
if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user