mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
## Job
Test running on most CI jobs.

## Test binary
* `test_main.cpp`: entry point for gtest.
* `test_operator_registration.cpp`: test cases for gtest.

## Helper sources
* `operator_registry.h/cpp`: a simple operator registry for testing purposes.
* `Evalue.h`: a boxed data type that wraps ATen types, for testing purposes.
* `selected_operators.yaml`: the operators Executorch cares about so far; we should cover all of them.

## Templates
* `NativeFunctions.h`: for generating headers for native functions (not compiled in the test, since we will be using `libtorch`).
* `RegisterCodegenUnboxedKernels.cpp`: for registering boxed operators.
* `Functions.h`: for declaring operator C++ APIs. The generated `Functions.h` merely wraps `ATen/Functions.h`.

## Build files
* `CMakeLists.txt`: generates code to register ops.
* `build.sh`: driver script, to be called by the CI job.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89596
Approved by: https://github.com/ezyang
71 lines
1.4 KiB
C++
71 lines
1.4 KiB
C++
#pragma once
|
|
|
|
#include <cstdint>
#include <cstring>
#include <functional>
#include <map>
#include <utility>

#include <c10/util/ArrayRef.h>

#include "Evalue.h"
|
|
|
|
namespace torch {
|
|
namespace executor {
|
|
|
|
using OpFunction = std::function<void(EValue**)>;
|
|
|
|
template<typename T>
|
|
using ArrayRef = at::ArrayRef<T>;
|
|
|
|
#define EXECUTORCH_SCOPE_PROF(x)
|
|
|
|
struct Operator {
|
|
const char* name_;
|
|
OpFunction op_;
|
|
|
|
Operator() = default;
|
|
|
|
/**
|
|
* We are doing a copy of the string pointer instead of duplicating the string
|
|
* itself, we require the lifetime of the operator name to be at least as long
|
|
* as the operator registry.
|
|
*/
|
|
explicit Operator(const char* name, OpFunction func)
|
|
: name_(name), op_(func) {}
|
|
};
|
|
|
|
/**
|
|
* See OperatorRegistry::hasOpsFn()
|
|
*/
|
|
bool hasOpsFn(const char* name);
|
|
|
|
/**
|
|
* See OperatorRegistry::getOpsFn()
|
|
*/
|
|
OpFunction& getOpsFn(const char* name);
|
|
|
|
|
|
[[nodiscard]] bool register_operators(const ArrayRef<Operator>&);
|
|
|
|
struct OperatorRegistry {
|
|
public:
|
|
OperatorRegistry() : operatorRegSize_(0) {}
|
|
|
|
bool register_operators(const ArrayRef<Operator>&);
|
|
|
|
/**
|
|
* Checks whether an operator with a given name is registered
|
|
*/
|
|
bool hasOpsFn(const char* name);
|
|
|
|
/**
|
|
* Checks whether an operator with a given name is registered
|
|
*/
|
|
OpFunction& getOpsFn(const char* name);
|
|
|
|
private:
|
|
std::map<const char*, OpFunction> operators_map_;
|
|
uint32_t operatorRegSize_;
|
|
};
|
|
|
|
} // namespace executor
|
|
} // namespace torch
|