Check kernel against function schema in c10 op registration (#18256)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/18256

This diff infers the function schema from the kernel function/functor and checks that it matches the specified function schema.

This diff does not allow (yet) to omit specifying the function schema in the registration API. That will come in a future diff.

Reviewed By: dzhulgakov

Differential Revision: D14552738

fbshipit-source-id: 00202b489ede19f26ae686c97416b38c72c11532
This commit is contained in:
Sebastian Messmer 2019-03-30 00:03:44 -07:00 committed by Facebook Github Bot
parent c4bb09cc42
commit 14c28fabd2
11 changed files with 784 additions and 83 deletions

View File

@ -56,6 +56,7 @@ namespace detail {
TensorTypeId dispatch_key;
KernelFunction* kernel_func = nullptr;
KernelCacheCreatorFunction cache_creator_func = nullptr;
std::unique_ptr<FunctionSchema> inferred_function_schema = nullptr;
};
// is_registration_config_parameter is a concept that returns true_type iff its argument is

View File

@ -0,0 +1,47 @@
#include "infer_schema.h"
#include <sstream>
namespace c10 {
namespace {
std::string serialize_schema(const FunctionSchema& schema) {
std::ostringstream str;
str << schema;
return str.str();
}
}
C10_EXPORT void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified) {
if (inferred.arguments().size() != specified.arguments().size()) {
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
"The number of arguments is different. Specified ", specified.arguments().size(),
" but inferred ", inferred.arguments().size());
}
if (inferred.returns().size() != specified.returns().size()) {
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
"The number of returns is different.Specified ", specified.returns().size(),
" but inferred ", inferred.returns().size());
}
for (size_t i = 0; i < inferred.arguments().size(); ++i) {
if (*inferred.arguments()[i].type() != *specified.arguments()[i].type()) {
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
"Type mismatch in argument ", i, ": specified ", specified.arguments()[i].type()->str(),
" but inferred ", inferred.arguments()[i].type()->str());
}
}
for (size_t i = 0; i < inferred.returns().size(); ++i) {
if (*inferred.returns()[i].type() != *specified.returns()[i].type()) {
AT_ERROR("In operator registration: Specified function schema [", serialize_schema(specified), "] ",
"doesn't match inferred function schema [", serialize_schema(inferred), "]. ",
"Type mismatch in return ", i, ": specified ", specified.returns()[i].type()->str(),
" but inferred ", inferred.returns()[i].type()->str());
}
}
}
}

View File

