mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[PyTorch Edge] Conditionally trim dispatch key set to save heap memory at runtime (#65732)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/65732 For certain on-device uses, runtime memory comes at a premium. On-device deployments won't use all the available dispatch keys, so it makes sense to keep only the on-device specific ones around for such uses to reduce the runtime heap memory allocated. This change keeps just 10 dispatch keys (the ones that are used on-device), guarded under the `C10_MOBILE_TRIM_DISPATCH_KEYS` macro. It tries to keep the other code paths unaffected: it uses `constexpr` so the mapping function can be used in the `array` declaration, and simple inline functions to ensure that the compiler is able to optimize these for server builds. Test Plan: Build and check mobile models end to end. ``` buck build -c "pt.enable_milan_dispatch_keys_trimming"=1 //xplat/caffe2/fb/lite_predictor:lite_predictor ``` Reviewed By: ezyang Differential Revision: D31185407 fbshipit-source-id: e954765606373dea6ee9466a851dca7684167b0b
This commit is contained in:
parent
7b5d676fa1
commit
a84feeeade
|
|
@ -538,6 +538,14 @@ if(ANDROID OR IOS OR DEFINED ENV{BUILD_PYTORCH_MOBILE_WITH_HOST_TOOLCHAIN})
|
||||||
# c10/macros/Macros.h, so it needs to be explicitly set here.
|
# c10/macros/Macros.h, so it needs to be explicitly set here.
|
||||||
string(APPEND CMAKE_CXX_FLAGS " -DC10_MOBILE")
|
string(APPEND CMAKE_CXX_FLAGS " -DC10_MOBILE")
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
|
if(DEFINED ENV{PYTORCH_MOBILE_TRIM_DISPATCH_KEY_SET})
|
||||||
|
# If PYTORCH_MOBILE_TRIM_DISPATCH_KEY_SET is defined (env var),
|
||||||
|
# then define C10_MOBILE_TRIM_DISPATCH_KEYS, which limits the
|
||||||
|
# number of dispatch keys in OperatorEntry::dispatchTable_
|
||||||
|
# to reduce peak memory during library initialization.
|
||||||
|
string(APPEND CMAKE_CXX_FLAGS " -DC10_MOBILE_TRIM_DISPATCH_KEYS")
|
||||||
|
endif()
|
||||||
endif()
|
endif()
|
||||||
|
|
||||||
# INTERN_BUILD_ATEN_OPS is used to control whether to build ATen/TH operators.
|
# INTERN_BUILD_ATEN_OPS is used to control whether to build ATen/TH operators.
|
||||||
|
|
|
||||||
|
|
@ -299,7 +299,10 @@ std::pair<const AnnotatedKernel&, const char*> OperatorEntry::computeDispatchTab
|
||||||
// or alias keys and their associated keysets).
|
// or alias keys and their associated keysets).
|
||||||
// This function should be considered a private helper for updateDispatchTable_()
|
// This function should be considered a private helper for updateDispatchTable_()
|
||||||
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
|
void OperatorEntry::updateDispatchTableEntry_(const c10::Dispatcher& dispatcher, DispatchKey dispatch_key) {
|
||||||
auto dispatch_ix = static_cast<uint8_t>(dispatch_key);
|
const auto dispatch_ix = c10::getDispatchTableIndexForDispatchKey(dispatch_key);
|
||||||
|
if (C10_UNLIKELY(dispatch_ix == -1)) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
|
dispatchTable_[dispatch_ix] = computeDispatchTableEntry(dispatcher, dispatch_key);
|
||||||
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
|
dispatchKeyExtractor_.setOperatorHasFallthroughForKey(dispatch_key, dispatchTable_[dispatch_ix].isFallthrough());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,11 @@ public:
|
||||||
[[noreturn]] void reportError(DispatchKey dispatchKey) const;
|
[[noreturn]] void reportError(DispatchKey dispatchKey) const;
|
||||||
|
|
||||||
const KernelFunction& lookup(DispatchKey k) const {
|
const KernelFunction& lookup(DispatchKey k) const {
|
||||||
const auto& kernel = dispatchTable_[static_cast<uint8_t>(k)];
|
const auto idx = getDispatchTableIndexForDispatchKey(k);
|
||||||
|
if (C10_UNLIKELY(idx == -1)) {
|
||||||
|
reportError(k);
|
||||||
|
}
|
||||||
|
const auto& kernel = dispatchTable_[idx];
|
||||||
// A valid kernel *always* has a boxed kernel and *may* have an
|
// A valid kernel *always* has a boxed kernel and *may* have an
|
||||||
// unboxed kernel. However, we typically do unboxed calls in at::
|
// unboxed kernel. However, we typically do unboxed calls in at::
|
||||||
// APIs, where the kernel 1) will very likely be valid and 2)
|
// APIs, where the kernel 1) will very likely be valid and 2)
|
||||||
|
|
@ -203,7 +207,7 @@ private:
|
||||||
OperatorName name_;
|
OperatorName name_;
|
||||||
c10::optional<AnnotatedSchema> schema_;
|
c10::optional<AnnotatedSchema> schema_;
|
||||||
|
|
||||||
std::array<KernelFunction, static_cast<uint8_t>(DispatchKey::NumDispatchKeys)> dispatchTable_;
|
std::array<KernelFunction, c10::getDispatchTableIndexForDispatchKey(DispatchKey::NumDispatchKeys)> dispatchTable_;
|
||||||
DispatchKeyExtractor dispatchKeyExtractor_;
|
DispatchKeyExtractor dispatchKeyExtractor_;
|
||||||
|
|
||||||
// kernels_ stores all registered kernels for the corresponding dispatch key
|
// kernels_ stores all registered kernels for the corresponding dispatch key
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,6 @@ namespace c10 {
|
||||||
//
|
//
|
||||||
// NOTE: Keep the list in sync with `DispatchKey` in tools/codegen/model.py
|
// NOTE: Keep the list in sync with `DispatchKey` in tools/codegen/model.py
|
||||||
enum class DispatchKey : uint8_t {
|
enum class DispatchKey : uint8_t {
|
||||||
|
|
||||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~ UNDEFINED ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ //
|
||||||
// This is not a "real" tensor id, but it exists to give us a "nullopt"
|
// This is not a "real" tensor id, but it exists to give us a "nullopt"
|
||||||
// element we can return for cases when a DispatchKeySet contains no elements.
|
// element we can return for cases when a DispatchKeySet contains no elements.
|
||||||
|
|
@ -358,6 +357,51 @@ static_assert(
|
||||||
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
|
static_cast<uint8_t>(DispatchKey::NumDispatchKeys) < 64,
|
||||||
"DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
|
"DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries");
|
||||||
|
|
||||||
|
#if defined(C10_MOBILE_TRIM_DISPATCH_KEYS)
|
||||||
|
/**
|
||||||
|
* The method below maps the dispatch key in the enum DispatchKey to an
|
||||||
|
* integer index in the dispatchTable_ array in OperatorEntry. The array
|
||||||
|
* is trimmed for mobile to reduce peak memory usage since it's
|
||||||
|
* unnecessary to reserve additional space for dispatch keys that will
|
||||||
|
* never be used on mobile.
|
||||||
|
*/
|
||||||
|
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
|
||||||
|
switch (dk) {
|
||||||
|
case DispatchKey::Undefined:
|
||||||
|
return 0;
|
||||||
|
case DispatchKey::CPU:
|
||||||
|
return 1;
|
||||||
|
case DispatchKey::Vulkan:
|
||||||
|
return 2;
|
||||||
|
case DispatchKey::Metal:
|
||||||
|
return 3;
|
||||||
|
case DispatchKey::QuantizedCPU:
|
||||||
|
return 4;
|
||||||
|
case DispatchKey::SparseCPU:
|
||||||
|
return 5;
|
||||||
|
case DispatchKey::BackendSelect:
|
||||||
|
return 6;
|
||||||
|
case DispatchKey::ADInplaceOrView:
|
||||||
|
return 7;
|
||||||
|
case DispatchKey::AutogradOther:
|
||||||
|
return 8;
|
||||||
|
case DispatchKey::AutogradCPU:
|
||||||
|
return 9;
|
||||||
|
case DispatchKey::NumDispatchKeys: // Sentinel, end of runtime keys.
|
||||||
|
return 10; // One past the last kept index (0-9); OperatorEntry uses this value as the dispatchTable_ array size.
|
||||||
|
default: // Every key not listed above has no slot in the trimmed table.
|
||||||
|
return -1; // "No index" marker; callers compare the result against -1 before indexing dispatchTable_.
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
/**
|
||||||
|
* For the server use-case, make this a simple pass-through.
|
||||||
|
*/
|
||||||
|
C10_API constexpr int getDispatchTableIndexForDispatchKey(DispatchKey dk) {
|
||||||
|
return static_cast<int>(dk); // Index is the enum's underlying uint8_t value, so this variant never yields -1.
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
C10_API const char* toString(DispatchKey);
|
C10_API const char* toString(DispatchKey);
|
||||||
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
|
C10_API std::ostream& operator<<(std::ostream&, DispatchKey);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user