mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MTIAGraph][Pytorch][2/n] Add binding for Python to C++, and hook for Pytorch to Fbcode (#165963)
Summary: This diff is the binding and hook layer for MTIA Graph, including 1. binding between Python and C++ 2. hook between Pytorch and mtia fbcode <img width="1780" height="754" alt="image" src="https://github.com/user-attachments/assets/31e24e5b-8324-42d8-8d3b-59536bc18340" /> [Doc](https://docs.google.com/document/d/1Q3xdZAIqhBvuy2HxGDfJyXVmxYXUEeYSZSwsp7bcJF8/edit?tab=t.osb46a42t6wb#heading=h.ayp9tkk08x00) Test Plan: Will be tested in the python implementation which will use the binding and hook Differential Revision: D84457757 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165963 Approved by: https://github.com/malfet, https://github.com/albanD
This commit is contained in:
parent
1129605415
commit
d3be06cbdc
|
|
@ -1,5 +1,6 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
|
#include <c10/core/CachingDeviceAllocator.h>
|
||||||
#include <c10/core/Device.h>
|
#include <c10/core/Device.h>
|
||||||
#include <c10/util/Exception.h>
|
#include <c10/util/Exception.h>
|
||||||
|
|
||||||
|
|
@ -151,6 +152,36 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
virtual bool isAvailable() const override;
|
virtual bool isAvailable() const override;
|
||||||
|
|
||||||
|
/* MTIAGraph related APIs */
|
||||||
|
virtual int64_t mtiagraphCreate(bool keep_graph = false) const {
|
||||||
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void mtiagraphCaptureBegin(int64_t handle, MempoolId_t pool) const {
|
||||||
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void mtiagraphCaptureEnd(int64_t handle) const {
|
||||||
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void mtiagraphInstantiate(int64_t handle) const {
|
||||||
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void mtiagraphReplay(int64_t handle) const {
|
||||||
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual void mtiagraphReset(int64_t handle) const {
|
||||||
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
|
}
|
||||||
|
|
||||||
|
virtual MempoolId_t mtiagraphPool(int64_t handle) const {
|
||||||
|
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TORCH_API MTIAHooksArgs {};
|
struct TORCH_API MTIAHooksArgs {};
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,42 @@
|
||||||
|
|
||||||
namespace torch::mtia {
|
namespace torch::mtia {
|
||||||
|
|
||||||
|
struct _MTIAGraph {
|
||||||
|
// MTIA use accelerator hooks to connect pytorch and outside.
|
||||||
|
// We need to provide the MTIAGraph class at Python layer, but the hooks only
|
||||||
|
// support hooking functions, not classes. Thus we store all MTIAGraph C++
|
||||||
|
// instances in a map, and use a handle to choose the right instance.
|
||||||
|
int64_t handle_;
|
||||||
|
|
||||||
|
_MTIAGraph(bool keep_graph = false)
|
||||||
|
: handle_(at::detail::getMTIAHooks().mtiagraphCreate(keep_graph)) {}
|
||||||
|
~_MTIAGraph() = default;
|
||||||
|
|
||||||
|
void capture_begin(at::MempoolId_t pool) {
|
||||||
|
at::detail::getMTIAHooks().mtiagraphCaptureBegin(handle_, pool);
|
||||||
|
}
|
||||||
|
|
||||||
|
void capture_end() {
|
||||||
|
at::detail::getMTIAHooks().mtiagraphCaptureEnd(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void instantiate() {
|
||||||
|
at::detail::getMTIAHooks().mtiagraphInstantiate(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void replay() {
|
||||||
|
at::detail::getMTIAHooks().mtiagraphReplay(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
void reset() {
|
||||||
|
at::detail::getMTIAHooks().mtiagraphReset(handle_);
|
||||||
|
}
|
||||||
|
|
||||||
|
at::MempoolId_t pool() {
|
||||||
|
return at::detail::getMTIAHooks().mtiagraphPool(handle_);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void initModule(PyObject* module) {
|
void initModule(PyObject* module) {
|
||||||
auto m = py::handle(module).cast<py::module>();
|
auto m = py::handle(module).cast<py::module>();
|
||||||
|
|
||||||
|
|
@ -131,6 +167,15 @@ void initModule(PyObject* module) {
|
||||||
m.def("_mtia_resetPeakMemoryStats", [](c10::DeviceIndex device_index) {
|
m.def("_mtia_resetPeakMemoryStats", [](c10::DeviceIndex device_index) {
|
||||||
at::detail::getMTIAHooks().resetPeakMemoryStats(device_index);
|
at::detail::getMTIAHooks().resetPeakMemoryStats(device_index);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
py::class_<_MTIAGraph>(m, "_MTIAGraph")
|
||||||
|
.def(py::init<bool>(), py::arg("keep_graph") = false)
|
||||||
|
.def("capture_begin", &_MTIAGraph::capture_begin)
|
||||||
|
.def("capture_end", &_MTIAGraph::capture_end)
|
||||||
|
.def("instantiate", &_MTIAGraph::instantiate)
|
||||||
|
.def("replay", &_MTIAGraph::replay)
|
||||||
|
.def("reset", &_MTIAGraph::reset)
|
||||||
|
.def("pool", &_MTIAGraph::pool);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace torch::mtia
|
} // namespace torch::mtia
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user