pytorch/c10/core/impl/VirtualGuardImpl.h
Yu, Guangye 40c098f731 Introduce a device-agnostic runtime API design (#132204)
# Motivation
Following [[RFC] A device-agnostic Python runtime API design for stream-based accelerators](https://github.com/pytorch/pytorch/issues/128403), this PR introduces a device-agnostic runtime API design.
I personally prefer the **Simple Version** APIs, which no longer accept the device type as an input argument. Instead, they use `getAccelerator` to fetch the current accelerator, and they remain flexible enough to be extended to scenarios with multiple accelerator types. The design does **NOT** break the previous design philosophies.
I also believe the torch.accelerator namespace is the better choice: it makes clear to users that the APIs they call run on an accelerator rather than on the CPU, which is important. Accordingly, we can follow a simple set of API design principles:
1. Device-agnostic APIs should be placed under the torch.accelerator namespace and should not accept an optional device_type parameter.
2. Device-specific APIs should be placed under device-specific submodules.
3. APIs required by both the CPU and accelerators should be placed under the torch namespace and accept an optional device_type parameter.
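Principles 1 and 3 can be illustrated in C++ (the language of this header) as two function shapes. This is a minimal, self-contained sketch, not PyTorch code: `current_accelerator()` is a hypothetical stand-in for the `getAccelerator` lookup mentioned above, the per-backend device counts are made up, and treating an omitted device_type as the CPU in the shared API is an assumption. The device-agnostic function takes no device type and resolves it from the current accelerator; the shared function takes an optional one.

```cpp
#include <cassert>
#include <optional>

// Illustrative device types; c10::DeviceType has many more entries.
enum class DeviceType { CPU, CUDA, XPU };

// Hypothetical stand-in for getAccelerator(): returns the accelerator type
// this runtime was built for, or std::nullopt on a CPU-only build.
// Assumption: we pretend this is a CUDA build.
std::optional<DeviceType> current_accelerator() {
  return DeviceType::CUDA;
}

// Hypothetical per-backend query with made-up device counts.
int backend_device_count(DeviceType type) {
  switch (type) {
    case DeviceType::CPU:
      return 1;
    case DeviceType::CUDA:
      return 2;  // pretend two GPUs are visible
    case DeviceType::XPU:
      return 0;
  }
  assert(false && "unknown device type");
  return 0;
}

namespace accelerator {
// Principle 1: device-agnostic API with no device_type parameter;
// the backend is resolved from the current accelerator.
int device_count() {
  const std::optional<DeviceType> type = current_accelerator();
  return type.has_value() ? backend_device_count(*type) : 0;
}

bool is_available() {
  return device_count() > 0;
}
}  // namespace accelerator

// Principle 3: API shared by CPU and accelerators, taking an optional
// device_type. Assumption: an omitted device_type means the CPU.
int device_count(std::optional<DeviceType> device_type = std::nullopt) {
  return backend_device_count(device_type.value_or(DeviceType::CPU));
}
```

With these stubs, `accelerator::device_count()` returns 2 and `accelerator::is_available()` is true, while the shared `device_count()` returns 1 when called without an argument.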

Here are the pros and cons of the **Simple Version**:
Pros:
- `torch.accelerator.foo` will have the same input arguments as `torch.xxx.foo`, giving a more consistent user experience;
- it is more concise, making it easier for developers to write device-agnostic code.

Cons:
- no obvious drawbacks.

# Additional Context
The new APIs are listed below:
```python
torch.accelerator.is_available() -> bool
torch.accelerator.current_accelerator() -> torch.device
torch.accelerator.device_count() -> int
torch.accelerator.current_device_idx() -> int
torch.accelerator.set_device_idx(device: Union[torch.device, str, int, None]) -> None
torch.accelerator.current_stream(device: Union[torch.device, str, int, None]) -> torch.Stream
torch.accelerator.set_stream(stream: torch.Stream) -> None
torch.accelerator.synchronize(device: Union[torch.device, str, int, None]) -> None
```
Following a discussion with Alban, we decided to rename `set_device` to `set_device_idx` and `current_device` to `current_device_idx` to make the names more explicit. Device and stream context managers will be supported in a follow-up PR.
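To make the list concrete, here is a self-contained C++ sketch of how these device-agnostic entry points could be lowered onto a device guard implementation such as `VirtualGuardImpl`: each call resolves the current accelerator's guard impl and forwards to `deviceCount`, `getDevice`/`setDevice`, `getStream`/`exchangeStream`, or `synchronizeDevice`. Everything here is a mock under stated assumptions: `GuardImpl`, `FakeBackend`, `g_backend`, `current_impl`, and the snake_case functions are hypothetical stand-ins rather than PyTorch's actual C++ code, and an omitted device argument is assumed to mean the current device.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

using DeviceIndex = std::int8_t;  // illustrative index type
enum class DeviceType { CPU, CUDA };
struct Device {
  DeviceType type;
  DeviceIndex index;
};
struct Stream {
  Device device;
  std::int64_t id;
};

// Hypothetical, trimmed-down analogue of DeviceGuardImplInterface.
class GuardImpl {
 public:
  virtual ~GuardImpl() = default;
  virtual DeviceType type() const = 0;
  virtual DeviceIndex deviceCount() const = 0;
  virtual Device getDevice() const = 0;
  virtual void setDevice(Device d) const = 0;
  virtual Stream getStream(Device d) const = 0;
  virtual Stream exchangeStream(Stream s) const = 0;
  virtual void synchronizeDevice(DeviceIndex index) const = 0;
};

// Fake two-device backend; `mutable` members stand in for driver-global
// state that const guard methods are allowed to change.
class FakeBackend final : public GuardImpl {
 public:
  DeviceType type() const override { return DeviceType::CUDA; }
  DeviceIndex deviceCount() const override { return 2; }
  Device getDevice() const override { return {DeviceType::CUDA, current_}; }
  void setDevice(Device d) const override { current_ = d.index; }
  Stream getStream(Device d) const override { return streams_[d.index]; }
  Stream exchangeStream(Stream s) const override {
    Stream previous = streams_[s.device.index];
    streams_[s.device.index] = s;  // install s as the device's current stream
    return previous;
  }
  void synchronizeDevice(DeviceIndex index) const override { ++sync_count[index]; }

  mutable int sync_count[2] = {0, 0};  // how often each device was synchronized

 private:
  mutable DeviceIndex current_ = 0;
  mutable Stream streams_[2] = {{{DeviceType::CUDA, 0}, 0},
                                {{DeviceType::CUDA, 1}, 0}};
};

FakeBackend g_backend;

// Hypothetical: the guard impl of the current accelerator.
const GuardImpl& current_impl() { return g_backend; }

namespace accelerator {
DeviceIndex device_count() { return current_impl().deviceCount(); }

DeviceIndex current_device_idx() { return current_impl().getDevice().index; }

void set_device_idx(DeviceIndex index) {
  assert(index >= 0 && index < device_count());
  current_impl().setDevice({current_impl().type(), index});
}

// Assumption: an omitted device means the current device.
Stream current_stream(std::optional<DeviceIndex> index = std::nullopt) {
  const GuardImpl& impl = current_impl();
  return impl.getStream({impl.type(), index.value_or(current_device_idx())});
}

void set_stream(Stream stream) { current_impl().exchangeStream(stream); }

void synchronize(std::optional<DeviceIndex> index = std::nullopt) {
  current_impl().synchronizeDevice(index.value_or(current_device_idx()));
}
}  // namespace accelerator
```

The point of the design is that none of the `accelerator::` functions name a backend: swapping `current_impl()` for another registered implementation would change which device they act on without changing their signatures.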

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132204
Approved by: https://github.com/EikanWang, https://github.com/abhilash1910, https://github.com/gujinghui, https://github.com/albanD
2024-10-27 10:37:09 +00:00


#pragma once

#include <c10/core/impl/DeviceGuardImplInterface.h>

namespace c10::impl {

/**
 * An implementation of DeviceGuardImplInterface which delegates
 * to virtual dispatch on the DeviceGuardImpl registry.
 */
class VirtualGuardImpl final : public DeviceGuardImplInterface {
 public:
  VirtualGuardImpl(DeviceType device_type)
      : impl_(getDeviceGuardImpl(device_type)) {}
  // This constructor exists purely for testing
  VirtualGuardImpl(const DeviceGuardImplInterface* impl) : impl_(impl) {}

  // Copying and moving is OK!
  VirtualGuardImpl(const VirtualGuardImpl&) = default;
  VirtualGuardImpl& operator=(const VirtualGuardImpl&) = default;
  VirtualGuardImpl(VirtualGuardImpl&&) noexcept = default;
  VirtualGuardImpl& operator=(VirtualGuardImpl&&) noexcept = default;

  DeviceType type() const override {
    return impl_->type();
  }
  Device exchangeDevice(Device d) const override {
    return impl_->exchangeDevice(d);
  }
  Device getDevice() const override {
    return impl_->getDevice();
  }
  void setDevice(Device d) const override {
    impl_->setDevice(d);
  }
  void uncheckedSetDevice(Device d) const noexcept override {
    impl_->uncheckedSetDevice(d);
  }
  Stream getStream(Device d) const noexcept override {
    return impl_->getStream(d);
  }
  Stream getNewStream(Device d, int priority = 0) const override {
    return impl_->getNewStream(d, priority);
  }
  Stream getDefaultStream(Device d) const override {
    return impl_->getDefaultStream(d);
  }
  Stream getStreamFromGlobalPool(Device d, bool isHighPriority = false)
      const override {
    return impl_->getStreamFromGlobalPool(d, isHighPriority);
  }
  Stream exchangeStream(Stream s) const noexcept override {
    return impl_->exchangeStream(s);
  }
  DeviceIndex deviceCount() const noexcept override {
    return impl_->deviceCount();
  }

  // Event functions
  void record(
      void** event,
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
    impl_->record(event, stream, device_index, flag);
  }
  void block(void* event, const Stream& stream) const override {
    impl_->block(event, stream);
  }
  bool queryEvent(void* event) const override {
    return impl_->queryEvent(event);
  }
  void destroyEvent(void* event, const DeviceIndex device_index)
      const noexcept override {
    impl_->destroyEvent(event, device_index);
  }

  bool queryStream(const Stream& stream) const override {
    return impl_->queryStream(stream);
  }
  void synchronizeStream(const Stream& stream) const override {
    impl_->synchronizeStream(stream);
  }

  void recordDataPtrOnStream(const c10::DataPtr& data_ptr, const Stream& stream)
      const override {
    impl_->recordDataPtrOnStream(data_ptr, stream);
  }

  double elapsedTime(void* event1, void* event2, const DeviceIndex device_index)
      const override {
    return impl_->elapsedTime(event1, event2, device_index);
  }

  void synchronizeEvent(void* event) const override {
    return impl_->synchronizeEvent(event);
  }

  void synchronizeDevice(const DeviceIndex device_index) const override {
    return impl_->synchronizeDevice(device_index);
  }

 private:
  const DeviceGuardImplInterface* impl_ = nullptr;
};

} // namespace c10::impl
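The header above is pure forwarding: `VirtualGuardImpl` resolves a backend's `DeviceGuardImplInterface` once, via `getDeviceGuardImpl(device_type)`, and then routes every call through a non-owning pointer, so copying or moving it only copies that pointer. The self-contained sketch below mimics the pattern with hypothetical names (`GuardInterface`, `register_guard_impl`, `get_guard_impl`, `VirtualGuard`, `FakeCuda`) and a fixed-size array as the registry. c10's real registry and interface are larger, so treat this as an illustration of the technique, not the actual implementation. It also shows why a pointer-taking constructor is useful for testing: a fake backend can be injected without touching the registry.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Illustrative device types; COUNT sizes the registry.
enum class DeviceType : std::size_t { CPU, CUDA, COUNT };

// Hypothetical, trimmed-down analogue of DeviceGuardImplInterface.
class GuardInterface {
 public:
  virtual ~GuardInterface() = default;
  virtual DeviceType type() const = 0;
  virtual int deviceCount() const noexcept = 0;
};

// Hypothetical registry: one non-owning slot per device type.
std::array<const GuardInterface*, static_cast<std::size_t>(DeviceType::COUNT)>
    g_registry{};

void register_guard_impl(DeviceType type, const GuardInterface* impl) {
  g_registry[static_cast<std::size_t>(type)] = impl;
}

const GuardInterface* get_guard_impl(DeviceType type) {
  const GuardInterface* impl = g_registry[static_cast<std::size_t>(type)];
  assert(impl != nullptr && "no guard impl registered for this device type");
  return impl;
}

// Forwarding wrapper in the style of VirtualGuardImpl.
class VirtualGuard final : public GuardInterface {
 public:
  // Normal path: look the backend up once, at construction.
  explicit VirtualGuard(DeviceType type) : impl_(get_guard_impl(type)) {}
  // Test path: inject any implementation directly.
  explicit VirtualGuard(const GuardInterface* impl) : impl_(impl) {}

  DeviceType type() const override { return impl_->type(); }
  int deviceCount() const noexcept override { return impl_->deviceCount(); }

 private:
  const GuardInterface* impl_ = nullptr;  // non-owning; copies share the target
};

// Fake backend used both for direct injection and for registration.
class FakeCuda final : public GuardInterface {
 public:
  DeviceType type() const override { return DeviceType::CUDA; }
  int deviceCount() const noexcept override { return 4; }
};
```

Because the wrapper stores only a pointer, a copy such as `VirtualGuard copy = looked_up;` costs no more than copying that pointer, and both objects forward to the same registered backend.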