mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pytorch][counters] Pybind for WaitCounter (#132167)
Summary: Basic pybind integration for WaitCounter providing a guard API. Also fixes broken copy/move constructor in WaitGuard (it wasn't really used with the macro-based C++ API). Test Plan: unit test Reviewed By: asiab4 Differential Revision: D60463979 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132167 Approved by: https://github.com/asiab4
This commit is contained in:
parent
39a3c98aa6
commit
2c7bd61afa
|
|
@ -45,6 +45,13 @@ class C10_API WaitCounterHandle {
|
||||||
|
|
||||||
class WaitGuard {
|
class WaitGuard {
|
||||||
public:
|
public:
|
||||||
|
WaitGuard(WaitGuard&& other) noexcept
|
||||||
|
: handle_{std::exchange(other.handle_, {})},
|
||||||
|
ctxs_{std::move(other.ctxs_)} {}
|
||||||
|
WaitGuard(const WaitGuard&) = delete;
|
||||||
|
WaitGuard& operator=(const WaitGuard&) = delete;
|
||||||
|
WaitGuard& operator=(WaitGuard&&) = delete;
|
||||||
|
|
||||||
~WaitGuard() {
|
~WaitGuard() {
|
||||||
stop();
|
stop();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@ from torch.monitor import (
|
||||||
unregister_event_handler,
|
unregister_event_handler,
|
||||||
Stat,
|
Stat,
|
||||||
TensorboardEventHandler,
|
TensorboardEventHandler,
|
||||||
|
WaitCounter,
|
||||||
)
|
)
|
||||||
|
|
||||||
class TestMonitor(TestCase):
|
class TestMonitor(TestCase):
|
||||||
|
|
@ -98,6 +99,13 @@ class TestMonitor(TestCase):
|
||||||
log_event(e)
|
log_event(e)
|
||||||
self.assertEqual(len(events), 2)
|
self.assertEqual(len(events), 2)
|
||||||
|
|
||||||
|
# Smoke test for the WaitCounter binding: builds a counter keyed
# "test_wait_counter" and enters/exits its guard() context manager once.
# No assertions are made; the test passes if nothing raises.
def test_wait_counter(self) -> None:
|
||||||
|
wait_counter = WaitCounter(
|
||||||
|
"test_wait_counter",
|
||||||
|
)
|
||||||
|
# `wcg` is bound but never used in the body.
with wait_counter.guard() as wcg:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@skipIfTorchDynamo("Really weird error")
|
@skipIfTorchDynamo("Really weird error")
|
||||||
class TestMonitorTensorboard(TestCase):
|
class TestMonitorTensorboard(TestCase):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include <c10/util/WaitCounter.h>
|
||||||
|
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
#include <torch/csrc/utils/python_arg_parser.h>
|
#include <torch/csrc/utils/python_arg_parser.h>
|
||||||
#include <torch/csrc/utils/python_numbers.h>
|
#include <torch/csrc/utils/python_numbers.h>
|
||||||
|
|
@ -296,6 +298,47 @@ void initMonitorBindings(PyObject* module) {
|
||||||
after calling ``register_event_handler``. After this returns the event
|
after calling ``register_event_handler``. After this returns the event
|
||||||
handler will no longer receive events.
|
handler will no longer receive events.
|
||||||
)DOC");
|
)DOC");
|
||||||
|
|
||||||
|
// Helper object handed to Python by WaitCounter.guard(): it stores its own
// copy of a WaitCounterHandle together with an optional WaitGuard. The guard
// starts out empty; the __enter__/__exit__ bindings registered for this type
// emplace and reset it.
struct WaitCounterTracker {
|
||||||
|
// Copies `h` into `handle`; `guard` is left disengaged.
explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& h)
|
||||||
|
: handle{h} {}
|
||||||
|
// Counter handle that guards are started from.
c10::monitor::WaitCounterHandle handle;
|
||||||
|
// Engaged only between __enter__ and __exit__.
std::optional<c10::monitor::WaitCounterHandle::WaitGuard> guard;
|
||||||
|
};
|
||||||
|
// Registers WaitCounterTracker with pybind11 (held by std::shared_ptr) and
// gives it the two methods a Python `with` statement calls.
py::class_<WaitCounterTracker, std::shared_ptr<WaitCounterTracker>>(
|
||||||
|
m, "WaitCounterTracker")
|
||||||
|
.def(
|
||||||
|
"__enter__",
|
||||||
|
// Emplaces `guard` from the result of handle.start().
// NOTE(review): this lambda returns nothing, so `with ... as x` presumably
// binds None on the Python side — confirm that is intended.
[](const std::shared_ptr<WaitCounterTracker>& self) {
|
||||||
|
self->guard.emplace(self->handle.start());
|
||||||
|
})
|
||||||
|
.def(
|
||||||
|
"__exit__",
|
||||||
|
// Accepts and ignores whatever positional arguments Python passes, then
// resets the optional, destroying the held guard if there is one.
[](const std::shared_ptr<WaitCounterTracker>& self,
|
||||||
|
const pybind11::args&) { self->guard.reset(); });
|
||||||
|
|
||||||
|
// Exposes c10::monitor::WaitCounterHandle to Python under the name
// "WaitCounter": constructible from a string key, with a guard() method that
// returns a WaitCounterTracker.
py::class_<c10::monitor::WaitCounterHandle>(
|
||||||
|
m,
|
||||||
|
"WaitCounter",
|
||||||
|
R"DOC(
|
||||||
|
WaitCounter represents a named duration counter.
|
||||||
|
Multiple units of work can be tracked by the same WaitCounter. Depending
|
||||||
|
on the backend, the WaitCounter may track the number of units of work,
|
||||||
|
their duration etc.
|
||||||
|
)DOC")
|
||||||
|
.def(
|
||||||
|
// Python constructor: WaitCounter(key) builds a handle from `key`.
py::init([](const std::string& key) {
|
||||||
|
return std::make_unique<c10::monitor::WaitCounterHandle>(key);
|
||||||
|
}),
|
||||||
|
py::arg("key"))
|
||||||
|
.def(
|
||||||
|
"guard",
|
||||||
|
// Builds a new tracker from *self; the tracker's constructor copies the
// handle into its own member.
[](const c10::monitor::WaitCounterHandle* self) {
|
||||||
|
return std::make_shared<WaitCounterTracker>(*self);
|
||||||
|
},
|
||||||
|
R"DOC(
|
||||||
|
Creates a guard that manages a single unit of work.
|
||||||
|
)DOC");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace monitor
|
} // namespace monitor
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user