Add config alias (#142088)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142088
Approved by: https://github.com/c00w
This commit is contained in:
Oguz Ulgen 2024-12-14 09:24:12 -08:00 committed by PyTorch MergeBot
parent 1b6b86fad7
commit 17b71e5d6a
4 changed files with 119 additions and 24 deletions

View File

@ -9,7 +9,10 @@ os.environ["ENV_FALSE"] = "0"
from typing import Optional from typing import Optional
from torch.testing._internal import fake_config_module as config from torch.testing._internal import (
fake_config_module as config,
fake_config_module2 as config2,
)
from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _UNSET_SENTINEL, Config from torch.utils._config_module import _UNSET_SENTINEL, Config
@ -98,7 +101,7 @@ class TestConfigModule(TestCase):
def test_save_config(self): def test_save_config(self):
p = config.save_config() p = config.save_config()
self.assertEqual( self.assertDictEqual(
pickle.loads(p), pickle.loads(p),
{ {
"_cache_config_ignore_prefix": ["magic_cache_config"], "_cache_config_ignore_prefix": ["magic_cache_config"],
@ -133,7 +136,7 @@ class TestConfigModule(TestCase):
def test_save_config_portable(self): def test_save_config_portable(self):
p = config.save_config_portable() p = config.save_config_portable()
self.assertEqual( self.assertDictEqual(
p, p,
{ {
"e_bool": True, "e_bool": True,
@ -174,22 +177,36 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
) )
def test_get_hash(self): def test_get_hash(self):
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") hash_value = b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
self.assertEqual(
config.get_hash(),
hash_value,
)
# Test cached value # Test cached value
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") self.assertEqual(
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") config.get_hash(),
hash_value,
)
self.assertEqual(
config.get_hash(),
hash_value,
)
config._hash_digest = "fake" config._hash_digest = "fake"
self.assertEqual(config.get_hash(), "fake") self.assertEqual(config.get_hash(), "fake")
config.e_bool = False config.e_bool = False
self.assertNotEqual( self.assertNotEqual(
config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ" config.get_hash(),
hash_value,
) )
config.e_bool = True config.e_bool = True
# Test ignored values # Test ignored values
config.e_compile_ignored = False config.e_compile_ignored = False
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ") self.assertEqual(
config.get_hash(),
hash_value,
)
def test_dict_copy_semantics(self): def test_dict_copy_semantics(self):
p = config.shallow_copy_dict() p = config.shallow_copy_dict()
@ -319,6 +336,15 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
): ):
Config(default="bad", justknob="fake_knob") Config(default="bad", justknob="fake_knob")
def test_alias(self):
self.assertFalse(config2.e_aliasing_bool)
self.assertFalse(config.e_aliased_bool)
with config2.patch(e_aliasing_bool=True):
self.assertTrue(config2.e_aliasing_bool)
self.assertTrue(config.e_aliased_bool)
with config.patch(e_aliased_bool=True):
self.assertTrue(config2.e_aliasing_bool)
if __name__ == "__main__": if __name__ == "__main__":
run_tests() run_tests()

View File

@ -19,12 +19,15 @@ _e_ignored = True
magic_cache_config_ignored = True magic_cache_config_ignored = True
# [@compile_ignored: debug] # [@compile_ignored: debug]
e_compile_ignored = True e_compile_ignored = True
e_config = Config(default=True) e_config: bool = Config(default=True)
e_jk = Config(justknob="does_not_exist", default=True) e_jk: bool = Config(justknob="does_not_exist", default=True)
e_jk_false = Config(justknob="does_not_exist", default=False) e_jk_false: bool = Config(justknob="does_not_exist", default=False)
e_env_default = Config(env_name_default="ENV_TRUE", default=False) e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False)
e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True) e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True)
e_env_force = Config(env_name_force="ENV_TRUE", default=False) e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False)
e_aliased_bool: bool = Config(
alias="torch.testing._internal.fake_config_module2.e_aliasing_bool"
)
class nested: class nested:

View File

@ -0,0 +1,8 @@
import sys
from torch.utils._config_module import install_config_module
e_aliasing_bool = False
install_config_module(sys.modules[__name__])

View File

