Add XLA / TPU device type, backend type and type id (#16763)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16763

Replicate the easy bits in https://github.com/pytorch/pytorch/pull/15153 with TPU / XLA instead of MSNPU. Also don't initialize the storage for XLA tensors for now.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/16585

Reviewed By: ezyang
Differential Revision: D13912118
Pulled By: gchanan
fbshipit-source-id: 4889177e2478768fb281ed075b71146d1d850bd9
This commit is contained in:
parent 6efa40e07b
commit 9811a4220d
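The summary above notes that XLA tensors do not get real storage for now. Before the diff, a minimal sketch of what that looks like in practice, adapted from the xla_tensor_test.cpp added in this commit; the toy malloc/free allocator and the standalone main() are illustrative assumptions for this revision's ATen API, not part of the change itself.

#include <cstddef>
#include <cstdlib>

#include <ATen/ATen.h>

using namespace at;

// Toy allocator that tags plain malloc/free memory with the new XLA device
// type; a real XLA backend would supply its own memory management.
void XLAFree(void* ptr) {
  free(ptr);
}

void* XLAMalloc(ptrdiff_t size) {
  return malloc(size);
}

struct XLAAllocator final : public at::Allocator {
  at::DataPtr allocate(size_t size) const override {
    auto* ptr = XLAMalloc(size);
    return {ptr, ptr, &XLAFree, at::DeviceType::XLA};
  }
  at::DeleterFnPtr raw_deleter() const override {
    return &XLAFree;
  }
};

int main() {
  XLAAllocator allocator;
  // Zero-sized storage: per this change, XLA tensors skip storage initialization.
  auto storage = Storage(caffe2::TypeMeta::Make<float>(), 0, &allocator, /*resizable=*/true);
  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      std::move(storage), XLATensorId(), /*is_variable=*/false);
  at::Tensor t(std::move(tensor_impl));
  // The tensor carries the new XLA type id and reports DeviceType::XLA.
  return t.device() == DeviceType::XLA ? 0 : 1;
}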
@@ -75,6 +75,7 @@ enum class TypeID {
   SparseCUDALong,
   SparseCUDAShort,
   MSNPU,
+  XLA,
   CPUComplexFloat,
   CPUComplexDouble,
   CUDAComplexFloat,
@@ -174,7 +174,7 @@ generators = {
 backends = ['CPU', 'CUDA']
 densities = ['Dense', 'Sparse']
-extension_backends = ['MSNPU']
+extension_backends = ['MSNPU', 'XLA']
 
 # scalar_name, c_type, accreal, th_scalar_type, is_floating_type
 scalar_types = [
@@ -21,7 +21,8 @@ list(APPEND ATen_CPU_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
-  ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp)
+  ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp
+  ${CMAKE_CURRENT_SOURCE_DIR}/xla_tensor_test.cpp)
 
 list(APPEND ATen_CUDA_TEST_SRCS
   ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
aten/src/ATen/test/xla_tensor_test.cpp (new file, 34 lines)
@@ -0,0 +1,34 @@
+#include <gtest/gtest.h>
+
+#include <ATen/ATen.h>
+
+using namespace at;
+
+void XLAFree(void *ptr) {
+  free(ptr);
+}
+
+void* XLAMalloc(ptrdiff_t size) {
+  return malloc(size);
+}
+
+struct XLAAllocator final : public at::Allocator {
+  at::DataPtr allocate(size_t size) const override {
+    auto* ptr = XLAMalloc(size);
+    return {ptr, ptr, &XLAFree, at::DeviceType::XLA};
+  }
+  at::DeleterFnPtr raw_deleter() const override {
+    return &XLAFree;
+  }
+};
+
+TEST(XlaTensorTest, TestNoStorage) {
+  XLAAllocator allocator;
+  auto storage = Storage(caffe2::TypeMeta::Make<float>(), 0, &allocator, true);
+  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
+      std::move(storage),
+      XLATensorId(),
+      /*is_variable=*/false);
+  at::Tensor t(std::move(tensor_impl));
+  ASSERT_TRUE(t.device() == DeviceType::XLA);
+}
@@ -18,6 +18,7 @@ VALGRIND=${VALGRIND:=ON}
 ./tensor_interop_test
 ./undefined_tensor_test
 ./extension_backend_test
+./xla_tensor_test
 if [[ -x ./cudnn_test ]]; then
   ./cudnn_test
 fi
@@ -20,7 +20,7 @@ namespace c10 {
  * would make sense in your use case. If it doesn't make sense, maybe
  * you want DeviceType.
  */
-enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, Undefined, NumOptions };
+enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, XLA, Undefined, NumOptions };
 
 static inline Backend toSparse(Backend b) {
   switch (b) {
@@ -51,6 +51,8 @@ static inline Backend toDense(Backend b) {
       return Backend::HIP;
     case Backend::MSNPU:
       return Backend::MSNPU;
+    case Backend::XLA:
+      return Backend::XLA;
     case Backend::SparseCPU:
       return Backend::CPU;
     case Backend::SparseCUDA:
@@ -71,6 +73,8 @@ static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
     return Backend::HIP;
   } else if (t == MSNPUTensorId()) {
     return Backend::MSNPU;
+  } else if (t == XLATensorId()) {
+    return Backend::XLA;
   } else if (t == SparseCPUTensorId()) {
     return Backend::SparseCPU;
   } else if (t == SparseCUDATensorId()) {
@@ -94,6 +98,8 @@ static inline TensorTypeId backendToTensorTypeId(Backend b) {
       return HIPTensorId();
     case Backend::MSNPU:
       return MSNPUTensorId();
+    case Backend::XLA:
+      return XLATensorId();
     case Backend::SparseCPU:
       return SparseCPUTensorId();
     case Backend::SparseCUDA:
@@ -117,6 +123,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
       return DeviceType::HIP;
     case Backend::MSNPU:
       return DeviceType::MSNPU;
+    case Backend::XLA:
+      return DeviceType::XLA;
     case Backend::SparseCPU:
       return DeviceType::CPU;
     case Backend::SparseCUDA:
@@ -140,6 +148,8 @@ static inline Backend deviceTypeToBackend(DeviceType d) {
       return Backend::HIP;
     case DeviceType::MSNPU:
       return Backend::MSNPU;
+    case DeviceType::XLA:
+      return Backend::XLA;
     default:
       AT_ERROR("Unknown device type ", d);
   }
@@ -160,6 +170,7 @@ static inline Backend backendToCPU(Backend b) {
     case Backend::SparseHIP:
       return Backend::SparseCPU;
     case Backend::MSNPU:
+    case Backend::XLA:
       return Backend::CPU;
     case Backend::Undefined:
       return Backend::Undefined;
@@ -174,6 +185,7 @@ static inline Backend backendToCUDA(Backend b) {
     case Backend::CUDA:
     case Backend::HIP:
     case Backend::MSNPU:
+    case Backend::XLA:
      return Backend::CUDA;
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
@@ -192,6 +204,7 @@ static inline Backend backendToHIP(Backend b) {
     case Backend::CUDA:
     case Backend::HIP:
     case Backend::MSNPU:
+    case Backend::XLA:
      return Backend::HIP;
     case Backend::SparseCPU:
     case Backend::SparseCUDA:
@@ -208,6 +221,7 @@ constexpr DeviceType kCPU = DeviceType::CPU;
 constexpr DeviceType kCUDA = DeviceType::CUDA;
 constexpr DeviceType kHIP = DeviceType::HIP;
 constexpr DeviceType kMSNPU = DeviceType::MSNPU;
+constexpr DeviceType kXLA = DeviceType::XLA;
 
 static inline const char* toString(Backend b) {
   switch (b) {
@@ -219,6 +233,8 @@ static inline const char* toString(Backend b) {
       return "HIP";
     case Backend::MSNPU:
       return "MSNPU";
+    case Backend::XLA:
+      return "XLA";
     case Backend::SparseCPU:
       return "SparseCPU";
     case Backend::SparseCUDA:
@@ -25,6 +25,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
       return lower_case ? "fpga" : "FPGA";
     case DeviceType::MSNPU:
       return lower_case ? "msnpu" : "MSNPU";
+    case DeviceType::XLA:
+      return lower_case ? "xla" : "XLA";
     default:
       AT_ERROR(
           "Unknown device: ",
@@ -56,6 +58,7 @@ bool isValidDeviceType(DeviceType d) {
     case DeviceType::HIP:
     case DeviceType::FPGA:
     case DeviceType::MSNPU:
+    case DeviceType::XLA:
       return true;
     default:
       return false;
@@ -22,11 +22,12 @@ enum class DeviceType : int16_t {
   HIP = 6, // AMD HIP
   FPGA = 7, // FPGA
   MSNPU = 8, // MSNPU
+  XLA = 9, // XLA / TPU
   // NB: If you add more devices:
   //  - Change the implementations of DeviceTypeName and isValidDeviceType
   //    in DeviceType.cpp
   //  - Change the number below
-  COMPILE_TIME_MAX_DEVICE_TYPES = 9,
+  COMPILE_TIME_MAX_DEVICE_TYPES = 10,
   ONLY_FOR_TEST = 20901, // This device type is only for test.
 };
 
@@ -357,7 +357,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
   int64_t get_device() const {
     // NB: This method is not virtual and tries to avoid dispatches in the common case for perf.
     const auto tid = type_id();
-    if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
+    if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId() || tid == XLATensorId()) {
       // TODO: #12934 investigate caching device on TensorImpl to avoid this vdispatch.
       return storage().device().index();
     }
@@ -369,7 +369,8 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     // TODO: This is a little convoluted so it would be good to investigate
     // caching device on TensorImpl (#12934) to speed up device() calls in all cases.
     const auto tid = type_id();
-    if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
+    if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId() ||
+        tid == XLATensorId()) {
       // NB: storage(), not storage_, b/c of Variable.
       const auto& mystorage = storage();
       if (mystorage) {
@@ -511,6 +511,8 @@ inline TensorTypeId computeTensorTypeId(TensorOptions options) {
         return HIPTensorId();
       case DeviceType::MSNPU:
         return MSNPUTensorId();
+      case DeviceType::XLA:
+        return XLATensorId();
       default:
         AT_ERROR("Unsupported device type for dense layout: ", options.device().type());
     }
@@ -549,6 +551,8 @@ inline DeviceType computeDeviceType(TensorTypeId tid) {
     return DeviceType::HIP;
   } else if (tid == MSNPUTensorId()) {
     return DeviceType::MSNPU;
+  } else if (tid == XLATensorId()) {
+    return DeviceType::XLA;
   } else if (tid == SparseCPUTensorId()) {
     return DeviceType::CPU;
   } else if (tid == SparseCUDATensorId()) {
@@ -69,5 +69,6 @@ C10_DEFINE_TENSOR_TYPE(IDEEPTensorId);
 C10_DEFINE_TENSOR_TYPE(HIPTensorId);
 C10_DEFINE_TENSOR_TYPE(SparseHIPTensorId);
 C10_DEFINE_TENSOR_TYPE(MSNPUTensorId);
+C10_DEFINE_TENSOR_TYPE(XLATensorId);
 
 } // namespace c10
@@ -108,6 +108,7 @@ C10_DECLARE_TENSOR_TYPE(IDEEPTensorId); // Caffe2 only
 C10_DECLARE_TENSOR_TYPE(HIPTensorId); // PyTorch/Caffe2 supported
 C10_DECLARE_TENSOR_TYPE(SparseHIPTensorId); // PyTorch only
 C10_DECLARE_TENSOR_TYPE(MSNPUTensorId); // PyTorch only
+C10_DECLARE_TENSOR_TYPE(XLATensorId); // PyTorch only
 
 } // namespace c10
 
@@ -179,8 +179,9 @@ enum DeviceTypeProto {
   PROTO_HIP = 6; // AMD HIP
   PROTO_FPGA = 7; // FPGA
   PROTO_MSNPU = 8; // MSNPU
+  PROTO_XLA = 9; // XLA / TPU
   // Change the following number if you add more devices in the code.
-  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 9;
+  PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 10;
   PROTO_ONLY_FOR_TEST = 20901; // This device type is only for test.
 }
 