pytorch/torch/csrc/lazy/backend/backend_device.h
cyy 799521fae5 Fixes 96676 (#96714)
Fixes #96676

PR #95942 changed several function implementations to take const-reference parameters instead of const by-value ones. However, GetBackendDevice was missed and still had the old signature. This quick fix resolves the resulting type mismatch.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96714
Approved by: https://github.com/antoniojkim, https://github.com/Skylion007
2023-03-14 19:00:59 +00:00

101 lines
2.8 KiB
C++

#pragma once
#include <memory>
#include <ostream>
#include <string>
#include <ATen/Tensor.h>
#include <c10/macros/Export.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Optional.h>
// Forward declaration of c10::Device, which the ATen <-> lazy device
// conversion helpers declared later in this header take and return.
namespace c10 { struct Device; } // namespace c10
namespace torch {
namespace lazy {
// Backend should extend it and define their own supported hardware types.
struct TORCH_API BackendDeviceType {
int8_t type{(int8_t)at::kCPU};
// Note: previous default value was '0', which actually maps to at::kCPU, at
// least now it is explicit, we may want to make default/undefined semantics
// more clear though
BackendDeviceType() : type((int8_t)at::kCPU) {}
BackendDeviceType(int8_t type) : type(type) {}
virtual ~BackendDeviceType() = default;
virtual std::string toString() const {
return "Unknown";
}
};
// Identifies one device of a lazy backend as a (device type, ordinal) pair.
// Instances are copyable because the device type is held by shared_ptr.
class TORCH_API BackendDevice {
 public:
  // The default constructor will set both the device type and ordinal
  // to backend specific defaults (defined out of line).
  BackendDevice();
  BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal);

  // Device type as a raw int8_t; defined out of line.
  int8_t type() const;

  int64_t ordinal() const {
    return ordinal_;
  }

  // Every comparison is expressed through the three-way compare() helper.
  bool operator<(const BackendDevice& other) const {
    return compare(other) < 0;
  }
  bool operator==(const BackendDevice& other) const {
    return compare(other) == 0;
  }
  bool operator!=(const BackendDevice& other) const {
    return !(*this == other);
  }

  std::string toString() const;

 private:
  // Three-way comparison used by the operators above: a negative result is
  // treated as "less than", zero as "equal". Defined out of line.
  int compare(const BackendDevice& other) const;

  // Use shared_ptr instead of unique_ptr so that BackendDevice can be copied.
  std::shared_ptr<BackendDeviceType> type_;
  int64_t ordinal_;
};
// Writes `backend_device` to `out`; defined out of line.
TORCH_API std::ostream& operator<<(
    std::ostream& out,
    const BackendDevice& backend_device);

// Conversions between ATen's c10::Device and the lazy BackendDevice, in both
// directions; both are defined out of line.
TORCH_API c10::Device backendDeviceToAtenDevice(
    const BackendDevice& backend_device);
TORCH_API BackendDevice atenDeviceToBackendDevice(
    const c10::Device& aten_device);
// Tries to extract the backend device out of the lazy tensor. Returns nullopt
// if the input is not a lazy tensor.
//
// Single-argument overloads are provided for a tensor, a list of tensors (both
// the ITensorListRef and TensorList forms) and an optional c10::Device; all are
// defined out of line. They also serve as the per-argument probes of the
// variadic GetBackendDevice template.
//
// NOTE(review): the top-level `const` on the by-value list parameters is not
// part of the function type in standard C++. It is presumably kept to mirror
// the out-of-line definitions textually; confirm before removing, since a
// declaration/definition mismatch of GetBackendDevice is exactly what #96714
// fixed.
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
const at::ITensorListRef tensors);
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
const at::TensorList tensors);
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
const at::Tensor& tensor);
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
const c10::optional<c10::Device>& device);
// For variadic template.
// Zero-argument overload that terminates the variadic template's recursion
// once every argument has been probed (presumably returns nullopt; defined out
// of line).
TORCH_API c10::optional<BackendDevice> GetBackendDevice();
// Variadic form: probes the arguments left to right and returns the first
// backend device found. When all arguments are exhausted, the zero-argument
// overload ends the recursion.
//
// NOTE(review): each argument is probed via overload resolution against the
// single-argument overloads declared above. If an argument's type is not an
// exact match for one of them (e.g. std::vector<at::Tensor> or c10::Device),
// this template is the better match for the one-argument probe. Once such an
// argument is reached at run time, the call therefore re-enters itself without
// bound. Pass only the supported types, or make sure an earlier argument
// always yields a device.
template <typename T, typename... Args>
c10::optional<BackendDevice> GetBackendDevice(
    const T& tensor,
    const Args&... forward_tensors) {
  // Probe the head argument first.
  if (auto head_device = GetBackendDevice(tensor)) {
    return head_device;
  }
  // No device in the head: continue with the remaining arguments.
  return GetBackendDevice(forward_tensors...);
}
} // namespace lazy
} // namespace torch