Create and use DynamicTypes for the check in DispatchKeyExtractor::makeBitsetForDispatchArgs (#151802)

On mobile, much of the JIT type subsystem (but not all of it) uses DynamicType. Not using DynamicType here was imposing a measurable startup-time cost, as explained in the code comment.
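
The fix follows a common C++ idiom: convert the comparison targets once, cache them in a function-local static, and reuse them on every call instead of letting each call re-convert its argument. Below is a minimal standalone sketch of that idiom; ExpensiveRepr, convert(), and isInterestingType() are made-up names for illustration, not the PyTorch API.

#include <memory>
#include <string>

// Hypothetical stand-in for a representation whose construction heap-allocates
// (analogous to building a DynamicType from a non-DynamicType argument).
struct ExpensiveRepr {
  std::string name;
};
using ExpensiveReprPtr = std::shared_ptr<ExpensiveRepr>;

// Imagine this conversion is costly enough to show up in startup profiles.
ExpensiveReprPtr convert(const std::string& name) {
  return std::make_shared<ExpensiveRepr>(ExpensiveRepr{name});
}

bool matches(const ExpensiveRepr& a, const ExpensiveRepr& b) {
  return a.name == b.name;
}

bool isInterestingType(const ExpensiveRepr& type) {
  // Convert the comparison targets exactly once; the function-local static is
  // initialized on first use and reused by every later call.
  struct CachedTargets {
    ExpensiveReprPtr tensor;
    ExpensiveReprPtr tensorList;
  };
  static const CachedTargets cached = {convert("Tensor"), convert("Tensor[]")};
  return matches(type, *cached.tensor) || matches(type, *cached.tensorList);
}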

Differential Revision: [D73129442](https://our.internmc.facebook.com/intern/diff/D73129442/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151802
Approved by: https://github.com/malfet
ghstack dependencies: #151801
Author: Scott Wolchok, 2025-04-23 17:22:21 -07:00 (committed by PyTorch MergeBot)
Parent: 5de92e676a
Commit: fabbcddab1

@@ -200,6 +200,31 @@ struct TORCH_API DispatchKeyExtractor final {
   void checkInvariants(const FunctionSchema& schema) const;
 
  private:
+  static bool isDispatchType(const Type& type) {
+    // Checking isSubtypeOf on a DynamicType heap-allocates a
+    // DynamicType version of the argument if it's not a DynamicType
+    // already, and this has measurable overhead during startup.
+#ifdef C10_MOBILE
+    struct CachedTypes {
+      DynamicTypePtr listOfTensors;
+      DynamicTypePtr listOfOptionalTensors;
+      DynamicTypePtr optionalOfTensor;
+    };
+    static const CachedTypes ct = {
+        DynamicType::create(*ListType::ofTensors()),
+        DynamicType::create(*ListType::ofOptionalTensors()),
+        DynamicType::create(*OptionalType::ofTensor())};
+    return type.isSubtypeOf(c10::TypeFactory::get<TensorType>()) ||
+        type.isSubtypeOf(ct.listOfTensors) ||
+        type.isSubtypeOf(ct.listOfOptionalTensors) ||
+        type.isSubtypeOf(ct.optionalOfTensor);
+#else // C10_MOBILE
+    return type.isSubtypeOf(*TensorType::get()) ||
+        type.isSubtypeOf(*ListType::ofTensors()) ||
+        type.isSubtypeOf(*ListType::ofOptionalTensors()) ||
+        type.isSubtypeOf(*OptionalType::ofTensor());
+#endif // C10_MOBILE
+  }
   static c10::utils::bitset makeBitsetForDispatchArgs(
       const FunctionSchema& schema) {
     TORCH_CHECK(
@@ -210,13 +235,7 @@ struct TORCH_API DispatchKeyExtractor final {
         c10::utils::bitset::NUM_BITS());
     c10::utils::bitset dispatch_arg_indices_reverse;
     for (const auto index : c10::irange(schema.arguments().size())) {
-      if (schema.arguments()[index].type()->isSubtypeOf(*TensorType::get()) ||
-          schema.arguments()[index].type()->isSubtypeOf(
-              *ListType::ofTensors()) ||
-          schema.arguments()[index].type()->isSubtypeOf(
-              *ListType::ofOptionalTensors()) ||
-          schema.arguments()[index].type()->isSubtypeOf(
-              *OptionalType::ofTensor())) {
+      if (isDispatchType(*schema.arguments()[index].type())) {
         dispatch_arg_indices_reverse.set(schema.arguments().size() - 1 - index);
       }
     }
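
For context, the bitset records dispatch-relevant arguments in reverse: argument index i of an N-argument schema sets bit N-1-i, which is why the loop uses schema.arguments().size() - 1 - index. A small self-contained illustration of that indexing using std::bitset (not PyTorch's c10::utils::bitset), with a hypothetical 3-argument schema:

#include <bitset>
#include <cassert>
#include <cstddef>

int main() {
  constexpr size_t num_args = 3;
  // Suppose arguments 0 and 2 are Tensor-like (dispatch-relevant).
  const bool is_dispatch_arg[num_args] = {true, false, true};

  std::bitset<64> dispatch_arg_indices_reverse;
  for (size_t index = 0; index < num_args; ++index) {
    if (is_dispatch_arg[index]) {
      // Same reversed indexing as makeBitsetForDispatchArgs.
      dispatch_arg_indices_reverse.set(num_args - 1 - index);
    }
  }

  assert(dispatch_arg_indices_reverse.test(2));   // argument 0 -> bit 2
  assert(!dispatch_arg_indices_reverse.test(1));  // argument 1 is not dispatch-relevant
  assert(dispatch_arg_indices_reverse.test(0));   // argument 2 -> bit 0
  return 0;
}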