@ -18,63 +18,76 @@ void checkStaticTypes() {
// Give nice error messages for some of the common error cases.
// Use a LOUD ERROR MESSAGE SO USERS SEE THE STATIC_ASSERT
static_assert(
!std::is_integral<T>::value || std::is_same<T, int64_t>::value,
"INVALID TYPE: Only int64_t is supported as an integral argument type");
!std::is_integral<T>::value || std::is_same<T, int64_t>::value || std::is_same<T, bool>::value,
"INVALID TYPE: Only int64_t and bool are supported as an integral argument type");
static_assert(
!std::is_same<T, float>::value,
"INVALID TYPE: float is not supported as an argument type, use double instead");
}
template <typename First, typename Second, typename... Rest>
void checkStaticTypes() {
checkStaticTypes<First>();
checkStaticTypes<Second, Rest...>();
}
template <typename... Ts, size_t... Is>
::std::vector<Argument> createArgumentVectorFromTypes(guts::index_sequence<Is...>) {
checkStaticTypes<guts::decay_t<Ts>...>();
// Check types for common errors
(void)std::initializer_list<int>{(
checkStaticTypes<Ts>()
, 0)...};
// Arguments are named "_<index>"
return {Argument("_" + std::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
return {Argument("_" + c10::guts::to_string(Is), getTypePtr<guts::decay_t<Ts>>())...};
}
template <typename... Ts, size_t... Is>
::std::vector<Argument> createReturns(guts::index_sequence<Is...>) {
return createArgumentVectorFromTypes<Ts..., Is...>();
/// Creates a vector of `Argument` from a list of C++ types that are specified
/// as template arguments.
template<class ParameterTypes> struct createArguments final {};
template<class... ParameterTypes>
struct createArguments<guts::typelist::typelist<ParameterTypes...>> final {
static std::vector<Argument> call() {
return createArgumentVectorFromTypes<ParameterTypes...>(
guts::make_index_sequence<sizeof...(ParameterTypes)>()
);
}
};
/// Unpack a tuple return type into a vector of return types, one per tuple
/// element.
template <typename... Ts>
::std::vector<Argument> createReturns(std::tuple<Ts...>* tuple) {
return createReturns<Ts...>(guts::make_index_sequence<sizeof...(Ts)>());
}
/// Creates a vector of `Argument` from a list of C++ types that are specified
/// as a tuple (i.e. in the way c10 kernels return values).
/// It can be a tuple<A, B, C> if there's three output arguments with types A, B, C.
/// It can be an empty tuple<>, or void for kernels that don't return anything.
/// It can be a single type A (i.e. no tuple) for the case where a kernel just
/// returns one value.
template<class ReturnTypeTuple, class Enable = void> struct createReturns final {};
/// Create a single-element `vector` for simple (non-tuple) return types.
template <typename ReturnType>
::std::vector<Argument> createReturns(ReturnType*) {
checkStaticTypes<guts::decay_t<ReturnType>>();
return {Argument("_1", getTypePtr<guts::decay_t<ReturnType>>())};
template<class... ReturnTypes>
struct createReturns<std::tuple<ReturnTypes...>, void> final {
static std::vector<Argument> call() {
return createArgumentVectorFromTypes<ReturnTypes...>(
guts::make_index_sequence<sizeof...(ReturnTypes)>()
);
}
};
/// Creates a vector of `Argument` from `FunctionTraits` and a pack of indices
/// into the argument list.
template <typename FunctionTraits, size_t... Is>
::std::vector<Argument> createArgumentVectorFromTraits(guts::index_sequence<Is...> indices) {
using ArgumentTypes = typename FunctionTraits::parameter_types;
return createArgumentVectorFromTypes<
c10::guts::typelist::element_t<Is, ArgumentTypes>...>(indices);
template<class ReturnType>
struct createReturns<ReturnType, guts::enable_if_t<!std::is_same<void, ReturnType>::value && !guts::is_instantiation_of<std::tuple, ReturnType>::value>> final {
static std::vector<Argument> call() {
return createReturns<std::tuple<ReturnType>>::call();
}
};
template<>
struct createReturns<void, void> final {
static std::vector<Argument> call() {
return createReturns<std::tuple<>>::call();
}
};
/// Creates a `FunctionSchema` object from a `FunctionTraits` type for a
/// function.
template <typename FunctionTraits>
FunctionSchema createFunctionSchemaFromTraits(std::string name, std::string overload_name) {
using ReturnType = typename FunctionTraits::return_type;
using ParameterTypes = typename FunctionTraits::parameter_types;
auto arguments = createArgumentVectorFromTraits<FunctionTraits>(
guts::make_index_sequence<FunctionTraits::number_of_parameters>());
auto returns = createReturns(static_cast<ReturnType*>(nullptr));
auto arguments = createArguments<ParameterTypes>::call();
auto returns = createReturns<ReturnType>::call();
return {std::move(name), std::move(overload_name), std::move(arguments), std::move(returns)};
}
@ -85,4 +98,6 @@ FunctionSchema inferFunctionSchema(std::string name, std::string overload_name)
return detail::createFunctionSchemaFromTraits<guts::infer_function_traits_t<FuncType>>(std::move(name), std::move(overload_name));
}
C10_API void assertSchemasHaveSameSignature(const FunctionSchema& inferred, const FunctionSchema& specified);
}

View File

@ -8,6 +8,7 @@ using c10::RegisterOperators;
using c10::FunctionSchema;
using c10::Argument;
using c10::IntType;
using c10::FloatType;
using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
@ -29,21 +30,23 @@ C10_DEFINE_TENSOR_TYPE(TensorType1);
C10_DECLARE_TENSOR_TYPE(TensorType2);
C10_DEFINE_TENSOR_TYPE(TensorType2);
void errorKernel(const Tensor&) {
int64_t errorKernel(const Tensor& tensor, int64_t input) {
EXPECT_TRUE(false); // this kernel should never be called
return 0;
}
FunctionSchema errorOpSchema(
"_test::error",
"",
(std::vector<Argument>{Argument("dummy")}),
(std::vector<Argument>{}));
(std::vector<Argument>{Argument("dummy"),
Argument("input", IntType::get())}),
(std::vector<Argument>{Argument("output", IntType::get())}));
int incrementKernel(const Tensor& tensor, int input) {
int64_t incrementKernel(const Tensor& tensor, int64_t input) {
return input + 1;
}
int decrementKernel(const Tensor& tensor, int input) {
int64_t decrementKernel(const Tensor& tensor, int64_t input) {
return input - 1;
}
@ -159,7 +162,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithZeroOutputs_wh
EXPECT_EQ(0, result.size());
}
int kernelWithIntOutput(Tensor, int a, int b) {
int64_t kernelWithIntOutput(Tensor, int64_t a, int64_t b) {
return a + b;
}
@ -237,7 +240,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutp
EXPECT_EQ(TensorType1(), result[0].toTensorListRef()[2].type_id());
}
std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int input1, int input2, int input3) {
std::vector<int64_t> kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) {
return {input1, input2, input3};
}
@ -393,9 +396,9 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByV
EXPECT_EQ(TensorType2(), captured_input.type_id());
}
int captured_int_input = 0;
int64_t captured_int_input = 0;
void kernelWithIntInputWithoutOutput(Tensor, int input1) {
void kernelWithIntInputWithoutOutput(Tensor, int64_t input1) {
captured_int_input = input1;
}
@ -419,7 +422,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_witho
EXPECT_EQ(3, captured_int_input);
}
int kernelWithIntInputWithOutput(Tensor, int input1) {
int64_t kernelWithIntInputWithOutput(Tensor, int64_t input1) {
return input1 + 1;
}
@ -442,7 +445,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withO
EXPECT_EQ(4, outputs[0].toInt());
}
int captured_input_list_size = 0;
int64_t captured_input_list_size = 0;
void kernelWithIntListInputWithoutOutput(Tensor, ArrayRef<int64_t> input1) {
captured_input_list_size = input1.size();
@ -468,7 +471,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_w
EXPECT_EQ(3, captured_input_list_size);
}
int kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
int64_t kernelWithIntListInputWithOutput(Tensor, ArrayRef<int64_t> input1) {
return input1.size();
}
@ -514,7 +517,7 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu
EXPECT_EQ(2, captured_input_list_size);
}
int kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
int64_t kernelWithTensorListInputWithOutput(ArrayRef<Tensor> input1) {
return input1.size();
}
@ -536,4 +539,308 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInpu
EXPECT_EQ(2, outputs[0].toInt());
}
template<class Return, class... Args> struct kernel_func final {
static Return func(Args...) { return {}; }
};
template<class... Args> struct kernel_func<void, Args...> final {
static void func(Args...) {}
};
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{})
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{}),
(std::vector<Argument>{})
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
(std::vector<Argument>{})
), kernel<decltype(kernel_func<void, Tensor, Tensor>::func), &kernel_func<void, Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor, int64_t>::func), &kernel_func<int64_t, Tensor, int64_t>::func>(), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()),
Argument("ret2", IntType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
), kernel<decltype(kernel_func<void, Tensor>::func), &kernel_func<void, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2")})
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1")})
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
), kernel<decltype(kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func), &kernel_func<std::tuple<Tensor, Tensor>, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
), kernel<decltype(kernel_func<int64_t, Tensor>::func), &kernel_func<int64_t, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
), kernel<decltype(kernel_func<Tensor, Tensor>::func), &kernel_func<Tensor, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
), kernel<decltype(kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func), &kernel_func<std::tuple<Tensor, int64_t>, Tensor>::func>(), dispatchKey(TensorType1())),
c10::Error
);
}
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/core/op_registration/kernel_stackbased.h>
#include <ATen/core/op_registration/infer_schema.h>
namespace c10 {
/**
@ -129,6 +130,14 @@ namespace detail {
private:
std::tuple<Args...> constructor_parameters_;
};
template<class KernelFunctor>
class FunctionSchemaInferer final {
public:
std::unique_ptr<FunctionSchema> operator()() const {
return guts::make_unique<FunctionSchema>(inferFunctionSchema<KernelFunctor>("", ""));
}
};
}
/**
@ -168,20 +177,15 @@ namespace detail {
* > c10::dispatchKey(CPUTensorId()));
*/
template<class KernelFunctor, class... ConstructorParameters>
inline constexpr auto kernel(ConstructorParameters&&... constructorParameters)
// enable_if: only enable it if KernelFunctor is actually a functor and inherits from c10::OperatorKernel
-> guts::enable_if_t<
guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
decltype(kernel(
inline constexpr guts::enable_if_t<guts::is_functor<KernelFunctor>::value && std::is_base_of<OperatorKernel, KernelFunctor>::value,
detail::KernelRegistrationConfigParameter<detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>, detail::FunctionSchemaInferer<KernelFunctor>>>
kernel(ConstructorParameters&&... constructorParameters) {
return {
&detail::wrap_kernel_functor<KernelFunctor>::call,
detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
))> {
static_assert(std::is_constructible<KernelFunctor, ConstructorParameters...>::value, "KernelFunctor cannot be constructed with the given arguments");
return kernel(
&detail::wrap_kernel_functor<KernelFunctor>::call,
detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...)
);
detail::KernelFactory<KernelFunctor, guts::decay_t<ConstructorParameters>...>(std::forward<ConstructorParameters>(constructorParameters)...),
detail::FunctionSchemaInferer<KernelFunctor>()
};
}
}

