[MTIA] Add _mtia_exchangeDevice to MTIA module (#149322)

Summary: The FlexAttention path uses `_exchange_device`, so it will be needed eventually for MTIA as well.

Test Plan: `buck2 test fbcode//mtia/host_runtime/torch_mtia/tests:test_torch_mtia_api -- test_exchange_device`

Reviewed By: chaos5958

Differential Revision: D70072059

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149322
Approved by: https://github.com/chaos5958
This commit is contained in:
Pat Vignola 2025-03-17 19:31:10 +00:00 committed by PyTorch MergeBot
parent 8d7c430e84
commit 769f19bf95
2 changed files with 17 additions and 0 deletions

View File

@ -60,6 +60,13 @@ void initModule(PyObject* module) {
at::detail::getMTIAHooks().getCurrentDevice()); at::detail::getMTIAHooks().getCurrentDevice());
}); });
m.def("_mtia_exchangeDevice", [](c10::DeviceIndex device_index) {
if (device_index < 0) {
return static_cast<c10::DeviceIndex>(-1);
}
return at::detail::getMTIAHooks().exchangeDevice(device_index);
});
m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) { m.def("_mtia_getDefaultStream", [](c10::DeviceIndex device_index) {
torch::utils::device_lazy_init(at::kMTIA); torch::utils::device_lazy_init(at::kMTIA);
return at::detail::getMTIAHooks().getDefaultStream(device_index); return at::detail::getMTIAHooks().getDefaultStream(device_index);

View File

@ -30,6 +30,16 @@ _initialization_lock = threading.Lock()
_lazy_seed_tracker = _LazySeedTracker() _lazy_seed_tracker = _LazySeedTracker()
if hasattr(torch._C, "_mtia_exchangeDevice"):
_exchange_device = torch._C._mtia_exchangeDevice
else:
def _exchange_device(device: int) -> int:
if device < 0:
return -1
raise RuntimeError("PyTorch was compiled without MTIA support")
def init(): def init():
_lazy_init() _lazy_init()