mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Check kernel against function schema in c10 op registration (#18256)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/18256 This diff infers the function schema from the kernel function/functor and checks that it matches the specified function schema. This diff does not allow (yet) to omit specifying the function schema in the registration API. That will come in a future diff. Reviewed By: dzhulgakov Differential Revision: D14552738 fbshipit-source-id: 00202b489ede19f26ae686c97416b38c72c11532
This commit is contained in:
parent
c4bb09cc42
commit
14c28fabd2
|
|
@ -56,6 +56,7 @@ namespace detail {
|
|||
TensorTypeId dispatch_key;
|
||||
KernelFunction* kernel_func = nullptr;
|
||||
KernelCacheCreatorFunction cache_creator_func = nullptr;
|
||||
std::unique_ptr<FunctionSchema> inferred_function_schema = nullptr;
|
||||
};
|
||||
|
||||
// is_registration_config_parameter is a concept that returns true_type iff its argument is
|
||||
|
|
|
|||
47
aten/src/ATen/core/op_registration/infer_schema.cpp
Normal file
47
aten/src/ATen/core/op_registration/infer_schema.cpp
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#include "infer_schema.h"
|
||||
#include <sstream>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
namespace {
|
||||
std::string serialize_schema(const FunctionSchema& schema) {
|
||||
std::ostringstream str;
|
||||
str << schema;
|
||||
return str.str();
|
||||
}
|
||||
}
|
||||
|
||||
C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified) {
|
||||
if (inferred.arguments().size() != specified.arguments().size()) {
|
||||
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
|
||||
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
|
||||
"The number of arguments is different. Specified ", specified.arguments().size(),
|
||||
" but inferred ", inferred.arguments().size());
|
||||
}
|
||||
if (inferred.returns().size() != specified.returns().size()) {
|
||||
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
|
||||
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
|
||||
"The number of returns is different.Specified ", specified.returns().size(),
|
||||
" but inferred ", inferred.returns().size());
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < inferred.arguments().size(); ++i) {
|
||||
if (*inferred.arguments()[i].type() != *specified.arguments()[i].type()) {
|
||||
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
|
||||
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
|
||||
"Type mismatch in argument ", i, ": specified ", specified.arguments()[i].type()->str(),
|
||||
" but inferred ", inferred.arguments()[i].type()->str());
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < inferred.returns().size(); ++i) {
|
||||
if (*inferred.returns()[i].type() != *specified.returns()[i].type()) {
|
||||
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
|
||||
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
|
||||
"Type mismatch in return ", i, ": specified ", specified.returns()[i].type()->str(),
|
||||
" but inferred ", inferred.returns()[i].type()->str());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -18,63 +18,76 @@ void checkStaticTypes() {
|
|||
// Give nice error messages for some of the common error cases.
|
||||
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
|
||||
static_assert(
|
||||
!std::is_integral<T>::value || std::is_same<T, int64_t>::value,
|
||||
"INVALID TYPE: Only int64_t is supported as an integral argument type");
|
||||
!std::is_integral<T>::value || std::is_same<T, int64_t>::value || std::is_same<T, bool>::value,
|
||||
"INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
|
||||
static_assert(
|
||||
!std::is_same<T, float>::value,
|
||||
"INVALID TYPE: float is not supported as an argument type, use double instead");
|
||||
}
|
||||
|
||||
template <typename First, typename Second, typename... Rest>
|
||||
void checkStaticTypes() {
|
||||
checkStaticTypes<First>();
|
||||
checkStaticTypes<Second, Rest...>();
|
||||
}
|
||||
|
||||
template <typename... Ts, size_t... Is>
|
||||
::std::vector<Argument> createArgumentVectorFromTypes(guts::index_sequence<Is...>) {
|
||||
checkStaticTypes<guts::decay_t<Ts>...>();
|
||||
// Check types for common errors
|
||||
(void)std::initializer_list<int>{(
|
||||
checkStaticTypes<Ts>()
|
||||
, 0)...};
|
||||
|
||||
// Arguments are named "_<index>"
|
||||
return {Argument("_" + std::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
|
||||
return {Argument("_" + c10::guts::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
|
||||
}
|
||||
|
||||
template <typename... Ts, size_t... Is>
|
||||
::std::vector<Argument> createReturns(guts::index_sequence<Is...>) {
|
||||
return createArgumentVectorFromTypes<Ts..., Is...>();
|
||||
}
|
||||
/// Creates a vector of `Argument` from a list of C++ types that are specified
|
||||
/// as template arguments.
|
||||
template<class ParameterTypes> struct createArguments final {};
|
||||
template<class... ParameterTypes>
|
||||
struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
|
||||
static std::vector<Argument> call() {
|
||||
return createArgumentVectorFromTypes<ParameterTypes...>(
|
||||
guts::make_index_sequence<sizeof...(ParameterTypes)>()
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/// Unpack a tuple return type into a vector of return types, one per tuple
|
||||
/// element.
|
||||
template <typename... Ts>
|
||||
::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
|
||||
return createReturns<Ts...>(guts::make_index_sequence<sizeof...(Ts)>());
|
||||
}
|
||||
/// Creates a vector of `Argument` from a list of C++ types that are specified
|
||||
/// as a tuple (i.e. in the way c10 kernels return values).
|
||||
/// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
|
||||
/// It can be an empty tuple<>, or void for kernels that don't return anything.
|
||||
/// It can be a single type A (i.e. no tuple) for the case where a kernel just
|
||||
/// returns one value.
|
||||
template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
|
||||
|
||||
/// Create a single-element `vector` for simple (non-tuple) return types.
|
||||
template <typename ReturnType>
|
||||
::std::vector<Argument> createReturns(ReturnType*) {
|
||||
checkStaticTypes<guts::decay_t<ReturnType>>();
|
||||
return {Argument("_1", getTypePtr<guts::decay_t<ReturnType>>())};
|
||||
}
|
||||
template<class... ReturnTypes>
|
||||
struct createReturns<std::tuple<ReturnTypes...>, void> final {
|
||||
static std::vector<Argument> call() {
|
||||
return createArgumentVectorFromTypes<ReturnTypes...>(
|
||||
guts::make_index_sequence<sizeof...(ReturnTypes)>()
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
|
||||
/// into the argument list.
|
||||
template <typename FunctionTraits, size_t... Is>
|
||||
::std::vector<Argument> createArgumentVectorFromTraits(guts::index_sequence<Is...> indices) {
|
||||
using ArgumentTypes = typename FunctionTraits::parameter_types;
|
||||
return createArgumentVectorFromTypes<
|
||||
c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
|
||||
}
|
||||
template<class ReturnType>
|
||||
struct createReturns<ReturnType, guts::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
|
||||
static std::vector<Argument> call() {
|
||||
return createReturns<std::tuple<ReturnType>>::call();
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct createReturns<void, void> final {
|
||||
static std::vector<Argument> call() {
|
||||
return createReturns<std::tuple<>>::call();
|
||||
}
|
||||
};
|
||||
|
||||
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
|
||||
/// function.
|
||||
template <typename FunctionTraits>
|
||||
FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) {
|
||||
using ReturnType = typename FunctionTraits::return_type;
|
||||
using ParameterTypes = typename FunctionTraits::parameter_types;
|
||||
|
||||
auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
|
||||
guts::make_index_sequence<FunctionTraits::number_of_parameters>());
|
||||
auto returns = createReturns(static_cast<ReturnType*>(nullptr));
|
||||
auto arguments = createArguments<ParameterTypes>::call();
|
||||
auto returns = createReturns<ReturnType>::call();
|
||||
|
||||
return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)};
|
||||
}
|
||||
|
|
@ -85,4 +98,6 @@ FunctionSchema inferFunctionSchema(std::string name, std::string overload_name)
|
|||
return detail::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
|
||||
}
|
||||
|
||||
C10_API void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified);
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ using c10::RegisterOperators;
|
|||
using c10::FunctionSchema;
|
||||
using c10::Argument;
|
||||
using c10::IntType;
|
||||
using c10::FloatType;
|
||||
using c10::ListType;
|
||||
using c10::kernel;
|
||||
using c10::dispatchKey;
|
||||
|
|
@ -29,21 +30,23 @@ C10_DEFINE_TENSOR_TYPE(TensorType1);
|
|||
C10_DECLARE_TENSOR_TYPE(TensorType2);
|
||||
C10_DEFINE_TENSOR_TYPE(TensorType2);
|
||||
|
||||
void errorKernel(const Tensor&) {
|
||||
int64_t errorKernel(const Tensor& tensor, int64_t input) {
|
||||
EXPECT_TRUE(false); // this kernel should never be called
|
||||
return 0;
|
||||
}
|
||||
|
||||
FunctionSchema errorOpSchema(
|
||||
"_test::error",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("dummy")}),
|
||||
(std::vector<Argument>{}));
|
||||
(std::vector<Argument>{Argument("dummy"),
|
||||
Argument("input", IntType::get())}),
|
||||
(std::vector<Argument>{Argument("output", IntType::get())}));
|
||||
|
||||
int incrementKernel(const Tensor& tensor, int input) {
|
||||
int64_t incrementKernel(const Tensor& tensor, int64_t input) {
|
||||
return input + 1;
|
||||
}
|
||||
|
||||
int decrementKernel(const Tensor& tensor, int input) {
|
||||
int64_t decrementKernel(const Tensor& tensor, int64_t input) {
|
||||
return input - 1;
|
||||
}
|
||||
|
||||
|
|
@ -159,7 +162,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithZeroOutputs_wh
|
|||
EXPECT_EQ(0, result.size());
|
||||
}
|
||||
|
||||
int kernelWithIntOutput(Tensor, int a, int b) {
|
||||
int64_t kernelWithIntOutput(Tensor, int64_t a, int64_t b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
|
|
@ -237,7 +240,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutp
|
|||
EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
|
||||
}
|
||||
|
||||
std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int input1, int input2, int input3) {
|
||||
std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) {
|
||||
return {input1, input2, input3};
|
||||
}
|
||||
|
||||
|
|
@ -393,9 +396,9 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByV
|
|||
EXPECT_EQ(TensorType2(), captured_input.type_id());
|
||||
}
|
||||
|
||||
int captured_int_input = 0;
|
||||
int64_t captured_int_input = 0;
|
||||
|
||||
void kernelWithIntInputWithoutOutput(Tensor, int input1) {
|
||||
void kernelWithIntInputWithoutOutput(Tensor, int64_t input1) {
|
||||
captured_int_input = input1;
|
||||
}
|
||||
|
||||
|
|
@ -419,7 +422,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_witho
|
|||
EXPECT_EQ(3, captured_int_input);
|
||||
}
|
||||
|
||||
int kernelWithIntInputWithOutput(Tensor, int input1) {
|
||||
int64_t kernelWithIntInputWithOutput(Tensor, int64_t input1) {
|
||||
return input1 + 1;
|
||||
}
|
||||
|
||||
|
|
@ -442,7 +445,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withO
|
|||
EXPECT_EQ(4, outputs[0].toInt());
|
||||
}
|
||||
|
||||
int captured_input_list_size = 0;
|
||||
int64_t captured_input_list_size = 0;
|
||||
|
||||
void kernelWithIntListInputWithoutOutput(Tensor, ArrayRef<int64_t> input1) {
|
||||
captured_input_list_size = input1.size();
|
||||
|
|
@ -468,7 +471,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_w
|
|||
EXPECT_EQ(3, captured_input_list_size);
|
||||
}
|
||||
|
||||
int kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
|
||||
int64_t kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
|
||||
return input1.size();
|
||||
}
|
||||
|
||||
|
|
@ -514,7 +517,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu
|
|||
EXPECT_EQ(2, captured_input_list_size);
|
||||
}
|
||||
|
||||
int kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
|
||||
int64_t kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
|
||||
return input1.size();
|
||||
}
|
||||
|
||||
|
|
@ -536,4 +539,308 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu
|
|||
EXPECT_EQ(2, outputs[0].toInt());
|
||||
}
|
||||
|
||||
template<class Return, class... Args> struct kernel_func final {
|
||||
static Return func(Args...) { return {}; }
|
||||
};
|
||||
template<class... Args> struct kernel_func<void, Args...> final {
|
||||
static void func(Args...) {}
|
||||
};
|
||||
|
||||
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1", IntType::get()),
|
||||
Argument("ret2", IntType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret")})
|
||||
), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
|
||||
), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1"), Argument("ret2")})
|
||||
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1")})
|
||||
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
|
||||
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret")})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", FloatType::get())})
|
||||
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret")})
|
||||
), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", FloatType::get())})
|
||||
), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
|
||||
), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
|
||||
), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
|
||||
), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/core/op_registration/kernel_stackbased.h>
|
||||
#include <ATen/core/op_registration/infer_schema.h>
|
||||
|
||||
namespace c10 {
|
||||
/**
|
||||
|
|
@ -129,6 +130,14 @@ namespace detail {
|
|||
private:
|
||||
std::tuple<Args...> constructor_parameters_;
|
||||
};
|
||||
|
||||
template<class KernelFunctor>
|
||||
class FunctionSchemaInferer final {
|
||||
public:
|
||||
std::unique_ptr<FunctionSchema> operator()() const {
|
||||
return guts::make_unique<FunctionSchema>(inferFunctionSchema<KernelFunctor>("", ""));
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -168,20 +177,15 @@ namespace detail {
|
|||
* > c10::dispatchKey(CPUTensorId()));
|
||||
*/
|
||||
template<class KernelFunctor, class... ConstructorParameters>
|
||||
inline constexpr auto kernel(ConstructorParameters&&... constructorParameters)
|
||||
// enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
|
||||
-> guts::enable_if_t<
|
||||
guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
|
||||
decltype(kernel(
|
||||
inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
|
||||
detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
|
||||
kernel(ConstructorParameters&&... constructorParameters) {
|
||||
return {
|
||||
&detail::wrap_kernel_functor<KernelFunctor>::call,
|
||||
detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
|
||||
))> {
|
||||
static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "KernelFunctor cannot be constructed with the given arguments");
|
||||
|
||||
return kernel(
|
||||
&detail::wrap_kernel_functor<KernelFunctor>::call,
|
||||
detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
|
||||
);
|
||||
detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
|
||||
detail::FunctionSchemaInferer<KernelFunctor>()
|
||||
};
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ using c10::FunctionSchema;
|
|||
using c10::OperatorKernel;
|
||||
using c10::Argument;
|
||||
using c10::IntType;
|
||||
using c10::FloatType;
|
||||
using c10::ListType;
|
||||
using c10::kernel;
|
||||
using c10::dispatchKey;
|
||||
|
|
@ -31,25 +32,27 @@ C10_DECLARE_TENSOR_TYPE(TensorType2);
|
|||
C10_DEFINE_TENSOR_TYPE(TensorType2);
|
||||
|
||||
struct ErrorKernel final : public OperatorKernel {
|
||||
void operator()(const Tensor&) {
|
||||
int64_t operator()(const Tensor&, int64_t) {
|
||||
EXPECT_TRUE(false); // this kernel should never be called
|
||||
return 0;
|
||||
}
|
||||
};
|
||||
|
||||
FunctionSchema errorOpSchema(
|
||||
"_test::error",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("dummy")}),
|
||||
(std::vector<Argument>{}));
|
||||
(std::vector<Argument>{Argument("dummy"),
|
||||
Argument("input", IntType::get())}),
|
||||
(std::vector<Argument>{Argument("output", IntType::get())}));
|
||||
|
||||
struct IncrementKernel final : OperatorKernel {
|
||||
int operator()(const Tensor& tensor, int input) {
|
||||
int64_t operator()(const Tensor& tensor, int64_t input) {
|
||||
return input + 1;
|
||||
}
|
||||
};
|
||||
|
||||
struct DecrementKernel final : OperatorKernel {
|
||||
int operator()(const Tensor& tensor, int input) {
|
||||
int64_t operator()(const Tensor& tensor, int64_t input) {
|
||||
return input - 1;
|
||||
}
|
||||
};
|
||||
|
|
@ -171,7 +174,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithZeroOutputs_whe
|
|||
}
|
||||
|
||||
struct KernelWithIntOutput final : OperatorKernel {
|
||||
int operator()(Tensor, int a, int b) {
|
||||
int64_t operator()(Tensor, int64_t a, int64_t b) {
|
||||
return a + b;
|
||||
}
|
||||
};
|
||||
|
|
@ -256,7 +259,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListOutpu
|
|||
}
|
||||
|
||||
struct KernelWithIntListOutput final : OperatorKernel {
|
||||
std::vector<int64_t> operator()(const Tensor&, int input1, int input2, int input3) {
|
||||
std::vector<int64_t> operator()(const Tensor&, int64_t input1, int64_t input2, int64_t input3) {
|
||||
return {input1, input2, input3};
|
||||
}
|
||||
};
|
||||
|
|
@ -423,10 +426,10 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByVa
|
|||
EXPECT_EQ(TensorType2(), captured_input.type_id());
|
||||
}
|
||||
|
||||
int captured_int_input = 0;
|
||||
int64_t captured_int_input = 0;
|
||||
|
||||
struct KernelWithIntInputWithoutOutput final : OperatorKernel {
|
||||
void operator()(Tensor, int input1) {
|
||||
void operator()(Tensor, int64_t input1) {
|
||||
captured_int_input = input1;
|
||||
}
|
||||
};
|
||||
|
|
@ -452,7 +455,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withou
|
|||
}
|
||||
|
||||
struct KernelWithIntInputWithOutput final : OperatorKernel {
|
||||
int operator()(Tensor, int input1) {
|
||||
int64_t operator()(Tensor, int64_t input1) {
|
||||
return input1 + 1;
|
||||
}
|
||||
};
|
||||
|
|
@ -476,7 +479,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withOu
|
|||
EXPECT_EQ(4, outputs[0].toInt());
|
||||
}
|
||||
|
||||
int captured_input_list_size = 0;
|
||||
int64_t captured_input_list_size = 0;
|
||||
|
||||
struct KernelWithIntListInputWithoutOutput final : OperatorKernel {
|
||||
void operator()(Tensor, ArrayRef<int64_t> input1) {
|
||||
|
|
@ -505,7 +508,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_wi
|
|||
}
|
||||
|
||||
struct KernelWithIntListInputWithOutput final : OperatorKernel {
|
||||
int operator()(Tensor, ArrayRef<int64_t> input1) {
|
||||
int64_t operator()(Tensor, ArrayRef<int64_t> input1) {
|
||||
return input1.size();
|
||||
}
|
||||
};
|
||||
|
|
@ -555,7 +558,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput
|
|||
}
|
||||
|
||||
struct KernelWithTensorListInputWithOutput final : OperatorKernel {
|
||||
int operator()(ArrayRef<Tensor> input1) {
|
||||
int64_t operator()(ArrayRef<Tensor> input1) {
|
||||
return input1.size();
|
||||
}
|
||||
};
|
||||
|
|
@ -689,5 +692,308 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleConstru
|
|||
EXPECT_EQ(13, outputs[0].toInt());
|
||||
}
|
||||
|
||||
template<class Return, class... Args> struct KernelFunc final : OperatorKernel{
|
||||
Return operator()(Args...) { return {}; }
|
||||
};
|
||||
template<class... Args> struct KernelFunc<void, Args...> final : OperatorKernel {
|
||||
void operator()(Args...) {}
|
||||
};
|
||||
|
||||
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1", IntType::get()),
|
||||
Argument("ret2", IntType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret")})
|
||||
), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
|
||||
), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1"), Argument("ret2")})
|
||||
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{})
|
||||
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1")})
|
||||
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
|
||||
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", IntType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret")})
|
||||
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", FloatType::get())})
|
||||
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret")})
|
||||
), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret", FloatType::get())})
|
||||
), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
// assert this does not fail because it matches
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
|
||||
), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
|
||||
|
||||
// and now a set of mismatching schemas
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
|
||||
), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
|
||||
EXPECT_THROW(
|
||||
RegisterOperators()
|
||||
.op(FunctionSchema(
|
||||
"_test::mismatch",
|
||||
"",
|
||||
(std::vector<Argument>{Argument("arg")}),
|
||||
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
|
||||
), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
|
||||
c10::Error
|
||||
);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,29 +17,40 @@ namespace c10 {
|
|||
|
||||
namespace detail {
|
||||
|
||||
template<class KernelCacheCreatorFunction_>
|
||||
struct NoFunctionSchemaInference final {
|
||||
std::unique_ptr<FunctionSchema> operator()() const {
|
||||
return nullptr;
|
||||
}
|
||||
};
|
||||
|
||||
template<class KernelCacheCreatorFunction_, class InferFunctionSchemaFunction>
|
||||
struct KernelRegistrationConfigParameter final {
|
||||
template<class KernelCacheCreatorFunction__>
|
||||
constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func)
|
||||
: kernel_func_(kernel_func), cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func)) {
|
||||
constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func, InferFunctionSchemaFunction&& infer_function_schema_func)
|
||||
: kernel_func_(kernel_func)
|
||||
, cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func))
|
||||
, infer_function_schema_func_(std::forward<InferFunctionSchemaFunction>(infer_function_schema_func)) {
|
||||
}
|
||||
|
||||
void apply(KernelRegistrationConfig* registration) const & {
|
||||
registration->kernel_func = kernel_func_;
|
||||
registration->cache_creator_func = cache_creator_func_;
|
||||
registration->inferred_function_schema = infer_function_schema_func_();
|
||||
}
|
||||
|
||||
void apply(KernelRegistrationConfig* registration) && {
|
||||
registration->kernel_func = kernel_func_;
|
||||
registration->cache_creator_func = std::move(cache_creator_func_);
|
||||
registration->inferred_function_schema = std::move(infer_function_schema_func_)();
|
||||
}
|
||||
|
||||
private:
|
||||
KernelFunction* kernel_func_;
|
||||
KernelCacheCreatorFunction_ cache_creator_func_;
|
||||
InferFunctionSchemaFunction infer_function_schema_func_;
|
||||
};
|
||||
|
||||
static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
|
||||
static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction, NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -61,10 +72,10 @@ namespace detail {
|
|||
* > c10::dispatchKey(CPUTensorId()));
|
||||
*/
|
||||
template<class KernelCacheCreatorFunction_>
|
||||
inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
|
||||
static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
|
||||
inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>, detail::NoFunctionSchemaInference> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
|
||||
static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>, detail::NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
|
||||
|
||||
return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator)};
|
||||
return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator), detail::NoFunctionSchemaInference()};
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <ATen/core/op_registration/kernel_stackbased.h>
|
||||
#include <ATen/core/op_registration/kernel_functor.h>
|
||||
#include <ATen/core/op_registration/kernel_function.h>
|
||||
#include <ATen/core/op_registration/infer_schema.h>
|
||||
|
||||
namespace c10 {
|
||||
|
||||
|
|
@ -65,6 +66,11 @@ public:
|
|||
guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
|
||||
op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
|
||||
detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...);
|
||||
|
||||
if (config.inferred_function_schema.get() != nullptr) {
|
||||
assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
|
||||
}
|
||||
|
||||
registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
|
||||
return std::move(*this);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -229,6 +229,9 @@ class DummyClassForToString final {};
|
|||
namespace std {
|
||||
// We use SFINAE to detect if std::to_string exists for a type, but that only works
|
||||
// if the function name is defined. So let's define a std::to_string for a dummy type.
|
||||
// If you're getting an error here saying that this overload doesn't match your
|
||||
// std::to_string() call, then you're calling std::to_string() but should be calling
|
||||
// c10::guts::to_string().
|
||||
inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; }
|
||||
}
|
||||
namespace c10 { namespace guts { namespace detail {
|
||||
|
|
|
|||
|
|
@ -61,5 +61,6 @@ struct is_type_condition : std::false_type {};
|
|||
template<template<class> class C>
|
||||
struct is_type_condition<C, guts::enable_if_t<std::is_same<bool, guts::remove_cv_t<decltype(C<int>::value)>>::value>> : std::true_type {};
|
||||
|
||||
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -111,7 +111,7 @@ static auto registry = c10::RegisterOperators().op(
|
|||
(std::vector<c10::Argument>{
|
||||
c10::Argument("inputs", ListType::ofTensors()),
|
||||
c10::Argument("output"),
|
||||
c10::Argument("split_info", FloatType::get()),
|
||||
c10::Argument("split_info"),
|
||||
c10::Argument("add", IntType::get()),
|
||||
c10::Argument("add_axis", IntType::get())}),
|
||||
(std::vector<c10::Argument>{})),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user