@ -1,6 +1,7 @@
import contextlib import contextlib
import copy import copy
import hashlib import hashlib
import importlib
import inspect import inspect
import io import io
import os import os
@ -20,6 +21,7 @@ from typing import (
NoReturn, NoReturn,
Optional, Optional,
Set, Set,
Tuple,
TYPE_CHECKING, TYPE_CHECKING,
TypeVar, TypeVar,
Union, Union,
@ -38,6 +40,9 @@ CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict]) T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict])
_UNSET_SENTINEL = object()
@dataclass @dataclass
class _Config(Generic[T]): class _Config(Generic[T]):
"""Represents a config with richer behaviour than just a default value. """Represents a config with richer behaviour than just a default value.
@ -49,7 +54,9 @@ class _Config(Generic[T]):
This configs must be installed with install_config_module to be used This configs must be installed with install_config_module to be used
Precedence Order: Precedence Order:
env_name_force: If set, this environment variable overrides everything alias: If set, the directly use the value of the alias.
env_name_force: If set, this environment variable has precedence over
everything after this.
user_override: If a user sets a value (i.e. foo.bar=True), that user_override: If a user sets a value (i.e. foo.bar=True), that
has precedence over everything after this. has precedence over everything after this.
env_name_default: If set, this environment variable will override everything env_name_default: If set, this environment variable will override everything
@ -65,25 +72,28 @@ class _Config(Generic[T]):
Arguments: Arguments:
justknob: the name of the feature / JK. In OSS this is unused. justknob: the name of the feature / JK. In OSS this is unused.
default: is the value to default this knob to in OSS. default: is the value to default this knob to in OSS.
alias: The alias config to read instead.
env_name_force: The environment variable to read that is a FORCE env_name_force: The environment variable to read that is a FORCE
environment variable. I.e. it overrides everything environment variable. I.e. it overrides everything except for alias.
env_name_default: The environment variable to read that changes the env_name_default: The environment variable to read that changes the
default behaviour. I.e. user overrides take preference. default behaviour. I.e. user overrides take preference.
""" """
default: T default: Union[T, object]
justknob: Optional[str] = None justknob: Optional[str] = None
env_name_default: Optional[str] = None env_name_default: Optional[str] = None
env_name_force: Optional[str] = None env_name_force: Optional[str] = None
value_type: Optional[type] = None value_type: Optional[type] = None
alias: Optional[str] = None
def __init__( def __init__(
self, self,
default: T, default: Union[T, object] = _UNSET_SENTINEL,
justknob: Optional[str] = None, justknob: Optional[str] = None,
env_name_default: Optional[str] = None, env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None, env_name_force: Optional[str] = None,
value_type: Optional[type] = None, value_type: Optional[type] = None,
alias: Optional[str] = None,
): ):
# python 3.9 does not support kw_only on the dataclass :(. # python 3.9 does not support kw_only on the dataclass :(.
self.default = default self.default = default
@ -91,10 +101,18 @@ class _Config(Generic[T]):
self.env_name_default = env_name_default self.env_name_default = env_name_default
self.env_name_force = env_name_force self.env_name_force = env_name_force
self.value_type = value_type self.value_type = value_type
self.alias = alias
if self.justknob is not None: if self.justknob is not None:
assert isinstance( assert isinstance(
self.default, bool self.default, bool
), f"justknobs only support booleans, {self.default} is not a boolean" ), f"justknobs only support booleans, {self.default} is not a boolean"
if self.alias is not None:
assert (
default is _UNSET_SENTINEL
and justknob is None
and env_name_default is None
and env_name_force is None
), "if alias is set, default, justknob or env var cannot be set"
# In runtime, we unbox the Config[T] to a T, but typechecker cannot see this, # In runtime, we unbox the Config[T] to a T, but typechecker cannot see this,
@ -104,24 +122,28 @@ class _Config(Generic[T]):
if TYPE_CHECKING: if TYPE_CHECKING:
def Config( def Config(
default: T, default: Union[T, object] = _UNSET_SENTINEL,
justknob: Optional[str] = None, justknob: Optional[str] = None,
env_name_default: Optional[str] = None, env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None, env_name_force: Optional[str] = None,
value_type: Optional[type] = None, value_type: Optional[type] = None,
alias: Optional[str] = None,
) -> T: ) -> T:
... ...
else: else:
def Config( def Config(
default: T, default: Union[T, object] = _UNSET_SENTINEL,
justknob: Optional[str] = None, justknob: Optional[str] = None,
env_name_default: Optional[str] = None, env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None, env_name_force: Optional[str] = None,
value_type: Optional[type] = None, value_type: Optional[type] = None,
alias: Optional[str] = None,
) -> _Config[T]: ) -> _Config[T]:
return _Config(default, justknob, env_name_default, env_name_force, value_type) return _Config(
default, justknob, env_name_default, env_name_force, value_type, alias
)
def _read_env_variable(name: str) -> Optional[bool]: def _read_env_variable(name: str) -> Optional[bool]:
@ -243,9 +265,6 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str
return assignments return assignments
_UNSET_SENTINEL = object()
@dataclass @dataclass
class _ConfigEntry: class _ConfigEntry:
# The default value specified in the configuration # The default value specified in the configuration
@ -272,6 +291,7 @@ class _ConfigEntry:
# call so the final state is correct. It's just very unintuitive. # call so the final state is correct. It's just very unintuitive.
# upstream bug - python/cpython#126886 # upstream bug - python/cpython#126886
hide: bool = False hide: bool = False
alias: Optional[str] = None
def __init__(self, config: _Config): def __init__(self, config: _Config):
self.default = config.default self.default = config.default
@ -279,6 +299,7 @@ class _ConfigEntry:
config.value_type if config.value_type is not None else type(self.default) config.value_type if config.value_type is not None else type(self.default)
) )
self.justknob = config.justknob self.justknob = config.justknob
self.alias = config.alias
if config.env_name_default is not None: if config.env_name_default is not None:
if (env_value := _read_env_variable(config.env_name_default)) is not None: if (env_value := _read_env_variable(config.env_name_default)) is not None:
self.env_value_default = env_value self.env_value_default = env_value
@ -309,6 +330,8 @@ class ConfigModule(ModuleType):
super().__setattr__(name, value) super().__setattr__(name, value)
elif name not in self._config: elif name not in self._config:
raise AttributeError(f"{self.__name__}.{name} does not exist") raise AttributeError(f"{self.__name__}.{name} does not exist")
elif self._config[name].alias is not None:
self._set_alias_val(self._config[name], value)
else: else:
self._config[name].user_override = value self._config[name].user_override = value
self._is_dirty = True self._is_dirty = True
@ -321,6 +344,10 @@ class ConfigModule(ModuleType):
if config.hide: if config.hide:
raise AttributeError(f"{self.__name__}.{name} does not exist") raise AttributeError(f"{self.__name__}.{name} does not exist")
alias_val = self._get_alias_val(config)
if alias_val is not _UNSET_SENTINEL:
return alias_val
if config.env_value_force is not _UNSET_SENTINEL: if config.env_value_force is not _UNSET_SENTINEL:
return config.env_value_force return config.env_value_force
@ -353,6 +380,33 @@ class ConfigModule(ModuleType):
self._config[name].user_override = _UNSET_SENTINEL self._config[name].user_override = _UNSET_SENTINEL
self._config[name].hide = True self._config[name].hide = True
def _get_alias_module_and_name(
self, entry: _ConfigEntry
) -> Optional[Tuple[ModuleType, str]]:
alias = entry.alias
if alias is None:
return None
module_name, constant_name = alias.rsplit(".", 1)
try:
module = importlib.import_module(module_name)
except ImportError as e:
raise AttributeError("config alias {alias} does not exist") from e
return module, constant_name
def _get_alias_val(self, entry: _ConfigEntry) -> Any:
data = self._get_alias_module_and_name(entry)
if data is None:
return _UNSET_SENTINEL
module, constant_name = data
constant_value = getattr(module, constant_name)
return constant_value
def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None:
data = self._get_alias_module_and_name(entry)
assert data is not None
module, constant_name = data
setattr(module, constant_name, val)
def _is_default(self, name: str) -> bool: def _is_default(self, name: str) -> bool:
return self._config[name].user_override is _UNSET_SENTINEL return self._config[name].user_override is _UNSET_SENTINEL
@ -369,6 +423,7 @@ class ConfigModule(ModuleType):
This is used by a number of different user facing export methods This is used by a number of different user facing export methods
which all have slightly different semantics re: how and what to which all have slightly different semantics re: how and what to
skip. skip.
If a config is aliased, it skips this config.
Arguments: Arguments:
ignored_keys are keys that should not be exported. ignored_keys are keys that should not be exported.
@ -391,7 +446,10 @@ class ConfigModule(ModuleType):
continue continue
if skip_default and self._is_default(key): if skip_default and self._is_default(key):
continue continue
if self._config[key].alias is not None:
continue
config[key] = copy.deepcopy(getattr(self, key)) config[key] = copy.deepcopy(getattr(self, key))
return config return config
def get_type(self, config_name: str) -> type: def get_type(self, config_name: str) -> type: