mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enable AcceleratorAllocatorConfig key check (#157908)
# Motivation Add a mechanism to ensure raise the key if the key is unrecognized in allocator config. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157908 Approved by: https://github.com/albanD ghstack dependencies: #149601
This commit is contained in:
parent
905b084690
commit
65fcca4f8c
|
|
@ -221,6 +221,15 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
|
|||
} else if (key == "pinned_use_background_threads") {
|
||||
i = parsePinnedUseBackgroundThreads(tokenizer, i);
|
||||
} else {
|
||||
// If a device-specific configuration parser hook is registered, it will
|
||||
// check if the key is unrecognized.
|
||||
if (device_config_parser_hook_) {
|
||||
TORCH_CHECK(
|
||||
keys_.find(key) != keys_.end(),
|
||||
"Unrecognized key '",
|
||||
key,
|
||||
"' in Accelerator allocator config.");
|
||||
}
|
||||
i = tokenizer.skipKey(i);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include <atomic>
|
||||
#include <mutex>
|
||||
#include <string>
|
||||
#include <unordered_set>
|
||||
#include <vector>
|
||||
|
||||
namespace c10::CachingAllocator {
|
||||
|
|
@ -180,7 +181,7 @@ class C10_API AcceleratorAllocatorConfig {
|
|||
|
||||
// Returns the vector of division factors used for rounding up allocation
|
||||
// sizes. These divisions apply to size intervals between 1MB and 64GB.
|
||||
static std::vector<size_t> roundup_power2_divisions() {
|
||||
static const std::vector<size_t>& roundup_power2_divisions() {
|
||||
return instance().roundup_power2_divisions_;
|
||||
}
|
||||
|
||||
|
|
@ -219,6 +220,13 @@ class C10_API AcceleratorAllocatorConfig {
|
|||
return instance().last_allocator_settings_;
|
||||
}
|
||||
|
||||
// Returns the set of valid keys for the allocator configuration.
|
||||
// This set is used to validate the presence and correctness of keys in
|
||||
// device-specific configuration parsers.
|
||||
static const std::unordered_set<std::string>& getKeys() {
|
||||
return instance().keys_;
|
||||
}
|
||||
|
||||
// Parses the environment variable `env` to update the allocator settings.
|
||||
// If the environment variable is not set, it does nothing.
|
||||
// The configuration string should be a comma-separated list of key-value
|
||||
|
|
@ -227,16 +235,24 @@ class C10_API AcceleratorAllocatorConfig {
|
|||
// "max_split_size_mb:100,max_non_split_rounding_mb:20,garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,256:4,1024:4,>:1],expandable_segments:true,pinned_use_background_threads:true"
|
||||
void parseArgs(const std::string& env);
|
||||
|
||||
// Registers a device-specific configuration parser hook. This allows
|
||||
// backends to parse additional device-specific configuration options from the
|
||||
// environment variable. The hook should be a function that takes a string
|
||||
// (the environment variable value) and parses it to set device-specific
|
||||
// configuration options.
|
||||
// The hook will be called when the environment variable is parsed.
|
||||
// If a hook is already registered, it will be replaced with the new one.
|
||||
// Registers a device-specific configuration parser hook and its key. This
|
||||
// allows backends to parse additional device-specific configuration options
|
||||
// from the environment variable. The hook should be a function that takes a
|
||||
// string (the environment variable value) and parses it to set
|
||||
// device-specific configuration options. The hook will be called when the
|
||||
// environment variable is parsed. If a hook is already registered, it will be
|
||||
// replaced with the new one.
|
||||
void registerDeviceConfigParserHook(
|
||||
std::function<void(const std::string&)> hook) {
|
||||
std::function<void(const std::string&)>&& hook,
|
||||
const std::unordered_set<std::string>& keys) {
|
||||
device_config_parser_hook_ = std::move(hook);
|
||||
for (auto& key : keys) {
|
||||
TORCH_CHECK(
|
||||
keys_.insert(key).second,
|
||||
"Duplicated key '",
|
||||
key,
|
||||
"' found in device-specific configuration parser hook registration");
|
||||
}
|
||||
}
|
||||
|
||||
// Calls the registered device-specific configuration parser hook with the
|
||||
|
|
@ -309,6 +325,17 @@ class C10_API AcceleratorAllocatorConfig {
|
|||
// This allows backends (e.g., CUDA, XPU) to register a custom parser for
|
||||
// their own environment configuration extensions.
|
||||
std::function<void(const std::string&)> device_config_parser_hook_{nullptr};
|
||||
|
||||
// A set of valid configuration keys, including both common and
|
||||
// device-specific options. This set is used to validate the presence and
|
||||
// correctness of keys during parsing.
|
||||
std::unordered_set<std::string> keys_{
|
||||
"max_split_size_mb",
|
||||
"max_non_split_rounding_mb",
|
||||
"garbage_collection_threshold",
|
||||
"roundup_power2_divisions",
|
||||
"expandable_segments",
|
||||
"pinned_use_background_threads"};
|
||||
};
|
||||
|
||||
C10_API inline void setAllocatorSettings(const std::string& env) {
|
||||
|
|
@ -322,16 +349,22 @@ C10_API inline std::string getAllocatorSettings() {
|
|||
|
||||
struct DeviceConfigParserHookRegistry {
|
||||
explicit DeviceConfigParserHookRegistry(
|
||||
std::function<void(const std::string&)> hook) {
|
||||
std::function<void(const std::string&)>&& hook,
|
||||
const std::unordered_set<std::string>& keys) {
|
||||
AcceleratorAllocatorConfig::instance().registerDeviceConfigParserHook(
|
||||
std::move(hook));
|
||||
std::move(hook), keys);
|
||||
}
|
||||
};
|
||||
|
||||
#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(hook) \
|
||||
// Assume each config parser has `parseArgs` and `getKeys` methods
|
||||
#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(parser_cls) \
|
||||
namespace { \
|
||||
static at::CachingAllocator::DeviceConfigParserHookRegistry \
|
||||
g_device_config_parse_hook_registry_instance(hook); \
|
||||
g_device_config_parse_hook_registry_instance( \
|
||||
[](const std::string& env) { \
|
||||
parser_cls::instance().parseArgs(env); \
|
||||
}, \
|
||||
parser_cls::getKeys()); \
|
||||
}
|
||||
|
||||
} // namespace c10::CachingAllocator
|
||||
|
|
|
|||
|
|
@ -16,6 +16,10 @@ struct ExtendedAllocatorConfig {
|
|||
return instance().device_specific_option_;
|
||||
}
|
||||
|
||||
static const std::unordered_set<std::string>& getKeys() {
|
||||
return instance().keys_;
|
||||
}
|
||||
|
||||
void parseArgs(const std::string& env) {
|
||||
// Parse device-specific options from the environment variable
|
||||
ConfigTokenizer tokenizer(env);
|
||||
|
|
@ -37,11 +41,10 @@ struct ExtendedAllocatorConfig {
|
|||
private:
|
||||
// Device-specific option, e.g., memory limit for a specific device.
|
||||
std::atomic<size_t> device_specific_option_{0};
|
||||
std::unordered_set<std::string> keys_{"device_specific_option_mb"};
|
||||
};
|
||||
|
||||
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK([](const std::string& env) {
|
||||
ExtendedAllocatorConfig::instance().parseArgs(env);
|
||||
})
|
||||
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(ExtendedAllocatorConfig)
|
||||
|
||||
TEST(AllocatorConfigTest, allocator_config_test) {
|
||||
std::string env =
|
||||
|
|
@ -120,4 +123,7 @@ TEST(AllocatorConfigTest, allocator_config_test) {
|
|||
c10::CachingAllocator::setAllocatorSettings(env);
|
||||
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
|
||||
EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), false);
|
||||
|
||||
env = "foo:123,bar:456";
|
||||
ASSERT_THROW(c10::CachingAllocator::setAllocatorSettings(env), c10::Error);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user