pytorch/torch/csrc/lazy/backend/backend_device.h
Wonjoo Lee f9d07ae644 Update torch::lazy::BackendDevice to have a new default ordinal (#76264)
Summary:
Fixes https://github.com/pytorch/xla/issues/3490. Updates `torch::lazy::BackendDevice` with the following changes:

1. Remove the no-op string constructor.
2. Update the default ordinal to `-1`.
3. Add an `is_valid` function to check whether `ordinal` is valid, i.e. non-default (`ordinal >= 0`).
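
A minimal sketch of the new semantics (hypothetical usage; the validity check appears as `has_index()` in the header below):

```cpp
#include <memory>

#include <torch/csrc/lazy/backend/backend_device.h>

void example() {
  using torch::lazy::BackendDevice;
  using torch::lazy::BackendDeviceType;

  // A device built with an explicit ordinal reports has_index() == true.
  BackendDevice dev(std::make_shared<BackendDeviceType>(), /*ordinal=*/0);
  bool indexed = dev.has_index(); // true: ordinal >= 0 means "explicitly set"

  // With the default ordinal now -1 instead of 0, an unset ordinal is
  // distinguishable from device 0: for such a device, has_index() is false.
  (void)indexed;
}
```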

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76264

Reviewed By: mrshenli

Differential Revision: D35860266

Pulled By: alanwaketan

fbshipit-source-id: 554ebe16a0683d37b00270c4f35163bf690bfe28
(cherry picked from commit b941d10e8545dfecfb34e4d5c24a29a1cc49bc4b)
2022-04-25 23:57:18 +00:00


#pragma once

#include <memory>
#include <ostream>
#include <string>

#include <ATen/Tensor.h>
#include <c10/macros/Export.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Optional.h>
namespace c10 {
struct Device;
} // namespace c10

namespace torch {
namespace lazy {
// Backends should extend this struct and define their own supported hardware
// types.
struct TORCH_API BackendDeviceType {
  int8_t type{(int8_t)at::kCPU};
  // Note: the previous default value was '0', which also happens to map to
  // at::kCPU; at least the default is explicit now. We may still want to make
  // the default/undefined semantics clearer, though.
  BackendDeviceType() : type((int8_t)at::kCPU) {}
  BackendDeviceType(int8_t type) : type(type) {}

  virtual ~BackendDeviceType() = default;
  virtual std::string toString() const {
    return "Unknown";
  }
};
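
// Illustrative sketch (hypothetical, not part of this header): a backend can
// subclass BackendDeviceType to name its own hardware type, e.g.
//
//   struct MyBackendDeviceType : public BackendDeviceType {
//     MyBackendDeviceType() : BackendDeviceType((int8_t)at::kXLA) {}
//     std::string toString() const override { return "XLA"; }
//   };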
class TORCH_API BackendDevice {
 public:
  // The default constructor will set both the device type and ordinal
  // to backend specific defaults.
  BackendDevice();
  BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal);

  int8_t type() const;
  int64_t ordinal() const {
    return ordinal_;
  }

  bool operator==(const BackendDevice& other) const {
    return compare(other) == 0;
  }
  bool operator!=(const BackendDevice& other) const {
    return compare(other) != 0;
  }
  bool operator<(const BackendDevice& rhs) const {
    return compare(rhs) < 0;
  }

  // True iff the ordinal has been explicitly set, i.e. it is not the default
  // value of -1.
  bool has_index() const {
    return ordinal_ >= 0;
  }

  std::string toString() const;

 private:
  int compare(const BackendDevice& rhs) const;

  // Use shared_ptr instead of unique_ptr so that BackendDevice can be copied.
  std::shared_ptr<BackendDeviceType> type_;
  // -1 marks the ordinal as unset; see has_index().
  int64_t ordinal_{-1};
};
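
// Illustrative sketch (hypothetical usage): an explicitly constructed device
// carries a real ordinal, so has_index() is true; with the default ordinal of
// -1, a device whose ordinal was never set reports false instead.
//
//   BackendDevice dev(std::make_shared<BackendDeviceType>(), /*ordinal=*/1);
//   bool indexed = dev.has_index(); // true, since ordinal_ >= 0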
TORCH_API std::ostream& operator<<(
    std::ostream& os,
    const BackendDevice& device);

// Helpers for converting a c10::Device to a BackendDevice and vice versa.
TORCH_API BackendDevice atenDeviceToBackendDevice(const c10::Device& device);
TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device);

// Tries to extract the backend device out of a lazy tensor. Returns nullopt
// if the input is not a lazy tensor.
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
    const at::TensorList tensors);
TORCH_API c10::optional<BackendDevice> GetBackendDevice(
    const at::Tensor& tensor);

// Base case for the variadic template below: no tensors, no device.
TORCH_API c10::optional<BackendDevice> GetBackendDevice();
template <typename T, typename... Args>
c10::optional<BackendDevice> GetBackendDevice(
    const T& tensor,
    const Args&... forward_tensors) {
  auto optional_device = GetBackendDevice(tensor);
  if (optional_device) {
    return optional_device;
  }
  return GetBackendDevice(forward_tensors...);
}
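
// Illustrative sketch (hypothetical usage): the variadic overload scans its
// arguments left to right and returns the device of the first lazy tensor it
// finds; if none of the arguments is a lazy tensor, the nullary base case
// above returns c10::nullopt.
//
//   auto device = GetBackendDevice(t1, t2, t3);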
} // namespace lazy
} // namespace torch