mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[pytorch][counters] DynamicCounter (#132166)
Summary: Implement a callback-based dynamic counter with pluggable backends. The backend API and integration is similar to WaitCounter. Note that this counter should only be used with C++ callbacks, since making it safe to be used for GIL-requiring callbacks would be pretty challenging and may defeat the whole purpose of this counter (since the duration of the callback can no longer be guaranteed). Test Plan: unit test Differential Revision: D60464055 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132166 Approved by: https://github.com/asiab4
This commit is contained in:
parent
dc38646c58
commit
cb4c107d70
76
c10/util/DynamicCounter.cpp
Normal file
76
c10/util/DynamicCounter.cpp
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
#include <c10/util/DynamicCounter.h>
|
||||
|
||||
#include <c10/util/Synchronized.h>
|
||||
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace c10::monitor {
|
||||
|
||||
namespace {
|
||||
using DynamicCounterBackends =
|
||||
std::vector<std::shared_ptr<detail::DynamicCounterBackendIf>>;
|
||||
|
||||
Synchronized<DynamicCounterBackends>& dynamicCounterBackends() {
|
||||
static auto instance = new Synchronized<DynamicCounterBackends>();
|
||||
return *instance;
|
||||
}
|
||||
|
||||
Synchronized<std::unordered_set<std::string>>& registeredCounters() {
|
||||
static auto instance = new Synchronized<std::unordered_set<std::string>>();
|
||||
return *instance;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
namespace detail {
|
||||
void registerDynamicCounterBackend(
|
||||
std::unique_ptr<DynamicCounterBackendIf> backend) {
|
||||
dynamicCounterBackends().withLock(
|
||||
[&](auto& backends) { backends.push_back(std::move(backend)); });
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
struct DynamicCounter::Guard {
|
||||
Guard(std::string_view key, Callback&& getCounterCallback)
|
||||
: key_{key},
|
||||
getCounterCallback_(std::move(getCounterCallback)),
|
||||
backends_{dynamicCounterBackends().withLock(
|
||||
[](auto& backends) { return backends; })} {
|
||||
registeredCounters().withLock([&](auto& registeredCounters) {
|
||||
if (!registeredCounters.insert(std::string(key)).second) {
|
||||
throw std::logic_error(
|
||||
"Counter " + std::string(key) + " already registered");
|
||||
}
|
||||
});
|
||||
|
||||
for (const auto& backend : backends_) {
|
||||
// Avoid copying the user-provided callback to avoid unexpected behavior
|
||||
// changes when more than one backend is registered.
|
||||
backend->registerCounter(key, [&]() { return getCounterCallback_(); });
|
||||
}
|
||||
}
|
||||
|
||||
~Guard() {
|
||||
for (const auto& backend : backends_) {
|
||||
backend->unregisterCounter(key_);
|
||||
}
|
||||
|
||||
registeredCounters().withLock(
|
||||
[&](auto& registeredCounters) { registeredCounters.erase(key_); });
|
||||
}
|
||||
|
||||
private:
|
||||
std::string key_;
|
||||
Callback getCounterCallback_;
|
||||
DynamicCounterBackends backends_;
|
||||
};
|
||||
|
||||
DynamicCounter::DynamicCounter(
|
||||
std::string_view key,
|
||||
Callback getCounterCallback)
|
||||
: guard_{std::make_unique<Guard>(key, std::move(getCounterCallback))} {}
|
||||
|
||||
DynamicCounter::~DynamicCounter() = default;
|
||||
|
||||
} // namespace c10::monitor
|
||||
49
c10/util/DynamicCounter.h
Normal file
49
c10/util/DynamicCounter.h
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
#pragma once
|
||||
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <string_view>
|
||||
|
||||
#include <c10/macros/Macros.h>
|
||||
|
||||
namespace c10::monitor {
|
||||
|
||||
class C10_API DynamicCounter {
|
||||
public:
|
||||
using Callback = std::function<int64_t()>;
|
||||
|
||||
// Creates a dynamic counter that can be queried at any point in time by
|
||||
// multiple backends. Only one counter with a given key can exist at any point
|
||||
// in time.
|
||||
//
|
||||
// The callback is invoked every time the counter is queried.
|
||||
// The callback must be thread-safe.
|
||||
// The callback must not throw.
|
||||
// The callback must not block.
|
||||
DynamicCounter(std::string_view key, Callback getCounterCallback);
|
||||
|
||||
// Unregisters the callback.
|
||||
// Waits for all ongoing callback invocations to finish.
|
||||
~DynamicCounter();
|
||||
|
||||
private:
|
||||
struct Guard;
|
||||
std::unique_ptr<Guard> guard_;
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
class DynamicCounterBackendIf {
|
||||
public:
|
||||
virtual ~DynamicCounterBackendIf() = default;
|
||||
|
||||
virtual void registerCounter(
|
||||
std::string_view key,
|
||||
DynamicCounter::Callback getCounterCallback) = 0;
|
||||
// MUST wait for all ongoing callback invocations to finish
|
||||
virtual void unregisterCounter(std::string_view key) = 0;
|
||||
};
|
||||
|
||||
void C10_API
|
||||
registerDynamicCounterBackend(std::unique_ptr<DynamicCounterBackendIf>);
|
||||
} // namespace detail
|
||||
} // namespace c10::monitor
|
||||
Loading…
Reference in New Issue
Block a user