Enable AcceleratorAllocatorConfig key check (#157908)

# Motivation
Add a mechanism that raises an error when an unrecognized key is found in the allocator config.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157908
Approved by: https://github.com/albanD
ghstack dependencies: #149601
This commit is contained in:
Yu, Guangye 2025-07-10 11:20:25 +00:00 committed by PyTorch MergeBot
parent 905b084690
commit 65fcca4f8c
3 changed files with 64 additions and 16 deletions

View File

@ -221,6 +221,15 @@ void AcceleratorAllocatorConfig::parseArgs(const std::string& env) {
} else if (key == "pinned_use_background_threads") {
i = parsePinnedUseBackgroundThreads(tokenizer, i);
} else {
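// Keys not handled above are validated only when a device-specific
// configuration parser hook is registered: the key must then be present in
// `keys_`, otherwise TORCH_CHECK fails. The key itself is skipped here
// either way.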
if (device_config_parser_hook_) {
TORCH_CHECK(
keys_.find(key) != keys_.end(),
"Unrecognized key '",
key,
"' in Accelerator allocator config.");
}
i = tokenizer.skipKey(i);
}

View File

@ -7,6 +7,7 @@
#include <atomic>
#include <mutex>
#include <string>
#include <unordered_set>
#include <vector>
namespace c10::CachingAllocator {
@ -180,7 +181,7 @@ class C10_API AcceleratorAllocatorConfig {
// Returns the vector of division factors used for rounding up allocation
// sizes. These divisions apply to size intervals between 1MB and 64GB.
static std::vector<size_t> roundup_power2_divisions() {
static const std::vector<size_t>& roundup_power2_divisions() {
return instance().roundup_power2_divisions_;
}
@ -219,6 +220,13 @@ class C10_API AcceleratorAllocatorConfig {
return instance().last_allocator_settings_;
}
// Returns a const reference to `keys_` of `instance()`: the set of valid keys
// for the allocator configuration. The set is returned by reference, so no
// copy is made.
// This set is used to validate the presence and correctness of keys in
// device-specific configuration parsers.
static const std::unordered_set<std::string>& getKeys() {
return instance().keys_;
}
// Parses the environment variable `env` to update the allocator settings.
// If the environment variable is not set, it does nothing.
// The configuration string should be a comma-separated list of key-value
@ -227,16 +235,24 @@ class C10_API AcceleratorAllocatorConfig {
// "max_split_size_mb:100,max_non_split_rounding_mb:20,garbage_collection_threshold:0.5,roundup_power2_divisions:[64:8,256:4,1024:4,>:1],expandable_segments:true,pinned_use_background_threads:true"
void parseArgs(const std::string& env);
// Registers a device-specific configuration parser hook. This allows
// backends to parse additional device-specific configuration options from the
// environment variable. The hook should be a function that takes a string
// (the environment variable value) and parses it to set device-specific
// configuration options.
// The hook will be called when the environment variable is parsed.
// If a hook is already registered, it will be replaced with the new one.
// Registers a device-specific configuration parser hook together with the
// set of keys it understands. This allows backends to parse additional
// device-specific configuration options from the environment variable. The
// hook should be a function that takes a string (the environment variable
// value) and parses it to set device-specific configuration options. The
// hook will be called when the environment variable is parsed. If a hook is
// already registered, it will be replaced with the new one.
//
// Each entry of `keys` is inserted into `keys_`; TORCH_CHECK fails with a
// "Duplicated key" message if a key is already present in `keys_`.
// NOTE(review): the hook is assigned before the keys are validated, and keys
// added by an earlier registration are not removed here - confirm this is the
// intended behavior when a hook is replaced.
void registerDeviceConfigParserHook(
std::function<void(const std::string&)> hook) {
std::function<void(const std::string&)>&& hook,
const std::unordered_set<std::string>& keys) {
device_config_parser_hook_ = std::move(hook);
for (auto& key : keys) {
TORCH_CHECK(
keys_.insert(key).second,
"Duplicated key '",
key,
"' found in device-specific configuration parser hook registration");
}
}
// Calls the registered device-specific configuration parser hook with the
@ -309,6 +325,17 @@ class C10_API AcceleratorAllocatorConfig {
// This allows backends (e.g., CUDA, XPU) to register a custom parser for
// their own environment configuration extensions.
std::function<void(const std::string&)> device_config_parser_hook_{nullptr};
// A set of valid configuration keys, including both common and
// device-specific options. It is initialized with the common keys listed
// below; device-specific keys are inserted by
// registerDeviceConfigParserHook. This set is used to validate the presence
// and correctness of keys during parsing.
std::unordered_set<std::string> keys_{
"max_split_size_mb",
"max_non_split_rounding_mb",
"garbage_collection_threshold",
"roundup_power2_divisions",
"expandable_segments",
"pinned_use_background_threads"};
};
C10_API inline void setAllocatorSettings(const std::string& env) {
@ -322,16 +349,22 @@ C10_API inline std::string getAllocatorSettings() {
struct DeviceConfigParserHookRegistry {
explicit DeviceConfigParserHookRegistry(
std::function<void(const std::string&)> hook) {
std::function<void(const std::string&)>&& hook,
const std::unordered_set<std::string>& keys) {
AcceleratorAllocatorConfig::instance().registerDeviceConfigParserHook(
std::move(hook));
std::move(hook), keys);
}
};
#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(hook) \
// Assume each config parser has `parseArgs` and `getKeys` methods
#define REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(parser_cls) \
namespace { \
static at::CachingAllocator::DeviceConfigParserHookRegistry \
g_device_config_parse_hook_registry_instance(hook); \
g_device_config_parse_hook_registry_instance( \
[](const std::string& env) { \
parser_cls::instance().parseArgs(env); \
}, \
parser_cls::getKeys()); \
}
} // namespace c10::CachingAllocator

View File

@ -16,6 +16,10 @@ struct ExtendedAllocatorConfig {
return instance().device_specific_option_;
}
// Returns a const reference to `keys_` of `instance()`: the configuration
// keys declared by this device-specific parser. The
// REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK macro calls `parser_cls::getKeys()`
// to obtain the keys passed along with the parse hook.
static const std::unordered_set<std::string>& getKeys() {
return instance().keys_;
}
void parseArgs(const std::string& env) {
// Parse device-specific options from the environment variable
ConfigTokenizer tokenizer(env);
@ -37,11 +41,10 @@ struct ExtendedAllocatorConfig {
private:
// Device-specific option, e.g., memory limit for a specific device.
std::atomic<size_t> device_specific_option_{0};
std::unordered_set<std::string> keys_{"device_specific_option_mb"};
};
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK([](const std::string& env) {
ExtendedAllocatorConfig::instance().parseArgs(env);
})
REGISTER_ALLOCATOR_CONFIG_PARSE_HOOK(ExtendedAllocatorConfig)
TEST(AllocatorConfigTest, allocator_config_test) {
std::string env =
@ -120,4 +123,7 @@ TEST(AllocatorConfigTest, allocator_config_test) {
c10::CachingAllocator::setAllocatorSettings(env);
EXPECT_EQ(c10::CachingAllocator::getAllocatorSettings(), env);
EXPECT_EQ(AcceleratorAllocatorConfig::pinned_use_background_threads(), false);
env = "foo:123,bar:456";
ASSERT_THROW(c10::CachingAllocator::setAllocatorSettings(env), c10::Error);
}