Add support for the ONNX Runtime Eager Mode backend (#58248)
Summary: This PR implements the necessary hooks/stubs/enums/etc. for complete ONNX Runtime (ORT) Eager Mode integration. The actual extension will live out of tree at https://github.com/pytorch/ort. We have been [working on this at Microsoft](https://github.com/microsoft/onnxruntime-pytorch/tree/eager-ort/torch_onnxruntime) for the last few months, and are finally ready to contribute the PyTorch core changes upstream (nothing major or exciting, just the usual boilerplate for adding new backends). The ORT backend will allow us to ferry [almost] all torch ops into granular ONNX kernels that ORT will eagerly execute against any devices it supports (therefore, we only need a single ORT backend from a PyTorch perspective).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58248
Reviewed By: astaff
Differential Revision: D30344992
Pulled By: albanD
fbshipit-source-id: 69082b32121246340d686e16653626114b7714b2
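For orientation, a minimal sketch of how the eager ORT backend is meant to be used once the out-of-tree extension is installed. The `torch_ort` package name comes from the ORT_HELP message added below; the device index and op coverage are assumptions:

```python
import torch
import torch_ort  # hypothetical out-of-tree package from https://github.com/pytorch/ort;
                  # importing it registers the ORT device guard, hooks, and kernels

# Once registered, ordinary eager-mode code can target the 'ort' device.
a = torch.ones(5, 5, device="ort")
b = torch.ones(5, 5, device="ort")
c = a + b            # dispatched through the ORT dispatch key to a granular ONNX kernel
print(c.is_ort)      # True -- Tensor.is_ort is one of the properties added in this PR
print(c.cpu())       # copy back to CPU to inspect values
```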
This commit is contained in:
parent b95ce1591d
commit c78ab28441
@@ -9,6 +9,7 @@
 #include <ATen/core/LegacyTypeDispatch.h>
 #include <ATen/detail/CUDAHooksInterface.h>
 #include <ATen/detail/HIPHooksInterface.h>
+#include <ATen/detail/ORTHooksInterface.h>
 #include <c10/util/Exception.h>
 #include <c10/core/impl/DeviceGuardImplInterface.h>
 #include <c10/core/QEngine.h>

@@ -79,6 +80,9 @@ class TORCH_API Context {
   static bool hasMLC() {
     return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
   }
+  static bool hasORT() {
+    return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
+  }
   // defined in header so that getNonVariableType has ability to inline
   // call_once check. getNonVariableType is called fairly frequently
   THCState* lazyInitCUDA() {

@@ -292,6 +296,10 @@ static inline bool hasMLC() {
   return globalContext().hasMLC();
 }

+static inline bool hasORT() {
+  return globalContext().hasORT();
+}
+
 // Despite its name, this function returns the number of *CUDA* GPUs.
 static inline size_t getNumGPUs() {
   // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS

@@ -184,6 +184,10 @@ std::string show_config() {
     ss << detail::getCUDAHooks().showConfig();
   }

+  if (hasORT()) {
+    ss << detail::getORTHooks().showConfig();
+  }
+
   ss << " - Build settings: ";
   for (const auto& pair : caffe2::GetBuildOptions()) {
     if (!pair.second.empty()) {

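With `show_config()` now consulting the ORT hooks, the extra section becomes visible from Python once the out-of-tree backend has registered itself. A small sketch, assuming the hypothetical `torch_ort` package; without it, `hasORT()` stays false and the output is unchanged:

```python
import torch
# import torch_ort   # hypothetical out-of-tree package; registering its device guard
                      # flips hasORT() to true and adds an ORT section to the output

# torch.__config__.show() returns the string assembled by at::show_config(),
# so the ORT section only appears when the extension has been loaded.
print(torch.__config__.show())
```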
@@ -405,6 +405,7 @@ _(aten, is_complex) \
 _(aten, is_contiguous) \
 _(aten, is_cuda) \
 _(aten, is_mlc) \
+_(aten, is_ort) \
 _(aten, is_distributed) \
 _(aten, is_floating_point) \
 _(aten, is_inference) \

@@ -13,13 +13,13 @@ There’s four main use cases
 * You’re writing a new operator that isn’t supposed to be part of the public PyTorch API.
 * You’re writing a new operator but don’t want to change the core pytorch code base, say you’re developing a shared library with operators.
 * You’re writing a C++ extension for PyTorch or you’re using inline c++ in your .py model files.
-* You’re writing a backend library like XLA or MSNPU that adds new kernels to all operators defined in `native_functions.yaml`.
+* You’re writing a backend library like XLA or ORT that adds new kernels to all operators defined in `native_functions.yaml`.

 For these use cases, the custom operator API is the better solution.

 ### What is the price for using the custom operator API instead of `native_functions.yaml`?

-If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/MSNPU example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
+If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/ORT example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.

 * It will not get a C++ API generated. There will not be `Tensor::your_op()` methods or `at::your_op()` functions to call your operator.
 * The API for calling the operator from Python looks a little bit different. It needs to be called through `torch.ops.your_op()` instead of `torch._C`.

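As a companion to the doc hunk above: a minimal sketch of defining an operator through the custom operator API and calling it via `torch.ops` rather than a generated `at::`/`torch._C` binding. It uses the Python-side `torch.library` API available in recent PyTorch releases; the `myops` namespace and the schema are made up for illustration:

```python
import torch

# Define a new operator outside native_functions.yaml (illustrative namespace/schema).
lib = torch.library.Library("myops", "DEF")
lib.define("scale(Tensor x, float alpha) -> Tensor")

def scale_cpu(x, alpha):
    # Plain Python/CPU kernel backing the op.
    return x * alpha

lib.impl("scale", scale_cpu, "CPU")

# No Tensor.scale()/at::scale() is generated for ops defined this way;
# the call goes through the torch.ops namespace instead.
y = torch.ops.myops.scale(torch.ones(3), 2.0)
print(y)  # tensor([2., 2., 2.])
```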
aten/src/ATen/detail/ORTHooksInterface.cpp (new file, 31 lines)

@@ -0,0 +1,31 @@
+#include <ATen/detail/ORTHooksInterface.h>
+
+#include <c10/util/Exception.h>
+
+#include <cstddef>
+#include <memory>
+#include <mutex>
+
+namespace at {
+namespace detail {
+
+// See getCUDAHooks for some more commentary
+const ORTHooksInterface& getORTHooks() {
+  static std::unique_ptr<ORTHooksInterface> ort_hooks;
+  static std::once_flag once;
+  std::call_once(once, [] {
+    ort_hooks = ORTHooksRegistry()->Create("ORTHooks", {});
+    if (!ort_hooks) {
+      ort_hooks =
+          // NOLINTNEXTLINE(modernize-make-unique)
+          std::unique_ptr<ORTHooksInterface>(new ORTHooksInterface());
+    }
+  });
+  return *ort_hooks;
+}
+} // namespace detail
+
+// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
+C10_DEFINE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs)
+
+} // namespace at

aten/src/ATen/detail/ORTHooksInterface.h (new file, 36 lines)

@@ -0,0 +1,36 @@
+#pragma once
+
+#include <c10/util/Exception.h>
+#include <c10/util/Registry.h>
+
+constexpr const char* ORT_HELP =
+  " You need to 'import torch_ort' to use the 'ort' device in PyTorch. "
+  "The 'torch_ort' module is provided by the ONNX Runtime itself "
+  "(https://onnxruntime.ai).";
+
+// NB: Class must live in `at` due to limitations of Registry.h.
+namespace at {
+
+struct TORCH_API ORTHooksInterface {
+  // This should never actually be implemented, but it is used to
+  // squelch -Werror=non-virtual-dtor
+  virtual ~ORTHooksInterface() {}
+
+  virtual std::string showConfig() const {
+    TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP);
+  }
+};
+
+// NB: dummy argument to suppress "ISO C++11 requires at least one argument
+// for the "..." in a variadic macro"
+struct TORCH_API ORTHooksArgs {};
+
+C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs);
+#define REGISTER_ORT_HOOKS(clsname) \
+  C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname)
+
+namespace detail {
+TORCH_API const ORTHooksInterface& getORTHooks();
+} // namespace detail
+
+} // namespace at

@@ -492,6 +492,12 @@ class TORCH_API Tensor {
     return impl_->is_mlc();
   }

+  /// Returns if a `Tensor` is ort tensor.
+  bool is_ort() const {
+    // NB: this is not a native function to avoid dispatching overhead.
+    return impl_->is_ort();
+  }
+
   /// Returns if a `Tensor` is vulkan tensor.
   bool is_vulkan() const {
     // NB: this is not a native function to avoid dispatching overhead.

@@ -6,6 +6,11 @@

 #include <torch/csrc/jit/runtime/operator.h>

+// NB. These tests use the ORT dispatch key to test backend dispatching
+// machinery, but these tests are not specific to ORT at all. The ORT
+// backend is fully out-of-tree, so it's safe to use this key for
+// in-tree tests.
+
 using namespace at;

 static int test_int;

@@ -17,16 +22,16 @@ Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::op
       Storage(
           Storage::use_byte_size_t(),
           0,
-          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)),
+          at::DataPtr(nullptr, Device(DeviceType::ORT, 1)),
           nullptr,
           false),
-      DispatchKey::MSNPU,
+      DispatchKey::ORT,
       caffe2::TypeMeta::Make<float>());
   return Tensor(std::move(tensor_impl));
 }

 Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) {
-  auto out = empty({5, 5}, at::kMSNPU); // Don't return self as-is
+  auto out = empty({5, 5}, at::kORT); // Don't return self as-is
   test_int = 2;
   return out;
 }

@@ -42,28 +47,28 @@ Tensor empty_strided_override(
   return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt);
 }

-TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
+TORCH_LIBRARY_IMPL(aten, ORT, m) {
   m.impl("aten::empty.memory_format", empty_override);
   m.impl("aten::empty_strided", empty_strided_override);
   m.impl("aten::add.Tensor", add_override);
 }

 TEST(BackendExtensionTest, TestRegisterOp) {
-  Tensor a = empty({5, 5}, at::kMSNPU);
-  ASSERT_EQ(a.device().type(), at::kMSNPU);
+  Tensor a = empty({5, 5}, at::kORT);
+  ASSERT_EQ(a.device().type(), at::kORT);
   ASSERT_EQ(a.device().index(), 1);
   ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
   ASSERT_EQ(test_int, 1);

-  Tensor b = empty_like(a, at::kMSNPU);
-  ASSERT_EQ(b.device().type(), at::kMSNPU);
+  Tensor b = empty_like(a, at::kORT);
+  ASSERT_EQ(b.device().type(), at::kORT);
   ASSERT_EQ(b.device().index(), 1);
   ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());

   add(a, b);
   ASSERT_EQ(test_int, 2);

-  // Ensure that non-MSNPU operator still works
+  // Ensure that non-ORT operator still works
   Tensor d = empty({5, 5}, at::kCPU);
   ASSERT_EQ(d.device().type(), at::kCPU);
 }

@@ -40,7 +40,7 @@ enum class Backend {
   SparseHIP,
   SparseVE,
   SparseXPU,
-  MSNPU,
+  ORT,
   XLA,
   Vulkan,
   Metal,

@@ -66,8 +66,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
     return Backend::VE;
   } else if (t == DispatchKey::FPGA) {
     return Backend::FPGA;
-  } else if (t == DispatchKey::MSNPU) {
-    return Backend::MSNPU;
+  } else if (t == DispatchKey::ORT) {
+    return Backend::ORT;
   } else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
     return Backend::XLA;
   } else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {

@@ -123,8 +123,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
       return DispatchKey::VE;
     case Backend::FPGA:
       return DispatchKey::FPGA;
-    case Backend::MSNPU:
-      return DispatchKey::MSNPU;
+    case Backend::ORT:
+      return DispatchKey::ORT;
     case Backend::XLA:
      return DispatchKey::XLA;
    case Backend::Lazy:

@@ -178,8 +178,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
      return DeviceType::VE;
    case Backend::FPGA:
      return DeviceType::FPGA;
-    case Backend::MSNPU:
-      return DeviceType::MSNPU;
+    case Backend::ORT:
+      return DeviceType::ORT;
    case Backend::XLA:
      return DeviceType::XLA;
    case Backend::Lazy:

@@ -235,8 +235,8 @@ static inline const char* toString(Backend b) {
      return "FPGA";
    case Backend::XPU:
      return "XPU";
-    case Backend::MSNPU:
-      return "MSNPU";
+    case Backend::ORT:
+      return "ORT";
    case Backend::XLA:
      return "XLA";
    case Backend::Lazy:

@@ -28,7 +28,7 @@ DeviceType parse_type(const std::string& device_string) {
       {"hip", DeviceType::HIP},
       {"ve", DeviceType::VE},
       {"fpga", DeviceType::FPGA},
-      {"msnpu", DeviceType::MSNPU},
+      {"ort", DeviceType::ORT},
       {"xla", DeviceType::XLA},
       {"lazy", DeviceType::Lazy},
       {"vulkan", DeviceType::Vulkan},

@@ -47,7 +47,7 @@ DeviceType parse_type(const std::string& device_string) {
   }
   TORCH_CHECK(
       false,
-      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
+      "Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
       device_string);
 }
 enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };

@@ -25,8 +25,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
      return lower_case ? "ve" : "VE";
    case DeviceType::FPGA:
      return lower_case ? "fpga" : "FPGA";
-    case DeviceType::MSNPU:
-      return lower_case ? "msnpu" : "MSNPU";
+    case DeviceType::ORT:
+      return lower_case ? "ort" : "ORT";
    case DeviceType::XLA:
      return lower_case ? "xla" : "XLA";
    case DeviceType::Lazy:

@@ -75,7 +75,7 @@ bool isValidDeviceType(DeviceType d) {
    case DeviceType::HIP:
    case DeviceType::VE:
    case DeviceType::FPGA:
-    case DeviceType::MSNPU:
+    case DeviceType::ORT:
    case DeviceType::XLA:
    case DeviceType::Lazy:
    case DeviceType::MLC:

@@ -21,7 +21,7 @@ enum class DeviceType : int8_t {
   IDEEP = 5, // IDEEP.
   HIP = 6, // AMD HIP
   FPGA = 7, // FPGA
-  MSNPU = 8, // MSNPU
+  ORT = 8, // ONNX Runtime / Microsoft
   XLA = 9, // XLA / TPU
   Vulkan = 10, // Vulkan
   Metal = 11, // Metal

@@ -42,7 +42,7 @@ constexpr DeviceType kCPU = DeviceType::CPU;
 constexpr DeviceType kCUDA = DeviceType::CUDA;
 constexpr DeviceType kHIP = DeviceType::HIP;
 constexpr DeviceType kFPGA = DeviceType::FPGA;
-constexpr DeviceType kMSNPU = DeviceType::MSNPU;
+constexpr DeviceType kORT = DeviceType::ORT;
 constexpr DeviceType kXLA = DeviceType::XLA;
 constexpr DeviceType kMLC = DeviceType::MLC;
 constexpr DeviceType kMeta = DeviceType::Meta;

@@ -19,8 +19,8 @@ const char* toString(DispatchKey t) {
      return "FPGA";
    case DispatchKey::XPU:
      return "XPU";
-    case DispatchKey::MSNPU:
-      return "MSNPU";
+    case DispatchKey::ORT:
+      return "ORT";
    case DispatchKey::XLA:
      return "XLA";
    case DispatchKey::Lazy:

@@ -59,8 +59,15 @@ enum class DispatchKey : uint8_t {
   // CUDA]
   FPGA, // Xilinx support lives out of tree at
         // https://gitlab.com/pytorch-complex/vitis_kernels
-  MSNPU, // unused externally, but tested at
-         // test/cpp_extensions/msnpu_extension.cpp
+
+  // ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and
+  // https://github.com/microsoft/onnxruntime, and is also used to test general
+  // backend/extension machinery in the core. cf:
+  // - test/cpp_extensions/ort_extension.cpp
+  // - test/test_torch.py
+  // - aten/src/ATen/test/extension_backend_test.cpp
+  ORT,
+
   XLA, // lives out of tree at https://github.com/pytorch/xla
   MLC, // lives out of tree at https://github.com/pytorch/MLCompute
   Vulkan,

@@ -114,7 +121,7 @@ enum class DispatchKey : uint8_t {

   // Here are reserved backends for user-defined backends, see Note [Private use
   // DispatchKey]
-  // To see some example about how to use this, check out MSNPU
+  // To see some example about how to use this, check out ORT
   PrivateUse1,
   PrivateUse2,
   PrivateUse3,

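The comment block above points at the in-tree test extension that exercises the ORT dispatch key. A rough sketch of building and poking at it from Python (assumes a PyTorch source checkout; the test suite normally builds it via test/cpp_extensions/setup.py, and building it ad hoc with default flags is an assumption):

```python
import torch
from torch.utils import cpp_extension

# Ad-hoc build of the in-tree test extension, which registers empty/add/convolution
# kernels under the ORT dispatch key (path relative to a pytorch checkout).
ort_ext = cpp_extension.load(
    name="ort_extension",
    sources=["test/cpp_extensions/ort_extension.cpp"],
)

a = torch.empty(5, 5, device="ort")   # served by the test kernels
b = torch.empty(5, 5, device="ort")
c = a + b
print(c.device)                       # ort:0
print(ort_ext.get_test_int())         # bumped by the add override
```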
@@ -19,6 +19,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
        DispatchKey::PrivateUse3,
        DispatchKey::MLC,
        DispatchKey::HPU,
+        DispatchKey::ORT,
        DispatchKey::Meta,
    });

@@ -248,7 +248,7 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet(
    {DispatchKey::HIP,
     DispatchKey::VE,
     DispatchKey::FPGA,
-     DispatchKey::MSNPU,
+     DispatchKey::ORT,
     DispatchKey::Vulkan,
     DispatchKey::Metal,
     DispatchKey::QuantizedCPU,

@@ -873,6 +873,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
     return key_set_.has(DispatchKey::MLC);
   }

+  bool is_ort() const {
+    return key_set_.has(DispatchKey::ORT);
+  }
+
   // TODO: remove this once we don't automatically enabled Autograd dispatch
   // keys
   // in TensorImpl constructor.

@@ -663,8 +663,8 @@ inline DispatchKey computeDispatchKey(
      return DispatchKey::VE;
    case DeviceType::FPGA:
      return DispatchKey::FPGA;
-    case DeviceType::MSNPU:
-      return DispatchKey::MSNPU;
+    case DeviceType::ORT:
+      return DispatchKey::ORT;
    case DeviceType::XLA:
      return DispatchKey::XLA;
    case DeviceType::Lazy:

@@ -790,10 +790,8 @@ inline DeviceType dispatchKeyToDeviceType(DispatchKey dispatch_key) {
    case DispatchKey::HPU:
    case DispatchKey::AutogradHPU:
      return DeviceType::HPU;
-
-    // stuff that isn't real
-    case DispatchKey::MSNPU:
-      return DeviceType::MSNPU;
+    case DispatchKey::ORT:
+      return DeviceType::ORT;
    default:
      TORCH_CHECK(
          false,

@@ -219,7 +219,7 @@ enum DeviceTypeProto {
   PROTO_IDEEP = 5; // IDEEP.
   PROTO_HIP = 6; // AMD HIP
   PROTO_FPGA = 7; // FPGA
-  PROTO_MSNPU = 8; // MSNPU
+  PROTO_ORT = 8; // ONNX Runtime
   PROTO_XLA = 9; // XLA / TPU
   PROTO_MLC = 10; // ML Compute
   // Change the following number if you add more devices in the code.

@@ -23,7 +23,7 @@ class _DeviceTypeProto(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapp
    PROTO_IDEEP = DeviceTypeProto.V(5)
    PROTO_HIP = DeviceTypeProto.V(6)
    PROTO_FPGA = DeviceTypeProto.V(7)
-    PROTO_MSNPU = DeviceTypeProto.V(8)
+    PROTO_ORT = DeviceTypeProto.V(8)
    PROTO_XLA = DeviceTypeProto.V(9)
    PROTO_MLC = DeviceTypeProto.V(10)
    PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11)

@@ -37,7 +37,7 @@ PROTO_OPENCL = DeviceTypeProto.V(4)
 PROTO_IDEEP = DeviceTypeProto.V(5)
 PROTO_HIP = DeviceTypeProto.V(6)
 PROTO_FPGA = DeviceTypeProto.V(7)
-PROTO_MSNPU = DeviceTypeProto.V(8)
+PROTO_ORT = DeviceTypeProto.V(8)
 PROTO_XLA = DeviceTypeProto.V(9)
 PROTO_MLC = DeviceTypeProto.V(10)
 PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = DeviceTypeProto.V(11)

@@ -10,10 +10,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
      Storage(
          Storage::use_byte_size_t(),
          0,
-          at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)),
+          at::DataPtr(nullptr, Device(DeviceType::ORT, 0)),
          nullptr,
          false),
-      DispatchKey::MSNPU,
+      DispatchKey::ORT,
      dtype);
  // This is a hack to workaround the shape checks in _convolution.
  tensor_impl->set_sizes_contiguous(size);

@@ -52,7 +52,7 @@ std::tuple<Tensor,Tensor,Tensor> fake_convolution_backward(
        get_tensor(input.dtype(), {}));
 }

-TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
+TORCH_LIBRARY_IMPL(aten, ORT, m) {
  m.impl("empty.memory_format", empty_override);
  m.impl("add.out", add_out_override);
  m.impl("convolution_overrideable", fake_convolution);

@@ -61,34 +61,34 @@ TORCH_LIBRARY_IMPL(aten, MSNPU, m) {

 // TODO: Extend this to exercise multi-device setting. In that case,
 // we need to add a thread local variable to track the current device.
-struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
-  static constexpr DeviceType static_type = DeviceType::MSNPU;
-  MSNPUGuardImpl() {}
-  MSNPUGuardImpl(DeviceType t) {
-    AT_ASSERT(t == DeviceType::MSNPU);
+struct ORTGuardImpl final : public c10::impl::DeviceGuardImplInterface {
+  static constexpr DeviceType static_type = DeviceType::ORT;
+  ORTGuardImpl() {}
+  ORTGuardImpl(DeviceType t) {
+    AT_ASSERT(t == DeviceType::ORT);
   }
   DeviceType type() const override {
-    return DeviceType::MSNPU;
+    return DeviceType::ORT;
   }
   Device exchangeDevice(Device d) const override {
-    AT_ASSERT(d.type() == DeviceType::MSNPU);
+    AT_ASSERT(d.type() == DeviceType::ORT);
     AT_ASSERT(d.index() == 0);
     return d;
   }
   Device getDevice() const override {
-    return Device(DeviceType::MSNPU, 0);
+    return Device(DeviceType::ORT, 0);
   }
   void setDevice(Device d) const override {
-    AT_ASSERT(d.type() == DeviceType::MSNPU);
+    AT_ASSERT(d.type() == DeviceType::ORT);
     AT_ASSERT(d.index() == 0);
   }
   void uncheckedSetDevice(Device d) const noexcept override {
   }
   Stream getStream(Device d) const noexcept override {
-    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
+    return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0));
   }
   Stream exchangeStream(Stream s) const noexcept override {
-    return Stream(Stream::DEFAULT, Device(DeviceType::MSNPU, 0));
+    return Stream(Stream::DEFAULT, Device(DeviceType::ORT, 0));
   }
   DeviceIndex deviceCount() const noexcept override {
     return 1;

@@ -99,23 +99,23 @@ struct MSNPUGuardImpl final : public c10::impl::DeviceGuardImplInterface {
      const Stream& stream,
      const DeviceIndex device_index,
      const EventFlag flag) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   void block(
      void* event,
      const Stream& stream) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   bool queryEvent(void* event) const override {
-    TORCH_CHECK(false, "MSNPU backend doesn't support events.");
+    TORCH_CHECK(false, "ORT backend doesn't support events.");
   }
   void destroyEvent(
      void* event,
      const DeviceIndex device_index) const noexcept override { }
 };

-constexpr DeviceType MSNPUGuardImpl::static_type;
-C10_REGISTER_GUARD_IMPL(MSNPU, MSNPUGuardImpl);
+constexpr DeviceType ORTGuardImpl::static_type;
+C10_REGISTER_GUARD_IMPL(ORT, ORTGuardImpl);

 int get_test_int() {
   return test_int;

@@ -21,7 +21,7 @@ ext_modules = [
        'torch_test_cpp_extension.cpp', ['extension.cpp'],
        extra_compile_args=CXX_FLAGS),
    CppExtension(
-        'torch_test_cpp_extension.msnpu', ['msnpu_extension.cpp'],
+        'torch_test_cpp_extension.ort', ['ort_extension.cpp'],
        extra_compile_args=CXX_FLAGS),
    CppExtension(
        'torch_test_cpp_extension.rng', ['rng_extension.cpp'],

@@ -19,11 +19,11 @@ except ImportError as e:
 try:
     if HAS_PYTEST:
         cpp_extension = pytest.importorskip("torch_test_cpp_extension.cpp")
-        msnpu_extension = pytest.importorskip("torch_test_cpp_extension.msnpu")
+        ort_extension = pytest.importorskip("torch_test_cpp_extension.ort")
         rng_extension = pytest.importorskip("torch_test_cpp_extension.rng")
     else:
         import torch_test_cpp_extension.cpp as cpp_extension
-        import torch_test_cpp_extension.msnpu as msnpu_extension
+        import torch_test_cpp_extension.ort as ort_extension
         import torch_test_cpp_extension.rng as rng_extension
 except ImportError as e:
     raise RuntimeError(

@@ -100,45 +100,45 @@ class TestCppExtensionAOT(common.TestCase):
         self.assertFalse(has_value)


-class TestMSNPUTensor(common.TestCase):
+class TestORTTensor(common.TestCase):
     def test_unregistered(self):
         a = torch.arange(0, 10, device='cpu')
         with self.assertRaisesRegex(RuntimeError, "Could not run"):
-            b = torch.arange(0, 10, device='msnpu')
+            b = torch.arange(0, 10, device='ort')

     def test_zeros(self):
         a = torch.empty(5, 5, device='cpu')
         self.assertEqual(a.device, torch.device('cpu'))

-        b = torch.empty(5, 5, device='msnpu')
-        self.assertEqual(b.device, torch.device('msnpu', 0))
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        b = torch.empty(5, 5, device='ort')
+        self.assertEqual(b.device, torch.device('ort', 0))
+        self.assertEqual(ort_extension.get_test_int(), 0)
         self.assertEqual(torch.get_default_dtype(), b.dtype)

-        c = torch.empty((5, 5), dtype=torch.int64, device='msnpu')
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        c = torch.empty((5, 5), dtype=torch.int64, device='ort')
+        self.assertEqual(ort_extension.get_test_int(), 0)
         self.assertEqual(torch.int64, c.dtype)

     def test_add(self):
-        a = torch.empty(5, 5, device='msnpu', requires_grad=True)
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        a = torch.empty(5, 5, device='ort', requires_grad=True)
+        self.assertEqual(ort_extension.get_test_int(), 0)

-        b = torch.empty(5, 5, device='msnpu')
-        self.assertEqual(msnpu_extension.get_test_int(), 0)
+        b = torch.empty(5, 5, device='ort')
+        self.assertEqual(ort_extension.get_test_int(), 0)

         c = a + b
-        self.assertEqual(msnpu_extension.get_test_int(), 1)
+        self.assertEqual(ort_extension.get_test_int(), 1)

     def test_conv_backend_override(self):
         # To simplify tests, we use 4d input here to avoid doing view4d( which
         # needs more overrides) in _convolution.
-        input = torch.empty(2, 4, 10, 2, device='msnpu', requires_grad=True)
-        weight = torch.empty(6, 4, 2, 2, device='msnpu', requires_grad=True)
-        bias = torch.empty(6, device='msnpu')
+        input = torch.empty(2, 4, 10, 2, device='ort', requires_grad=True)
+        weight = torch.empty(6, 4, 2, 2, device='ort', requires_grad=True)
+        bias = torch.empty(6, device='ort')

         # Make sure forward is overriden
         out = torch.nn.functional.conv1d(input, weight, bias, 2, 0, 1, 1)
-        self.assertEqual(msnpu_extension.get_test_int(), 2)
+        self.assertEqual(ort_extension.get_test_int(), 2)
         self.assertEqual(out.shape[0], input.shape[0])
         self.assertEqual(out.shape[1], weight.shape[0])

@@ -146,7 +146,7 @@ class TestMSNPUTensor(common.TestCase):
         # Double backward is dispatched to _convolution_double_backward.
         # It is not tested here as it involves more computation/overrides.
         grad = torch.autograd.grad(out, input, out, create_graph=True)
-        self.assertEqual(msnpu_extension.get_test_int(), 3)
+        self.assertEqual(ort_extension.get_test_int(), 3)
         self.assertEqual(grad[0].shape, input.shape)

@@ -138,11 +138,11 @@ supported:
         self.assertExpectedInline(output_error, '''Found an invalid operator name: abs_BAD''')

     # The backend is valid, but doesn't have a valid autograd key. They can't override autograd kernels in that case.
-    # Only using MSNPU here because it has a valid backend key but not an autograd key- if this changes we can update the test.
+    # Only using Vulkan here because it has a valid backend key but not an autograd key- if this changes we can update the test.
     def test_backend_has_no_autograd_key_but_provides_entries(self):
         yaml_str = '''\
-backend: MSNPU
-cpp_namespace: torch_msnpu
+backend: Vulkan
+cpp_namespace: torch_vulkan
 supported:
 - add
 autograd:

@@ -155,7 +155,7 @@ autograd:
     def test_backend_autograd_kernel_mismatch_out_functional(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:

@@ -168,7 +168,7 @@ autograd:
     def test_backend_autograd_kernel_mismatch_functional_inplace(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:

@@ -182,7 +182,7 @@ autograd:
     def test_op_appears_in_supported_and_autograd_lists(self):
         yaml_str = '''\
 backend: XLA
-cpp_namespace: torch_msnpu
+cpp_namespace: torch_xla
 supported:
 - add.Tensor
 autograd:

@@ -221,10 +221,10 @@ class AbstractTestCases:
            # TODO: add torch.* tests when we have proper namespacing on ATen functions
            # test_namespace(torch)

-        def test_msnpu_error(self):
+        def test_ort_error(self):
            with self.assertRaisesRegex(RuntimeError,
-                                        "Could not run 'aten::empty.memory_format' with arguments from the 'MSNPU' backend"):
-                torch.zeros(1, device=torch.device('msnpu'))
+                                        "Could not run 'aten::empty.memory_format' with arguments from the 'ORT' backend"):
+                torch.zeros(1, device=torch.device('ort'))

        def test_has_storage(self):
            self.assertIsNotNone(torch.tensor([]).storage())

@@ -829,6 +829,7 @@ aten_cpu_source_non_codegen_list = [
    "aten/src/ATen/detail/CPUGuardImpl.cpp",
    "aten/src/ATen/detail/CUDAHooksInterface.cpp",
    "aten/src/ATen/detail/HIPHooksInterface.cpp",
+    "aten/src/ATen/detail/ORTHooksInterface.cpp",
    "aten/src/ATen/metal/Context.cpp",
    "aten/src/ATen/native/AutogradComposite.cpp",
    "aten/src/ATen/native/BatchLinearAlgebraKernel.cpp",

@@ -56,7 +56,7 @@ class DispatchKey(Enum):
    CUDA = auto()
    HIP = auto()
    FPGA = auto()
-    MSNPU = auto()
+    ORT = auto()
    XLA = auto()
    Lazy = auto()
    Vulkan = auto()

@@ -469,6 +469,7 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
        'is_sparse_csr' : ['is_sparse_csr: _bool'],
        'is_quantized': ['is_quantized: _bool'],
        'is_meta': ['is_meta: _bool'],
+        'is_ort': ['is_ort: _bool'],
        'is_mkldnn': ['is_mkldnn: _bool'],
        'is_vulkan': ['is_vulkan: _bool'],
        'storage_offset': ['def storage_offset(self) -> _int: ...'],

@@ -24,7 +24,7 @@ class DeviceType(Enum):
    IDEEP = ...
    HIP = ...
    FPGA = ...
-    MSNPU = ...
+    ORT = ...
    XLA = ...
    MLC = ...
    HPU = ...

@@ -90,7 +90,7 @@ class Tensor(torch._C._TensorBase):
        # does accurate alias tracking; however, the code below
        # doesn't work because of
        # https://github.com/pytorch/pytorch/issues/47442
-        if self.is_sparse or self.device.type in ['xla', 'mlc', 'meta']:
+        if self.is_sparse or self.device.type in ['xla', 'mlc', 'ort', 'meta']:
            new_tensor = self.clone()
        else:
            new_storage = self.storage().__deepcopy__(memo)

@@ -153,28 +153,21 @@
        # See Note [Don't serialize hooks]
        torch.utils.hooks.warn_if_has_hooks(self)
        backward_hooks: Dict[Any, Any] = OrderedDict()
-        # Note: Numpy array is chosen to be the rebuild component for XLA Tensor.
+        # Note: Numpy array is chosen to be the rebuild component for XLA, ORT, MLC Tensors.
        # We considered a few options:
        # 1. CPU tensor can't be used here.
        #    Otherwise in torch.load CPU storage is reconstructed with randomly
-        #    initialized data, moved onto XLA device, and then storage is updated
-        #    to the serialized content. This works perfectly for CPU/CUDA but not XLA.
-        #    XLA tensor is disconnected with storage so it doesn't get the update.
+        #    initialized data, moved onto backend device, and then storage is updated
+        #    to the serialized content. This works perfectly for CPU/CUDA but not these backends;
+        #    their tensors are disconnected with storage so they don't get the update.
        # 2. Python list is not a good fit due to performance reason.
        #    `tolist()` converts every single element in the tensor into python objects
        #    and serialize them one by one.
-        if self.device.type == 'xla':
-            arg_xla = (self.cpu().numpy(),
+        if self.device.type in ['xla', 'ort', 'mlc']:
+            return (torch._utils._rebuild_device_tensor_from_numpy, (self.cpu().numpy(),
                       self.dtype,
                       str(self.device),
-                       self.requires_grad)
-            return (torch._utils._rebuild_xla_tensor, arg_xla)
-        if self.device.type == 'mlc':
-            arg_mlc = (self.cpu().numpy(),
-                       self.dtype,
-                       str(self.device),
-                       self.requires_grad)
-            return (torch._utils._rebuild_mlc_tensor, arg_mlc)
+                       self.requires_grad))
        if self.device.type == 'meta':
            # NB: This implementation BREAKS storage sharing. Current
            # hypothesis is that no one cares for meta tensors.

@@ -173,16 +173,15 @@ def _rebuild_sparse_tensor(layout, data):
    raise NotImplementedError("rebuilding sparse tensor for layout %s" % (layout))


-def _rebuild_xla_tensor(data, dtype, device, requires_grad):
+def _rebuild_device_tensor_from_numpy(data, dtype, device, requires_grad):
    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
    tensor.requires_grad = requires_grad
    return tensor


-def _rebuild_mlc_tensor(data, dtype, device, requires_grad):
-    tensor = torch.from_numpy(data).to(dtype=dtype, device=device)
-    tensor.requires_grad = requires_grad
-    return tensor
+# Should not be used, only here to be able to load Tensors serialized with older versions of pytorch
+_rebuild_xla_tensor = _rebuild_device_tensor_from_numpy
+_rebuild_mlc_tensor = _rebuild_device_tensor_from_numpy


 def _rebuild_meta_tensor_no_storage(dtype, size, stride, requires_grad):

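The consolidated rebuild helper above can be exercised directly. A small sketch of the round trip that `__reduce_ex__` sets up for 'xla'/'ort'/'mlc' tensors, using a CPU device so it runs without any extension:

```python
import torch

t = torch.arange(4.0, requires_grad=True)

# What Tensor.__reduce_ex__ packs for numpy-rebuilt backends: a numpy copy of
# the data plus dtype, device string, and requires_grad.
args = (t.detach().cpu().numpy(), t.dtype, "cpu", t.requires_grad)

rebuilt = torch._utils._rebuild_device_tensor_from_numpy(*args)
print(rebuilt)                # tensor([0., 1., 2., 3.], requires_grad=True)
print(rebuilt.requires_grad)  # True
```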
@@ -17,6 +17,6 @@ inline bool THPDevice_Check(PyObject *obj) {
  return Py_TYPE(obj) == &THPDeviceType;
 }

-PyObject * THPDevice_New(const at::Device& device);
+TORCH_API PyObject * THPDevice_New(const at::Device& device);

-void THPDevice_init(PyObject *module);
+TORCH_API void THPDevice_init(PyObject *module);

@@ -114,7 +114,7 @@ PyObject* THPAutograd_initExtension(PyObject* _unused, PyObject *unused) {
      .value("IDEEP", c10::DeviceType::IDEEP)
      .value("HIP", c10::DeviceType::HIP)
      .value("FPGA", c10::DeviceType::FPGA)
-      .value("MSNPU", c10::DeviceType::MSNPU)
+      .value("ORT", c10::DeviceType::ORT)
      .value("XLA", c10::DeviceType::XLA)
      .value("Lazy", c10::DeviceType::Lazy)
      .value("MLC", c10::DeviceType::MLC)

@@ -834,6 +834,17 @@ PyObject *THPVariable_is_mlc(THPVariable *self, void *unused)
  END_HANDLE_TH_ERRORS
 }

+PyObject *THPVariable_is_ort(THPVariable *self, void *unused)
+{
+  HANDLE_TH_ERRORS
+  if (check_has_torch_function((PyObject *)self)) {
+    return handle_torch_function_getter(self, "is_ort");
+  }
+  auto& self_ = THPVariable_Unpack(self);
+  return torch::autograd::utils::wrap(self_.is_ort());
+  END_HANDLE_TH_ERRORS
+}
+
 PyObject *THPVariable_is_vulkan(THPVariable *self, void *unused)
 {
  HANDLE_TH_ERRORS

@@ -980,6 +991,7 @@ static struct PyGetSetDef THPVariable_properties[] = {
  {"is_sparse_csr", (getter)THPVariable_is_sparse_csr, nullptr, nullptr, nullptr},
  {"is_mkldnn", (getter)THPVariable_is_mkldnn, nullptr, nullptr, nullptr},
  {"is_mlc", (getter)THPVariable_is_mlc, nullptr, nullptr, nullptr},
+  {"is_ort", (getter)THPVariable_is_ort, nullptr, nullptr, nullptr},
  {"is_vulkan", (getter)THPVariable_is_vulkan, nullptr, nullptr, nullptr},
  {"is_complex", (getter)THPVariable_is_complex, nullptr, nullptr, nullptr},
  {"is_quantized", (getter)THPVariable_is_quantized, nullptr, nullptr, nullptr},

@@ -119,7 +119,7 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
           {"layout", "prim"}, {"T", "prim"},
           {"ndim", "prim"}, {"name", "prim"},
           {"real", "aten"}, {"imag", "aten"},
-           {"retains_grad", "aten"},
+           {"retains_grad", "aten"}, {"is_ort", "prim"},
       }},
      {TypeKind::DeviceObjType, {{"type", "prim"}, {"index", "prim"}}}};
  auto kind = value_->type()->kind();

@@ -2211,6 +2211,14 @@ RegisterOperators reg1(
          push(stack, a.is_meta());
        },
        aliasAnalysisFromSchema()),
+    OperatorGenerator(
+        TORCH_SELECTIVE_SCHEMA("prim::is_ort(Tensor a) -> bool"),
+        [](Stack* stack) {
+          at::Tensor a;
+          pop(stack, a);
+          push(stack, a.is_ort());
+        },
+        aliasAnalysisFromSchema()),
    OperatorGenerator(
        TORCH_SELECTIVE_SCHEMA("prim::name(Tensor a) -> str?"),
        [](Stack* stack) {

@@ -317,8 +317,8 @@ inline CppFunction dispatch(c10::DeviceType type, Func&& raw_f) {
      return c10::DispatchKey::Meta;
    case c10::DeviceType::HIP:
      return c10::DispatchKey::HIP;
-    case c10::DeviceType::MSNPU:
-      return c10::DispatchKey::MSNPU;
+    case c10::DeviceType::ORT:
+      return c10::DispatchKey::ORT;
    case c10::DeviceType::HPU:
      return c10::DispatchKey::HPU;
    default:

@@ -1030,6 +1030,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        Tensor.retains_grad.__get__: lambda self: -1,
        Tensor.is_meta.__get__: lambda self: -1,
        Tensor.is_mlc.__get__: lambda self: -1,
+        Tensor.is_ort.__get__: lambda self: -1,
        Tensor.is_mkldnn.__get__: lambda self: -1,
        Tensor.is_quantized.__get__: lambda self: -1,
        Tensor.is_sparse.__get__: lambda self: -1,