Optimize perf for calling ops with custom classes (#38257)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38257

It seems we're doing a runtime type check for custom classes on each operator call if the operator has custom class arguments.
This does not have an effect on operators without custom class arguments, but this is a problem for operators with custom class arguments,
for example operators taking a at::native::xnnpack::Conv2dOpContext argument.

The long term solution would be to move those checks to op registration time instead of doing them at call time,
but as an intermediate fix, we can at least make the check fast by

- Using ska::flat_hash_map instead of std::unordered_map
- Using std::type_index instead of std::string (i.e. avoid calling std::hash on a std::string)
ghstack-source-id: 106805209

Test Plan: waitforsandcastle

Reviewed By: ezyang

Differential Revision: D21507226

fbshipit-source-id: bd120d5574734be843c197673ea4222599fee7cb
This commit is contained in:
Sebastian Messmer 2020-07-01 19:24:23 -07:00 committed by Facebook GitHub Bot
parent 2f47e953f7
commit d7c9f96e43
4 changed files with 12 additions and 9 deletions

View File

@ -53,7 +53,8 @@ using supported_primitive_arg_types = guts::typelist::typelist<
guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
/* everything is ok, this is a primitive type */
}, /* else */ [] {
auto tmap = c10::getCustomClassTypeMap();
// TODO This is called for each operator call and potentially expensive.
// This check should be moved to operator registration time instead.
TORCH_CHECK(
c10::isCustomClassRegistered<T>(),
"Tried to use undefined class ",
@ -146,7 +147,8 @@ using supported_primitive_arg_types = guts::typelist::typelist<
guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
/* everything is ok, this is a primitive type */
}, /* else */ [] {
auto tmap = getCustomClassTypeMap();
// TODO This is called for each operator call and potentially expensive.
// This check should be moved to operator registration time instead.
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class ", c10::util::get_fully_qualified_type_name<T>(), " as output");
});
}

View File

@ -618,8 +618,8 @@ StrongTypePtr::StrongTypePtr(
TORCH_INTERNAL_ASSERT(type_);
}
std::unordered_map<std::string, c10::ClassTypePtr>& getCustomClassTypeMap() {
static std::unordered_map<std::string, c10::ClassTypePtr> tmap;
ska::flat_hash_map<std::type_index, c10::ClassTypePtr>& getCustomClassTypeMap() {
static ska::flat_hash_map<std::type_index, c10::ClassTypePtr> tmap;
return tmap;
}

View File

@ -5,6 +5,7 @@
#include <c10/util/C++17.h>
#include <c10/util/intrusive_ptr.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <typeindex>
namespace torch {
class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
@ -795,12 +796,12 @@ struct TORCH_API StrongTypePtr {
std::shared_ptr<Type> type_;
};
TORCH_API std::unordered_map<std::string, c10::ClassTypePtr>& getCustomClassTypeMap();
TORCH_API ska::flat_hash_map<std::type_index, c10::ClassTypePtr>& getCustomClassTypeMap();
template<typename T>
c10::ClassTypePtr getCustomClassType() {
auto tmap = c10::getCustomClassTypeMap();
auto res = tmap.find(typeid(T).name());
auto res = tmap.find(std::type_index(typeid(T)));
if (res == tmap.end()) {
throw c10::Error("Can't find class id in custom class type map", "");
}
@ -810,7 +811,7 @@ c10::ClassTypePtr getCustomClassType() {
template<typename T>
inline bool isCustomClassRegistered() {
auto tmap = c10::getCustomClassTypeMap();
return tmap.find(typeid(T).name()) != tmap.end();
return tmap.find(std::type_index(typeid(T))) != tmap.end();
}
TORCH_API std::unordered_map<std::string, std::function<PyObject*(void*)>>&

View File

@ -69,9 +69,9 @@ class class_ {
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
c10::getCustomClassTypeMap().insert(
{typeid(c10::intrusive_ptr<CurClass>).name(), classTypePtr});
{std::type_index(typeid(c10::intrusive_ptr<CurClass>)), classTypePtr});
c10::getCustomClassTypeMap().insert(
{typeid(c10::tagged_capsule<CurClass>).name(), classTypePtr});
{std::type_index(typeid(c10::tagged_capsule<CurClass>)), classTypePtr});
registerCustomClass(classTypePtr);
}