Add support for the ONNX Runtime Eager Mode backend (#58248)

Summary:
This PR implements the hooks, stubs, enums, and other boilerplate needed for complete ONNX Runtime (ORT) Eager Mode integration. The actual extension will live out of tree at https://github.com/pytorch/ort.

We have been [working on this at Microsoft](https://github.com/microsoft/onnxruntime-pytorch/tree/eager-ort/torch_onnxruntime) for the last few months and are finally ready to contribute the PyTorch core changes upstream (nothing major or exciting, just the usual boilerplate for adding a new backend).

The ORT backend lets us ferry [almost] all torch ops to granular ONNX kernels that ORT executes eagerly against any device it supports, so from PyTorch's perspective a single ORT backend is all we need.
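For a sense of what this enables, here is a rough usage sketch. It is hypothetical: the `torch_ort` module comes from the out-of-tree extension (the name follows the hint in the new `ORT_HELP` message), and the exact op coverage is up to that extension, not this PR.

```python
# Hypothetical usage once the out-of-tree extension is installed; this PR alone
# only reserves the 'ort' device string and the ORT dispatch key in PyTorch core.
import torch
import torch_ort  # noqa: F401 -- registers ORT kernels under the ORT dispatch key

a = torch.ones(2, 3, device="ort")  # ops dispatched to ONNX Runtime kernels
b = torch.rand(2, 3, device="ort")
c = (a + b).cpu()                   # copy back to CPU to inspect the result
print(c)
```

Without the extension, the new device string simply produces the "Could not run 'aten::empty.memory_format' with arguments from the 'ORT' backend" error that `test/test_torch.py` now asserts on.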

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58248

Reviewed By: astaff

Differential Revision: D30344992

Pulled By: albanD

fbshipit-source-id: 69082b32121246340d686e16653626114b7714b2
Author: Aaron Bockover, 2021-08-20 11:11:47 -07:00 (committed by Facebook GitHub Bot)
Parent: b95ce1591d
Commit: c78ab28441
38 changed files with 236 additions and 120 deletions


@@ -9,6 +9,7 @@
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/detail/HIPHooksInterface.h>
+#include <ATen/detail/ORTHooksInterface.h>
 #include <c10/util/Exception.h>
 #include <c10/core/impl/DeviceGuardImplInterface.h>
 #include <c10/core/QEngine.h>
@@ -79,6 +80,9 @@ class TORCH_API Context {
   static bool hasMLC() {
     return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
   }
+  static bool hasORT() {
+    return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
+  }
   // defined in header so that getNonVariableType has ability to inline
   // call_once check. getNonVariableType is called fairly frequently
   THCState* lazyInitCUDA() {
@@ -292,6 +296,10 @@ static inline bool hasMLC() {
   return globalContext().hasMLC();
 }
 
+static inline bool hasORT() {
+  return globalContext().hasORT();
+}
+
 // Despite its name, this function returns the number of *CUDA* GPUs.
 static inline size_t getNumGPUs() {
   // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS


@@ -184,6 +184,10 @@ std::string show_config() {
     ss << detail::getCUDAHooks().showConfig();
   }
 
+  if (hasORT()) {
+    ss << detail::getORTHooks().showConfig();
+  }
+
   ss << "  - Build settings: ";
   for (const auto& pair : caffe2::GetBuildOptions()) {
     if (!pair.second.empty()) {


@@ -405,6 +405,7 @@ _(aten, is_complex) \
 _(aten, is_contiguous) \
 _(aten, is_cuda) \
 _(aten, is_mlc) \
+_(aten, is_ort) \
 _(aten, is_distributed) \
 _(aten, is_floating_point) \
 _(aten, is_inference) \


@@ -13,13 +13,13 @@ There's four main use cases
 * You're writing a new operator that isn't supposed to be part of the public PyTorch API.
 * You're writing a new operator but don't want to change the core pytorch code base, say you're developing a shared library with operators.
 * You're writing a C++ extension for PyTorch or you're using inline c++ in your .py model files.
-* You're writing a backend library like XLA or MSNPU that adds new kernels to all operators defined in `native_functions.yaml`.
+* You're writing a backend library like XLA or ORT that adds new kernels to all operators defined in `native_functions.yaml`.
 
 For these use cases, the custom operator API is the better solution.
 
 ### What is the price for using the custom operator API instead of `native_functions.yaml`?
 
-If you're just using the custom operator API to add new kernels for existing operators (e.g. the XLA/MSNPU example above), then you're fine and don't pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
+If you're just using the custom operator API to add new kernels for existing operators (e.g. the XLA/ORT example above), then you're fine and don't pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
 
 * It will not get a C++ API generated. There will not be `Tensor::your_op()` methods or `at::your_op()` functions to call your operator.
 * The API for calling the operator from Python looks a little bit different. It needs to be called through `torch.ops.your_op()` instead of `torch._C`.


@@ -0,0 +1,31 @@
+#include <ATen/detail/ORTHooksInterface.h>
+
+#include <c10/util/Exception.h>
+
+#include <cstddef>
+#include <memory>
+#include <mutex>
+
+namespace at {
+namespace detail {
+
+// See getCUDAHooks for some more commentary
+const ORTHooksInterface& getORTHooks() {
+  static std::unique_ptr<ORTHooksInterface> ort_hooks;
+  static std::once_flag once;
+  std::call_once(once, [] {
+    ort_hooks = ORTHooksRegistry()->Create("ORTHooks", {});
+    if (!ort_hooks) {
+      ort_hooks =
+          // NOLINTNEXTLINE(modernize-make-unique)
+          std::unique_ptr<ORTHooksInterface>(new ORTHooksInterface());
+    }
+  });
+  return *ort_hooks;
+}
+} // namespace detail
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+C10_DEFINE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs)
+
+} // namespace at


@@ -0,0 +1,36 @@
+#pragma once
+
+#include <c10/util/Exception.h>
+#include <c10/util/Registry.h>
+
+constexpr const char* ORT_HELP =
+    " You need to 'import torch_ort' to use the 'ort' device in PyTorch. "
+    "The 'torch_ort' module is provided by the ONNX Runtime itself "
+    "(https://onnxruntime.ai).";
+
+// NB: Class must live in `at` due to limitations of Registry.h.
+namespace at {
+
+struct TORCH_API ORTHooksInterface {
+  // This should never actually be implemented, but it is used to
+  // squelch -Werror=non-virtual-dtor
+  virtual ~ORTHooksInterface() {}
+
+  virtual std::string showConfig() const {
+    TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP);
+  }
+};
+
+// NB: dummy argument to suppress "ISO C++11 requires at least one argument
+// for the "..." in a variadic macro"
+struct TORCH_API ORTHooksArgs {};
+
+C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs);
+#define REGISTER_ORT_HOOKS(clsname) \
+  C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const ORTHooksInterface& getORTHooks();
+} // namespace detail
+
+} // namespace at


@@ -492,6 +492,12 @@ class TORCH_API Tensor {
     return impl_->is_mlc();
   }
 
+  /// Returns if a `Tensor` is ort tensor.
+  bool is_ort() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_ort();
+  }
+
   /// Returns if a `Tensor` is vulkan tensor.
   bool is_vulkan() const {
     // NB: this is not a native function to avoid dispatching overhead.


@@ -6,6 +6,11 @@
 #include <torch/csrc/jit/runtime/operator.h>
 
+// NB. These tests use the ORT dispatch key to test backend dispatching
+// machinery, but these tests are not specific to ORT at all. The ORT
+// backend is fully out-of-tree, so it's safe to use this key for
+// in-tree tests.
+
 using namespace at;
 
 static int test_int;
@@ -17,16 +22,16 @@ Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::op
       Storage(
           Storage::use_byte_size_t(),
           0,
-          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)),
+          at::DataPtr(nullptr, Device(DeviceType::ORT, 1)),
           nullptr,
           false),
-      DispatchKey::MSNPU,
+      DispatchKey::ORT,
       caffe2::TypeMeta::Make<float>());
   return Tensor(std::move(tensor_impl));
 }
 
 Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) {
-  auto out = empty({5, 5}, at::kMSNPU);  // Don't return self as-is
+  auto out = empty({5, 5}, at::kORT);  // Don't return self as-is
   test_int = 2;
   return out;
 }
@@ -42,28 +47,28 @@ Tensor empty_strided_override(
   return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt);
 }
 
-TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
+TORCH_LIBRARY_IMPL(aten, ORT, m) {
   m.impl("aten::empty.memory_format", empty_override);
   m.impl("aten::empty_strided", empty_strided_override);
   m.impl("aten::add.Tensor", add_override);
 }
 
 TEST(BackendExtensionTest, TestRegisterOp) {
-  Tensor a = empty({5, 5}, at::kMSNPU);
-  ASSERT_EQ(a.device().type(), at::kMSNPU);
+  Tensor a = empty({5, 5}, at::kORT);
+  ASSERT_EQ(a.device().type(), at::kORT);
   ASSERT_EQ(a.device().index(), 1);
   ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
   ASSERT_EQ(test_int, 1);
 
-  Tensor b = empty_like(a, at::kMSNPU);
-  ASSERT_EQ(b.device().type(), at::kMSNPU);
+  Tensor b = empty_like(a, at::kORT);
+  ASSERT_EQ(b.device().type(), at::kORT);
   ASSERT_EQ(b.device().index(), 1);
   ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());
 
   add(a, b);
   ASSERT_EQ(test_int, 2);
 
-  // Ensure that non-MSNPU operator still works
+  // Ensure that non-ORT operator still works
   Tensor d = empty({5, 5}, at::kCPU);
   ASSERT_EQ(d.device().type(), at::kCPU);
 }


@@ -40,7 +40,7 @@ enum class Backend {
   SparseHIP,
   SparseVE,
   SparseXPU,
-  MSNPU,
+  ORT,
   XLA,
   Vulkan,
   Metal,
@@ -66,8 +66,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::VE;
   } else if (t == DispatchKey::FPGA) {
     return Backend::FPGA;
-  } else if (t == DispatchKey::MSNPU) {
-    return Backend::MSNPU;
+  } else if (t == DispatchKey::ORT) {
+    return Backend::ORT;
   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
    return Backend::XLA;
  } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
@@ -123,8 +123,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
      return DispatchKey::VE;
    case Backend::FPGA:
      return DispatchKey::FPGA;
-    case Backend::MSNPU:
-      return DispatchKey::MSNPU;
+    case Backend::ORT:
+      return DispatchKey::ORT;
    case Backend::XLA:
      return DispatchKey::XLA;
    case Backend::Lazy:
@@ -178,8 +178,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
      return DeviceType::VE;
    case Backend::FPGA:
      return DeviceType::FPGA;
-    case Backend::MSNPU:
-      return DeviceType::MSNPU;
+    case Backend::ORT:
+      return DeviceType::ORT;
    case Backend::XLA:
      return DeviceType::XLA;
    case Backend::Lazy:
@@ -235,8 +235,8 @@ static inline const char* toString(Backend b) {
      return "FPGA";
    case Backend::XPU:
      return "XPU";
-    case Backend::MSNPU:
-      return "MSNPU";
+    case Backend::ORT:
+      return "ORT";
    case Backend::XLA:
      return "XLA";
    case Backend::Lazy:


@@ -28,7 +28,7 @@ DeviceType parse_type(const std::string& device_string) {
       {"hip", DeviceType::HIP},
       {"ve", DeviceType::VE},
       {"fpga", DeviceType::FPGA},
-      {"msnpu", DeviceType::MSNPU},
+      {"ort", DeviceType::ORT},
       {"xla", DeviceType::XLA},
       {"lazy", DeviceType::Lazy},
       {"vulkan", DeviceType::Vulkan},
@@ -47,7 +47,7 @@ DeviceType parse_type(const std::string& device_string) {
   }
   TORCH_CHECK(
       false,
-      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
+      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
       device_string);
 }
 
 enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };


@@ -25,8 +25,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
      return lower_case ? "ve" : "VE";
    case DeviceType::FPGA:
      return lower_case ? "fpga" : "FPGA";
-    case DeviceType::MSNPU:
-      return lower_case ? "msnpu" : "MSNPU";
+    case DeviceType::ORT:
+      return lower_case ? "ort" : "ORT";
    case DeviceType::XLA:
      return lower_case ? "xla" : "XLA";
    case DeviceType::Lazy:
@@ -75,7 +75,7 @@ bool isValidDeviceType(DeviceType d) {
    case DeviceType::HIP:
    case DeviceType::VE:
    case DeviceType::FPGA:
-    case DeviceType::MSNPU:
+    case DeviceType::ORT:
    case DeviceType::XLA:
    case DeviceType::Lazy:
    case DeviceType::MLC:


@@ -21,7 +21,7 @@ enum class DeviceType : int8_t {
   IDEEP = 5, // IDEEP.
   HIP = 6, // AMD HIP
   FPGA = 7, // FPGA
-  MSNPU = 8, // MSNPU
+  ORT = 8, // ONNX Runtime / Microsoft
   XLA = 9, // XLA / TPU
   Vulkan = 10, // Vulkan
   Metal = 11, // Metal
@@ -42,7 +42,7 @@ constexpr DeviceType kCPU = DeviceType::CPU;
 constexpr DeviceType kCUDA = DeviceType::CUDA;
 constexpr DeviceType kHIP = DeviceType::HIP;
 constexpr DeviceType kFPGA = DeviceType::FPGA;
-constexpr DeviceType kMSNPU = DeviceType::MSNPU;
+constexpr DeviceType kORT = DeviceType::ORT;
 constexpr DeviceType kXLA = DeviceType::XLA;
 constexpr DeviceType kMLC = DeviceType::MLC;
 constexpr DeviceType kMeta = DeviceType::Meta;


@@ -19,8 +19,8 @@ const char* toString(DispatchKey t) {
      return "FPGA";
    case DispatchKey::XPU:
      return "XPU";
-    case DispatchKey::MSNPU:
-      return "MSNPU";
+    case DispatchKey::ORT:
+      return "ORT";
    case DispatchKey::XLA:
      return "XLA";
    case DispatchKey::Lazy:


@@ -59,8 +59,15 @@ enum class DispatchKey : uint8_t {
   //     CUDA]
   FPGA, // Xilinx support lives out of tree at
         // https://gitlab.com/pytorch-complex/vitis_kernels
-  MSNPU, // unused externally, but tested at
-         // test/cpp_extensions/msnpu_extension.cpp
+
+  // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and
+  // https://github.com/microsoft/onnxruntime, and is also used to test general
+  // backend/extension machinery in the core. cf:
+  // - test/cpp_extensions/ort_extension.cpp
+  // - test/test_torch.py
+  // - aten/src/ATen/test/extension_backend_test.cpp
+  ORT,
+
   XLA, // lives out of tree at https://github.com/pytorch/xla
   MLC, // lives out of tree at https://github.com/pytorch/MLCompute
   Vulkan,
@@ -114,7 +121,7 @@ enum class DispatchKey : uint8_t {
   // Here are reserved backends for user-defined backends, see Note [Private use
   // DispatchKey]
-  // To see some example about how to use this, check out MSNPU
+  // To see some example about how to use this, check out ORT
   PrivateUse1,
   PrivateUse2,
   PrivateUse3,


@@ -19,6 +19,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
         DispatchKey::PrivateUse3,
         DispatchKey::MLC,
         DispatchKey::HPU,
+        DispatchKey::ORT,
         DispatchKey::Meta,
     });


@@ -248,7 +248,7 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet(
     {DispatchKey::HIP,
      DispatchKey::VE,
      DispatchKey::FPGA,
-     DispatchKey::MSNPU,
+     DispatchKey::ORT,
      DispatchKey::Vulkan,
      DispatchKey::Metal,
      DispatchKey::QuantizedCPU,


@@ -873,6 +873,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return key_set_.has(DispatchKey::MLC);
   }
 
+  bool is_ort() const {
+    return key_set_.has(DispatchKey::ORT);
+  }
+
   // TODO: remove this once we don't automatically enabled Autograd dispatch
   // keys
   // in TensorImpl constructor.


@@ -663,8 +663,8 @@ inline DispatchKey computeDispatchKey(
      return DispatchKey::VE;
    case DeviceType::FPGA:
      return DispatchKey::FPGA;
-    case DeviceType::MSNPU:
-      return DispatchKey::MSNPU;
+    case DeviceType::ORT:
+      return DispatchKey::ORT;
    case DeviceType::XLA:
      return DispatchKey::XLA;
    case DeviceType::Lazy:
@@ -790,10 +790,8 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
    case DispatchKey::HPU:
    case DispatchKey::AutogradHPU:
      return DeviceType::HPU;
-
-    // stuff that isn't real
-    case DispatchKey::MSNPU:
-      return DeviceType::MSNPU;
+    case DispatchKey::ORT:
+      return DeviceType::ORT;
    default:
      TORCH_CHECK(
          false,


@@ -219,7 +219,7 @@ enum DeviceTypeProto {
   PROTO_IDEEP = 5; // IDEEP.
   PROTO_HIP = 6; // AMD HIP
   PROTO_FPGA = 7; // FPGA
-  PROTO_MSNPU = 8; // MSNPU
+  PROTO_ORT = 8; // ONNX Runtime
   PROTO_XLA = 9; // XLA / TPU
   PROTO_MLC = 10; // ML Compute
   // Change the following number if you add more devices in the code.


@@ -23,7 +23,7 @@ class _DeviceTypeProto(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapp
     PROTO_IDEEP = DeviceTypeProto.V(5)
     PROTO_HIP = DeviceTypeProto.V(6)
     PROTO_FPGA = DeviceTypeProto.V(7)
-    PROTO_MSNPU = DeviceTypeProto.V(8)
+    PROTO_ORT = DeviceTypeProto.V(8)
     PROTO_XLA = DeviceTypeProto.V(9)
     PROTO_MLC = DeviceTypeProto.V(10)
     PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11)
@@ -37,7 +37,7 @@ PROTO_OPENCL = DeviceTypeProto.V(4)
 PROTO_IDEEP = DeviceTypeProto.V(5)
 PROTO_HIP = DeviceTypeProto.V(6)
 PROTO_FPGA = DeviceTypeProto.V(7)
-PROTO_MSNPU = DeviceTypeProto.V(8)
+PROTO_ORT = DeviceTypeProto.V(8)
 PROTO_XLA = DeviceTypeProto.V(9)
 PROTO_MLC = DeviceTypeProto.V(10)
 PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11)


@@ -10,10 +10,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
       Storage(
           Storage::use_byte_size_t(),
           0,
-          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)),
+          at::DataPtr(nullptr, Device(DeviceType::ORT, 0)),
           nullptr,
           false),
-      DispatchKey::MSNPU,
+      DispatchKey::ORT,
       dtype);
   // This is a hack to workaround the shape checks in _convolution.
   tensor_impl->set_sizes_contiguous(size);
@@ -52,7 +52,7 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
       get_tensor(input.dtype(), {}));
 }
 
-TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
+TORCH_LIBRARY_IMPL(aten, ORT, m) {
   m.impl("empty.memory_format", empty_override);
   m.impl("add.out", add_out_override);
   m.impl("convolution_overrideable", fake_convolution);
@@ -61,34 +61,34 @@ TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
 // TODO: Extend this to exercise multi-device setting. In that case,
 // we need to add a thread local variable to track the current device.
-struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
-  static constexpr DeviceType static_type = DeviceType::MSNPU;
-  MSNPUGuardImpl() {}
-  MSNPUGuardImpl(DeviceType t) {
-    AT_ASSERT(t == DeviceType::MSNPU);
+struct ORTGuardImpl final : public c10::impl::DeviceGuardImplInterface {
+  static constexpr DeviceType static_type = DeviceType::ORT;
+  ORTGuardImpl() {}
+  ORTGuardImpl(DeviceType t) {
+    AT_ASSERT(t == DeviceType::ORT);
   }
   DeviceType type() const override {
-    return DeviceType::MSNPU;
+    return DeviceType::ORT;
   }
   Device exchangeDevice(Device d) const override {
-    AT_ASSERT(d.type() == DeviceType::MSNPU);
+    AT_ASSERT(d.type() == DeviceType::ORT);
     AT_ASSERT(d.index() == 0);
     return d;
   }
   Device getDevice() const override {
-    return Device(DeviceType::MSNPU, 0);
+    return Device(DeviceType::ORT, 0);
   }
   void setDevice(Device d) const override {
-    AT_ASSERT(d.type() == DeviceType::MSNPU);
+    AT_ASSERT(d.type() == DeviceType::ORT);
     AT_ASSERT(d.index() == 0);
   }
   void uncheckedSetDevice(Device d) const noexcept override {
   }
   Stream getStream(Device d) const noexcept override {
-    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
+    return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0));
   }
   Stream exchangeStream(Stream s) const noexcept override {
-    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
+    return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0));
   }
   DeviceIndex deviceCount() const noexcept override {
     return 1;
@@ -99,23 +99,23 @@ struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
       const Stream& stream,
       const DeviceIndex device_index,
       const EventFlag flag) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   void block(
     void* event,
    const Stream& stream) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   bool queryEvent(void* event) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   void destroyEvent(
     void* event,
    const DeviceIndex device_index) const noexcept override { }
 };
 
-constexpr DeviceType MSNPUGuardImpl::static_type;
-C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
+constexpr DeviceType ORTGuardImpl::static_type;
+C10_REGISTER_GUARD_IMPL(ORT, ORTGuardImpl);
 
 int get_test_int() {
   return test_int;


@@ -21,7 +21,7 @@ ext_modules = [
         'torch_test_cpp_extension.cpp', ['extension.cpp'],
         extra_compile_args=CXX_FLAGS),
     CppExtension(
-        'torch_test_cpp_extension.msnpu', ['msnpu_extension.cpp'],
+        'torch_test_cpp_extension.ort', ['ort_extension.cpp'],
         extra_compile_args=CXX_FLAGS),
     CppExtension(
         'torch_test_cpp_extension.rng', ['rng_extension.cpp'],


@@ -19,11 +19,11 @@ except ImportError as e:
 try:
     if HAS_PYTEST:
         cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
-        msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu")
+        ort_extension = pytest.importorskip("torch_test_cpp_extension.ort")
         rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
     else:
         import torch_test_cpp_extension.cpp as cpp_extension
-        import torch_test_cpp_extension.msnpu as msnpu_extension
+        import torch_test_cpp_extension.ort as ort_extension
         import torch_test_cpp_extension.rng as rng_extension
 except ImportError as e:
     raise RuntimeError(
@@ -100,45 +100,45 @@ class TestCppExtensionAOT(common.TestCase):
         self.assertFalse(has_value)
 
 
-class TestMSNPUTensor(common.TestCase):
+class TestORTTensor(common.TestCase):
     def test_unregistered(self):
         a = torch.arange(0, 10, device='cpu')
         with self.assertRaisesRegex(RuntimeError, "Could not run"):
-            b = torch.arange(0, 10, device='msnpu')
+            b = torch.arange(0, 10, device='ort')
 
     def test_zeros(self):
         a = torch.empty(5, 5, device='cpu')
         self.assertEqual(a.device, torch.device('cpu'))
 
-        b = torch.empty(5, 5, device='msnpu')
-        self.assertEqual(b.device, torch.device('msnpu', 0))
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        b = torch.empty(5, 5, device='ort')
+        self.assertEqual(b.device, torch.device('ort', 0))
+        self.assertEqual(ort_extension.get_test_int(), 0)
         self.assertEqual(torch.get_default_dtype(), b.dtype)
 
-        c = torch.empty((5, 5), dtype=torch.int64, device='msnpu')
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        c = torch.empty((5, 5), dtype=torch.int64, device='ort')
+        self.assertEqual(ort_extension.get_test_int(), 0)
         self.assertEqual(torch.int64, c.dtype)
 
     def test_add(self):
-        a = torch.empty(5, 5, device='msnpu', requires_grad=True)
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        a = torch.empty(5, 5, device='ort', requires_grad=True)
+        self.assertEqual(ort_extension.get_test_int(), 0)
 
-        b = torch.empty(5, 5, device='msnpu')
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        b = torch.empty(5, 5, device='ort')
+        self.assertEqual(ort_extension.get_test_int(), 0)
 
         c = a + b
-        self.assertEqual(msnpu_extension.get_test_int(), 1)
+        self.assertEqual(ort_extension.get_test_int(), 1)
 
     def test_conv_backend_override(self):
         # To simplify tests, we use 4d input here to avoid doing view4d( which
         # needs more overrides) in _convolution.
-        input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True)
-        weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True)
-        bias = torch.empty(6, device='msnpu')
+        input = torch.empty(2, 4, 10, 2, device='ort', requires_grad=True)
+        weight = torch.empty(6, 4, 2, 2, device='ort', requires_grad=True)
+        bias = torch.empty(6, device='ort')
 
         # Make sure forward is overriden
         out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
-        self.assertEqual(msnpu_extension.get_test_int(), 2)
+        self.assertEqual(ort_extension.get_test_int(), 2)
         self.assertEqual(out.shape[0], input.shape[0])
         self.assertEqual(out.shape[1], weight.shape[0])
@@ -146,7 +146,7 @@ class TestMSNPUTensor(common.TestCase):
         # Double backward is dispatched to _convolution_double_backward.
         # It is not tested here as it involves more computation/overrides.
         grad = torch.autograd.grad(out, input, out, create_graph=True)
-        self.assertEqual(msnpu_extension.get_test_int(), 3)
+        self.assertEqual(ort_extension.get_test_int(), 3)
         self.assertEqual(grad[0].shape, input.shape)


@@ -138,11 +138,11 @@ supported:
         self.assertExpectedInline(output_error, '''Found an invalid operator name: abs_BAD''')
 
     # The backend is valid, but doesn't have a valid autograd key. They can't override autograd kernels in that case.
-    # Only using MSNPU here because it has a valid backend key but not an autograd key- if this changes we can update the test.
+    # Only using Vulkan here because it has a valid backend key but not an autograd key- if this changes we can update the test.
     def test_backend_has_no_autograd_key_but_provides_entries(self):
         yaml_str = '''\
-backend: MSNPU
-cpp_namespace: torch_msnpu
+backend: Vulkan
+cpp_namespace: torch_vulkan
 supported:
 - add
 autograd:
@@ -155,7 +155,7 @@ autograd:
     def test_backend_autograd_kernel_mismatch_out_functional(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:
@@ -168,7 +168,7 @@ autograd:
     def test_backend_autograd_kernel_mismatch_functional_inplace(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:
@@ -182,7 +182,7 @@ autograd:
     def test_op_appears_in_supported_and_autograd_lists(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:


@@ -221,10 +221,10 @@ class AbstractTestCases:
             # TODO: add torch.* tests when we have proper namespacing on ATen functions
             # test_namespace(torch)
 
-        def test_msnpu_error(self):
+        def test_ort_error(self):
             with self.assertRaisesRegex(RuntimeError,
-                                        "Could not run 'aten::empty.memory_format' with arguments from the 'MSNPU' backend"):
-                torch.zeros(1, device=torch.device('msnpu'))
+                                        "Could not run 'aten::empty.memory_format' with arguments from the 'ORT' backend"):
+                torch.zeros(1, device=torch.device('ort'))
 
         def test_has_storage(self):
             self.assertIsNotNone(torch.tensor([]).storage())


@@ -829,6 +829,7 @@ aten_cpu_source_non_codegen_list = [
     "aten/src/ATen/detail/CPUGuardImpl.cpp",
     "aten/src/ATen/detail/CUDAHooksInterface.cpp",
     "aten/src/ATen/detail/HIPHooksInterface.cpp",
+    "aten/src/ATen/detail/ORTHooksInterface.cpp",
     "aten/src/ATen/metal/Context.cpp",
     "aten/src/ATen/native/AutogradComposite.cpp",
     "aten/src/ATen/native/BatchLinearAlgebraKernel.cpp",


@@ -56,7 +56,7 @@ class DispatchKey(Enum):
     CUDA = auto()
     HIP = auto()
     FPGA = auto()
-    MSNPU = auto()
+    ORT = auto()
     XLA = auto()
     Lazy = auto()
     Vulkan = auto()


@@ -469,6 +469,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
         'is_sparse_csr' : ['is_sparse_csr: _bool'],
         'is_quantized': ['is_quantized: _bool'],
         'is_meta': ['is_meta: _bool'],
+        'is_ort': ['is_ort: _bool'],
         'is_mkldnn': ['is_mkldnn: _bool'],
         'is_vulkan': ['is_vulkan: _bool'],
         'storage_offset': ['def storage_offset(self) -> _int: ...'],


@@ -24,7 +24,7 @@ class DeviceType(Enum):
     IDEEP = ...
     HIP = ...
     FPGA = ...
-    MSNPU = ...
+    ORT = ...
     XLA = ...
     MLC = ...
     HPU = ...


@@ -90,7 +90,7 @@ class Tensor(torch._C._TensorBase):
             # does accurate alias tracking; however, the code below
             # doesn't work because of
             # https://github.com/pytorch/pytorch/issues/47442
-            if self.is_sparse or self.device.type in ['xla', 'mlc', 'meta']:
+            if self.is_sparse or self.device.type in ['xla', 'mlc', 'ort', 'meta']:
                 new_tensor = self.clone()
             else:
                 new_storage = self.storage().__deepcopy__(memo)
@@ -153,28 +153,21 @@ class Tensor(torch._C._TensorBase):
         # See Note [Don't serialize hooks]
         torch.utils.hooks.warn_if_has_hooks(self)
         backward_hooks: Dict[Any, Any] = OrderedDict()
-        # Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
+        # Note: Numpy array is chosen to be the rebuild component for XLA, ORT, MLC Tensors.
         # We considered a few options:
         # 1. CPU tensor can't be used here.
         #    Otherwise in torch.load CPU storage is reconstructed with randomly
-        #    initialized data, moved onto XLA device, and then storage is updated
-        #    to the serialized content. This works perfectly for CPU/CUDA but not XLA.
-        #    XLA tensor is disconnected with storage so it doesn't get the update.
+        #    initialized data, moved onto backend device, and then storage is updated
+        #    to the serialized content. This works perfectly for CPU/CUDA but not these backends;
+        #    their tensors are disconnected with storage so they don't get the update.
         # 2. Python list is not a good fit due to performance reason.
         #    `tolist()` converts every single element in the tensor into python objects
         #    and serialize them one by one.
-        if self.device.type == 'xla':
-            arg_xla = (self.cpu().numpy(),
-                       self.dtype,
-                       str(self.device),
-                       self.requires_grad)
-            return (torch._utils._rebuild_xla_tensor, arg_xla)
-        if self.device.type == 'mlc':
-            arg_mlc = (self.cpu().numpy(),
-                       self.dtype,
-                       str(self.device),
-                       self.requires_grad)
-            return (torch._utils._rebuild_mlc_tensor, arg_mlc)
+        if self.device.type in ['xla', 'ort', 'mlc']:
+            return (torch._utils._rebuild_device_tensor_from_numpy, (self.cpu().numpy(),
+                                                                     self.dtype,
+                                                                     str(self.device),
+                                                                     self.requires_grad))
         if self.device.type == 'meta':
             # NB: This implementation BREAKS storage sharing. Current
             # hypothesis is that no one cares for meta tensors.


@@ -173,16 +173,15 @@ def _rebuild_sparse_tensor(layout, data):
     raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))
 
 
-def _rebuild_xla_tensor(data, dtype, device, requires_grad):
+def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
     tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
     tensor.requires_grad = requires_grad
     return tensor
 
 
-def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
-    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
-    tensor.requires_grad = requires_grad
-    return tensor
+# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
+_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy
+_rebuild_mlc_tensor = _rebuild_device_tensor_from_numpy
 
 
 def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):
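As a small aside, here is a sketch of how the consolidated rebuild helper above behaves. It is illustrative only and not part of this PR; the CPU device is used so the snippet runs without any out-of-tree backend, whereas XLA/ORT/MLC tensors go through the same path with their own device strings during `torch.load`.

```python
# Sketch of the single rebuild component now shared by XLA/ORT/MLC serialization.
import numpy as np
import torch

data = np.arange(6, dtype=np.float32).reshape(2, 3)
t = torch._utils._rebuild_device_tensor_from_numpy(data, torch.float32, "cpu", False)
print(t.device, t.dtype, t.requires_grad)  # cpu torch.float32 False
```

The old `_rebuild_xla_tensor` / `_rebuild_mlc_tensor` names remain as aliases so checkpoints written by older PyTorch versions still load.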


@@ -17,6 +17,6 @@ inline bool THPDevice_Check(PyObject *obj) {
   return Py_TYPE(obj) == &THPDeviceType;
 }
 
-PyObject * THPDevice_New(const at::Device& device);
+TORCH_API PyObject * THPDevice_New(const at::Device& device);
 
-void THPDevice_init(PyObject *module);
+TORCH_API void THPDevice_init(PyObject *module);


@@ -114,7 +114,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
       .value("IDEEP", c10::DeviceType::IDEEP)
       .value("HIP", c10::DeviceType::HIP)
       .value("FPGA", c10::DeviceType::FPGA)
-      .value("MSNPU", c10::DeviceType::MSNPU)
+      .value("ORT", c10::DeviceType::ORT)
       .value("XLA", c10::DeviceType::XLA)
       .value("Lazy", c10::DeviceType::Lazy)
       .value("MLC", c10::DeviceType::MLC)


@@ -834,6 +834,17 @@ PyObject *THPVariable_is_mlc(THPVariable *self, void *unused)
   END_HANDLE_TH_ERRORS
 }
 
+PyObject *THPVariable_is_ort(THPVariable *self, void *unused)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject *)self)) {
+    return handle_torch_function_getter(self, "is_ort");
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  return torch::autograd::utils::wrap(self_.is_ort());
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
 {
   HANDLE_TH_ERRORS
@@ -980,6 +991,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
   {"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr},
   {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
   {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
+  {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr},
   {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
   {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
   {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},


@@ -119,7 +119,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
           {"layout", "prim"},      {"T", "prim"},
           {"ndim", "prim"},        {"name", "prim"},
           {"real", "aten"},        {"imag", "aten"},
-          {"retains_grad", "aten"},
+          {"retains_grad", "aten"}, {"is_ort", "prim"},
       }},
      {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
   auto kind = value_->type()->kind();


@@ -2211,6 +2211,14 @@ RegisterOperators reg1(
           push(stack, a.is_meta());
         },
         aliasAnalysisFromSchema()),
+    OperatorGenerator(
+        TORCH_SELECTIVE_SCHEMA("prim::is_ort(Tensor a) -> bool"),
+        [](Stack* stack) {
+          at::Tensor a;
+          pop(stack, a);
+          push(stack, a.is_ort());
+        },
+        aliasAnalysisFromSchema()),
     OperatorGenerator(
         TORCH_SELECTIVE_SCHEMA("prim::name(Tensor a) -> str?"),
         [](Stack* stack) {


@@ -317,8 +317,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
       return c10::DispatchKey::Meta;
     case c10::DeviceType::HIP:
       return c10::DispatchKey::HIP;
-    case c10::DeviceType::MSNPU:
-      return c10::DispatchKey::MSNPU;
+    case c10::DeviceType::ORT:
+      return c10::DispatchKey::ORT;
     case c10::DeviceType::HPU:
       return c10::DispatchKey::HPU;
     default:


@@ -1030,6 +1030,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
         Tensor.retains_grad.__get__: lambda self: -1,
         Tensor.is_meta.__get__: lambda self: -1,
         Tensor.is_mlc.__get__: lambda self: -1,
+        Tensor.is_ort.__get__: lambda self: -1,
         Tensor.is_mkldnn.__get__: lambda self: -1,
         Tensor.is_quantized.__get__: lambda self: -1,
         Tensor.is_sparse.__get__: lambda self: -1,