mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MTIA] Add _mtia_getCurrentRawStream to MTIA module (#149436)
Summary: The FlexAttention path generates code that uses this function. Although streams are not used yet in Triton-MTIA, adding this now allows us to not branch out just for MTIA and generate different code. Test Plan: CI Reviewed By: chaos5958 Differential Revision: D70072057 Pull Request resolved: https://github.com/pytorch/pytorch/pull/149436 Approved by: https://github.com/chaos5958
This commit is contained in:
parent
ef93cdfb8a
commit
42bd4a09a3
|
|
@ -82,6 +82,11 @@ struct TORCH_API MTIAHooksInterface : AcceleratorHooksInterface {
|
|||
return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
|
||||
}
|
||||
|
||||
virtual int64_t getCurrentRawStream(DeviceIndex device) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
virtual c10::Stream getDefaultStream(DeviceIndex device) const {
|
||||
FAIL_MTIAHOOKS_FUNC(__func__);
|
||||
return c10::Stream::unpack3(-1, 0, c10::DeviceType::MTIA);
|
||||
|
|
|
|||
|
|
@ -54,6 +54,11 @@ void initModule(PyObject* module) {
|
|||
return at::detail::getMTIAHooks().getCurrentStream(device_index);
|
||||
});
|
||||
|
||||
m.def("_mtia_getCurrentRawStream", [](c10::DeviceIndex device_index) {
|
||||
torch::utils::device_lazy_init(at::kMTIA);
|
||||
return at::detail::getMTIAHooks().getCurrentRawStream(device_index);
|
||||
});
|
||||
|
||||
m.def("_mtia_deviceSynchronize", []() {
|
||||
torch::utils::device_lazy_init(at::kMTIA);
|
||||
at::detail::getMTIAHooks().deviceSynchronize(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user