[kineto] Optimize getStepCallbacks for common case of no active callbacks

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77804

If I understand correctly, the result of `getStepCallbacks` is empty and unused when there are no sampled callbacks, which is the common case. We can speed up this case by wrapping the result in an optional (via a new `getStepCallbacksUnlessEmpty` variant), which saves initializing an empty SmallVector.

Differential Revision: [D36497279](https://our.internmc.facebook.com/intern/diff/D36497279/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook-specific changes or comments; please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36497279/)!

Approved by: https://github.com/robieta
This commit is contained in:
Scott Wolchok 2022-05-19 15:26:22 -07:00 committed by PyTorch MergeBot
parent 02c4d877b4
commit c083489f46
7 changed files with 59 additions and 25 deletions

View File

@ -545,9 +545,9 @@ C10_ALWAYS_INLINE_UNLESS_MOBILE Return Dispatcher::call(const TypedOperatorHandl
.template getDispatchKeySetUnboxed<Args...>(args...); .template getDispatchKeySetUnboxed<Args...>(args...);
const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet); const KernelFunction& kernel = op.operatorDef_->op.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION); auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(!step_callbacks.empty() && op.operatorDef_->op.isObserved())) { if (C10_UNLIKELY(step_callbacks.has_value() && op.operatorDef_->op.isObserved())) {
return callWithDispatchKeySlowPath<Return, Args...>(op, step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...); return callWithDispatchKeySlowPath<Return, Args...>(op, *step_callbacks, dispatchKeySet, kernel, std::forward<Args>(args)...);
} }
#endif // PYTORCH_DISABLE_PER_OP_PROFILING #endif // PYTORCH_DISABLE_PER_OP_PROFILING
return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...); return kernel.template call<Return, Args...>(op, dispatchKeySet, std::forward<Args>(args)...);
@ -568,9 +568,9 @@ inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const
auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack); auto dispatchKeySet = entry.dispatchKeyExtractor().getDispatchKeySetBoxed(stack);
const auto& kernel = entry.lookup(dispatchKeySet); const auto& kernel = entry.lookup(dispatchKeySet);
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
auto step_callbacks = at::getStepCallbacks(at::RecordScope::FUNCTION); auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::FUNCTION);
if (C10_UNLIKELY(!step_callbacks.empty() && entry.isObserved())) { if (C10_UNLIKELY(step_callbacks.has_value() && entry.isObserved())) {
at::RecordFunction guard(std::move(step_callbacks)); at::RecordFunction guard(std::move(*step_callbacks));
auto dispatchKey = dispatchKeySet.highestPriorityTypeId(); auto dispatchKey = dispatchKeySet.highestPriorityTypeId();
auto& schema = op.schema(); auto& schema = op.schema();
auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema); auto schema_ref = std::reference_wrapper<const FunctionSchema>(schema);

View File

@ -130,6 +130,7 @@ class CacheEntry {
// The caller is expected to check `GlobalCallbackManager::get().version()' // The caller is expected to check `GlobalCallbackManager::get().version()'
// and call CacheEntry::update() if necessary. // and call CacheEntry::update() if necessary.
StepCallbacks getActiveCallbacks(); StepCallbacks getActiveCallbacks();
c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty();
// Full rebuild. (E.g. during registration) // Full rebuild. (E.g. during registration)
void update(const std::vector<RecordFunctionCallback>& callbacks); void update(const std::vector<RecordFunctionCallback>& callbacks);
@ -142,6 +143,8 @@ class CacheEntry {
int tries_left_{-1}; int tries_left_{-1};
}; };
C10_ALWAYS_INLINE void getActiveCallbacksImpl();
void rebuildActiveCallbacks(); void rebuildActiveCallbacks();
int sampleTries(double p) const; int sampleTries(double p) const;
@ -169,6 +172,7 @@ class LocalCallbackManager {
public: public:
const RecordFunctionTLS& getTLS() const; const RecordFunctionTLS& getTLS() const;
StepCallbacks getActiveCallbacks(const RecordScope scope); StepCallbacks getActiveCallbacks(const RecordScope scope);
c10::optional<StepCallbacks> getActiveCallbacksUnlessEmpty(const RecordScope scope);
void setTLS(const RecordFunctionTLS& tls); void setTLS(const RecordFunctionTLS& tls);
void seed(uint32_t seed); void seed(uint32_t seed);
@ -178,6 +182,8 @@ class LocalCallbackManager {
void clearCallbacks(); void clearCallbacks();
private: private:
void rebuildActiveCallbacksIfNeeded();
void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot); void rebuild_all(const GlobalCallbackManager::snapshot_t& global_snapshot);
void rebuild_callback_scopes( void rebuild_callback_scopes(
@ -271,7 +277,7 @@ void CacheEntry::update(const std::vector<RecordFunctionCallback>& callbacks) {
rebuildActiveCallbacks(); rebuildActiveCallbacks();
} }
StepCallbacks CacheEntry::getActiveCallbacks() { void CacheEntry::getActiveCallbacksImpl() {
// We rebuild the active set when `sampling_countdown_` reaches zero, so if it // We rebuild the active set when `sampling_countdown_` reaches zero, so if it
// reaches zero at the start of this function something has gone wrong. // reaches zero at the start of this function something has gone wrong.
TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_); TORCH_INTERNAL_ASSERT(sampling_countdown_ > 0, sampling_countdown_);
@ -295,7 +301,18 @@ StepCallbacks CacheEntry::getActiveCallbacks() {
} }
} }
} }
}
// Returns this entry's active callback set by value.
// Runs getActiveCallbacksImpl() first (the logic that used to live directly in
// this function), then returns active_callbacks_ unconditionally, whether or
// not it is empty.
StepCallbacks CacheEntry::getActiveCallbacks() {
getActiveCallbacksImpl();
return active_callbacks_;
}
// Variant of getActiveCallbacks() that reports "nothing active" as
// c10::nullopt instead of an empty StepCallbacks, so callers can test
// has_value() rather than empty().
// Runs the same getActiveCallbacksImpl() step before inspecting
// active_callbacks_.
c10::optional<StepCallbacks> CacheEntry::getActiveCallbacksUnlessEmpty() {
getActiveCallbacksImpl();
// The empty case is annotated as the likely one; it returns without
// producing a StepCallbacks value.
if (C10_LIKELY(active_callbacks_.empty())) {
return c10::nullopt;
}
return active_callbacks_; return active_callbacks_;
} }
@ -365,15 +382,25 @@ const RecordFunctionTLS& LocalCallbackManager::getTLS() const {
return registered_callbacks_; return registered_callbacks_;
} }
StepCallbacks LocalCallbackManager::getActiveCallbacks( void LocalCallbackManager::rebuildActiveCallbacksIfNeeded() {
const RecordScope scope) {
const auto global_version = GlobalCallbackManager::get().version(); const auto global_version = GlobalCallbackManager::get().version();
if (C10_UNLIKELY(global_version != global_version_)) { if (C10_UNLIKELY(global_version != global_version_)) {
rebuild_all(GlobalCallbackManager::get().getSnapshot()); rebuild_all(GlobalCallbackManager::get().getSnapshot());
} }
}
// Returns the active callbacks for `scope` by value (possibly empty).
// First calls rebuildActiveCallbacksIfNeeded(), which rebuilds from the global
// snapshot when the global callback version differs from the cached one, then
// forwards to the entry stored at index `scope` in active_callbacks_.
StepCallbacks LocalCallbackManager::getActiveCallbacks(
const RecordScope scope) {
rebuildActiveCallbacksIfNeeded();
return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks(); return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacks();
} }
// Optional-returning counterpart of getActiveCallbacks(scope).
// Performs the same rebuildActiveCallbacksIfNeeded() check, then forwards to
// getActiveCallbacksUnlessEmpty() on the entry stored at index `scope` in
// active_callbacks_ and returns its result unchanged.
c10::optional<StepCallbacks> LocalCallbackManager::getActiveCallbacksUnlessEmpty(
const RecordScope scope) {
rebuildActiveCallbacksIfNeeded();
return active_callbacks_[static_cast<size_t>(scope)].getActiveCallbacksUnlessEmpty();
}
void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) { void LocalCallbackManager::setTLS(const RecordFunctionTLS& tls) {
registered_callbacks_ = tls; registered_callbacks_ = tls;
rebuild_all(GlobalCallbackManager::get().getSnapshot()); rebuild_all(GlobalCallbackManager::get().getSnapshot());
@ -572,6 +599,10 @@ StepCallbacks getStepCallbacks(RecordScope scope) {
return LocalCallbackManager::get().getActiveCallbacks(scope); return LocalCallbackManager::get().getActiveCallbacks(scope);
} }
// Free-function entry point mirroring getStepCallbacks(), but returning an
// optional: forwards to LocalCallbackManager::get()
// .getActiveCallbacksUnlessEmpty(scope). Call sites check has_value() and
// dereference the result only when it is set.
c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope) {
return LocalCallbackManager::get().getActiveCallbacksUnlessEmpty(scope);
}
const RecordFunctionTLS& get_record_function_tls_() { const RecordFunctionTLS& get_record_function_tls_() {
return LocalCallbackManager::get().getTLS(); return LocalCallbackManager::get().getTLS();
} }

