mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Optimize perf for calling ops with custom classes (#38257)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/38257

It seems we're doing a runtime type check for custom classes on each operator call if the operator has custom class arguments. This has no effect on operators without custom class arguments, but it is a problem for operators with custom class arguments, for example operators taking an at::native::xnnpack::Conv2dOpContext argument.

The long-term solution would be to move those checks to op registration time instead of doing them at call time, but as an intermediate fix, we can at least make the check fast by:
- Using ska::flat_hash_map instead of std::unordered_map
- Using std::type_index instead of std::string (i.e. avoid calling std::hash on a std::string)

ghstack-source-id: 106805209

Test Plan: waitforsandcastle

Reviewed By: ezyang

Differential Revision: D21507226

fbshipit-source-id: bd120d5574734be843c197673ea4222599fee7cb
This commit is contained in:
parent
2f47e953f7
commit
d7c9f96e43
|
|
@ -53,7 +53,8 @@ using supported_primitive_arg_types = guts::typelist::typelist<
|
|||
guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
|
||||
/* everything is ok, this is a primitive type */
|
||||
}, /* else */ [] {
|
||||
auto tmap = c10::getCustomClassTypeMap();
|
||||
// TODO This is called for each operator call and potentially expensive.
|
||||
// This check should be moved to operator registration time instead.
|
||||
TORCH_CHECK(
|
||||
c10::isCustomClassRegistered<T>(),
|
||||
"Tried to use undefined class ",
|
||||
|
|
@ -146,7 +147,8 @@ using supported_primitive_arg_types = guts::typelist::typelist<
|
|||
guts::if_constexpr<guts::typelist::contains<supported_primitive_arg_types, T>::value>([] {
|
||||
/* everything is ok, this is a primitive type */
|
||||
}, /* else */ [] {
|
||||
auto tmap = getCustomClassTypeMap();
|
||||
// TODO This is called for each operator call and potentially expensive.
|
||||
// This check should be moved to operator registration time instead.
|
||||
TORCH_CHECK(c10::isCustomClassRegistered<T>(), "Tried to use undefined class ", c10::util::get_fully_qualified_type_name<T>(), " as output");
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -618,8 +618,8 @@ StrongTypePtr::StrongTypePtr(
|
|||
TORCH_INTERNAL_ASSERT(type_);
|
||||
}
|
||||
|
||||
std::unordered_map<std::string, c10::ClassTypePtr>& getCustomClassTypeMap() {
|
||||
static std::unordered_map<std::string, c10::ClassTypePtr> tmap;
|
||||
ska::flat_hash_map<std::type_index, c10::ClassTypePtr>& getCustomClassTypeMap() {
|
||||
static ska::flat_hash_map<std::type_index, c10::ClassTypePtr> tmap;
|
||||
return tmap;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
#include <c10/util/C++17.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
#include <torch/csrc/WindowsTorchApiMacro.h>
|
||||
#include <typeindex>
|
||||
|
||||
namespace torch {
|
||||
class TORCH_API CustomClassHolder : public c10::intrusive_ptr_target {};
|
||||
|
|
@ -795,12 +796,12 @@ struct TORCH_API StrongTypePtr {
|
|||
std::shared_ptr<Type> type_;
|
||||
};
|
||||
|
||||
TORCH_API std::unordered_map<std::string, c10::ClassTypePtr>& getCustomClassTypeMap();
|
||||
TORCH_API ska::flat_hash_map<std::type_index, c10::ClassTypePtr>& getCustomClassTypeMap();
|
||||
|
||||
template<typename T>
|
||||
c10::ClassTypePtr getCustomClassType() {
|
||||
auto tmap = c10::getCustomClassTypeMap();
|
||||
auto res = tmap.find(typeid(T).name());
|
||||
auto res = tmap.find(std::type_index(typeid(T)));
|
||||
if (res == tmap.end()) {
|
||||
throw c10::Error("Can't find class id in custom class type map", "");
|
||||
}
|
||||
|
|
@ -810,7 +811,7 @@ c10::ClassTypePtr getCustomClassType() {
|
|||
template<typename T>
|
||||
inline bool isCustomClassRegistered() {
|
||||
auto tmap = c10::getCustomClassTypeMap();
|
||||
return tmap.find(typeid(T).name()) != tmap.end();
|
||||
return tmap.find(std::type_index(typeid(T))) != tmap.end();
|
||||
}
|
||||
|
||||
TORCH_API std::unordered_map<std::string, std::function<PyObject*(void*)>>&
|
||||
|
|
|
|||
|
|
@ -69,9 +69,9 @@ class class_ {
|
|||
classTypePtr->addAttribute("capsule", at::CapsuleType::get());
|
||||
|
||||
c10::getCustomClassTypeMap().insert(
|
||||
{typeid(c10::intrusive_ptr<CurClass>).name(), classTypePtr});
|
||||
{std::type_index(typeid(c10::intrusive_ptr<CurClass>)), classTypePtr});
|
||||
c10::getCustomClassTypeMap().insert(
|
||||
{typeid(c10::tagged_capsule<CurClass>).name(), classTypePtr});
|
||||
{std::type_index(typeid(c10::tagged_capsule<CurClass>)), classTypePtr});
|
||||
|
||||
registerCustomClass(classTypePtr);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user