[PyTorch] Remove reference_cast in make_boxed_from_unboxed_functor (#51319)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51319

We were going out of our way to accommodate `IValue::to<Tensor>` returning a copy of the inner Tensor. `IValue::toTensor` is capable of returning a reference without copying, so if we use it directly, we can allow kernels that want to take `Tensor &` to do so!
As a bonus, we get reduced build times.
ghstack-source-id: 121378961

Test Plan:
Rely on CI for correctness.
Profiled build time with -ftime-trace for RegisterCPU.cpp using an extracted build invocation.

Before: P168244900

After: P168245014

Note reduced time spent compiling make_boxed_from_unboxed_functor.

I also ran the AdIndexer benchmark (https://fb.quip.com/ztERAYjuzdlr) with static runtime disabled and batch size 1 to see how big the effect on boxed call performance was (any kernels that take `Tensor&` or `const Tensor&` should now actually save a refcount bump). Looks like it was roughly 1% better:

Before: 124-125 usec/iter
After: 122-123 usec/iter

Reviewed By: bhosmer

Differential Revision: D26138549

fbshipit-source-id: b0f830527da360c542c815bef2f7e1692615b32a
This commit is contained in:
Scott Wolchok 2021-02-17 10:33:13 -08:00 committed by Facebook GitHub Bot
parent c442776f3c
commit a9f5e7229e
3 changed files with 53 additions and 34 deletions

View File

@ -249,20 +249,59 @@ namespace impl {
// ivalue_to_arg
// Maps a kernel parameter type T to the type used to select an
// ivalue_to_arg implementation:
//   - `at::Tensor&` and `const at::Tensor&` are kept exactly as written,
//   - every other T is decayed with std::decay_t.
template<class T>
struct decay_if_not_tensor final {
  using type = std::conditional_t<
      std::is_same<T, at::Tensor&>::value ||
          std::is_same<T, const at::Tensor&>::value,
      T,
      std::decay_t<T>>;
};
// Generic conversion of an IValue into a kernel argument of type T.
// Validates T through assert_is_valid_input_type<T, AllowDeprecatedTypes>()
// and then converts with `std::move(v).to<T>()`.
template<class T, bool AllowDeprecatedTypes>
struct ivalue_to_arg final {
// NOTE(review): the next two lines both open `call`. They look like the
// pre-change (`IValue&&`, returns T) and post-change (`IValue&`, returns
// decltype(auto)) signatures shown back to back by the diff view -- confirm
// against the real header before relying on either one.
static T call(IValue&& v) {
static decltype(auto) call(IValue& v) {
assert_is_valid_input_type<T, AllowDeprecatedTypes>();
// The conversion is invoked on an rvalue of `v`, even when `v` was received
// as an lvalue reference.
return std::move(v).to<T>();
}
};
// The following two specializations take advantage of specialized
// `toTensor()` overloads on IValue to avoid copying.
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
  // A kernel that asks for `at::Tensor&` needs an lvalue reference. The
  // generic implementation converts from `std::move(v)`, which cannot hand
  // one out, so this specialization calls `toTensor()` on the IValue itself.
  // `at::Tensor&` is always an accepted input type, hence no
  // assert_is_valid_input_type call here.
  static at::Tensor& call(IValue& ivalue) {
    at::Tensor& tensor = ivalue.toTensor();
    return tensor;
  }
};
template<bool AllowDeprecatedTypes>
struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
  // A kernel that asks for `const at::Tensor&` did not ask to take the value
  // out of the IValue, so the generic implementation (which converts from
  // `std::move(v)`) is not appropriate. Call `toTensor()` on the IValue
  // itself instead.
  // `const at::Tensor&` is always an accepted input type, hence no
  // assert_is_valid_input_type call here.
  static const at::Tensor& call(IValue& ivalue) {
    const at::Tensor& tensor = ivalue.toTensor();
    return tensor;
  }
};
// Conversion used when a kernel parameter is an ArrayRef<T>.
template<class T, bool AllowDeprecatedTypes>
struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
// If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that
// to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>.
// NOTE(review): two `call` definitions follow back to back. They look like the
// pre-change (`IValue&&`, forwards `std::move(v)`) and post-change (`IValue&`,
// forwards `v`) versions from the diff view -- confirm against the real header.
static std::vector<T> call(IValue&& v) {
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(std::move(v));
static std::vector<T> call(IValue& v) {
// Delegates to the std::vector<T> conversion and returns that vector by value.
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
}
};
template<class T, bool AllowDeprecatedTypes>
@ -270,8 +309,8 @@ namespace impl {
// If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that
// to the operator. OptionalArray<T> is basically an optional<std::vector<T>> but implicitly convertible
// to optional<ArrayRef<T>>.
static OptionalArray<T> call(IValue&& v) {
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(std::move(v));
static OptionalArray<T> call(IValue& v) {
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
}
};
@ -296,19 +335,6 @@ namespace impl {
}
};
// reference_cast<Target>(value) forwards `value` as a `Target`.
//
// With a reference target, the argument is re-bound as that reference, e.g.
// an rvalue can be turned into an lvalue reference:
//   T make_t() {}
//   T& v = reference_cast<T&>(make_t());
// With a non-reference target, the argument is moved into the result:
//   T make_t() {}
//   T v = reference_cast<T>(make_t());  // no copies involved
//
// The first example also shows why reference_cast is usually a very bad idea:
// `v` ends up as an lvalue reference to a temporary that is already dead.
// Use with caution!
template<class Target, class Source>
Target reference_cast(Source&& value) {
  return std::forward<Target>(value);
}
// wrap_kernel_functor_unboxed_
template<class KernelFunctor, class OpSignature>
@ -363,21 +389,14 @@ namespace impl {
call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) {
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning.
/*
* For ops that take "Tensor&" as an argument, ivalue_to_arg would still return a "Tensor" by value
* and C++ doesn't allow us to call (*functor) with a temporary "Tensor" when it expects "Tensor&".
* We use reference_cast to explicitly cast our temporary to a "Tensor&" and make it pass the compiler.
* Even though usually dangerous, this is ok here because temporaries live until the end of the statement.
* TODO We should remove reference_cast once kernels don't take "Tensor&" arguments anymore
*/
// We're explicitly filtering out DispatchKeySet from the argument list.
// Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
// We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
// See Note [Plumbing Keys Through The Dispatcher] for the background.
return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet, reference_cast<ArgTypes>(
ivalue_to_arg<std::decay_t<ArgTypes>, AllowDeprecatedTypes>::call(
std::move(torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))
))...);
return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet,
ivalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type, AllowDeprecatedTypes>::call(
torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices))
)...);
}
template<class Functor, bool AllowDeprecatedTypes>

View File

@ -639,7 +639,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingC
expectThrows<c10::Error>([] {
auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
.kernel(DispatchKey::CPU, [] (const int64_t&) {})
.kernel(DispatchKey::CUDA, [] (int64_t&) {}));
.kernel(DispatchKey::CUDA, [] (int64_t) {}));
}, "Mismatch in kernel C++ signatures");
}

View File

@ -79,11 +79,11 @@ call_torchbind_method_from_stack(
typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
// TODO We shouldn't use c10::impl stuff directly here. We should use the KernelFunction API instead.
return (functor)(c10::impl::ivalue_to_arg<
std::remove_cv_t<std::remove_reference_t<
typename c10::impl::decay_if_not_tensor<
c10::guts::typelist::
element_t<ivalue_arg_indices, IValueArgTypes>>>,
AllowDeprecatedTypes>::call(std::move(
torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args)))...);
element_t<ivalue_arg_indices, IValueArgTypes>>::type,
AllowDeprecatedTypes>::call(
torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args))...);
}
template <class Functor, bool AllowDeprecatedTypes>