pytorch/test/edge/operator_registry.h
Mengwei Liu 2f154f68ea [torchgen] Add CI job to make sure torchgen works for Executorch op registration (#89596)
## Job

The test runs on most CI jobs.

## Test binary

* `test_main.cpp`: entry for gtest
* `test_operator_registration.cpp`: test cases for gtest

## Helper sources

* `operator_registry.h/cpp`: a simple operator registry, for testing purposes.
* `Evalue.h`: a boxed data type that wraps ATen types, for testing purposes.
* `selected_operators.yaml`: the operators Executorch cares about so far; we should cover all of them.

## Templates

* `NativeFunctions.h`: for generating headers for native functions. (not compiled in the test, since we will be using `libtorch`)
* `RegisterCodegenUnboxedKernels.cpp`: for registering boxed operators.
* `Functions.h`: for declaring operator C++ APIs. Generated `Functions.h` merely wraps `ATen/Functions.h`.

## Build files

* `CMakeLists.txt`: generates the code that registers ops.
* `build.sh`: driver file, to be called by CI job.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89596
Approved by: https://github.com/ezyang
2022-12-21 03:07:32 +00:00


#pragma once
#include <cstdint>
#include <cstring>
#include <c10/util/ArrayRef.h>
#include "Evalue.h"
#include <functional>
#include <map>
namespace torch {
namespace executor {
using OpFunction = std::function<void(EValue**)>;
template<typename T>
using ArrayRef = at::ArrayRef<T>;
#define EXECUTORCH_SCOPE_PROF(x)
struct Operator {
  const char* name_;
  OpFunction op_;

  Operator() = default;

  /**
   * Only the string pointer is copied, not the string itself, so the operator
   * name must live at least as long as the operator registry.
   */
  explicit Operator(const char* name, OpFunction func)
      : name_(name), op_(func) {}
};
/**
* See OperatorRegistry::hasOpsFn()
*/
bool hasOpsFn(const char* name);
/**
* See OperatorRegistry::getOpsFn()
*/
OpFunction& getOpsFn(const char* name);
[[nodiscard]] bool register_operators(const ArrayRef<Operator>&);
struct OperatorRegistry {
 public:
  OperatorRegistry() : operatorRegSize_(0) {}

  bool register_operators(const ArrayRef<Operator>&);

  /**
   * Checks whether an operator with a given name is registered
   */
  bool hasOpsFn(const char* name);

  /**
   * Returns the function registered for the operator with a given name
   */
  OpFunction& getOpsFn(const char* name);

 private:
  // Keyed on the pointer value: the default std::less<const char*> compares
  // addresses, not string contents.
  std::map<const char*, OpFunction> operators_map_;
  uint32_t operatorRegSize_;
};
} // namespace executor
} // namespace torch
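
The header only declares the registry, so the following is a minimal, self-contained sketch of how registration and lookup might fit together. It is not the code in `operator_registry.cpp`. `EValue` here is a hypothetical stand-in for the type in `Evalue.h`, and `SketchRegistry`, `run_registry_demo`, the `demo::add` operator name, and the stack layout are all assumptions made for illustration. Unlike the header, the sketch keys its map on `std::string`, so lookups compare string contents rather than pointer values. It also takes a `std::vector` instead of `ArrayRef` to avoid depending on `c10`.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for the EValue type declared in Evalue.h.
struct EValue {
  int payload = 0;
};

// Same shape as the header's OpFunction: a boxed kernel over a value stack.
using OpFunction = std::function<void(EValue**)>;

struct Operator {
  const char* name_;
  OpFunction op_;
};

// Illustrative registry only; the behavior on duplicates is an assumption.
class SketchRegistry {
 public:
  // Returns false and registers nothing if any name is already taken.
  [[nodiscard]] bool register_operators(const std::vector<Operator>& ops) {
    for (const Operator& op : ops) {
      if (operators_map_.count(op.name_) != 0) return false;
    }
    for (const Operator& op : ops) {
      operators_map_.emplace(op.name_, op.op_);
    }
    return true;
  }

  bool hasOpsFn(const char* name) const {
    return operators_map_.count(name) != 0;
  }

  // Precondition: hasOpsFn(name) is true; std::map::at throws otherwise.
  OpFunction& getOpsFn(const char* name) {
    return operators_map_.at(name);
  }

 private:
  // Keyed on std::string so that lookups compare contents, not addresses.
  std::map<std::string, OpFunction> operators_map_;
};

// Registers a boxed "add" kernel, looks it up by name, and calls it.
// Assumed stack layout: stack[0] and stack[1] are inputs, stack[2] is output.
int run_registry_demo() {
  SketchRegistry registry;
  const bool ok = registry.register_operators({
      Operator{"demo::add", [](EValue** stack) {
        stack[2]->payload = stack[0]->payload + stack[1]->payload;
      }},
  });
  assert(ok);
  assert(registry.hasOpsFn("demo::add"));
  assert(!registry.hasOpsFn("demo::mul"));

  EValue a{2}, b{3}, out{};
  EValue* stack[] = {&a, &b, &out};
  registry.getOpsFn("demo::add")(stack);
  return out.payload;
}
```

Calling `run_registry_demo()` from a `main` function returns 5, the sum of the two input payloads. Keying on `std::string` costs one string copy per registration but removes the requirement that callers pass the exact pointer that was registered.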