View File

@ -9,6 +9,7 @@ using c10::FunctionSchema;
using c10::OperatorKernel;
using c10::Argument;
using c10::IntType;
using c10::FloatType;
using c10::ListType;
using c10::kernel;
using c10::dispatchKey;
@ -31,25 +32,27 @@ C10_DECLARE_TENSOR_TYPE(TensorType2);
C10_DEFINE_TENSOR_TYPE(TensorType2);
struct ErrorKernel final : public OperatorKernel {
void operator()(const Tensor&) {
int64_t operator()(const Tensor&, int64_t) {
EXPECT_TRUE(false); // this kernel should never be called
return 0;
}
};
FunctionSchema errorOpSchema(
"_test::error",
"",
(std::vector<Argument>{Argument("dummy")}),
(std::vector<Argument>{}));
(std::vector<Argument>{Argument("dummy"),
Argument("input", IntType::get())}),
(std::vector<Argument>{Argument("output", IntType::get())}));
struct IncrementKernel final : OperatorKernel {
int operator()(const Tensor& tensor, int input) {
int64_t operator()(const Tensor& tensor, int64_t input) {
return input + 1;
}
};
struct DecrementKernel final : OperatorKernel {
int operator()(const Tensor& tensor, int input) {
int64_t operator()(const Tensor& tensor, int64_t input) {
return input - 1;
}
};
@ -171,7 +174,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithZeroOutputs_whe
}
struct KernelWithIntOutput final : OperatorKernel {
int operator()(Tensor, int a, int b) {
int64_t operator()(Tensor, int64_t a, int64_t b) {
return a + b;
}
};
@ -256,7 +259,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListOutpu
}
struct KernelWithIntListOutput final : OperatorKernel {
std::vector<int64_t> operator()(const Tensor&, int input1, int input2, int input3) {
std::vector<int64_t> operator()(const Tensor&, int64_t input1, int64_t input2, int64_t input3) {
return {input1, input2, input3};
}
};
@ -423,10 +426,10 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByVa
EXPECT_EQ(TensorType2(), captured_input.type_id());
}
int captured_int_input = 0;
int64_t captured_int_input = 0;
struct KernelWithIntInputWithoutOutput final : OperatorKernel {
void operator()(Tensor, int input1) {
void operator()(Tensor, int64_t input1) {
captured_int_input = input1;
}
};
@ -452,7 +455,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withou
}
struct KernelWithIntInputWithOutput final : OperatorKernel {
int operator()(Tensor, int input1) {
int64_t operator()(Tensor, int64_t input1) {
return input1 + 1;
}
};
@ -476,7 +479,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withOu
EXPECT_EQ(4, outputs[0].toInt());
}
int captured_input_list_size = 0;
int64_t captured_input_list_size = 0;
struct KernelWithIntListInputWithoutOutput final : OperatorKernel {
void operator()(Tensor, ArrayRef<int64_t> input1) {
@ -505,7 +508,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_wi
}
struct KernelWithIntListInputWithOutput final : OperatorKernel {
int operator()(Tensor, ArrayRef<int64_t> input1) {
int64_t operator()(Tensor, ArrayRef<int64_t> input1) {
return input1.size();
}
};
@ -555,7 +558,7 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput
}
struct KernelWithTensorListInputWithOutput final : OperatorKernel {
int operator()(ArrayRef<Tensor> input1) {
int64_t operator()(ArrayRef<Tensor> input1) {
return input1.size();
}
};
@ -689,5 +692,308 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleConstru
EXPECT_EQ(13, outputs[0].toInt());
}
template<class Return, class... Args> struct KernelFunc final : OperatorKernel{
Return operator()(Args...) { return {}; }
};
template<class... Args> struct KernelFunc<void, Args...> final : OperatorKernel {
void operator()(Args...) {}
};
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2")}),
(std::vector<Argument>{})
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{}),
(std::vector<Argument>{})
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg"), Argument("arg2"), Argument("arg3")}),
(std::vector<Argument>{})
), kernel<KernelFunc<void, Tensor, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1"), Argument("arg2", FloatType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg1", IntType::get()), Argument("arg2", IntType::get())}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<KernelFunc<int64_t, Tensor, int64_t>>(), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()),
Argument("ret2", IntType::get())})
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret"), Argument("ret2")})
), kernel<KernelFunc<void, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2")})
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{})
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1")})
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2"), Argument("ret3")})
), kernel<KernelFunc<std::tuple<Tensor, Tensor>, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
}
TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) {
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", IntType::get())})
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
), kernel<KernelFunc<int64_t, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret")})
), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret", FloatType::get())})
), kernel<KernelFunc<Tensor, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
// assert this does not fail because it matches
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", IntType::get())})
), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1()));
// and now a set of mismatching schemas
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1"), Argument("ret2", FloatType::get())})
), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
EXPECT_THROW(
RegisterOperators()
.op(FunctionSchema(
"_test::mismatch",
"",
(std::vector<Argument>{Argument("arg")}),
(std::vector<Argument>{Argument("ret1", IntType::get()), Argument("ret2", IntType::get())})
), kernel<KernelFunc<std::tuple<Tensor, int64_t>, Tensor>>(), dispatchKey(TensorType1())),
c10::Error
);
}
}

