[PyTorch Edge] Conditionally trim dispatch key set to save heap memory at runtime (#65732)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65732

For certain on-device uses, runtime memory comes at a premium. On-device deployments won't use all the available dispatch keys, so it makes sense to keep only the on-device specific ones around for such uses to reduce runtime heap memory allocated.

This change keeps just 10 dispatch keys (the ones that used on-device), guarded under the `C10_MOBILE_TRIM_DISPATCH_KEYS` macro. it tries to keep the other code-paths unaffected and uses `constexpr` for use in the `array` declaration, and simple inline functions to ensure that the compiler is able to optimize these for server builds.

Test Plan:
Build and check mobile models end to end.

```
buck build -c "pt.enable_milan_dispatch_keys_trimming"=1 //xplat/caffe2/fb/lite_predictor:lite_predictor
```

Reviewed By: ezyang

Differential Revision: D31185407

fbshipit-source-id: e954765606373dea6ee9466a851dca7684167b0b
This commit is contained in:
Dhruv Matani 2021-09-29 12:18:56 -07:00 committed by Facebook GitHub Bot
parent 7b5d676fa1
commit a84feeeade
4 changed files with 63 additions and 4 deletions

View File

@ -538,6 +538,14 @@ if(ANDROID OR IOS OR DEFINED ENV{BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN})
# c10/macros/Macros.h, so it needs to be explicitly set here.
string(APPEND CMAKE_CXX_FLAGS " -DC10_MOBILE")
endif()
if(DEFINED ENV{PYTORCH_MOBILE_TRIM_DISPATCH_KEY_SET})
# If PYTORCH_MOBILE_TRIM_DISPATCH_KEY_SET is defined (env var),
# then define C10_MOBILE_TRIM_DISPATCH_KEYS, which limits the
# number of dispatch keys in OperatorEntry::dispatchTable_
# to reduce peak memory during library initialization.
string(APPEND CMAKE_CXX_FLAGS " -DC10_MOBILE_TRIM_DISPATCH_KEYS")
endif()
endif()
# INTERN_BUILD_ATEN_OPS is used to control whether to build ATen/TH operators.

View File

@ -299,7 +299,10 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
// or alias keys and their associated keysets).
// This function should be considered a private helper for updateDispatchTable_()
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
const auto dispatch_ix = c10::getDispatchTableIndexForDispatchKey(dispatch_key);
if (C10_UNLIKELY(dispatch_ix == -1)) {
return;
}
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
}

View File

@ -170,7 +170,11 @@ public:
[[noreturn]] void reportError(DispatchKey dispatchKey) const;
const KernelFunction& lookup(DispatchKey k) const {
const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
const auto idx = getDispatchTableIndexForDispatchKey(k);
if (C10_UNLIKELY(idx == -1)) {
reportError(k);
}
const auto& kernel = dispatchTable_[idx];
// A valid kernel *always* has a boxed kernel and *may* have an
// unboxed kernel. However, we typically do unboxed calls in at::
// APIs, where the kernel 1) will very likely be valid and 2)
@ -203,7 +207,7 @@ private:
OperatorName name_;
c10::optional<AnnotatedSchema> schema_;
std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
std::array<KernelFunction, c10::getDispatchTableIndexForDispatchKey(DispatchKey::NumDispatchKeys)> dispatchTable_;
DispatchKeyExtractor dispatchKeyExtractor_;
// kernels_ stores all registered kernels for the corresponding dispatch key

View File

@ -21,7 +21,6 @@ namespace c10 {
//
// NOTE: Keep the list in sync with `DispatchKey` in tools/codegen/model.py
enum class DispatchKey : uint8_t {
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
// This is not a "real" tensor id, but it exists to give us a "nullopt"
// element we can return for cases when a DispatchKeySet contains no elements.
@ -358,6 +357,51 @@ static_assert(
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
"DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
/**
* The method below maps the dispatch key in the enum DispatchKey to an
* integer index in the dispatchTable_ array in OperatorEntry. The array
* is trimmed for mobile to reduce peak memory usage since it's
* unnecessary to reserve additional space for dispatch keys that will
* never be used on mobile.
*/
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
switch (dk) {
case DispatchKey::Undefined:
return 0;
case DispatchKey::CPU:
return 1;
case DispatchKey::Vulkan:
return 2;
case DispatchKey::Metal:
return 3;
case DispatchKey::QuantizedCPU:
return 4;
case DispatchKey::SparseCPU:
return 5;
case DispatchKey::BackendSelect:
return 6;
case DispatchKey::ADInplaceOrView:
return 7;
case DispatchKey::AutogradOther:
return 8;
case DispatchKey::AutogradCPU:
return 9;
case DispatchKey::NumDispatchKeys: // Sentinel, end of runtime keys.
return 10;
default:
return -1;
}
}
#else
/**
* For the server use-case, make this a simple pass-through.
*/
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
return static_cast<int>(dk);
}
#endif
C10_API const char* toString(DispatchKey);
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);