Support gpu trace on XPU (#121795)

# Motivation
Support GPU trace on the XPU backend by adding GPU-trace hooks to the XPU runtime. This will also help generalize the device caching allocator in a follow-up step.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121795
Approved by: https://github.com/EikanWang, https://github.com/gujinghui, https://github.com/jgong5, https://github.com/albanD
ghstack dependencies: #121794
This commit is contained in:
Yu, Guangye 2024-03-30 08:36:54 +00:00 committed by PyTorch MergeBot
parent eb7adc3ae0
commit b8550f527f
8 changed files with 218 additions and 4 deletions

View File

@ -2563,6 +2563,7 @@ exclude_patterns = [
'torch/utils/viz/__init__.py', 'torch/utils/viz/__init__.py',
'torch/utils/viz/_cycles.py', 'torch/utils/viz/_cycles.py',
'torch/utils/weak.py', 'torch/utils/weak.py',
'torch/xpu/_gpu_trace.py',
] ]
init_command = [ init_command = [
'python3', 'python3',

View File

@ -22,7 +22,15 @@ struct TORCH_XPU_API XPUEvent {
XPUEvent(bool enable_timing = false) noexcept XPUEvent(bool enable_timing = false) noexcept
: enable_timing_{enable_timing} {} : enable_timing_{enable_timing} {}
~XPUEvent() = default; ~XPUEvent() {
if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_deletion(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
}
}
XPUEvent(const XPUEvent&) = delete; XPUEvent(const XPUEvent&) = delete;
XPUEvent& operator=(const XPUEvent&) = delete; XPUEvent& operator=(const XPUEvent&) = delete;
@ -77,6 +85,13 @@ struct TORCH_XPU_API XPUEvent {
void record(const XPUStream& stream) { void record(const XPUStream& stream) {
if (!isCreated()) { if (!isCreated()) {
device_index_ = stream.device_index(); device_index_ = stream.device_index();
event_ = std::make_unique<sycl::event>(
stream.queue().ext_oneapi_submit_barrier());
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_creation(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
} else { } else {
TORCH_CHECK( TORCH_CHECK(
device_index_ == stream.device_index(), device_index_ == stream.device_index(),
@ -86,9 +101,16 @@ struct TORCH_XPU_API XPUEvent {
stream.device_index(), stream.device_index(),
"."); ".");
event_.reset(); event_.reset();
event_ = std::make_unique<sycl::event>(
stream.queue().ext_oneapi_submit_barrier());
}
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
at::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
} }
event_ = std::make_unique<sycl::event>(
stream.queue().ext_oneapi_submit_barrier());
} }
void block(const XPUStream& stream) { void block(const XPUStream& stream) {
@ -96,6 +118,13 @@ struct TORCH_XPU_API XPUEvent {
std::vector<sycl::event> event_list{event()}; std::vector<sycl::event> event_list{event()};
// Make this stream wait until event_ is completed. // Make this stream wait until event_ is completed.
stream.queue().ext_oneapi_submit_barrier(event_list); stream.queue().ext_oneapi_submit_barrier(event_list);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
at::kXPU,
reinterpret_cast<uintptr_t>(event_.get()),
reinterpret_cast<uintptr_t>(&stream.queue()));
}
} }
} }
@ -117,6 +146,11 @@ struct TORCH_XPU_API XPUEvent {
void synchronize() const { void synchronize() const {
if (isCreated()) { if (isCreated()) {
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_synchronization(
at::kXPU, reinterpret_cast<uintptr_t>(event_.get()));
}
event().wait_and_throw(); event().wait_and_throw();
} }
} }

View File

@ -467,6 +467,11 @@ class XPUAllocator : public Allocator {
Block* block = device_allocators[device]->malloc(device, size, queue); Block* block = device_allocators[device]->malloc(device, size, queue);
add_allocated_block(block); add_allocated_block(block);
*devPtr = block->ptr; *devPtr = block->ptr;
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_allocation(
c10::kXPU, reinterpret_cast<uintptr_t>(*devPtr));
}
} }
void free(void* ptr) { void free(void* ptr) {
@ -476,6 +481,11 @@ class XPUAllocator : public Allocator {
Block* block = get_allocated_block(ptr, /* remove */ true); Block* block = get_allocated_block(ptr, /* remove */ true);
TORCH_CHECK(block, "invalid device pointer: ", ptr); TORCH_CHECK(block, "invalid device pointer: ", ptr);
device_allocators[block->device]->free(block); device_allocators[block->device]->free(block);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_memory_deallocation(
c10::kXPU, reinterpret_cast<uintptr_t>(block->ptr));
}
} }
void emptyCache() { void emptyCache() {

View File

@ -103,11 +103,17 @@ void initDeviceStreamState(DeviceIndex device) {
{sycl::property::queue::in_order(), queue::priority_high()}}; {sycl::property::queue::in_order(), queue::priority_high()}};
for (const auto p : c10::irange(max_compile_time_stream_priorities)) { for (const auto p : c10::irange(max_compile_time_stream_priorities)) {
for (const auto i : c10::irange(kStreamsPerPool)) { for (const auto i : c10::irange(kStreamsPerPool)) {
streams[device][p][i] = std::make_unique<sycl::queue>(sycl::queue( auto& stream = streams[device][p][i];
stream = std::make_unique<sycl::queue>(sycl::queue(
c10::xpu::get_device_context(), c10::xpu::get_device_context(),
c10::xpu::get_raw_device(device), c10::xpu::get_raw_device(device),
c10::xpu::asyncHandler, c10::xpu::asyncHandler,
properties[p])); properties[p]));
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_creation(
c10::kXPU, reinterpret_cast<uintptr_t>(stream.get()));
}
} }
priority_counters[device][p] = 0; priority_counters[device][p] = 0;
} }
@ -280,6 +286,10 @@ void syncStreamsOnDevice(DeviceIndex device) {
streams[device][p][i]->wait(); streams[device][p][i]->wait();
} }
} }
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_device_synchronization(c10::kXPU);
}
} }
} // namespace c10::xpu } // namespace c10::xpu

