Introduce backend extensions (overriding operators on custom backends)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/15153

Reviewed By: gchanan

Differential Revision: D13445571

fbshipit-source-id: 62e2ebe0a6e81c4983b47cddb57ee5eb78e96708
This commit is contained in:
parent 64186e06ec
commit 7e642dfff3
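For orientation, a minimal usage sketch of the feature this commit introduces, based on the test added below (`aten/src/ATen/test/extension_backend_test.cpp`): an out-of-tree backend supplies its own kernel and registers it against the generated schema string for the new MSNPU extension backend. The wrapper function `register_msnpu_add` is illustrative only.

```cpp
#include <ATen/ATen.h>
#include <ATen/ExtensionBackendRegistration.h>

// An extension backend's own kernel; a real backend would dispatch to its
// device implementation here instead of returning the first argument.
at::Tensor add_override(const at::Tensor & a, const at::Tensor & b, at::Scalar c) {
  return a;
}

// Register the kernel for the MSNPU backend under the schema string that the
// code generator derives for add (see create_extension_backend below).
void register_msnpu_add() {
  at::register_extension_backend_op(
      at::Backend::MSNPU,
      "add(Tensor self, Tensor other, Scalar alpha) -> Tensor",
      &add_override);
}
```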
@@ -74,6 +74,7 @@ enum class TypeID {
  SparseCUDAInt,
  SparseCUDALong,
  SparseCUDAShort,
  MSNPU,
  CPUComplexFloat,
  CPUComplexDouble,
  CUDAComplexFloat,
@@ -116,6 +116,13 @@ TYPE_DEFINITION_BODY_NATIVE = CodeTemplate("""\
    ${return_call} at::native::${native_type_method_dispatch}(/* native_actuals */ ${native_actuals});
""")

# Overrideable stubs to be used in user-extendable backends
TYPE_DEFINITION_EXTENSION_BACKEND = CodeTemplate("""\
${return_type} ${Type}::${method_prefix_derived}${api_name}(${type_method_formals}) const {
    return ${Type}Dispatch::get_function<${return_type} (*)(${formals_types})>("${schema}")(${native_actuals});
}
""")

# add non-virtual declaration to Tensor.h
TENSOR_METHOD_DECLARATION = CodeTemplate("""\
${return_type} ${api_name}(${method_formals_with_defaults})${const_mark};
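As an illustration (a hypothetical expansion, not text generated verbatim by this commit), the extension-backend stub template above could produce something like the following for `add` on the MSNPU backend; the exact formal names and prefixes come from the parsed declarations.

```cpp
// Sketch of a generated stub: it does no computation itself, it looks up a
// user-registered function pointer by schema string and forwards the arguments.
Tensor MSNPUType::add(const Tensor & self, const Tensor & other, Scalar alpha) const {
  return MSNPUTypeDispatch::get_function<Tensor (*)(const Tensor &, const Tensor &, Scalar)>(
      "add(Tensor self, Tensor other, Scalar alpha) -> Tensor")(self, other, alpha);
}
```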
@@ -489,6 +496,7 @@ FunctionOption = TypedDict('FunctionOption', {
    'formals_list': List[AtFormal],
    'formals_with_defaults': List[str],
    'formals': List[str],
    'formals_types': List[str],
    'inferred_type': str,
    'inplace': bool,
    'matches_jit_signature': bool,
@@ -513,6 +521,8 @@ FunctionOption = TypedDict('FunctionOption', {
    'return': ReturnDecl,
    'returns': List[ReturnType],
    'scalar_check': str,
    # schema used for extension backend operator registration
    'schema': str,
    'sparse': bool,
    'type_definition_body': List[str],
    'type_method_actuals': List[str],
@@ -1595,3 +1605,28 @@ def create_derived(backend_type_env, declarations):
            except NYIError:
                pass
    return type_object_declarations, type_object_definitions


def create_extension_backend(backend_type_env, declarations):
    # type: (Environment, List[FunctionOption]) -> Tuple[List[str], List[str]]
    type_object_declarations = []
    type_object_definitions = []

    for declaration in declarations:
        for option in declaration['options']:
            if not option.get('skip', False):
                try:
                    option['formals_types'] = [f['type'] for f in option['formals_list']]
                    option['native_actuals'] = [f['name'] for f in option['formals_list']]
                    schema_args = ", ".join(
                        ["{} {}".format(f['dynamic_type'], f['name']) for f in option['formals_list']])
                    return_type = NATIVE_DYNAMIC_TYPE.get(option['return_type'], option['return_type'])
                    option['schema'] = "{}({}) -> {}".format(option['api_name'], schema_args, return_type)
                    env = nested_dict(option, backend_type_env)
                    type_object_declarations.append(
                        TYPE_DERIVED_DECLARATION.substitute(env))
                    type_object_definitions.append(
                        TYPE_DEFINITION_EXTENSION_BACKEND.substitute(env))
                except NYIError:
                    pass
    return type_object_declarations, type_object_definitions
@@ -121,6 +121,8 @@ TYPE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Type.h")
TYPE_EXTENDED_INTERFACE_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtendedInterface.h")
TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h")
TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp")
TYPE_EXTENSION_BACKEND_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.h")
TYPE_EXTENSION_BACKEND_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeExtension.cpp")

LEGACY_TH_DISPATCHER_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.h")
LEGACY_TH_DISPATCHER_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHDispatcher.cpp")
@@ -141,10 +143,18 @@ LEGACY_TH_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHFunctio

NATIVE_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/NativeFunctions.h")

EXTENSION_BACKEND_REGISTRATION_H = CodeTemplate.from_file(TEMPLATE_PATH + "/ExtensionBackendRegistration.h")

TYPE_REGISTER = CodeTemplate("""\
context->registerType(Backend::${backend}, ScalarType::${scalar_type}, new ${type_name}());
""")

EXTENSION_BACKEND_REGISTER_SWITCH = CodeTemplate("""\
case Backend::${Backend}:
    ${Type}Dispatch::register_function(schema, fn);
    break;
""")

core_file_manager = FileManager(core_install_dir)
file_manager = FileManager()
cuda_file_manager = FileManager()
@@ -164,6 +174,7 @@ generators = {

backends = ['CPU', 'CUDA']
densities = ['Dense', 'Sparse']
extension_backends = ['MSNPU']

# scalar_name, c_type, accreal, th_scalar_type, is_floating_type
scalar_types = [
@@ -193,6 +204,8 @@ top_env = {
    'function_definitions': [],
    'type_ids': [],
    'native_function_declarations': [],
    'extension_backend_headers': [],
    'extension_backend_register_switches': [],
}

@@ -347,6 +360,37 @@ def generate_storage_type_and_tensor(backend, density, scalar_type, declarations
    return env


def generate_type_extension_backend(backend, declarations):
    env = {}
    env['Type'] = "{}Type".format(backend)
    env['Backend'] = backend
    env['DeviceType'] = backend
    env['is_extension_backend'] = True
    env['TypeID'] = 'TypeID::' + backend
    top_env['type_ids'].append(backend + ',')

    declarations, definitions = function_wrapper.create_extension_backend(
        env, declarations)
    env['type_method_declarations'] = declarations
    env['type_method_definitions'] = definitions

    fm = file_manager
    fm.write(env['Type'] + ".cpp", TYPE_EXTENSION_BACKEND_CPP, env)
    fm.write(env['Type'] + ".h", TYPE_EXTENSION_BACKEND_H, env)

    for scalar_name, _, _, _, _ in scalar_types:
        type_register = TYPE_REGISTER.substitute(backend=env['Backend'], scalar_type=scalar_name, type_name=env['Type'])
        top_env['cpu_type_registrations'].append(type_register)
    extension_backend_register_switch = EXTENSION_BACKEND_REGISTER_SWITCH.substitute(env)
    top_env['extension_backend_register_switches'].append(extension_backend_register_switch)
    top_env['extension_backend_headers'].append(
        '#include <ATen/{}.h>'.format(env['Type']))
    top_env['cpu_type_headers'].append(
        '#include "ATen/{}.h"'.format(env['Type']))

    return env


def generate_legacy_th_dispatcher(backend, density, scalar_type, declarations):
    assert density != 'Sparse'
    scalar_name, c_type, accreal, th_scalar_type, is_floating_type = scalar_type
@@ -384,7 +428,7 @@ def declare_outputs():
        core_file_manager.will_write(f)
    files = ['Declarations.yaml', 'TypeExtendedInterface.h', 'TypeDefault.cpp', 'TypeDefault.h',
             'LegacyTHDispatcher.h', 'LegacyTHDispatcher.cpp', 'LegacyTHFunctions.h',
-            'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h']
+            'Functions.h', 'NativeFunctions.h', 'RegisterCPU.cpp', 'RegisterCPU.h', 'ExtensionBackendRegistration.h']
    for f in files:
        file_manager.will_write(f)
    cuda_files = ['RegisterCUDA.cpp', 'RegisterCUDA.h']
@@ -411,6 +455,9 @@ def declare_outputs():
            if density != 'Sparse':
                fm.will_write("{}{}{}{}.h".format('LegacyTH', full_backend, scalar_name, 'Dispatcher'))
                fm.will_write("{}{}{}{}.cpp".format('LegacyTH', full_backend, scalar_name, 'Dispatcher'))
    for backend in extension_backends:
        file_manager.will_write("{}Type.h".format(backend))
        file_manager.will_write("{}Type.cpp".format(backend))


def filter_by_extension(files, *extensions):
@@ -472,6 +519,8 @@ def generate_outputs():
    for backend, density, scalar_type in iterate_types():
        all_types.append(generate_storage_type_and_tensor(
            backend, density, scalar_type, declarations))
    for backend in extension_backends:
        all_types.append(generate_type_extension_backend(backend, declarations))

    all_legacy_th_dispatchers = []
    for backend, density, scalar_type in iterate_types():
@@ -506,6 +555,8 @@ def generate_outputs():

    file_manager.write('NativeFunctions.h', NATIVE_FUNCTIONS_H, top_env)

    file_manager.write('ExtensionBackendRegistration.h', EXTENSION_BACKEND_REGISTRATION_H, top_env)

    file_manager.check_all_files_written()
    cuda_file_manager.check_all_files_written()
aten/src/ATen/templates/ExtensionBackendRegistration.h (new file, 19 lines)
@@ -0,0 +1,19 @@
#pragma once
#include <ATen/Backend.h>
${extension_backend_headers}

namespace at {

template <typename FnPtr>
inline void register_extension_backend_op(
    Backend backend,
    const char * schema,
    FnPtr fn) {
  switch (backend) {
    ${extension_backend_register_switches}
    default:
      AT_ERROR("Invalid extension backend: ", toString(backend));
  }
}

} // namespace at
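For reference, a sketch of what the body of this template plausibly looks like after gen.py substitutes the single extension backend configured in this commit (MSNPU, via `EXTENSION_BACKEND_REGISTER_SWITCH`); exact whitespace is determined by the code generator.

```cpp
// Generated dispatch of a registration call to the backend-specific table.
switch (backend) {
  case Backend::MSNPU:
    MSNPUTypeDispatch::register_function(schema, fn);
    break;
  default:
    AT_ERROR("Invalid extension backend: ", toString(backend));
}
```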
aten/src/ATen/templates/TypeExtension.cpp (new file, 51 lines)
@@ -0,0 +1,51 @@
#include <ATen/${Type}.h>

namespace at {

std::unordered_map<std::string, void *>& ${Type}Dispatch::get_fn_table() {
  static std::unordered_map<std::string, void *> fn_table;
  return fn_table;
}

${Type}::${Type}()
  : TypeDefault(${Backend}TensorId(), /*is_variable=*/false, /*is_undefined=*/false) {}

Allocator* ${Type}::allocator() const {
  AT_ERROR("allocator is not implemented for ${Type}");
}

Device ${Type}::getDeviceFromPtr(void * data) const {
  return DeviceType::${DeviceType};
}

std::unique_ptr<Generator> ${Type}::generator() const {
  AT_ERROR("generator is not implemented for ${Type}");
}

ScalarType ${Type}::scalarType() const {
  AT_ERROR("scalarType is not implemented for ${Type}");
}

caffe2::TypeMeta ${Type}::typeMeta() const {
  AT_ERROR("typeMeta is not implemented for ${Type}");
}

Backend ${Type}::backend() const {
  return Backend::${Backend};
}

const char * ${Type}::toString() const {
  return "${Type}";
}

TypeID ${Type}::ID() const {
  return ${TypeID};
}

size_t ${Type}::elementSizeInBytes() const {
  AT_ERROR("elementSizeInBytes is not implemented for ${Type}");
}

${type_method_definitions}

} // namespace at
aten/src/ATen/templates/TypeExtension.h (new file, 49 lines)
@@ -0,0 +1,49 @@
#pragma once
#include <ATen/TypeDefault.h>

namespace at {

// This dispatch class holds static map in which function pointers are
// registered by schema.
// TODO: Check for invalid schemas prior to registration.
struct CAFFE2_API ${Type}Dispatch {
  template<typename FnPtr>
  static FnPtr get_function(const std::string& schema) {
    auto & fn_table = get_fn_table();
    auto it = fn_table.find(schema);
    if (it != fn_table.end()) {
      return reinterpret_cast<FnPtr>(it->second);
    }
    AT_ERROR("No function registered for schema: ", schema);
  }

  template<typename FnPtr>
  static void register_function(const std::string& schema, FnPtr fn) {
    auto & fn_table = get_fn_table();
    if (fn_table.find(schema) != fn_table.end()) {
      AT_ERROR("Function already registered for schema: ", schema);
    }
    fn_table[schema] = reinterpret_cast<void *>(fn);
  }

  static std::unordered_map<std::string, void *>& get_fn_table();
};

struct CAFFE2_API ${Type} : public TypeDefault {
  explicit ${Type}();

  Allocator* allocator() const override;
  Device getDeviceFromPtr(void * data) const override;
  std::unique_ptr<Generator> generator() const override;

  virtual ScalarType scalarType() const override;
  virtual caffe2::TypeMeta typeMeta() const override;
  virtual Backend backend() const override;
  virtual const char * toString() const override;
  virtual size_t elementSizeInBytes() const override;
  virtual TypeID ID() const override;

  ${type_method_declarations}
};

} // namespace at
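The dispatch class above is a type-erased function table: pointers go in as `void *` and must come back out with the same function-pointer type they were registered with. A simplified, standalone illustration of that pattern (not PyTorch code; all names here are hypothetical):

```cpp
#include <string>
#include <unordered_map>
#include <stdexcept>

// Global schema-keyed table of type-erased function pointers.
static std::unordered_map<std::string, void *>& fn_table() {
  static std::unordered_map<std::string, void *> table;
  return table;
}

template <typename FnPtr>
void register_fn(const std::string& schema, FnPtr fn) {
  fn_table()[schema] = reinterpret_cast<void *>(fn);
}

template <typename FnPtr>
FnPtr get_fn(const std::string& schema) {
  auto it = fn_table().find(schema);
  if (it == fn_table().end()) throw std::runtime_error("no function for " + schema);
  // Only valid because the caller asks for the same FnPtr type used at registration.
  return reinterpret_cast<FnPtr>(it->second);
}

int add_ints(int a, int b) { return a + b; }

int main() {
  register_fn("add(int, int) -> int", &add_ints);
  auto fn = get_fn<int (*)(int, int)>("add(int, int) -> int");
  return fn(2, 3) == 5 ? 0 : 1;
}
```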
@@ -20,7 +20,8 @@ list(APPEND ATen_CPU_TEST_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/undefined_tensor_test.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/verify_api_visibility.cpp
  ${CMAKE_CURRENT_SOURCE_DIR}/tbb_init_test.cpp
- ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp)
+ ${CMAKE_CURRENT_SOURCE_DIR}/weakref_test.cpp
+ ${CMAKE_CURRENT_SOURCE_DIR}/extension_backend_test.cpp)

list(APPEND ATen_CUDA_TEST_SRCS
  ${CMAKE_CURRENT_SOURCE_DIR}/cuda_integer_divider_test.cu
aten/src/ATen/test/extension_backend_test.cpp (new file, 66 lines)
@@ -0,0 +1,66 @@
#include <gtest/gtest.h>

#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/ExtensionBackendRegistration.h>

using namespace at;

static int test_int;

Tensor empty_override(IntList size, const TensorOptions & options) {
  test_int = 1;
  auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
      Storage(
          caffe2::TypeMeta::Make<float>(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false),
      MSNPUTensorId(),
      false);
  return Tensor(std::move(tensor_impl));
}

Tensor empty_like_override(const Tensor & self, const TensorOptions & options) {
  test_int = 2;
  return self;
}

Tensor add_override(const Tensor & a, const Tensor & b, Scalar c) {
  test_int = 3;
  return a;
}

TEST(BackendExtensionTest, TestRegisterOp) {
  EXPECT_ANY_THROW(empty({5, 5}, at::kMSNPU));
  register_extension_backend_op(
      Backend::MSNPU,
      "empty(IntList size, TensorOptions options) -> Tensor", &empty_override);
  Tensor a = empty({5, 5}, at::kMSNPU);
  ASSERT_EQ(a.device().type(), at::kMSNPU);
  ASSERT_EQ(a.device().index(), 1);
  ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
  ASSERT_EQ(test_int, 1);

  EXPECT_ANY_THROW(empty_like(a, at::kMSNPU));
  register_extension_backend_op(
      Backend::MSNPU,
      "empty_like(Tensor self, TensorOptions options) -> Tensor", &empty_like_override);
  Tensor b = empty_like(a, at::kMSNPU);
  ASSERT_EQ(test_int, 2);

  EXPECT_ANY_THROW(add(a, b));
  register_extension_backend_op(
      Backend::MSNPU,
      "add(Tensor self, Tensor other, Scalar alpha) -> Tensor", &add_override);
  add(a, b);
  ASSERT_EQ(test_int, 3);

  // Ensure that non-MSNPU operator still works
  Tensor d = empty({5, 5}, at::kCPU);
  ASSERT_EQ(d.device().type(), at::kCPU);

  // Attempt to register on a schema that already has a function
  EXPECT_ANY_THROW(
      register_extension_backend_op(
          Backend::MSNPU,
          "empty(IntList size, TensorOptions options) -> Tensor", &empty_override)
  );
}
@@ -17,6 +17,7 @@ VALGRIND=${VALGRIND:=ON}
./scalar_tensor_test
./tensor_interop_test
./undefined_tensor_test
./extension_backend_test
if [[ -x ./cudnn_test ]]; then
  ./cudnn_test
fi
@@ -20,7 +20,7 @@ namespace c10 {
 * would make sense in your use case. If it doesn't make sense, maybe
 * you want DeviceType.
 */
-enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, Undefined, NumOptions };
+enum class Backend { CPU, CUDA, HIP, SparseCPU, SparseCUDA, SparseHIP, MSNPU, Undefined, NumOptions };

static inline Backend toSparse(Backend b) {
  switch (b) {
@@ -49,6 +49,8 @@ static inline Backend toDense(Backend b) {
      return Backend::CUDA;
    case Backend::HIP:
      return Backend::HIP;
    case Backend::MSNPU:
      return Backend::MSNPU;
    case Backend::SparseCPU:
      return Backend::CPU;
    case Backend::SparseCUDA:
@@ -67,6 +69,8 @@ static inline Backend tensorTypeIdToBackend(TensorTypeId t) {
    return Backend::CUDA;
  } else if (t == HIPTensorId()) {
    return Backend::HIP;
  } else if (t == MSNPUTensorId()) {
    return Backend::MSNPU;
  } else if (t == SparseCPUTensorId()) {
    return Backend::SparseCPU;
  } else if (t == SparseCUDATensorId()) {
@@ -88,6 +92,8 @@ static inline TensorTypeId backendToTensorTypeId(Backend b) {
      return CUDATensorId();
    case Backend::HIP:
      return HIPTensorId();
    case Backend::MSNPU:
      return MSNPUTensorId();
    case Backend::SparseCPU:
      return SparseCPUTensorId();
    case Backend::SparseCUDA:
@@ -109,6 +115,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
      return DeviceType::CUDA;
    case Backend::HIP:
      return DeviceType::HIP;
    case Backend::MSNPU:
      return DeviceType::MSNPU;
    case Backend::SparseCPU:
      return DeviceType::CPU;
    case Backend::SparseCUDA:
@@ -130,6 +138,8 @@ static inline Backend deviceTypeToBackend(DeviceType d) {
      return Backend::CUDA;
    case DeviceType::HIP:
      return Backend::HIP;
    case DeviceType::MSNPU:
      return Backend::MSNPU;
    default:
      AT_ERROR("Unknown device type ", d);
  }
@@ -149,6 +159,8 @@ static inline Backend backendToCPU(Backend b) {
      return Backend::SparseCPU;
    case Backend::SparseHIP:
      return Backend::SparseCPU;
    case Backend::MSNPU:
      return Backend::CPU;
    case Backend::Undefined:
      return Backend::Undefined;
    default:
@@ -161,6 +173,7 @@ static inline Backend backendToCUDA(Backend b) {
    case Backend::CPU:
    case Backend::CUDA:
    case Backend::HIP:
    case Backend::MSNPU:
      return Backend::CUDA;
    case Backend::SparseCPU:
    case Backend::SparseCUDA:
@@ -178,6 +191,7 @@ static inline Backend backendToHIP(Backend b) {
    case Backend::CPU:
    case Backend::CUDA:
    case Backend::HIP:
    case Backend::MSNPU:
      return Backend::HIP;
    case Backend::SparseCPU:
    case Backend::SparseCUDA:
@@ -193,6 +207,7 @@ static inline Backend backendToHIP(Backend b) {
constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kMSNPU = DeviceType::MSNPU;

static inline const char* toString(Backend b) {
  switch (b) {
@@ -202,6 +217,8 @@ static inline const char* toString(Backend b) {
      return "CUDA";
    case Backend::HIP:
      return "HIP";
    case Backend::MSNPU:
      return "MSNPU";
    case Backend::SparseCPU:
      return "SparseCPU";
    case Backend::SparseCUDA:
@@ -23,6 +23,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
      return lower_case ? "hip" : "HIP";
    case DeviceType::FPGA:
      return lower_case ? "fpga" : "FPGA";
    case DeviceType::MSNPU:
      return lower_case ? "msnpu" : "MSNPU";
    default:
      AT_ERROR(
          "Unknown device: ",
@@ -53,6 +55,7 @@ bool isValidDeviceType(DeviceType d) {
    case DeviceType::IDEEP:
    case DeviceType::HIP:
    case DeviceType::FPGA:
    case DeviceType::MSNPU:
      return true;
    default:
      return false;
@@ -21,11 +21,12 @@ enum class DeviceType : int16_t {
  IDEEP = 5, // IDEEP.
  HIP = 6, // AMD HIP
  FPGA = 7, // FPGA
  MSNPU = 8, // MSNPU
  // NB: If you add more devices:
  // - Change the implementations of DeviceTypeName and isValidDeviceType
  //   in DeviceType.cpp
  // - Change the number below
- COMPILE_TIME_MAX_DEVICE_TYPES = 8,
+ COMPILE_TIME_MAX_DEVICE_TYPES = 9,
  ONLY_FOR_TEST = 20901, // This device type is only for test.
};
@@ -357,7 +357,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
  int64_t get_device() const {
    // NB: This method is not virtual and tries to avoid dispatches in the common case for perf.
    const auto tid = type_id();
-   if (tid == CUDATensorId() || tid == HIPTensorId()) {
+   if (tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
      // TODO: #12934 investigate caching device on TensorImpl to avoid this vdispatch.
      return storage().device().index();
    }
@@ -369,7 +369,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    // TODO: This is a little convoluted so it would be good to investigate
    // caching device on TensorImpl (#12934) to speed up device() calls in all cases.
    const auto tid = type_id();
-   if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId()) {
+   if (tid == CPUTensorId() || tid == CUDATensorId() || tid == HIPTensorId() || tid == MSNPUTensorId()) {
      // NB: storage(), not storage_, b/c of Variable.
      const auto& mystorage = storage();
      if (mystorage) {
@@ -326,13 +326,15 @@ struct C10_API TensorOptions {

  // Resolves the ATen backend specified by the current construction axes.
  Backend backend() const noexcept {
-   Backend backend;
-   if (device().type() == Device::Type::CPU) {
-     backend = (layout() == kStrided) ? Backend::CPU : Backend::SparseCPU;
-   } else {
-     backend = (layout() == kStrided) ? Backend::CUDA : Backend::SparseCUDA;
+   Backend backend = deviceTypeToBackend(device().type());
+   switch (layout()) {
+     case kStrided:
+       return backend;
+     case kSparse:
+       return toSparse(backend);
+     default:
+       return backend;
    }
    return backend;
  }

 private:
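A sketch of the intended behavior of the rewritten `backend()` (illustrative expectations, not code from the commit): the device type now maps through `deviceTypeToBackend()` and the layout picks the sparse variant via `toSparse()`, so the MSNPU extension backend resolves without a dedicated branch.

```cpp
#include <ATen/ATen.h>

void backend_resolution_example() {
  // Sparse CPU: deviceTypeToBackend(CPU) then toSparse() -> SparseCPU.
  at::TensorOptions cpu_sparse = at::TensorOptions(at::kCPU).layout(at::kSparse);
  AT_ASSERT(cpu_sparse.backend() == at::Backend::SparseCPU);

  // Dense MSNPU: deviceTypeToBackend(MSNPU) -> MSNPU, no special case needed.
  at::TensorOptions msnpu_dense = at::TensorOptions(at::kMSNPU);
  AT_ASSERT(msnpu_dense.backend() == at::Backend::MSNPU);
}
```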
@@ -507,6 +509,8 @@ inline TensorTypeId computeTensorTypeId(TensorOptions options) {
      return IDEEPTensorId();
    case DeviceType::HIP:
      return HIPTensorId();
    case DeviceType::MSNPU:
      return MSNPUTensorId();
    default:
      AT_ERROR("Unsupported device type for dense layout: ", options.device().type());
  }
@@ -543,6 +547,8 @@ inline DeviceType computeDeviceType(TensorTypeId tid) {
    return DeviceType::IDEEP;
  } else if (tid == HIPTensorId()) {
    return DeviceType::HIP;
  } else if (tid == MSNPUTensorId()) {
    return DeviceType::MSNPU;
  } else if (tid == SparseCPUTensorId()) {
    return DeviceType::CPU;
  } else if (tid == SparseCUDATensorId()) {
@@ -68,5 +68,6 @@ C10_DEFINE_TENSOR_TYPE(OpenCLTensorId);
C10_DEFINE_TENSOR_TYPE(IDEEPTensorId);
C10_DEFINE_TENSOR_TYPE(HIPTensorId);
C10_DEFINE_TENSOR_TYPE(SparseHIPTensorId);
C10_DEFINE_TENSOR_TYPE(MSNPUTensorId);

} // namespace c10
@@ -107,6 +107,7 @@ C10_DECLARE_TENSOR_TYPE(OpenCLTensorId); // Caffe2 only
C10_DECLARE_TENSOR_TYPE(IDEEPTensorId); // Caffe2 only
C10_DECLARE_TENSOR_TYPE(HIPTensorId); // PyTorch/Caffe2 supported
C10_DECLARE_TENSOR_TYPE(SparseHIPTensorId); // PyTorch only
C10_DECLARE_TENSOR_TYPE(MSNPUTensorId); // PyTorch only

} // namespace c10
@@ -178,8 +178,9 @@ enum DeviceTypeProto {
  PROTO_IDEEP = 5; // IDEEP.
  PROTO_HIP = 6; // AMD HIP
  PROTO_FPGA = 7; // FPGA
  PROTO_MSNPU = 8; // MSNPU
  // Change the following number if you add more devices in the code.
- PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 8;
+ PROTO_COMPILE_TIME_MAX_DEVICE_TYPES = 9;
  PROTO_ONLY_FOR_TEST = 20901; // This device type is only for test.
}