[torchgen] Add CI job to make sure torchgen works for Executorch op registration (#89596)

## Job

Test running on most CI jobs.

## Test binary

* `test_main.cpp`: entry for gtest
* `test_operator_registration.cpp`: test cases for gtest

## Helper sources

* `operator_registry.h/cpp`: simple operator registry for testing purposes.
* `Evalue.h`: a boxed data type that wraps ATen types, for testing purposes.
* `selected_operators.yaml`: operators Executorch cares about so far; we should cover all of them.

## Templates

* `NativeFunctions.h`: for generating headers for native functions. (not compiled in the test, since we will be using `libtorch`)
* `RegisterCodegenUnboxedKernels.cpp`: for registering boxed operators.
* `Functions.h`: for declaring operator C++ APIs. Generated `Functions.h` merely wraps `ATen/Functions.h`.

## Build files

* `CMakeLists.txt`: generate code to register ops.
* `build.sh`: driver file, to be called by CI job.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89596
Approved by: https://github.com/ezyang
This commit is contained in:
Mengwei Liu 2022-12-20 21:50:39 +00:00 committed by PyTorch MergeBot
parent 37ea99cd25
commit 2f154f68ea
15 changed files with 1264 additions and 1 deletions

View File

@ -748,6 +748,13 @@ test_docs_test() {
.jenkins/pytorch/docs-test.sh
}
test_executorch() {
  # Exercise the torchgen-generated Executorch op registration by running the
  # gtest binary from the build's bin directory, then verify that the run did
  # not leave the git checkout modified.
  echo "Testing Executorch op registration"
  "${BUILD_BIN_DIR}/test_edge_op_registration"
  assert_git_not_dirty
}
if ! [[ "${BUILD_ENVIRONMENT}" == *libtorch* || "${BUILD_ENVIRONMENT}" == *-bazel-* || "${BUILD_ENVIRONMENT}" == *-tsan* ]]; then
(cd test && python -c "import torch; print(torch.__config__.show())")
(cd test && python -c "import torch; print(torch.__config__.parallel_info())")
@ -875,4 +882,5 @@ else
test_custom_backend
test_torch_function_benchmark
test_benchmarks
test_executorch
fi

View File

@ -455,7 +455,7 @@ option(
TRACING_BASED
"Master flag to build Lite Interpreter with tracing build option"
OFF)
option(BUILD_EXECUTORCH "Master flag to build Executorch" ON)
# This is a fix for a rare build issue on Ubuntu:
# symbol lookup error: miniconda3/envs/pytorch-py3.7/lib/libmkl_intel_lp64.so: undefined symbol: mkl_blas_dsyrk
# https://software.intel.com/en-us/articles/symbol-lookup-error-when-linking-intel-mkl-with-gcc-on-ubuntu

View File

@ -1126,6 +1126,12 @@ install(FILES
"${TORCH_SRC_DIR}/custom_class_detail.h"
DESTINATION ${TORCH_INSTALL_INCLUDE_DIR}/torch)
if(BUILD_TEST)
if(BUILD_EXECUTORCH)
add_subdirectory(
${TORCH_ROOT}/test/edge
${CMAKE_BINARY_DIR}/test_edge_op_registration
)
endif()
if(BUILD_LITE_INTERPRETER)
add_subdirectory(
${TORCH_ROOT}/test/cpp/lite_interpreter_runtime

View File

@ -536,6 +536,8 @@ class build_ext(setuptools.command.build_ext.build_ext):
report('-- Using static dispatch with backend {}'.format(cmake_cache_vars['STATIC_DISPATCH_BACKEND']))
if cmake_cache_vars['USE_LIGHTWEIGHT_DISPATCH']:
report('-- Using lightweight dispatch')
if cmake_cache_vars['BUILD_EXECUTORCH']:
report('-- Building Executorch')
if cmake_cache_vars['USE_ITT']:
report('-- Using ITT')

70
test/edge/CMakeLists.txt Normal file
View File

@ -0,0 +1,70 @@
# Builds `test_edge_op_registration`: runs torchgen.gen_executorch to generate
# unboxing kernels for the operators listed in selected_operators.yaml,
# compiles them together with the test operator registry into a static
# library, and links a gtest binary against it.
#
# target_link_options() (used below) was introduced in CMake 3.13.
cmake_minimum_required(VERSION 3.13)

set(TORCH_ROOT ${CMAKE_CURRENT_LIST_DIR}/../..)
set(TEST_ROOT ${TORCH_ROOT}/test/edge)
set(OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/out)
# Every torchgen script is treated as an input of the generator so that a
# change to any of them re-runs code generation.
file(GLOB_RECURSE all_python "${TORCH_ROOT}/torchgen/*.py")

# Generate unboxing kernels
set(GEN_COMMAND
    "${PYTHON_EXECUTABLE}" -m torchgen.gen_executorch
    --source-path=${TEST_ROOT}
    --install_dir=${OUTPUT_DIRECTORY}
    --tags-path=${TORCH_ROOT}/aten/src/ATen/native/tags.yaml
    --aten_yaml_path=${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml
    --use_aten_lib
    --op_selection_yaml_path=${TEST_ROOT}/selected_operators.yaml
)
set(GEN_COMMAND_sources
    ${OUTPUT_DIRECTORY}/RegisterCodegenUnboxedKernelsEverything.cpp
    ${OUTPUT_DIRECTORY}/Functions.h
    ${OUTPUT_DIRECTORY}/NativeFunctions.h
)
message(STATUS "Generating sources for unboxing kernels ${GEN_COMMAND}")
add_custom_command(
    COMMENT "Generating sources"
    OUTPUT ${GEN_COMMAND_sources}
    COMMAND ${GEN_COMMAND}
    DEPENDS
    ${all_python}
    ${TORCH_ROOT}/aten/src/ATen/native/native_functions.yaml
    ${TORCH_ROOT}/aten/src/ATen/native/tags.yaml
    # The op selection file is passed to the generator above, so editing it
    # must invalidate the generated sources as well.
    ${TEST_ROOT}/selected_operators.yaml
    ${TEST_ROOT}/templates/Functions.h
    ${TEST_ROOT}/templates/NativeFunctions.h
    ${TEST_ROOT}/templates/RegisterCodegenUnboxedKernels.cpp
    WORKING_DIRECTORY ${TORCH_ROOT}
    # Keep argument escaping identical across platforms and generators.
    VERBATIM
)
# Named target that only runs the code generation step.
add_custom_target(unbox_target DEPENDS ${GEN_COMMAND_sources})

add_library(unbox_lib STATIC
    ${GEN_COMMAND_sources}
    ${TEST_ROOT}/operator_registry.cpp
)
target_include_directories(unbox_lib PUBLIC ${TEST_ROOT} ${ATen_CPU_INCLUDE})
target_link_libraries(unbox_lib PUBLIC torch_cpu)
target_compile_definitions(unbox_lib PUBLIC USE_ATEN_LIB)

add_executable(test_edge_op_registration
    ${TEST_ROOT}/test_operator_registration.cpp
    ${TEST_ROOT}/test_main.cpp
)
target_compile_definitions(test_edge_op_registration PRIVATE USE_GTEST)

set(TEST_DEPENDENCIES gtest unbox_lib)
target_link_libraries(test_edge_op_registration PRIVATE
    ${TEST_DEPENDENCIES}
)

# Operators are registered from static initializers inside unbox_lib that the
# test never references by symbol, so the whole archive must be pulled into
# the executable. Apple's linker spells this -force_load and does not accept
# --whole-archive, which also applies to upstream (non-Apple) Clang on macOS.
if((CMAKE_CXX_COMPILER_ID MATCHES "AppleClang")
   OR (APPLE AND CMAKE_CXX_COMPILER_ID MATCHES "Clang"))
  target_link_options(test_edge_op_registration PRIVATE
      "-Wl,-force_load,$<TARGET_FILE:unbox_lib>"
  )
elseif(CMAKE_CXX_COMPILER_ID MATCHES "Clang|GNU")
  target_link_options(test_edge_op_registration PRIVATE
      "-Wl,--whole-archive,$<TARGET_FILE:unbox_lib>,--no-whole-archive"
  )
endif()

if(INSTALL_TEST)
  install(TARGETS test_edge_op_registration DESTINATION bin)
endif()

479
test/edge/Evalue.h Normal file
View File

@ -0,0 +1,479 @@
#pragma once
#include <ATen/ATen.h>
/**
* WARNING: EValue is a class used by Executorch, for its boxed operators. It
* contains similar logic as `IValue` in PyTorch, by providing APIs to convert
* boxed values to unboxed values.
*
* It's mirroring a fbcode internal source file
* [`EValue.h`](https://www.internalfb.com/code/fbsource/xplat/executorch/core/values/Evalue.h).
*
* The reason why we are mirroring this class, is to make sure we have CI job
* coverage on torchgen logic, given that torchgen is used for both Executorch
* and PyTorch.
*
* If any of the logic here needs to be changed, please update fbcode version of
* `Evalue.h` as well. These two versions will be merged as soon as Executorch
* is in OSS (hopefully by Q2 2023).
*/
namespace torch {
namespace executor {
// ET_CHECK_MSG(cond, msg) is used throughout this header as a precondition
// check. TORCH_CHECK tests `cond` and throws c10::Error with `msg` when it is
// false. TORCH_CHECK_MSG is only the message-building helper that TORCH_CHECK
// uses internally; it never evaluates the condition, so aliasing to it turned
// every check below into a no-op.
#define ET_CHECK_MSG TORCH_CHECK
#define EXECUTORCH_FORALL_TAGS(_) \
_(None) \
_(Tensor) \
_(String) \
_(Double) \
_(Int) \
_(Bool) \
_(ListBool) \
_(ListDouble) \
_(ListInt) \
_(ListTensor) \
_(ListScalar) \
_(ListOptionalTensor)
enum class Tag : uint32_t {
#define DEFINE_TAG(x) x,
EXECUTORCH_FORALL_TAGS(DEFINE_TAG)
#undef DEFINE_TAG
};
struct EValue;
template <typename T>
struct evalue_to_const_ref_overload_return {
using type = T;
};
template <>
struct evalue_to_const_ref_overload_return<at::Tensor> {
using type = const at::Tensor&;
};
template <typename T>
struct evalue_to_ref_overload_return {
using type = T;
};
template <>
struct evalue_to_ref_overload_return<at::Tensor> {
using type = at::Tensor&;
};
/*
* Helper class used to correlate EValues in the executor table, with the
* unwrapped list of the proper type. Because values in the runtime's values
* table can change during execution, we cannot statically allocate list of
* objects at deserialization. Imagine the serialized list says index 0 in the
* value table is element 2 in the list, but during execution the value in
* element 2 changes (in the case of tensor this means the TensorImpl* stored in
* the tensor changes). To solve this instead they must be created dynamically
* whenever they are used.
*/
// Pairs a list of EValue pointers with a caller-provided scratch buffer of T.
// Neither pointer is owned here: the class only stores the EValue** view and
// the raw T* it was given.
template <typename T>
class EValObjectList {
public:
// Leaves both members default-initialized.
EValObjectList() = default;
/*
* wrapped_vals is a list of pointers into the values table of the runtime
* whose destinations correlate with the elements of the list. unwrapped_vals
* is a container of the same size which serves as memory to construct the
* unwrapped values. size is the number of entries in wrapped_vals.
*/
EValObjectList(EValue** wrapped_vals, T* unwrapped_vals, int size)
: wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
/*
* Constructs and returns the list of T specified by the EValue pointers.
* Declared const, but writes into the scratch buffer (hence `mutable` below).
*/
at::ArrayRef<T> get() const;
private:
// Source of truth for the list
at::ArrayRef<EValue*> wrapped_vals_;
// Same size as wrapped_vals_; scratch storage that get() fills in.
mutable T* unwrapped_vals_;
};
// Aggregate typing system similar to IValue only slimmed down with less
// functionality, no dependencies on atomic, and fewer supported types to better
// suit embedded systems (ie no intrusive ptr)
struct EValue {
union Payload {
// When in ATen mode at::Tensor is not trivially copyable, this nested union
// lets us handle tensor as a special case while leaving the rest of the
// fields in a simple state instead of requiring a switch on tag everywhere.
union TriviallyCopyablePayload {
TriviallyCopyablePayload() : as_int(0) {}
// Scalar supported through these 3 types
int64_t as_int;
double as_double;
bool as_bool;
// TODO(jakeszwe): convert back to pointers to optimize size of this
// struct
at::ArrayRef<char> as_string;
at::ArrayRef<int64_t> as_int_list;
at::ArrayRef<double> as_double_list;
at::ArrayRef<bool> as_bool_list;
EValObjectList<at::Tensor> as_tensor_list;
EValObjectList<at::optional<at::Tensor>> as_list_optional_tensor;
} copyable_union;
// Since a Tensor just holds a TensorImpl*, there's no value to use Tensor*
// here.
at::Tensor as_tensor;
Payload() {}
~Payload() {}
};
// Data storage and type tag
Payload payload;
Tag tag;
// Basic ctors and assignments
EValue(const EValue& rhs) : EValue(rhs.payload, rhs.tag) {}
EValue(EValue&& rhs) noexcept : tag(rhs.tag) {
moveFrom(std::move(rhs));
}
EValue& operator=(EValue&& rhs) & noexcept {
if (&rhs == this) {
return *this;
}
destroy();
moveFrom(std::move(rhs));
return *this;
}
EValue& operator=(EValue const& rhs) & {
// Define copy assignment through copy ctor and move assignment
*this = EValue(rhs);
return *this;
}
~EValue() {
destroy();
}
/****** None Type ******/
EValue() : tag(Tag::None) {
payload.copyable_union.as_int = 0;
}
bool isNone() const {
return tag == Tag::None;
}
/****** Int Type ******/
/*implicit*/ EValue(int64_t i) : tag(Tag::Int) {
payload.copyable_union.as_int = i;
}
bool isInt() const {
return tag == Tag::Int;
}
int64_t toInt() const {
ET_CHECK_MSG(isInt(), "EValue is not an int.");
return payload.copyable_union.as_int;
}
/****** Double Type ******/
/*implicit*/ EValue(double d) : tag(Tag::Double) {
payload.copyable_union.as_double = d;
}
bool isDouble() const {
return tag == Tag::Double;
}
double toDouble() const {
ET_CHECK_MSG(isDouble(), "EValue is not a Double.");
return payload.copyable_union.as_double;
}
/****** Bool Type ******/
/*implicit*/ EValue(bool b) : tag(Tag::Bool) {
payload.copyable_union.as_bool = b;
}
bool isBool() const {
return tag == Tag::Bool;
}
bool toBool() const {
ET_CHECK_MSG(isBool(), "EValue is not a Bool.");
return payload.copyable_union.as_bool;
}
/****** Scalar Type ******/
/// Construct an EValue using the implicit value of a Scalar.
/*implicit*/ EValue(at::Scalar s) {
if (s.isIntegral(false)) {
tag = Tag::Int;
payload.copyable_union.as_int = s.to<int64_t>();
} else if (s.isFloatingPoint()) {
tag = Tag::Double;
payload.copyable_union.as_double = s.to<double>();
} else if (s.isBoolean()) {
tag = Tag::Bool;
payload.copyable_union.as_bool = s.to<bool>();
} else {
ET_CHECK_MSG(false, "Scalar passed to EValue is not initialized.");
}
}
bool isScalar() const {
return tag == Tag::Int || tag == Tag::Double || tag == Tag::Bool;
}
at::Scalar toScalar() const {
// Convert from implicit value to Scalar using implicit constructors.
if (isDouble()) {
return toDouble();
} else if (isInt()) {
return toInt();
} else if (isBool()) {
return toBool();
} else {
ET_CHECK_MSG(false, "EValue is not a Scalar.");
return c10::Scalar();
}
}
/****** Tensor Type ******/
/*implicit*/ EValue(at::Tensor t) : tag(Tag::Tensor) {
// When built in aten mode, at::Tensor has a non trivial constructor
// destructor, so regular assignment to a union field is UB. Instead we must
// go through placement new (which causes a refcount bump).
new (&payload.as_tensor) at::Tensor(t);
}
bool isTensor() const {
return tag == Tag::Tensor;
}
at::Tensor toTensor() && {
ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
return std::move(payload.as_tensor);
}
at::Tensor& toTensor() & {
ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
return payload.as_tensor;
}
const at::Tensor& toTensor() const& {
ET_CHECK_MSG(isTensor(), "EValue is not a Tensor.");
return payload.as_tensor;
}
/****** String Type ******/
/*implicit*/ EValue(const char* s, size_t size) : tag(Tag::String) {
payload.copyable_union.as_string = at::ArrayRef<char>(s, size);
}
bool isString() const {
return tag == Tag::String;
}
at::string_view toString() const {
ET_CHECK_MSG(isString(), "EValue is not a String.");
return at::string_view(
payload.copyable_union.as_string.data(),
payload.copyable_union.as_string.size());
}
/****** Int List Type ******/
/*implicit*/ EValue(at::ArrayRef<int64_t> i) : tag(Tag::ListInt) {
payload.copyable_union.as_int_list = i;
}
bool isIntList() const {
return tag == Tag::ListInt;
}
at::ArrayRef<int64_t> toIntList() const {
ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
return payload.copyable_union.as_int_list;
}
/****** Bool List Type ******/
/*implicit*/ EValue(at::ArrayRef<bool> b) : tag(Tag::ListBool) {
payload.copyable_union.as_bool_list = b;
}
bool isBoolList() const {
return tag == Tag::ListBool;
}
at::ArrayRef<bool> toBoolList() const {
ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
return payload.copyable_union.as_bool_list;
}
/****** Double List Type ******/
/*implicit*/ EValue(at::ArrayRef<double> d) : tag(Tag::ListDouble) {
payload.copyable_union.as_double_list = d;
}
bool isDoubleList() const {
return tag == Tag::ListDouble;
}
at::ArrayRef<double> toDoubleList() const {
ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
return payload.copyable_union.as_double_list;
}
/****** Tensor List Type ******/
/*implicit*/ EValue(EValObjectList<at::Tensor> t) : tag(Tag::ListTensor) {
payload.copyable_union.as_tensor_list = t;
}
bool isTensorList() const {
return tag == Tag::ListTensor;
}
at::ArrayRef<at::Tensor> toTensorList() const {
ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
return payload.copyable_union.as_tensor_list.get();
}
/****** List Optional Tensor Type ******/
/*implicit*/ EValue(EValObjectList<at::optional<at::Tensor>> t)
: tag(Tag::ListOptionalTensor) {
payload.copyable_union.as_list_optional_tensor = t;
}
bool isListOptionalTensor() const {
return tag == Tag::ListOptionalTensor;
}
/**
 * Returns the list of optional tensors held by this EValue.
 *
 * Checks the tag first, like every other to*() accessor, instead of reading
 * an inactive union member when the EValue holds something else. The method
 * is const for the same reason toTensorList() is: EValObjectList::get() is a
 * const member.
 */
at::ArrayRef<at::optional<at::Tensor>> toListOptionalTensor() const {
  ET_CHECK_MSG(
      isListOptionalTensor(), "EValue is not a List Optional Tensor.");
  return payload.copyable_union.as_list_optional_tensor.get();
}
/****** ScalarType Type ******/
at::ScalarType toScalarType() const {
ET_CHECK_MSG(isInt(), "EValue is not a ScalarType.");
return static_cast<at::ScalarType>(payload.copyable_union.as_int);
}
/****** MemoryFormat Type ******/
at::MemoryFormat toMemoryFormat() const {
ET_CHECK_MSG(isInt(), "EValue is not a MemoryFormat.");
return static_cast<at::MemoryFormat>(payload.copyable_union.as_int);
}
template <typename T>
T to() &&;
template <typename T>
typename evalue_to_ref_overload_return<T>::type to() &;
/**
* Converts the EValue to an optional object that can represent both T and
* an uninitialized state.
*/
template <typename T>
inline at::optional<T> toOptional() {
if (this->isNone()) {
return at::nullopt;
}
return this->to<T>();
}
private:
// Pre cond: the payload value has had its destructor called
void clearToNone() noexcept {
payload.copyable_union.as_int = 0;
tag = Tag::None;
}
// Shared move logic
void moveFrom(EValue&& rhs) noexcept {
if (rhs.isTensor()) {
new (&payload.as_tensor) at::Tensor(std::move(rhs.payload.as_tensor));
rhs.payload.as_tensor.~Tensor();
} else {
payload.copyable_union = rhs.payload.copyable_union;
}
tag = rhs.tag;
rhs.clearToNone();
}
// Destructs stored tensor if there is one
// Runs the destructors of any tensor state reachable from the payload.
// Payload's own destructor is empty, so nothing stored in the union is
// destroyed unless it is done by hand here.
void destroy() {
// Necessary for an ATen tensor: this drops the refcount on the TensorImpl
// that was bumped when the tensor was placed into the EValue.
// The original note adds that this is a no-op for the Executorch tensor type
// and that an #ifdef could skip it, trading maintainability for a minor
// performance gain.
if (isTensor()) {
payload.as_tensor.~Tensor();
} else if (isTensorList()) {
// toTensorList() goes through EValObjectList::get(), which first assigns
// every element of the scratch buffer from the wrapped EValues; the loop then
// destroys those freshly assigned elements.
// NOTE(review): this presumes the scratch buffer already holds constructed
// Tensors at this point (assignment into raw memory would be UB) - confirm
// against how callers allocate it.
for (auto& tensor : toTensorList()) {
tensor.~Tensor();
}
} else if (isListOptionalTensor()) {
// Same pattern as above, for optional<Tensor> elements.
for (auto& optional_tensor : toListOptionalTensor()) {
optional_tensor.~optional();
}
}
}
EValue(const Payload& p, Tag t) : tag(t) {
if (isTensor()) {
new (&payload.as_tensor) at::Tensor(p.as_tensor);
} else {
payload.copyable_union = p.copyable_union;
}
}
};
#define EVALUE_DEFINE_TO(T, method_name) \
template <> \
inline evalue_to_ref_overload_return<T>::type EValue::to<T>()& { \
return static_cast<T>(this->method_name()); \
}
template <>
inline at::Tensor& EValue::to<at::Tensor>() & {
return this->toTensor();
}
EVALUE_DEFINE_TO(at::Scalar, toScalar)
EVALUE_DEFINE_TO(int64_t, toInt)
EVALUE_DEFINE_TO(bool, toBool)
EVALUE_DEFINE_TO(double, toDouble)
EVALUE_DEFINE_TO(at::string_view, toString)
EVALUE_DEFINE_TO(at::ScalarType, toScalarType)
EVALUE_DEFINE_TO(at::MemoryFormat, toMemoryFormat)
EVALUE_DEFINE_TO(at::optional<at::Tensor>, toOptional<at::Tensor>)
EVALUE_DEFINE_TO(at::ArrayRef<int64_t>, toIntList)
EVALUE_DEFINE_TO(
at::optional<at::ArrayRef<int64_t>>,
toOptional<at::ArrayRef<int64_t>>)
EVALUE_DEFINE_TO(
at::optional<at::ArrayRef<double>>,
toOptional<at::ArrayRef<double>>)
EVALUE_DEFINE_TO(at::ArrayRef<at::optional<at::Tensor>>, toListOptionalTensor)
EVALUE_DEFINE_TO(at::ArrayRef<double>, toDoubleList)
#undef EVALUE_DEFINE_TO
// Refreshes the scratch buffer from the wrapped EValues and returns a view
// over it. Each call overwrites all `wrapped_vals_.size()` entries.
template <typename T>
at::ArrayRef<T> EValObjectList<T>::get() const {
  const size_t count = wrapped_vals_.size();
  for (size_t idx = 0; idx != count; ++idx) {
    EValue* boxed = wrapped_vals_[idx];
    unwrapped_vals_[idx] = boxed->template to<T>();
  }
  return at::ArrayRef<T>{unwrapped_vals_, count};
}
} // namespace executor
} // namespace torch

View File

@ -0,0 +1,45 @@
#include <c10/util/Exception.h>
#include <operator_registry.h>
namespace torch {
namespace executor {
OperatorRegistry& getOperatorRegistry() {
static OperatorRegistry operator_registry;
return operator_registry;
}
bool register_operators(const ArrayRef<Operator>& operators) {
return getOperatorRegistry().register_operators(operators);
}
bool OperatorRegistry::register_operators(
const ArrayRef<Operator>& operators) {
for (const auto& op : operators) {
this->operators_map_[op.name_] = op.op_;
}
return true;
}
bool hasOpsFn(const char* name) {
return getOperatorRegistry().hasOpsFn(name);
}
// True when `name` is a key of the registry's operator map.
bool OperatorRegistry::hasOpsFn(const char* name) {
  return this->operators_map_.count(name) != 0;
}
OpFunction& getOpsFn(const char* name) {
return getOperatorRegistry().getOpsFn(name);
}
// Returns the function registered under `name`; throws c10::Error when no
// such operator exists.
OpFunction& OperatorRegistry::getOpsFn(const char* name) {
  auto op = this->operators_map_.find(name);
  // TORCH_CHECK actually tests the condition. TORCH_CHECK_MSG only formats a
  // message, so with it a missing operator fell through to dereferencing
  // end().
  TORCH_CHECK(op != this->operators_map_.end(), "Operator not found!");
  return op->second;
}
} // namespace executor
} // namespace torch

View File

@ -0,0 +1,70 @@
#pragma once
#include <cstring>
#include <c10/util/ArrayRef.h>
#include "Evalue.h"
#include <functional>
#include <map>
namespace torch {
namespace executor {
using OpFunction = std::function<void(EValue**)>;
template<typename T>
using ArrayRef = at::ArrayRef<T>;
#define EXECUTORCH_SCOPE_PROF(x)
/**
 * A named boxed kernel: `op_` is the callable and `name_` is the key it is
 * registered under.
 */
struct Operator {
  // Initialized to nullptr so a default-constructed Operator never carries an
  // indeterminate pointer.
  const char* name_ = nullptr;
  OpFunction op_;
  Operator() = default;
  /**
   * We are doing a copy of the string pointer instead of duplicating the string
   * itself, we require the lifetime of the operator name to be at least as long
   * as the operator registry.
   */
  explicit Operator(const char* name, OpFunction func)
      : name_(name), op_(func) {}
};
/**
* See OperatorRegistry::hasOpsFn()
*/
bool hasOpsFn(const char* name);
/**
* See OperatorRegistry::getOpsFn()
*/
OpFunction& getOpsFn(const char* name);
[[nodiscard]] bool register_operators(const ArrayRef<Operator>&);
/**
 * Maps operator names to their boxed kernels.
 */
struct OperatorRegistry {
 public:
  OperatorRegistry() : operatorRegSize_(0) {}
  /**
   * Registers every operator in the list under its name.
   */
  bool register_operators(const ArrayRef<Operator>&);
  /**
   * Checks whether an operator with a given name is registered
   */
  bool hasOpsFn(const char* name);
  /**
   * Returns the function registered under the given name
   */
  OpFunction& getOpsFn(const char* name);

 private:
  // Orders keys by string contents. The default std::less<const char*> would
  // compare addresses, so a lookup with an equal name stored at a different
  // address (e.g. a literal in another translation unit) could miss.
  // A null name sorts before every non-null name.
  struct CStrLess {
    bool operator()(const char* lhs, const char* rhs) const {
      if (lhs == nullptr || rhs == nullptr) {
        return lhs == nullptr && rhs != nullptr;
      }
      return std::strcmp(lhs, rhs) < 0;
    }
  };
  std::map<const char*, OpFunction, CStrLess> operators_map_;
  uint32_t operatorRegSize_;
};
} // namespace executor
} // namespace torch

View File

@ -0,0 +1,450 @@
build_features: []
custom_classes: []
include_all_non_op_selectives: false
include_all_operators: false
kernel_metadata: {}
operators:
aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::_reshape_alias_copy.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::_softmax.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::_to_copy.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::_unique2.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::add.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::addmm.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::avg_pool2d.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::baddbmm.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::bitwise_and.Tensor_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::bmm.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::cat.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::clamp.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::clone.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::constant_pad_nd.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::conv1d.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::convolution.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::cumsum.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::detach_copy.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::div.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::embedding.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::eq.Scalar_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::eq.Tensor_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::exp.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::expand_copy.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::floor_divide.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::gelu.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::grid_sampler_2d.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::gt.Scalar_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::index.Tensor_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::index_put.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::index_select.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::leaky_relu.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::linalg_inv_ex.inverse:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::logit.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::masked_fill.Scalar_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::max.unary_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::max_pool2d_with_indices.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::mean.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::minimum.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::mm.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::mul.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::native_batch_norm.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::native_layer_norm.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::ne.Scalar_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::nonzero.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::permute_copy.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::pixel_shuffle.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::relu.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::remainder.Scalar_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::repeat.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::round.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::rsub.Scalar_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::select_copy.int_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::sigmoid.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::slice_copy.Tensor_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::softplus.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::sort.values:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::split_copy.Tensor_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::split_with_sizes_copy.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::stack.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::sub.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::sum.IntList_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::tanh.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::topk.values:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::transpose_copy.int_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::unbind_copy.int_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::unsafe_split.Tensor_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::unsqueeze_copy.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::upsample_bilinear2d.vec_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::upsample_nearest2d.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::upsample_nearest2d.vec_out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::view_copy.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true
aten::zeros_like.out:
debug_info:
- functions.yaml
include_all_overloads: false
is_root_operator: true
is_used_for_training: true

View File

@ -0,0 +1,25 @@
// clang-format off
#pragma once
#include <ATen/Context.h>
#include <ATen/DeviceGuard.h>
#include <ATen/TensorUtils.h>
#include <ATen/TracerMode.h>
#include <ATen/core/Generator.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Scalar.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Optional.h>
// ${generated_comment}
${static_dispatch_extra_headers}
namespace torch {
namespace executor {
${Functions_declarations}
} // namespace executor
} // namespace torch

View File

@ -0,0 +1,31 @@
#pragma once
// ${generated_comment}
#ifdef TORCH_ASSERT_NO_OPERATORS
#error This change adds a dependency on native_functions.yaml, \
meaning the file will need to be re-compiled every time an operator \
is changed or added. Consider if your change would be better placed in \
another file, or if a more specific header might achieve the same goal. \
See NOTE: [Tensor vs. TensorBase]
#endif
#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS)
#error This change adds a dependency on all pytorch operators, meaning the \
file will need to be re-compiled every time an operator is changed or added. \
Consider including a specific operator from <ATen/ops/{my_operator}_native.h> \
and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS].
#endif
#include <c10/core/Scalar.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorOptions.h>
#include <c10/util/Deprecated.h>
#include <c10/util/Optional.h>
#include <c10/core/QScheme.h>
#include <ATen/core/Reduction.h>
#include <ATen/core/Tensor.h>
#include <tuple>
#include <vector>
${nativeFunctions_declarations}

View File

@ -0,0 +1,25 @@
#include <operator_registry.h>
#include "Functions.h"
namespace torch {
namespace executor {
namespace {
using OpArrayRef = ::at::ArrayRef<::torch::executor::Operator>;
static Operator operators_to_register[] = {
${unboxed_ops} // Generated operators
};
// Explicitly convert to ArrayRef, so that the API can take an empty C array of
// Operators.
static OpArrayRef op_array_ref(
operators_to_register,
operators_to_register + sizeof(operators_to_register) / sizeof(Operator));
// Return value not used. Keep the static variable assignment to register
// operators in static initialization time.
static auto success_with_op_reg = register_operators(op_array_ref);
} // namespace
} // namespace executor
} // namespace torch

18
test/edge/test_main.cpp Normal file
View File

@ -0,0 +1,18 @@
#include <gtest/gtest.h>
// Returns the current gtest filter with `flag` appended as a negative
// pattern. The first '-' in a gtest filter starts the negative section, so
// one is added if the filter has none; otherwise `flag` is joined to the
// existing negative patterns with ':'.
std::string add_negative_flag(const std::string& flag) {
  std::string filter = ::testing::GTEST_FLAG(filter);
  const bool has_negative_section = filter.find('-') != std::string::npos;
  filter += has_negative_section ? ':' : '-';
  filter += flag;
  return filter;
}
int main(int argc, char* argv[]) {
::testing::InitGoogleTest(&argc, argv);
::testing::GTEST_FLAG(filter) = add_negative_flag("*_CUDA:*_MultiCUDA");
return RUN_ALL_TESTS();
}

View File

@ -0,0 +1,28 @@
#include "operator_registry.h"
#include <gtest/gtest.h>
namespace torch {
namespace executor {
// add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
TEST(OperatorRegistrationTest, Add) {
EValue values[4];
values[0] = EValue(at::ones({2, 3}));
values[1] = EValue(at::ones({2, 3}));
values[2] = EValue(int64_t(1));
values[3] = EValue(at::zeros({2, 3}));
ASSERT_TRUE(hasOpsFn("aten::add.out"));
auto op = getOpsFn("aten::add.out");
EValue* kernel_values[4];
for (size_t i = 0; i < 4; i++) {
kernel_values[i] = &values[i];
}
op(kernel_values);
at::Tensor expected = at::ones({2, 3});
expected = at::fill(expected, 2);
ASSERT_TRUE(expected.equal(kernel_values[3]->toTensor()));
}
} // namespace executor
} // namespace torch

View File

@ -14,3 +14,9 @@ def define_targets(rules):
srcs = [":torchgen"],
visibility = ["//visibility:public"],
)
rules.py_binary(
name = "gen_executorch",
srcs = [":torchgen"],
visibility = ["//visibility:public"],
)