pytorch/torch/csrc/mtia/Module.cpp
Andy (An) Wang d3be06cbdc [MTIAGraph][Pytorch][2/n] Add binding for Python to C++, and hook for Pytorch to Fbcode (#165963)
Summary:
This diff is the binding and hook layer for MTIA Graph, including
1. binding between Python and C++
2. hook between Pytorch and mtia fbcode
<img width="1780" height="754" alt="image" src="https://github.com/user-attachments/assets/31e24e5b-8324-42d8-8d3b-59536bc18340" />

[Doc](https://docs.google.com/document/d/1Q3xdZAIqhBvuy2HxGDfJyXVmxYXUEeYSZSwsp7bcJF8/edit?tab=t.osb46a42t6wb#heading=h.ayp9tkk08x00)

Test Plan: Will be tested in the python implementation which will use the binding and hook

Differential Revision: D84457757

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165963
Approved by: https://github.com/malfet, https://github.com/albanD
2025-10-31 02:52:51 +00:00

182 lines
5.9 KiB
C++

#include <ATen/ATen.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <torch/csrc/Generator.h>
#include <torch/csrc/Stream.h>
#include <torch/csrc/mtia/Module.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/device_lazy_init.h>
#include <torch/csrc/utils/pybind.h>
namespace torch::mtia {
// Thin wrapper that exposes an MTIA graph to Python.
//
// MTIA uses accelerator hooks to connect PyTorch with the outside
// implementation. We need to provide the MTIAGraph class at the Python layer,
// but the hooks only support hooking functions, not classes. Thus all
// MTIAGraph C++ instances are stored in a map, and a handle is used to choose
// the right instance. Every method below simply forwards that handle to the
// corresponding `mtiagraph*` hook.
struct _MTIAGraph {
  // Handle returned by mtiagraphCreate(); identifies this graph in all
  // subsequent hook calls.
  int64_t handle_;

  // Creates a graph through the hooks and remembers its handle.
  // `keep_graph` is forwarded unchanged to mtiagraphCreate().
  // The constructor is `explicit` so that a plain bool can never be converted
  // into a graph (and thereby create a backend graph) by accident.
  explicit _MTIAGraph(bool keep_graph = false)
      : handle_(at::detail::getMTIAHooks().mtiagraphCreate(keep_graph)) {}

  // NOTE(review): destruction does not hand the handle back to the hooks, so
  // whatever the hooks keep for this handle is not released from here.
  // Confirm whether that is intended.
  ~_MTIAGraph() = default;

  // Forwards to mtiagraphCaptureBegin() with this graph's handle and `pool`.
  void capture_begin(at::MempoolId_t pool) {
    at::detail::getMTIAHooks().mtiagraphCaptureBegin(handle_, pool);
  }

  // Forwards to mtiagraphCaptureEnd() with this graph's handle.
  void capture_end() {
    at::detail::getMTIAHooks().mtiagraphCaptureEnd(handle_);
  }

  // Forwards to mtiagraphInstantiate() with this graph's handle.
  void instantiate() {
    at::detail::getMTIAHooks().mtiagraphInstantiate(handle_);
  }

  // Forwards to mtiagraphReplay() with this graph's handle.
  void replay() {
    at::detail::getMTIAHooks().mtiagraphReplay(handle_);
  }

  // Forwards to mtiagraphReset() with this graph's handle.
  void reset() {
    at::detail::getMTIAHooks().mtiagraphReset(handle_);
  }

  // Returns the mempool id that mtiagraphPool() reports for this graph.
  at::MempoolId_t pool() {
    return at::detail::getMTIAHooks().mtiagraphPool(handle_);
  }
};
// Registers the private `_mtia_*` functions and the `_MTIAGraph` class on the
// given Python module. Every binding is a thin forwarder to the MTIA hooks
// obtained from at::detail::getMTIAHooks().
void initModule(PyObject* module) {
  auto m = py::handle(module).cast<py::module>();

  // --- Initialization and fork state ---------------------------------------
  m.def("_mtia_init", []() {
    TORCH_INTERNAL_ASSERT(!torch::utils::is_device_in_bad_fork(at::kMTIA));
    torch::utils::register_fork_handler_for_device_init(at::kMTIA);
    at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
  });

  // Reports whether the MTIAHooks class has been registered with the registry.
  m.def("_mtia_isBuilt", []() { return at::detail::isMTIAHooksBuilt(); });

  m.def("_mtia_isInBadFork", []() {
    return torch::utils::is_device_in_bad_fork(at::kMTIA);
  });

  // --- Stream queries (lazily initialize the device first) -----------------
  m.def("_mtia_getCurrentStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    const auto& hooks = at::detail::getMTIAHooks();
    return hooks.getCurrentStream(device_index);
  });

  m.def("_mtia_getCurrentRawStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    const auto& hooks = at::detail::getMTIAHooks();
    return hooks.getCurrentRawStream(device_index);
  });

  // Synchronizes whichever device the hooks report as current.
  m.def("_mtia_deviceSynchronize", []() {
    torch::utils::device_lazy_init(at::kMTIA);
    const auto& hooks = at::detail::getMTIAHooks();
    hooks.deviceSynchronize(hooks.getCurrentDevice());
  });

  // --- Device switching ----------------------------------------------------
  // Negative indices are never forwarded to the hooks; -1 is returned instead.
  m.def("_mtia_exchangeDevice", [](c10::DeviceIndex device_index) {
    if (device_index >= 0) {
      return at::detail::getMTIAHooks().exchangeDevice(device_index);
    }
    return static_cast<c10::DeviceIndex>(-1);
  });

  m.def("_mtia_maybeExchangeDevice", [](c10::DeviceIndex device_index) {
    if (device_index >= 0) {
      return at::detail::getMTIAHooks().maybeExchangeDevice(device_index);
    }
    return static_cast<c10::DeviceIndex>(-1);
  });

  m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
    torch::utils::device_lazy_init(at::kMTIA);
    const auto& hooks = at::detail::getMTIAHooks();
    return hooks.getDefaultStream(device_index);
  });

  // --- Stream selection ----------------------------------------------------
  // Rebuilds a c10::Stream from its three packed components and makes it the
  // current stream.
  m.def(
      "_mtia_setStream",
      [](int64_t stream_id,
         c10::DeviceIndex device_index,
         int64_t device_type) {
        torch::utils::device_lazy_init(at::kMTIA);
        const auto& hooks = at::detail::getMTIAHooks();
        const auto unpacked = c10::Stream::unpack3(
            stream_id,
            device_index,
            static_cast<c10::DeviceType>(device_type));
        hooks.setCurrentStream(unpacked);
      });

  // Switches the current device to the stream's device when they differ, then
  // makes the stream current.
  m.def("_mtia_setCurrentStream", [](const c10::Stream& stream) {
    torch::utils::device_lazy_init(at::kMTIA);
    const auto& hooks = at::detail::getMTIAHooks();
    const auto active_device = hooks.getCurrentDevice();
    if (active_device != stream.device_index()) {
      hooks.setCurrentDevice(stream.device_index());
    }
    hooks.setCurrentStream(stream);
  });

  // --- Memory and device introspection -------------------------------------
  // The hooks hand back raw PyObject pointers; reinterpret_steal wraps them
  // without adding another reference.
  m.def("_mtia_memoryStats", [](c10::DeviceIndex device_index) {
    return py::reinterpret_steal<py::object>(
        at::detail::getMTIAHooks().memoryStats(device_index));
  });

  m.def("_mtia_getDeviceCapability", [](c10::DeviceIndex device_index) {
    return py::reinterpret_steal<py::object>(
        at::detail::getMTIAHooks().getDeviceCapability(device_index));
  });

  m.def("_mtia_getDeviceProperties", [](c10::DeviceIndex device_index) {
    return py::reinterpret_steal<py::object>(
        at::detail::getMTIAHooks().getDeviceProperties(device_index));
  });

  m.def("_mtia_emptyCache", []() {
    at::detail::getMTIAHooks().emptyCache();
  });

  m.def(
      "_mtia_recordMemoryHistory",
      [](const std::optional<std::string>& enabled,
         const std::string& stacks,
         size_t max_entries) {
        const auto& hooks = at::detail::getMTIAHooks();
        hooks.recordMemoryHistory(enabled, stacks, max_entries);
      });

  // The snapshot hook is always invoked with std::nullopt.
  m.def("_mtia_memorySnapshot", []() {
    return py::reinterpret_steal<py::object>(
        at::detail::getMTIAHooks().memorySnapshot(std::nullopt));
  });

  // Passes the observer's underlying PyObject pointer to the hooks.
  m.def("_mtia_attachOutOfMemoryObserver", [](const py::function& observer) {
    at::detail::getMTIAHooks().attachOutOfMemoryObserver(observer.ptr());
  });

  m.def("_mtia_getDeviceCount", []() {
    return at::detail::getMTIAHooks().deviceCount();
  });

  m.def("_mtia_resetPeakMemoryStats", [](c10::DeviceIndex device_index) {
    at::detail::getMTIAHooks().resetPeakMemoryStats(device_index);
  });

  // --- Graph class ---------------------------------------------------------
  // `keep_graph` is an optional keyword argument defaulting to false.
  py::class_<_MTIAGraph> graph_class(m, "_MTIAGraph");
  graph_class.def(py::init<bool>(), py::arg("keep_graph") = false);
  graph_class.def("capture_begin", &_MTIAGraph::capture_begin);
  graph_class.def("capture_end", &_MTIAGraph::capture_end);
  graph_class.def("instantiate", &_MTIAGraph::instantiate);
  graph_class.def("replay", &_MTIAGraph::replay);
  graph_class.def("reset", &_MTIAGraph::reset);
  graph_class.def("pool", &_MTIAGraph::pool);
}
} // namespace torch::mtia