View File

@ -17,29 +17,40 @@ namespace c10 {
namespace detail {
template<class KernelCacheCreatorFunction_>
struct NoFunctionSchemaInference final {
std::unique_ptr<FunctionSchema> operator()() const {
return nullptr;
}
};
template<class KernelCacheCreatorFunction_, class InferFunctionSchemaFunction>
struct KernelRegistrationConfigParameter final {
template<class KernelCacheCreatorFunction__>
constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func)
: kernel_func_(kernel_func), cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func)) {
constexpr KernelRegistrationConfigParameter(KernelFunction* kernel_func, KernelCacheCreatorFunction__&& cache_creator_func, InferFunctionSchemaFunction&& infer_function_schema_func)
: kernel_func_(kernel_func)
, cache_creator_func_(std::forward<KernelCacheCreatorFunction__>(cache_creator_func))
, infer_function_schema_func_(std::forward<InferFunctionSchemaFunction>(infer_function_schema_func)) {
}
void apply(KernelRegistrationConfig* registration) const & {
registration->kernel_func = kernel_func_;
registration->cache_creator_func = cache_creator_func_;
registration->inferred_function_schema = infer_function_schema_func_();
}
void apply(KernelRegistrationConfig* registration) && {
registration->kernel_func = kernel_func_;
registration->cache_creator_func = std::move(cache_creator_func_);
registration->inferred_function_schema = std::move(infer_function_schema_func_)();
}
private:
KernelFunction* kernel_func_;
KernelCacheCreatorFunction_ cache_creator_func_;
InferFunctionSchemaFunction infer_function_schema_func_;
};
static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
static_assert(is_registration_config_parameter<KernelRegistrationConfigParameter<KernelCacheCreatorFunction, NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
}
/**
@ -61,10 +72,10 @@ namespace detail {
* > c10::dispatchKey(CPUTensorId()));
*/
template<class KernelCacheCreatorFunction_>
inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
inline constexpr detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>, detail::NoFunctionSchemaInference> kernel(KernelFunction* kernel_func, KernelCacheCreatorFunction_&& cache_creator) {
static_assert(detail::is_registration_config_parameter<detail::KernelRegistrationConfigParameter<guts::decay_t<KernelCacheCreatorFunction_>, detail::NoFunctionSchemaInference>>::value, "KernelRegistrationConfigParameter must fulfill the registration config parameter concept");
return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator)};
return {kernel_func, std::forward<KernelCacheCreatorFunction_>(cache_creator), detail::NoFunctionSchemaInference()};
}
}