View File

@ -1,6 +1,7 @@
#pragma once #pragma once
#include <c10/core/Stream.h> #include <c10/core/Stream.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/xpu/XPUFunctions.h> #include <c10/xpu/XPUFunctions.h>
namespace c10::xpu { namespace c10::xpu {
@ -96,6 +97,11 @@ class C10_XPU_API XPUStream {
/// stream. /// stream.
void synchronize() const { void synchronize() const {
queue().wait_and_throw(); queue().wait_and_throw();
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_stream_synchronization(
c10::kXPU, reinterpret_cast<uintptr_t>(&queue()));
}
} }
/// Return the priority that this stream is associated with. Lower numbers /// Return the priority that this stream is associated with. Lower numbers

View File

@ -2,6 +2,7 @@
#include <c10/core/DeviceGuard.h> #include <c10/core/DeviceGuard.h>
#include <c10/core/impl/DeviceGuardImplInterface.h> #include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/impl/GPUTrace.h>
#include <c10/xpu/XPUCachingAllocator.h> #include <c10/xpu/XPUCachingAllocator.h>
#include <c10/xpu/XPUFunctions.h> #include <c10/xpu/XPUFunctions.h>
#include <c10/xpu/XPUStream.h> #include <c10/xpu/XPUStream.h>
@ -84,6 +85,13 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
auto* xpu_event = reinterpret_cast<sycl::event*>(*event); auto* xpu_event = reinterpret_cast<sycl::event*>(*event);
const XPUStream xpu_stream{stream}; const XPUStream xpu_stream{stream};
*xpu_event = xpu_stream.queue().ext_oneapi_submit_barrier(); *xpu_event = xpu_stream.queue().ext_oneapi_submit_barrier();
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_record(
c10::kXPU,
reinterpret_cast<uintptr_t>(xpu_event),
reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
}
} }
void block(void* event, const Stream& stream) const override { void block(void* event, const Stream& stream) const override {
@ -93,6 +101,13 @@ struct XPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
std::vector<sycl::event> event_list{*xpu_event}; std::vector<sycl::event> event_list{*xpu_event};
const XPUStream xpu_stream(stream); const XPUStream xpu_stream(stream);
xpu_stream.queue().ext_oneapi_submit_barrier(event_list); xpu_stream.queue().ext_oneapi_submit_barrier(event_list);
const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace();
if (C10_UNLIKELY(interp)) {
(*interp)->trace_gpu_event_wait(
c10::kXPU,
reinterpret_cast<uintptr_t>(xpu_event),
reinterpret_cast<uintptr_t>(&xpu_stream.queue()));
}
} }
bool queryEvent(void* event) const override { bool queryEvent(void* event) const override {

View File

@ -4,6 +4,7 @@ import sys
import unittest import unittest
import torch import torch
import torch.xpu._gpu_trace as gpu_trace
from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, instantiate_device_type_tests,
onlyXPU, onlyXPU,
@ -239,5 +240,67 @@ if __name__ == "__main__":
instantiate_device_type_tests(TestXpu, globals(), only_for="xpu") instantiate_device_type_tests(TestXpu, globals(), only_for="xpu")
class TestXpuTrace(TestCase):
    """Verify that XPU runtime operations invoke the registered GPU-trace callbacks.

    Each test registers a MagicMock via torch.xpu._gpu_trace, performs the
    corresponding XPU operation, and asserts the mock was fired with the
    expected native handle(s).
    """

    def setUp(self):
        # Activate the Python-side GPU-trace hooks before every test so the
        # C++ runtime reports events to the registered callbacks.
        torch._C._activate_gpu_trace()
        self.mock = unittest.mock.MagicMock()

    def _recorded_event(self):
        # Helper: create an XPU event and record it on the current stream.
        event = torch.xpu.Event()
        event.record()
        return event

    def test_event_creation_callback(self):
        gpu_trace.register_callback_for_event_creation(self.mock)
        event = self._recorded_event()
        # The callback receives the raw sycl::event pointer as an int.
        self.mock.assert_called_once_with(event._as_parameter_.value)

    def test_event_deletion_callback(self):
        gpu_trace.register_callback_for_event_deletion(self.mock)
        event = self._recorded_event()
        event_id = event._as_parameter_.value
        del event
        # Destroying the event object must fire the deletion callback with
        # the id captured before deletion.
        self.mock.assert_called_once_with(event_id)

    def test_event_record_callback(self):
        gpu_trace.register_callback_for_event_record(self.mock)
        event = self._recorded_event()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_event_wait_callback(self):
        gpu_trace.register_callback_for_event_wait(self.mock)
        event = self._recorded_event()
        event.wait()
        self.mock.assert_called_once_with(
            event._as_parameter_.value, torch.xpu.current_stream().sycl_queue
        )

    def test_device_synchronization_callback(self):
        gpu_trace.register_callback_for_device_synchronization(self.mock)
        torch.xpu.synchronize()
        # Device sync takes no arguments; only assert it fired.
        self.mock.assert_called()

    def test_stream_synchronization_callback(self):
        gpu_trace.register_callback_for_stream_synchronization(self.mock)
        stream = torch.xpu.Stream()
        stream.synchronize()
        self.mock.assert_called_once_with(stream.sycl_queue)

    def test_event_synchronization_callback(self):
        gpu_trace.register_callback_for_event_synchronization(self.mock)
        event = self._recorded_event()
        event.synchronize()
        self.mock.assert_called_once_with(event._as_parameter_.value)
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

75
torch/xpu/_gpu_trace.py Normal file
View File

@ -0,0 +1,75 @@
from typing import Callable
from torch._utils import CallbackRegistry
EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"XPU event creation"
)
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"XPU event deletion"
)
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"XPU event record"
)
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"XPU event wait"
)
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"XPU memory allocation"
)
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"XPU memory deallocation"
)
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"XPU stream creation"
)
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
"XPU device synchronization"
)
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"XPU stream synchronization"
)
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"XPU event synchronization"
)
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
EventCreationCallbacks.add_callback(cb)
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
EventDeletionCallbacks.add_callback(cb)
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
EventRecordCallbacks.add_callback(cb)
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
EventWaitCallbacks.add_callback(cb)
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
MemoryAllocationCallbacks.add_callback(cb)
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
MemoryDeallocationCallbacks.add_callback(cb)
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
StreamCreationCallbacks.add_callback(cb)
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
DeviceSynchronizationCallbacks.add_callback(cb)
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
StreamSynchronizationCallbacks.add_callback(cb)
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
EventSynchronizationCallbacks.add_callback(cb)