mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch] Remove reference_cast in make_boxed_from_unboxed_functor (#51319)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51319 We were going out of our way to accommodate `IValue::to<Tensor>` returning a copy of the inner Tensor. `IValue::toTensor` is capable of returning a reference without copying, so if we use it directly, we can allow kernels that want to take `Tensor &` to do so! As a bonus, we get reduced build times. ghstack-source-id: 121378961 Test Plan: Rely on CI for correctness. Profiled build time with -ftime-trace for RegisterCPU.cpp using an extracted build invocation. Before: P168244900 After: P168245014 Note reduced time spent compiling make_boxed_from_unboxed_functor. I also ran the AdIndexer benchmark (https://fb.quip.com/ztERAYjuzdlr) with static runtime disabled and batch size 1 to see how big the effect on boxed call performance was (any kernels that take `Tensor&` or `const Tensor&` should now actually save a refcount bump). Looks like it was roughly 1% better: Before: 124-125 usec/iter After: 122-123 usec/iter Reviewed By: bhosmer Differential Revision: D26138549 fbshipit-source-id: b0f830527da360c542c815bef2f7e1692615b32a
This commit is contained in:
parent
c442776f3c
commit
a9f5e7229e
|
|
@ -249,20 +249,59 @@ namespace impl {
|
||||||
|
|
||||||
// ivalue_to_arg
|
// ivalue_to_arg
|
||||||
|
|
||||||
|
template<class T>
|
||||||
|
struct decay_if_not_tensor final {
|
||||||
|
using type = std::decay_t<T>;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct decay_if_not_tensor<at::Tensor&> final {
|
||||||
|
using type = at::Tensor&;
|
||||||
|
};
|
||||||
|
|
||||||
|
template<>
|
||||||
|
struct decay_if_not_tensor<const at::Tensor&> final {
|
||||||
|
using type = const at::Tensor&;
|
||||||
|
};
|
||||||
|
|
||||||
template<class T, bool AllowDeprecatedTypes>
|
template<class T, bool AllowDeprecatedTypes>
|
||||||
struct ivalue_to_arg final {
|
struct ivalue_to_arg final {
|
||||||
static T call(IValue&& v) {
|
static decltype(auto) call(IValue& v) {
|
||||||
assert_is_valid_input_type<T, AllowDeprecatedTypes>();
|
assert_is_valid_input_type<T, AllowDeprecatedTypes>();
|
||||||
return std::move(v).to<T>();
|
return std::move(v).to<T>();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// The following two specializations take advantage of specialized
|
||||||
|
// `toTensor()` overloads on IValue to avoid copying.
|
||||||
|
template<bool AllowDeprecatedTypes>
|
||||||
|
struct ivalue_to_arg<at::Tensor&, AllowDeprecatedTypes> final {
|
||||||
|
// We cannot use the default implementation if they asked for a
|
||||||
|
// `at::Tensor&` because it moves from the IValue, so it can't get
|
||||||
|
// an lvalue reference.
|
||||||
|
static at::Tensor& call(IValue& v) {
|
||||||
|
// Tensor& is valid, don't bother asserting
|
||||||
|
return v.toTensor();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
template<bool AllowDeprecatedTypes>
|
||||||
|
struct ivalue_to_arg<const at::Tensor&, AllowDeprecatedTypes> final {
|
||||||
|
// We should not use the default implementation if they asked for
|
||||||
|
// a `const at::Tensor&` because it moves from the IValue and they
|
||||||
|
// didn't ask for that.
|
||||||
|
static const at::Tensor& call(IValue& v) {
|
||||||
|
// const Tensor& is valid, don't bother asserting
|
||||||
|
return v.toTensor();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
template<class T, bool AllowDeprecatedTypes>
|
template<class T, bool AllowDeprecatedTypes>
|
||||||
struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
|
struct ivalue_to_arg<ArrayRef<T>, AllowDeprecatedTypes> final {
|
||||||
// If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that
|
// If an argument is ArrayRef<T>, convert the IValue to a std::vector<T> and pass that
|
||||||
// to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>.
|
// to the operator. std::vector<T> is implicitly convertible to ArrayRef<T>.
|
||||||
static std::vector<T> call(IValue&& v) {
|
static std::vector<T> call(IValue& v) {
|
||||||
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(std::move(v));
|
return ivalue_to_arg<std::vector<T>, AllowDeprecatedTypes>::call(v);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
template<class T, bool AllowDeprecatedTypes>
|
template<class T, bool AllowDeprecatedTypes>
|
||||||
|
|
@ -270,8 +309,8 @@ namespace impl {
|
||||||
// If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that
|
// If an argument is optional<ArrayRef<T>>, convert the IValue to an optional<std::vector<T>> and pass that
|
||||||
// to the operator. OptionalArray<T> is basically a optional<std::vector<T>> but impliticly convertible
|
// to the operator. OptionalArray<T> is basically a optional<std::vector<T>> but impliticly convertible
|
||||||
// to optional<ArrayRef<T>>.
|
// to optional<ArrayRef<T>>.
|
||||||
static OptionalArray<T> call(IValue&& v) {
|
static OptionalArray<T> call(IValue& v) {
|
||||||
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(std::move(v));
|
return ivalue_to_arg<OptionalArray<T>, AllowDeprecatedTypes>::call(v);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -296,19 +335,6 @@ namespace impl {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// reference_cast allows casting references, e.g. T&& to T&:
|
|
||||||
// T make_t() {}
|
|
||||||
// T& v = reference_cast<T&>(make_t()); // make_t() returns a T&& which is cast to T&.
|
|
||||||
// If the target is a non-reference value, then it gets moved:
|
|
||||||
// T make_t() {}
|
|
||||||
// T v = reference_cast<T>(make_t()); // no copies involved
|
|
||||||
// The first example actually also shows why reference_cast is usually a very bad idea. v now is a lvalue
|
|
||||||
// reference to a dead temporary. Use with caution!
|
|
||||||
template<class T, class U>
|
|
||||||
T reference_cast(U&& t) {
|
|
||||||
return std::forward<T>(t);
|
|
||||||
}
|
|
||||||
|
|
||||||
// wrap_kernel_functor_unboxed_
|
// wrap_kernel_functor_unboxed_
|
||||||
|
|
||||||
template<class KernelFunctor, class OpSignature>
|
template<class KernelFunctor, class OpSignature>
|
||||||
|
|
@ -363,21 +389,14 @@ namespace impl {
|
||||||
call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) {
|
call_functor_with_args_from_stack_(OperatorKernel* functor, DispatchKeySet dispatchKeySet, Stack* stack, std::index_sequence<ivalue_arg_indices...>, guts::typelist::typelist<ArgTypes...>*) {
|
||||||
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning.
|
(void)(stack); // when sizeof...(ivalue_arg_indices) == 0, this argument would be unused and we have to silence the compiler warning.
|
||||||
|
|
||||||
/*
|
|
||||||
* For ops that take "Tensor&" as an argument, ivalue_to_arg would still return a "Tensor" by value
|
|
||||||
* and C++ doesn't allow us to call (*functor) with a temporary "Tensor" when it expects "Tensor&".
|
|
||||||
* We use reference_cast to explicitly cast our temporary to a "Tensor&" and make it pass the compiler.
|
|
||||||
* Even though usually dangerous, this is ok here because temporaries live until the end of the statement.
|
|
||||||
* TODO We should remove reference_cast once kernels don't take "Tensor&" arguments anymore
|
|
||||||
*/
|
|
||||||
// We're explicitly filtering out DispatchKeySet from the argument list.
|
// We're explicitly filtering out DispatchKeySet from the argument list.
|
||||||
// Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
|
// Some kernels take a DispatchKeySet as their first argument in order to plumb keys through the dispatcher.
|
||||||
// We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
|
// We don't want to expose the DispatchKeySet type to jit, so we don't include this argument on the stack.
|
||||||
// See Note [Plumbing Keys Through The Dispatcher] for the background.
|
// See Note [Plumbing Keys Through The Dispatcher] for the background.
|
||||||
return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet, reference_cast<ArgTypes>(
|
return wrap_kernel_functor_unboxed<Functor>::call(functor, dispatchKeySet,
|
||||||
ivalue_to_arg<std::decay_t<ArgTypes>, AllowDeprecatedTypes>::call(
|
ivalue_to_arg<typename decay_if_not_tensor<ArgTypes>::type, AllowDeprecatedTypes>::call(
|
||||||
std::move(torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices)))
|
torch::jit::peek(*stack, ivalue_arg_indices, sizeof...(ivalue_arg_indices))
|
||||||
))...);
|
)...);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<class Functor, bool AllowDeprecatedTypes>
|
template<class Functor, bool AllowDeprecatedTypes>
|
||||||
|
|
|
||||||
|
|
@ -639,7 +639,7 @@ TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingC
|
||||||
expectThrows<c10::Error>([] {
|
expectThrows<c10::Error>([] {
|
||||||
auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
|
auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
|
||||||
.kernel(DispatchKey::CPU, [] (const int64_t&) {})
|
.kernel(DispatchKey::CPU, [] (const int64_t&) {})
|
||||||
.kernel(DispatchKey::CUDA, [] (int64_t&) {}));
|
.kernel(DispatchKey::CUDA, [] (int64_t) {}));
|
||||||
}, "Mismatch in kernel C++ signatures");
|
}, "Mismatch in kernel C++ signatures");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,11 +79,11 @@ call_torchbind_method_from_stack(
|
||||||
typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
|
typename c10::guts::infer_function_traits_t<Functor>::parameter_types;
|
||||||
// TODO We shouldn't use c10::impl stuff directly here. We should use the KernelFunction API instead.
|
// TODO We shouldn't use c10::impl stuff directly here. We should use the KernelFunction API instead.
|
||||||
return (functor)(c10::impl::ivalue_to_arg<
|
return (functor)(c10::impl::ivalue_to_arg<
|
||||||
std::remove_cv_t<std::remove_reference_t<
|
typename c10::impl::decay_if_not_tensor<
|
||||||
c10::guts::typelist::
|
c10::guts::typelist::
|
||||||
element_t<ivalue_arg_indices, IValueArgTypes>>>,
|
element_t<ivalue_arg_indices, IValueArgTypes>>::type,
|
||||||
AllowDeprecatedTypes>::call(std::move(
|
AllowDeprecatedTypes>::call(
|
||||||
torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args)))...);
|
torch::jit::peek(stack, ivalue_arg_indices, num_ivalue_args))...);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class Functor, bool AllowDeprecatedTypes>
|
template <class Functor, bool AllowDeprecatedTypes>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user