mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytorch][counters] Pybind for WaitCounter (#132357)
Summary: Basic pybind integration for WaitCounter providing a guard API. Also fixes broken copy/move constructor in WaitGuard (it wasn't really used with the macro-based C++ API). Test Plan: unit test Differential Revision: D60557660 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132357 Approved by: https://github.com/jamesperng, https://github.com/asiab4
This commit is contained in:
parent
d224857b3a
commit
fca2dba7ca
|
|
@ -45,6 +45,13 @@ class C10_API WaitCounterHandle {
|
||||||
|
|
||||||
class WaitGuard {
|
class WaitGuard {
|
||||||
public:
|
public:
|
||||||
|
WaitGuard(WaitGuard&& other) noexcept
|
||||||
|
: handle_{std::exchange(other.handle_, {})},
|
||||||
|
ctxs_{std::move(other.ctxs_)} {}
|
||||||
|
WaitGuard(const WaitGuard&) = delete;
|
||||||
|
WaitGuard& operator=(const WaitGuard&) = delete;
|
||||||
|
WaitGuard& operator=(WaitGuard&&) = delete;
|
||||||
|
|
||||||
~WaitGuard() {
|
~WaitGuard() {
|
||||||
stop();
|
stop();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,21 @@
|
||||||
# Owner(s): ["oncall: r2p"]
|
# Owner(s): ["oncall: r2p"]
|
||||||
|
|
||||||
from torch.testing._internal.common_utils import (
|
|
||||||
TestCase, run_tests, skipIfTorchDynamo,
|
|
||||||
)
|
|
||||||
|
|
||||||
from datetime import timedelta, datetime
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from torch.monitor import (
|
from torch.monitor import (
|
||||||
Aggregation,
|
Aggregation,
|
||||||
Event,
|
Event,
|
||||||
log_event,
|
log_event,
|
||||||
register_event_handler,
|
register_event_handler,
|
||||||
unregister_event_handler,
|
|
||||||
Stat,
|
Stat,
|
||||||
TensorboardEventHandler,
|
TensorboardEventHandler,
|
||||||
|
unregister_event_handler,
|
||||||
|
_WaitCounter,
|
||||||
)
|
)
|
||||||
|
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
|
||||||
|
|
||||||
class TestMonitor(TestCase):
|
class TestMonitor(TestCase):
|
||||||
def test_interval_stat(self) -> None:
|
def test_interval_stat(self) -> None:
|
||||||
|
|
@ -98,6 +97,13 @@ class TestMonitor(TestCase):
|
||||||
log_event(e)
|
log_event(e)
|
||||||
self.assertEqual(len(events), 2)
|
self.assertEqual(len(events), 2)
|
||||||
|
|
||||||
|
def test_wait_counter(self) -> None:
|
||||||
|
wait_counter = _WaitCounter(
|
||||||
|
"test_wait_counter",
|
||||||
|
)
|
||||||
|
with wait_counter.guard() as wcg:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@skipIfTorchDynamo("Really weird error")
|
@skipIfTorchDynamo("Really weird error")
|
||||||
class TestMonitorTensorboard(TestCase):
|
class TestMonitorTensorboard(TestCase):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
|
#include <c10/util/WaitCounter.h>
|
||||||
|
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
#include <torch/csrc/utils/python_arg_parser.h>
|
#include <torch/csrc/utils/python_arg_parser.h>
|
||||||
#include <torch/csrc/utils/python_numbers.h>
|
#include <torch/csrc/utils/python_numbers.h>
|
||||||
|
|
@ -296,6 +298,47 @@ void initMonitorBindings(PyObject* module) {
|
||||||
after calling ``register_event_handler``. After this returns the event
|
after calling ``register_event_handler``. After this returns the event
|
||||||
handler will no longer receive events.
|
handler will no longer receive events.
|
||||||
)DOC");
|
)DOC");
|
||||||
|
|
||||||
|
struct WaitCounterTracker {
|
||||||
|
explicit WaitCounterTracker(const c10::monitor::WaitCounterHandle& h)
|
||||||
|
: handle{h} {}
|
||||||
|
c10::monitor::WaitCounterHandle handle;
|
||||||
|
std::optional<c10::monitor::WaitCounterHandle::WaitGuard> guard;
|
||||||
|
};
|
||||||
|
py::class_<WaitCounterTracker, std::shared_ptr<WaitCounterTracker>>(
|
||||||
|
m, "_WaitCounterTracker")
|
||||||
|
.def(
|
||||||
|
"__enter__",
|
||||||
|
[](const std::shared_ptr<WaitCounterTracker>& self) {
|
||||||
|
self->guard.emplace(self->handle.start());
|
||||||
|
})
|
||||||
|
.def(
|
||||||
|
"__exit__",
|
||||||
|
[](const std::shared_ptr<WaitCounterTracker>& self,
|
||||||
|
const pybind11::args&) { self->guard.reset(); });
|
||||||
|
|
||||||
|
py::class_<c10::monitor::WaitCounterHandle>(
|
||||||
|
m,
|
||||||
|
"_WaitCounter",
|
||||||
|
R"DOC(
|
||||||
|
WaitCounter represents a named duration counter.
|
||||||
|
Multiple units of work can be tracked by the same WaitCounter. Depending
|
||||||
|
on the backend, the WaitCounter may track the number of units of work,
|
||||||
|
their duration etc.
|
||||||
|
)DOC")
|
||||||
|
.def(
|
||||||
|
py::init([](const std::string& key) {
|
||||||
|
return std::make_unique<c10::monitor::WaitCounterHandle>(key);
|
||||||
|
}),
|
||||||
|
py::arg("key"))
|
||||||
|
.def(
|
||||||
|
"guard",
|
||||||
|
[](const c10::monitor::WaitCounterHandle* self) {
|
||||||
|
return std::make_shared<WaitCounterTracker>(*self);
|
||||||
|
},
|
||||||
|
R"DOC(
|
||||||
|
Creates a guard that manages a single unit of work.
|
||||||
|
)DOC");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace monitor
|
} // namespace monitor
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
from torch._C._monitor import * # noqa: F403
|
from torch._C._monitor import * # noqa: F403
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from torch._C._monitor import _WaitCounter # type: ignore[attr-defined]
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from torch.utils.tensorboard import SummaryWriter
|
from torch.utils.tensorboard import SummaryWriter
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user