View File

@ -478,6 +478,8 @@ struct TORCH_API RecordFunction {
TORCH_API StepCallbacks getStepCallbacks(RecordScope scope); TORCH_API StepCallbacks getStepCallbacks(RecordScope scope);
TORCH_API c10::optional<StepCallbacks> getStepCallbacksUnlessEmpty(RecordScope scope);
namespace detail { namespace detail {
template <typename Inputs, typename F, typename... Args> template <typename Inputs, typename F, typename... Args>
void record_function_with_scope(RecordFunction& guard, F fn, const Inputs& inputs, Args&&... args) { void record_function_with_scope(RecordFunction& guard, F fn, const Inputs& inputs, Args&&... args) {

View File

@ -1,3 +1,4 @@
#include <torch/torch.h> #include <torch/torch.h>
#include <ATen/record_function.h> #include <ATen/record_function.h>
@ -49,9 +50,9 @@ float runPureRecordFunctionBench(int iter) {
typedef std::chrono::microseconds us; typedef std::chrono::microseconds us;
std::chrono::time_point<clock> start_time = clock::now(); std::chrono::time_point<clock> start_time = clock::now();
for (auto idx = 0; idx < iter; ++idx) { for (auto idx = 0; idx < iter; ++idx) {
auto step_callbacks = at::getStepCallbacks(at::RecordScope::USER_SCOPE); auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::USER_SCOPE);
if (!step_callbacks.empty()) { if (step_callbacks.has_value()) {
at::RecordFunction guard(std::move(step_callbacks)); at::RecordFunction guard(std::move(*step_callbacks));
guard.before("Test", -1); guard.before("Test", -1);
} }
} }

View File

@ -151,9 +151,9 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
// probably operate with names. // probably operate with names.
at::NoNamesGuard no_names_guard; at::NoNamesGuard no_names_guard;
auto step_callbacks = at::getStepCallbacks(at::RecordScope::BACKWARD_FUNCTION); auto step_callbacks = at::getStepCallbacksUnlessEmpty(at::RecordScope::BACKWARD_FUNCTION);
if (!step_callbacks.empty()) { if (C10_UNLIKELY(step_callbacks.has_value())) {
at::RecordFunction guard(std::move(step_callbacks)); at::RecordFunction guard(std::move(*step_callbacks));
// Using sequence number and thread id to correlate with // Using sequence number and thread id to correlate with
// the forward pass function // the forward pass function
guard.setForwardThreadId(thread_id_); guard.setForwardThreadId(thread_id_);

View File

@ -845,11 +845,11 @@ struct InterpreterStateImpl : c10::intrusive_ptr_target {
static void checkAndStartRecordFunction(Frame& frame, Stack& stack) { static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
if (!frame.record_function) { if (!frame.record_function) {
auto step_callbacks = auto step_callbacks = at::getStepCallbacksUnlessEmpty(
at::getStepCallbacks(at::RecordScope::TORCHSCRIPT_FUNCTION); at::RecordScope::TORCHSCRIPT_FUNCTION);
if (!step_callbacks.empty()) { if (C10_UNLIKELY(step_callbacks.has_value())) {
auto rec_fn = auto rec_fn =
std::make_unique<at::RecordFunction>(std::move(step_callbacks)); std::make_unique<at::RecordFunction>(std::move(*step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(rec_fn->isActive());
if (rec_fn->needsInputs()) { if (rec_fn->needsInputs()) {
rec_fn->before( rec_fn->before(

View File

@ -1201,9 +1201,9 @@ c10::IValue BlockRunner::run_impl_record_functions(
IValueList&& args, IValueList&& args,
const KeywordArgs& kwargs) { const KeywordArgs& kwargs) {
auto step_callbacks = auto step_callbacks =
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_MODEL); at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_MODEL);
if (!step_callbacks.empty()) { if (C10_UNLIKELY(step_callbacks.has_value())) {
at::RecordFunction guard(std::move(step_callbacks)); at::RecordFunction guard(std::move(*step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
guard.needsInputs() guard.needsInputs()
? guard.before( ? guard.before(
@ -1845,9 +1845,9 @@ std::vector<IValue> ProcessedNode::inputs_ivalue_vec() const {
void ProcessedNode::run() { void ProcessedNode::run() {
#ifndef PYTORCH_DISABLE_PER_OP_PROFILING #ifndef PYTORCH_DISABLE_PER_OP_PROFILING
auto step_callbacks = auto step_callbacks =
at::getStepCallbacks(at::RecordScope::STATIC_RUNTIME_OP); at::getStepCallbacksUnlessEmpty(at::RecordScope::STATIC_RUNTIME_OP);
if (!step_callbacks.empty()) { if (C10_UNLIKELY(step_callbacks.has_value())) {
at::RecordFunction guard(std::move(step_callbacks)); at::RecordFunction guard(std::move(*step_callbacks));
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive()); TORCH_INTERNAL_ASSERT_DEBUG_ONLY(guard.isActive());
if (guard.needsInputs()) { if (guard.needsInputs()) {
const auto inputs = inputs_ivalue_vec(); const auto inputs = inputs_ivalue_vec();