View File

@ -10,6 +10,7 @@
#include <ATen/core/op_registration/kernel_stackbased.h>
#include <ATen/core/op_registration/kernel_functor.h>
#include <ATen/core/op_registration/kernel_function.h>
#include <ATen/core/op_registration/infer_schema.h>
namespace c10 {
@ -65,6 +66,11 @@ public:
guts::enable_if_t<guts::conjunction<detail::is_registration_config_parameter<guts::decay_t<ConfigParameters>>...>::value, RegisterOperators>
op(FunctionSchema schema, ConfigParameters&&... configParameters) && {
detail::KernelRegistrationConfig config = detail::make_registration_config(std::forward<ConfigParameters>(configParameters)...);
if (config.inferred_function_schema.get() != nullptr) {
assertSchemasHaveSameSignature(*config.inferred_function_schema, schema);
}
registrars_.emplace_back(std::move(schema), config.dispatch_key, config.kernel_func, std::move(config.cache_creator_func));
return std::move(*this);
}

View File

@ -229,6 +229,9 @@ class DummyClassForToString final {};
namespace std {
// We use SFINAE to detect if std::to_string exists for a type, but that only works
// if the function name is defined. So let's define a std::to_string for a dummy type.
// If you're getting an error here saying that this overload doesn't match your
// std::to_string() call, then you're calling std::to_string() but should be calling
// c10::guts::to_string().
inline std::string to_string(c10::guts::detail::DummyClassForToString) { return ""; }
}
namespace c10 { namespace guts { namespace detail {

View File

@ -61,5 +61,6 @@ struct is_type_condition : std::false_type {};
template<template<class> class C>
struct is_type_condition<C, guts::enable_if_t<std::is_same<bool, guts::remove_cv_t<decltype(C<int>::value)>>::value>> : std::true_type {};
}
}

View File

@ -111,7 +111,7 @@ static auto registry = c10::RegisterOperators().op(
(std::vector<c10::Argument>{
c10::Argument("inputs", ListType::ofTensors()),
c10::Argument("output"),
c10::Argument("split_info", FloatType::get()),
c10::Argument("split_info"),
c10::Argument("add", IntType::get()),
c10::Argument("add_axis", IntType::get())}),
(std::vector<c10::Argument>{})),