mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Fixes https://github.com/pytorch/xla/issues/3490. Updates `torch::lazy::BackendDevice` with the changes below: 1. Remove the no-op string constructor. 2. Update the default ordinal to `-1`. 3. Add an `is_valid` function to check whether `ordinal` is valid/non-default (`ordinal >= 0`). Pull Request resolved: https://github.com/pytorch/pytorch/pull/76264 Reviewed By: mrshenli Differential Revision: D35860266 Pulled By: alanwaketan fbshipit-source-id: 554ebe16a0683d37b00270c4f35163bf690bfe28 (cherry picked from commit b941d10e8545dfecfb34e4d5c24a29a1cc49bc4b)
82 lines
2.6 KiB
C++
82 lines
2.6 KiB
C++
#pragma once
|
|
|
|
#include <ostream>
|
|
#include <memory>
|
|
#include <string>
|
|
|
|
#include <ATen/Tensor.h>
|
|
#include <c10/macros/Export.h>
|
|
#include <c10/util/Deprecated.h>
|
|
#include <c10/util/Optional.h>
|
|
|
|
// Forward declaration only: this header mentions c10::Device solely in
// function declarations (as a const-reference parameter and as a return
// type), neither of which requires the complete type.
namespace c10 {
struct Device;
} // namespace c10
|
|
|
|
namespace torch {
|
|
namespace lazy {
|
|
|
|
// Backend should extend it and define their own supported hardware types.
|
|
struct TORCH_API BackendDeviceType {
|
|
int8_t type {(int8_t)at::kCPU};
|
|
// Note: previous default value was '0', which actually maps to at::kCPU, at least now it is explicit,
|
|
// we may want to make default/undefined semantics more clear though
|
|
BackendDeviceType() :type((int8_t)at::kCPU) {}
|
|
BackendDeviceType(int8_t type) :type(type) {}
|
|
|
|
virtual ~BackendDeviceType() = default;
|
|
virtual std::string toString() const { return "Unknown"; }
|
|
};
|
|
|
|
class TORCH_API BackendDevice {
|
|
public:
|
|
// The default constructor will set both the device type and ordinal
|
|
// to backend specific defaults.
|
|
BackendDevice();
|
|
BackendDevice(std::shared_ptr<BackendDeviceType>&& type, int64_t ordinal);
|
|
|
|
int8_t type() const;
|
|
int64_t ordinal() const { return ordinal_; }
|
|
|
|
bool operator==(const BackendDevice& other) const { return compare(other) == 0; }
|
|
bool operator!=(const BackendDevice& other) const { return compare(other) != 0; }
|
|
bool operator<(const BackendDevice& rhs) const { return compare(rhs) < 0; }
|
|
|
|
bool has_index() const { return ordinal_ >= 0; }
|
|
|
|
std::string toString() const;
|
|
|
|
private:
|
|
int compare(const BackendDevice& rhs) const;
|
|
|
|
// Use shared_ptr instead of unique_ptr so that BackendDevice can be copied.
|
|
std::shared_ptr<BackendDeviceType> type_;
|
|
int64_t ordinal_ {0};
|
|
};
|
|
|
|
TORCH_API std::ostream& operator<<(std::ostream& os, const BackendDevice& device);
|
|
|
|
// Helpers for converting a c10::Device to BackendDevice and vice versa.
|
|
TORCH_API BackendDevice atenDeviceToBackendDevice(const c10::Device& device);
|
|
TORCH_API c10::Device backendDeviceToAtenDevice(const BackendDevice& device);
|
|
|
|
// Tries to extract the backend device out of the lazy tensor. Returns nullopt if the
|
|
// input is not a lazy tensor.
|
|
TORCH_API c10::optional<BackendDevice> GetBackendDevice(const at::TensorList tensors);
|
|
TORCH_API c10::optional<BackendDevice> GetBackendDevice(const at::Tensor& tensor);
|
|
|
|
// For variadic template.
|
|
TORCH_API c10::optional<BackendDevice> GetBackendDevice();
|
|
|
|
template<typename T, typename... Args>
|
|
c10::optional<BackendDevice> GetBackendDevice(const T& tensor, const Args&... forward_tensors) {
|
|
auto optional_device = GetBackendDevice(tensor);
|
|
if (optional_device) {
|
|
return optional_device;
|
|
}
|
|
return GetBackendDevice(forward_tensors...);
|
|
}
|
|
|
|
} // namespace lazy
|
|
} // namespace torch
|