mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch] Use plain old function pointer for RecordFunctionCallback (reapply) (#49408)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49408

Nearly every non-test call site doesn't need to capture any variables anyway, and storing a plain function pointer instead of a std::function saves 48 bytes per callback.

ghstack-source-id: 118665808

Test Plan: Wait for GitHub CI, since we had C++14-specific issues with this one in the previous PR: https://github.com/pytorch/pytorch/pull/48629

Reviewed By: malfet

Differential Revision: D25563207

fbshipit-source-id: 6a2831205917d465f8248ca37429ba2428d5626d
This commit is contained in:
parent
e9d7d37ad0
commit
22c6dafd33
|
|
@ -277,11 +277,13 @@ class CallbackManager {
|
|||
bool is_start) {
|
||||
try {
|
||||
if (is_start) {
|
||||
ctx = rfcb.start()(rf);
|
||||
ctx = rfcb.start() ? rfcb.start()(rf) : nullptr;
|
||||
}
|
||||
else {
|
||||
if (rfcb.end()) {
|
||||
rfcb.end()(rf, ctx.get());
|
||||
}
|
||||
}
|
||||
return true;
|
||||
} catch (const std::exception &e) {
|
||||
LOG(WARNING) << "Exception in RecordFunction callback: "
|
||||
|
|
|
|||
|
|
@ -305,14 +305,16 @@ struct TORCH_API RecordFunction {
|
|||
*/
|
||||
class TORCH_API RecordFunctionCallback {
|
||||
public:
|
||||
using StartCallback = std::unique_ptr<ObserverContext>(*)(const RecordFunction&);
|
||||
using EndCallback = void (*)(const RecordFunction&, ObserverContext*);
|
||||
|
||||
// This interface supports observers that require passing an ObserverContext
|
||||
// between start and end callbacks.
|
||||
explicit RecordFunctionCallback(
|
||||
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
|
||||
std::function<void(const RecordFunction&, ObserverContext*)> end =
|
||||
[](const RecordFunction&, ObserverContext*) {}):
|
||||
start_(std::move(start)),
|
||||
end_(std::move(end)) {
|
||||
StartCallback start,
|
||||
EndCallback end = nullptr) :
|
||||
start_(start),
|
||||
end_(end) {
|
||||
scopes_.fill(true);
|
||||
}
|
||||
|
||||
|
|
@ -368,18 +370,18 @@ class TORCH_API RecordFunctionCallback {
|
|||
return scopes_[(size_t)sc];
|
||||
}
|
||||
|
||||
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
|
||||
inline StartCallback start() const {
|
||||
return start_;
|
||||
}
|
||||
|
||||
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
|
||||
inline EndCallback end() const {
|
||||
return end_;
|
||||
}
|
||||
|
||||
private:
|
||||
friend class CallbackManager;
|
||||
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
|
||||
std::function<void(const RecordFunction&, ObserverContext*)> end_;
|
||||
StartCallback start_;
|
||||
EndCallback end_;
|
||||
bool(*should_run_)(const RecordFunctionCallback&) = nullptr;
|
||||
double sampling_prob_ = 1.0;
|
||||
std::array<bool, static_cast<size_t>(RecordScope::NUM_SCOPES)> scopes_ = {};
|
||||
|
|
|
|||
|
|
@ -19,10 +19,10 @@ const float kLowSamplingProb = 0.0001;
|
|||
|
||||
void addTestCallback(
|
||||
double sampling_prob = 1.0,
|
||||
std::function<std::unique_ptr<at::ObserverContext>(const at::RecordFunction&)> fn =
|
||||
[](const at::RecordFunction&) { return nullptr; }) {
|
||||
at::RecordFunctionCallback::StartCallback fn =
|
||||
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> { return nullptr; }) {
|
||||
auto cb = at::RecordFunctionCallback(
|
||||
std::move(fn),
|
||||
fn,
|
||||
[](const at::RecordFunction&, at::ObserverContext*) {})
|
||||
.needsInputs(false);
|
||||
if (sampling_prob < 1.0) {
|
||||
|
|
@ -106,10 +106,10 @@ int main(int argc, char** argv) {
|
|||
at::clearCallbacks();
|
||||
|
||||
std::cout << "Checking number of sampled observer invocations" << std::endl;
|
||||
int cb_count = 0;
|
||||
static int cb_count = 0;
|
||||
addTestCallback(
|
||||
kLowSamplingProb,
|
||||
[&](const at::RecordFunction& fn) {
|
||||
[](const at::RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
|
||||
++cb_count;
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -721,12 +721,40 @@ void checkTracedInputs(const TracedTestInputs& inputs) {
|
|||
TORCH_CHECK(found_mul);
|
||||
}
|
||||
|
||||
static bool bad_scope = false;
|
||||
template <RecordScope scope, size_t* cnt>
|
||||
std::unique_ptr<at::ObserverContext> checkScopeCallback(
|
||||
const at::RecordFunction& fn) {
|
||||
if (fn.scope() == scope) {
|
||||
++(*cnt);
|
||||
} else {
|
||||
bad_scope = true;
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
template <RecordScope scope, size_t* cnt>
|
||||
void pushScopedCallback() {
|
||||
at::addGlobalCallback(
|
||||
at::RecordFunctionCallback(checkScopeCallback<scope, cnt>)
|
||||
.scopes({scope}));
|
||||
}
|
||||
|
||||
// These cannot be function-local because that would prohibit them
|
||||
// from being used as template arguments prior to C++17.
|
||||
static size_t fun_cnt;
|
||||
static size_t ts_fun_cnt;
|
||||
static size_t user_scope_cnt;
|
||||
|
||||
void checkScopeCallbacks() {
|
||||
bool found_function_scope = false;
|
||||
bool found_method_scope = false;
|
||||
bool found_user_scope = false;
|
||||
static bool found_function_scope;
|
||||
static bool found_method_scope;
|
||||
static bool found_user_scope;
|
||||
found_function_scope = false;
|
||||
found_method_scope = false;
|
||||
found_user_scope = false;
|
||||
at::addGlobalCallback(at::RecordFunctionCallback(
|
||||
[&](const at::RecordFunction& fn) {
|
||||
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
if (fn.scope() == at::RecordScope::FUNCTION &&
|
||||
std::string(fn.name().str()) == "test_function") {
|
||||
found_function_scope = true;
|
||||
|
|
@ -742,27 +770,13 @@ void checkScopeCallbacks() {
|
|||
return nullptr;
|
||||
}));
|
||||
|
||||
bool bad_scope = false;
|
||||
auto pushScopedCallback = [&](at::RecordScope scope, size_t& cnt) {
|
||||
at::addGlobalCallback(
|
||||
at::RecordFunctionCallback(
|
||||
[&bad_scope, &cnt, scope](const at::RecordFunction& fn) {
|
||||
if (fn.scope() == scope) {
|
||||
++cnt;
|
||||
} else {
|
||||
bad_scope = true;
|
||||
}
|
||||
return nullptr;
|
||||
})
|
||||
.scopes({scope}));
|
||||
};
|
||||
|
||||
size_t fun_cnt = 0;
|
||||
pushScopedCallback(at::RecordScope::FUNCTION, fun_cnt);
|
||||
size_t ts_fun_cnt = 0;
|
||||
pushScopedCallback(at::RecordScope::TORCHSCRIPT_FUNCTION, ts_fun_cnt);
|
||||
size_t user_scope_cnt = 0;
|
||||
pushScopedCallback(at::RecordScope::USER_SCOPE, user_scope_cnt);
|
||||
bad_scope = false;
|
||||
fun_cnt = 0;
|
||||
pushScopedCallback<at::RecordScope::FUNCTION, &fun_cnt>();
|
||||
ts_fun_cnt = 0;
|
||||
pushScopedCallback<at::RecordScope::TORCHSCRIPT_FUNCTION, &ts_fun_cnt>();
|
||||
user_scope_cnt = 0;
|
||||
pushScopedCallback<at::RecordScope::USER_SCOPE, &user_scope_cnt>();
|
||||
|
||||
TORCH_CHECK(at::hasCallbacks());
|
||||
|
||||
|
|
@ -788,16 +802,11 @@ static bool shouldRunCallback(const RecordFunctionCallback&) {
|
|||
return should_run;
|
||||
}
|
||||
|
||||
TEST(RecordFunctionTest, Basic) {
|
||||
// disabling the inlining of method calls
|
||||
GraphOptimizerEnabledGuard opt_guard(false);
|
||||
static TracedTestInputs traced_inputs;
|
||||
static std::unordered_set<std::string> ts_names;
|
||||
|
||||
// [(fn, [[sizes], [sizes], ...]), ...]
|
||||
TracedTestInputs traced_inputs;
|
||||
std::unordered_set<std::string> ts_names;
|
||||
addGlobalCallback(
|
||||
RecordFunctionCallback(
|
||||
[&](const RecordFunction& fn) {
|
||||
std::unique_ptr<at::ObserverContext> tracedInputsCallback(
|
||||
const RecordFunction& fn) {
|
||||
if (fn.scope() == RecordScope::FUNCTION) {
|
||||
auto inputs = fn.inputs();
|
||||
std::vector<std::vector<int64_t>> sizes;
|
||||
|
|
@ -813,8 +822,15 @@ TEST(RecordFunctionTest, Basic) {
|
|||
ts_names.insert(fn.name().str());
|
||||
}
|
||||
return nullptr;
|
||||
})
|
||||
.needsInputs(true));
|
||||
}
|
||||
|
||||
TEST(RecordFunctionTest, TracedTestInputs) {
|
||||
// disabling the inlining of method calls
|
||||
GraphOptimizerEnabledGuard opt_guard(false);
|
||||
|
||||
// [(fn, [[sizes], [sizes], ...]), ...]
|
||||
addGlobalCallback(
|
||||
RecordFunctionCallback(tracedInputsCallback).needsInputs(true));
|
||||
|
||||
TracedTestInputs eager_inputs, jit_inputs;
|
||||
{
|
||||
|
|
@ -841,28 +857,36 @@ TEST(RecordFunctionTest, Basic) {
|
|||
checkTracedInputs(eager_inputs);
|
||||
checkTracedInputs(jit_inputs);
|
||||
at::clearCallbacks();
|
||||
}
|
||||
|
||||
// test sampled callbacks
|
||||
int sampled_cb_ctr = 0;
|
||||
auto setup_sampled_callback = [&sampled_cb_ctr](double sampling_prob) {
|
||||
return addGlobalCallback(RecordFunctionCallback(
|
||||
[&sampled_cb_ctr](const RecordFunction& fn) {
|
||||
static int sampled_cb_ctr = 0;
|
||||
std::unique_ptr<ObserverContext> sampledCallback(const RecordFunction& fn) {
|
||||
if (std::string(fn.name().str()) == "test") {
|
||||
++sampled_cb_ctr;
|
||||
}
|
||||
return nullptr;
|
||||
})
|
||||
.samplingProb(sampling_prob));
|
||||
};
|
||||
}
|
||||
|
||||
int non_sampled_cb_ctr = 0;
|
||||
addGlobalCallback(RecordFunctionCallback(
|
||||
[&non_sampled_cb_ctr](const RecordFunction& fn) {
|
||||
static int non_sampled_cb_ctr = 0;
|
||||
std::unique_ptr<ObserverContext> nonSampledCallback(const RecordFunction& fn) {
|
||||
if (std::string(fn.name().str()) == "test") {
|
||||
++non_sampled_cb_ctr;
|
||||
}
|
||||
return nullptr;
|
||||
}));
|
||||
}
|
||||
|
||||
TEST(RecordFunctionTest, SampledCallbacks) {
|
||||
// disabling the inlining of method calls
|
||||
GraphOptimizerEnabledGuard opt_guard(false);
|
||||
|
||||
// test sampled callbacks
|
||||
sampled_cb_ctr = 0;
|
||||
auto setup_sampled_callback = [](double sampling_prob) {
|
||||
return addGlobalCallback(
|
||||
RecordFunctionCallback(sampledCallback).samplingProb(sampling_prob));
|
||||
};
|
||||
|
||||
addGlobalCallback(RecordFunctionCallback(nonSampledCallback));
|
||||
|
||||
auto handle = setup_sampled_callback(0.5);
|
||||
|
||||
|
|
@ -897,13 +921,19 @@ TEST(RecordFunctionTest, Basic) {
|
|||
// test the scope of the callbacks
|
||||
checkScopeCallbacks();
|
||||
clearCallbacks();
|
||||
}
|
||||
|
||||
TEST(RecordFunctionTest, RecordFunctionGuard) {
|
||||
// disabling the inlining of method calls
|
||||
GraphOptimizerEnabledGuard opt_guard(false);
|
||||
|
||||
static std::vector<std::string> fn_names;
|
||||
static std::mutex guard_mtx;
|
||||
|
||||
// check record function guard
|
||||
std::vector<std::string> fn_names;
|
||||
std::mutex mtx;
|
||||
addGlobalCallback(RecordFunctionCallback(
|
||||
[&fn_names, &mtx](const RecordFunction& fn) {
|
||||
std::lock_guard<std::mutex> lock(mtx);
|
||||
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
std::lock_guard<std::mutex> lock(guard_mtx);
|
||||
fn_names.push_back(fn.name().str());
|
||||
return nullptr;
|
||||
}));
|
||||
|
|
@ -925,20 +955,26 @@ TEST(RecordFunctionTest, Basic) {
|
|||
TORCH_CHECK(fn_names.size() == 1);
|
||||
TORCH_CHECK(fn_names[0] == "B");
|
||||
clearCallbacks();
|
||||
}
|
||||
|
||||
// test add/remove
|
||||
std::vector<size_t> ids;
|
||||
auto add_remove_test_add_cb = [&ids](size_t id) {
|
||||
static std::vector<size_t> ids;
|
||||
|
||||
template <size_t id>
|
||||
auto add_remove_test_add_cb() {
|
||||
return addGlobalCallback(RecordFunctionCallback(
|
||||
[&ids, id](const RecordFunction& fn) {
|
||||
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
ids.push_back(id);
|
||||
return nullptr;
|
||||
}));
|
||||
};
|
||||
}
|
||||
|
||||
auto h1 = add_remove_test_add_cb(1);
|
||||
auto h2 = add_remove_test_add_cb(2);
|
||||
auto h3 = add_remove_test_add_cb(3);
|
||||
TEST(RecordFunctionTest, Callbacks) {
|
||||
// disabling the inlining of method calls
|
||||
GraphOptimizerEnabledGuard opt_guard(false);
|
||||
|
||||
auto h1 = add_remove_test_add_cb<1>();
|
||||
auto h2 = add_remove_test_add_cb<2>();
|
||||
auto h3 = add_remove_test_add_cb<3>();
|
||||
|
||||
{ RECORD_USER_SCOPE("test"); }
|
||||
|
||||
|
|
@ -969,8 +1005,7 @@ TEST(RecordFunctionTest, Basic) {
|
|||
// thread local / global callbacks
|
||||
|
||||
ids.clear();
|
||||
addGlobalCallback(RecordFunctionCallback(
|
||||
[&ids](const RecordFunction& fn) { ids.push_back(1); return nullptr; }));
|
||||
add_remove_test_add_cb<1>();
|
||||
|
||||
{ RECORD_USER_SCOPE("test"); }
|
||||
|
||||
|
|
@ -978,9 +1013,12 @@ TEST(RecordFunctionTest, Basic) {
|
|||
TORCH_CHECK(ids[0] == 1);
|
||||
ids.clear();
|
||||
|
||||
auto th = std::thread([&ids]() {
|
||||
auto th = std::thread([]() {
|
||||
addThreadLocalCallback(RecordFunctionCallback(
|
||||
[&ids](const RecordFunction& fn) { ids.push_back(2); return nullptr; }));
|
||||
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
ids.push_back(2);
|
||||
return nullptr;
|
||||
}));
|
||||
|
||||
{ RECORD_USER_SCOPE("test_thread"); }
|
||||
});
|
||||
|
|
@ -1005,22 +1043,20 @@ TEST(RecordFunctionTest, Basic) {
|
|||
};
|
||||
ids.clear();
|
||||
{ // START: global test
|
||||
const int test_val = 123;
|
||||
const std::string test_str = "test str";
|
||||
addGlobalCallback(RecordFunctionCallback(
|
||||
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
|
||||
[](const RecordFunction &
|
||||
/* unused */) -> std::unique_ptr<at::ObserverContext> {
|
||||
auto ctx = std::make_unique<TestContext>();
|
||||
ctx->a = test_val;
|
||||
ctx->b = test_str;
|
||||
ctx->a = 123;
|
||||
ctx->b = "test_str";
|
||||
ids.push_back(1);
|
||||
return ctx;
|
||||
},
|
||||
[test_val, test_str](
|
||||
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
||||
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
||||
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
||||
TORCH_CHECK(ctx_ptr != nullptr);
|
||||
TORCH_CHECK(ctx->a == test_val);
|
||||
TORCH_CHECK(ctx->b == test_str);
|
||||
TORCH_CHECK(ctx->a == 123);
|
||||
TORCH_CHECK(ctx->b == "test_str");
|
||||
}));
|
||||
|
||||
{ RECORD_USER_SCOPE("test"); }
|
||||
|
|
@ -1030,23 +1066,23 @@ TEST(RecordFunctionTest, Basic) {
|
|||
ids.clear();
|
||||
} // END: global test
|
||||
{ // START: thread local test
|
||||
auto ctx_th = std::thread([&ids]() {
|
||||
auto ctx_th = std::thread([]() {
|
||||
const int test_val = 234;
|
||||
const std::string test_str = "test thread str";
|
||||
addThreadLocalCallback(RecordFunctionCallback(
|
||||
[test_val, test_str, &ids](const RecordFunction& /* unused */) {
|
||||
[](const RecordFunction &
|
||||
/* unused */) -> std::unique_ptr<at::ObserverContext> {
|
||||
auto ctx = std::make_unique<TestContext>();
|
||||
ctx->a = test_val;
|
||||
ctx->b = test_str;
|
||||
ctx->a = 234;
|
||||
ctx->b = "test_thread_str";
|
||||
ids.push_back(2);
|
||||
return ctx;
|
||||
},
|
||||
[test_val, test_str](
|
||||
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
||||
[](const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
|
||||
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
|
||||
TORCH_CHECK(ctx_ptr != nullptr);
|
||||
TORCH_CHECK(ctx->a == test_val);
|
||||
TORCH_CHECK(ctx->b == test_str);
|
||||
TORCH_CHECK(ctx->a == 234);
|
||||
TORCH_CHECK(ctx->b == "test_thread_str");
|
||||
}));
|
||||
|
||||
// Will call both global and thread local callbacks.
|
||||
|
|
@ -1060,13 +1096,20 @@ TEST(RecordFunctionTest, Basic) {
|
|||
} // END: thread local test
|
||||
|
||||
clearCallbacks();
|
||||
}
|
||||
|
||||
// test should_run
|
||||
TEST(RecordFunctionTest, ShouldRun) {
|
||||
// disabling the inlining of method calls
|
||||
GraphOptimizerEnabledGuard opt_guard(false);
|
||||
|
||||
bool ran = false;
|
||||
should_run = false;
|
||||
addGlobalCallback(RecordFunctionCallback(
|
||||
[&ran](const RecordFunction& fn) { ran = true; return nullptr; })
|
||||
static bool ran = false;
|
||||
addGlobalCallback(
|
||||
RecordFunctionCallback(
|
||||
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
ran = true;
|
||||
return nullptr;
|
||||
})
|
||||
.setShouldRun(shouldRunCallback));
|
||||
|
||||
{ RECORD_USER_SCOPE("test"); }
|
||||
|
|
@ -1080,13 +1123,20 @@ TEST(RecordFunctionTest, Basic) {
|
|||
TORCH_CHECK(ran);
|
||||
|
||||
clearCallbacks();
|
||||
}
|
||||
|
||||
TEST(RecordFunctionTest, Basic) {
|
||||
// disabling the inlining of method calls
|
||||
GraphOptimizerEnabledGuard opt_guard(false);
|
||||
|
||||
static std::string recorded_op;
|
||||
static bool has_ids = false;
|
||||
|
||||
// test propagation of TLS callbacks
|
||||
std::thread t([]() {
|
||||
RecordFunctionGuard enable_rec_fn;
|
||||
std::string recorded_op;
|
||||
auto handle = addThreadLocalCallback(RecordFunctionCallback(
|
||||
[&recorded_op](const RecordFunction& fn) {
|
||||
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
recorded_op = fn.name().str();
|
||||
return nullptr;
|
||||
}));
|
||||
|
|
@ -1096,17 +1146,16 @@ TEST(RecordFunctionTest, Basic) {
|
|||
RECORD_USER_SCOPE("test_in_thread");
|
||||
});
|
||||
t_child.join();
|
||||
TORCH_CHECK(recorded_op == "test_in_thread");
|
||||
EXPECT_EQ(recorded_op, "test_in_thread");
|
||||
removeCallback(handle);
|
||||
});
|
||||
t.join();
|
||||
clearCallbacks();
|
||||
|
||||
// test set ids
|
||||
bool has_ids = false;
|
||||
addGlobalCallback(
|
||||
RecordFunctionCallback(
|
||||
[&has_ids](const RecordFunction& fn) {
|
||||
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
has_ids = fn.handle() > 0;
|
||||
return nullptr;
|
||||
})
|
||||
|
|
@ -1116,7 +1165,7 @@ TEST(RecordFunctionTest, Basic) {
|
|||
clearCallbacks();
|
||||
has_ids = false;
|
||||
addGlobalCallback(RecordFunctionCallback(
|
||||
[&has_ids](const RecordFunction& fn) {
|
||||
[](const RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
has_ids = fn.handle() > 0;
|
||||
return nullptr;
|
||||
}));
|
||||
|
|
@ -1126,10 +1175,10 @@ TEST(RecordFunctionTest, Basic) {
|
|||
}
|
||||
|
||||
TEST(RecordFunctionTest, OperatorNameOverload) {
|
||||
std::set<std::string> operator_names;
|
||||
|
||||
static std::set<std::string> operator_names;
|
||||
at::addGlobalCallback(at::RecordFunctionCallback(
|
||||
[&operator_names](const at::RecordFunction& fn) {
|
||||
[](const at::RecordFunction& fn)
|
||||
-> std::unique_ptr<at::ObserverContext> {
|
||||
c10::optional<c10::OperatorName> op_name =
|
||||
fn.operator_name();
|
||||
if (op_name.has_value()) {
|
||||
|
|
@ -1178,6 +1227,8 @@ void checkDebugInfo(c10::DebugInfoKind kind, int model_id) {
|
|||
}
|
||||
|
||||
TEST(ThreadLocalDebugInfoTest, Basic) {
|
||||
static std::atomic<bool> done{false};
|
||||
|
||||
TORCH_CHECK(
|
||||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
||||
auto debug_info = std::make_shared<TestThreadLocalDebugInfo>();
|
||||
|
|
@ -1190,10 +1241,9 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
|
|||
// check that thread local debug info is propagated through fork calls
|
||||
TORCH_CHECK(
|
||||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
||||
std::atomic<bool> done{false};
|
||||
{
|
||||
c10::DebugInfoGuard guard(c10::DebugInfoKind::TEST_INFO, debug_info);
|
||||
at::launch([&done]() {
|
||||
at::launch([]() {
|
||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
||||
done = true;
|
||||
});
|
||||
|
|
@ -1206,7 +1256,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
|
|||
c10::ThreadLocalDebugInfo::get(c10::DebugInfoKind::TEST_INFO) == nullptr);
|
||||
done = false;
|
||||
auto handle = addGlobalCallback(RecordFunctionCallback(
|
||||
[&done](const RecordFunction&) {
|
||||
[](const RecordFunction&) -> std::unique_ptr<at::ObserverContext> {
|
||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
||||
done = true;
|
||||
return nullptr;
|
||||
|
|
@ -1236,7 +1286,7 @@ TEST(ThreadLocalDebugInfoTest, Basic) {
|
|||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
|
||||
done = false;
|
||||
at::launch([&done]() {
|
||||
at::launch([]() {
|
||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO, 42);
|
||||
checkDebugInfo(c10::DebugInfoKind::TEST_INFO_2, 314);
|
||||
done = true;
|
||||
|
|
|
|||
|
|
@ -172,9 +172,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
|
|||
at::enableRecordFunction(enable);
|
||||
});
|
||||
m.def("_set_empty_test_observer", [](bool is_global, double sampling_prob) {
|
||||
auto cb = at::RecordFunctionCallback(
|
||||
[](const at::RecordFunction&) { return nullptr; },
|
||||
[](const at::RecordFunction&, at::ObserverContext*) {})
|
||||
auto cb = at::RecordFunctionCallback(nullptr)
|
||||
.needsInputs(true)
|
||||
.samplingProb(sampling_prob);
|
||||
if (is_global) {
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ void pushProfilingCallbacks() {
|
|||
auto state_ptr = getProfilerTLSState();
|
||||
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
||||
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
|
||||
[](const at::RecordFunction& fn) {
|
||||
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
if (!state_ptr || state_ptr->config().state != ProfilerState::KINETO) {
|
||||
return std::make_unique<KinetoObserverContext>();
|
||||
|
|
|
|||
|
|
@ -414,7 +414,7 @@ void pushProfilingCallbacksLegacy() {
|
|||
auto state_ptr = getProfilerTLSState();
|
||||
TORCH_INTERNAL_ASSERT(state_ptr, "Expected profiler state set");
|
||||
auto handle = at::addThreadLocalCallback(at::RecordFunctionCallback(
|
||||
[](const at::RecordFunction& fn) {
|
||||
[](const at::RecordFunction& fn) -> std::unique_ptr<at::ObserverContext> {
|
||||
auto state_ptr = getProfilerTLSState();
|
||||
if (!state_ptr || state_ptr->config().state == ProfilerState::Disabled) {
|
||||
return nullptr;
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user