mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add config alias (#142088)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142088 Approved by: https://github.com/c00w
This commit is contained in:
parent
1b6b86fad7
commit
17b71e5d6a
|
|
@ -9,7 +9,10 @@ os.environ["ENV_FALSE"] = "0"
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from torch.testing._internal import fake_config_module as config
|
||||
from torch.testing._internal import (
|
||||
fake_config_module as config,
|
||||
fake_config_module2 as config2,
|
||||
)
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.utils._config_module import _UNSET_SENTINEL, Config
|
||||
|
||||
|
|
@ -98,7 +101,7 @@ class TestConfigModule(TestCase):
|
|||
|
||||
def test_save_config(self):
|
||||
p = config.save_config()
|
||||
self.assertEqual(
|
||||
self.assertDictEqual(
|
||||
pickle.loads(p),
|
||||
{
|
||||
"_cache_config_ignore_prefix": ["magic_cache_config"],
|
||||
|
|
@ -133,7 +136,7 @@ class TestConfigModule(TestCase):
|
|||
|
||||
def test_save_config_portable(self):
|
||||
p = config.save_config_portable()
|
||||
self.assertEqual(
|
||||
self.assertDictEqual(
|
||||
p,
|
||||
{
|
||||
"e_bool": True,
|
||||
|
|
@ -174,22 +177,36 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||
)
|
||||
|
||||
def test_get_hash(self):
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
hash_value = b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
|
||||
self.assertEqual(
|
||||
config.get_hash(),
|
||||
hash_value,
|
||||
)
|
||||
# Test cached value
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
self.assertEqual(
|
||||
config.get_hash(),
|
||||
hash_value,
|
||||
)
|
||||
self.assertEqual(
|
||||
config.get_hash(),
|
||||
hash_value,
|
||||
)
|
||||
config._hash_digest = "fake"
|
||||
self.assertEqual(config.get_hash(), "fake")
|
||||
|
||||
config.e_bool = False
|
||||
self.assertNotEqual(
|
||||
config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
|
||||
config.get_hash(),
|
||||
hash_value,
|
||||
)
|
||||
config.e_bool = True
|
||||
|
||||
# Test ignored values
|
||||
config.e_compile_ignored = False
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
self.assertEqual(
|
||||
config.get_hash(),
|
||||
hash_value,
|
||||
)
|
||||
|
||||
def test_dict_copy_semantics(self):
|
||||
p = config.shallow_copy_dict()
|
||||
|
|
@ -319,6 +336,15 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||
):
|
||||
Config(default="bad", justknob="fake_knob")
|
||||
|
||||
def test_alias(self):
|
||||
self.assertFalse(config2.e_aliasing_bool)
|
||||
self.assertFalse(config.e_aliased_bool)
|
||||
with config2.patch(e_aliasing_bool=True):
|
||||
self.assertTrue(config2.e_aliasing_bool)
|
||||
self.assertTrue(config.e_aliased_bool)
|
||||
with config.patch(e_aliased_bool=True):
|
||||
self.assertTrue(config2.e_aliasing_bool)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -19,12 +19,15 @@ _e_ignored = True
|
|||
magic_cache_config_ignored = True
|
||||
# [@compile_ignored: debug]
|
||||
e_compile_ignored = True
|
||||
e_config = Config(default=True)
|
||||
e_jk = Config(justknob="does_not_exist", default=True)
|
||||
e_jk_false = Config(justknob="does_not_exist", default=False)
|
||||
e_env_default = Config(env_name_default="ENV_TRUE", default=False)
|
||||
e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True)
|
||||
e_env_force = Config(env_name_force="ENV_TRUE", default=False)
|
||||
e_config: bool = Config(default=True)
|
||||
e_jk: bool = Config(justknob="does_not_exist", default=True)
|
||||
e_jk_false: bool = Config(justknob="does_not_exist", default=False)
|
||||
e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False)
|
||||
e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True)
|
||||
e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False)
|
||||
e_aliased_bool: bool = Config(
|
||||
alias="torch.testing._internal.fake_config_module2.e_aliasing_bool"
|
||||
)
|
||||
|
||||
|
||||
class nested:
|
||||
|
|
|
|||
8
torch/testing/_internal/fake_config_module2.py
Normal file
8
torch/testing/_internal/fake_config_module2.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import sys
|
||||
|
||||
from torch.utils._config_module import install_config_module
|
||||
|
||||
|
||||
e_aliasing_bool = False
|
||||
|
||||
install_config_module(sys.modules[__name__])
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
import contextlib
|
||||
import copy
|
||||
import hashlib
|
||||
import importlib
|
||||
import inspect
|
||||
import io
|
||||
import os
|
||||
|
|
@ -20,6 +21,7 @@ from typing import (
|
|||
NoReturn,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
|
|
@ -38,6 +40,9 @@ CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
|
|||
T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict])
|
||||
|
||||
|
||||
_UNSET_SENTINEL = object()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _Config(Generic[T]):
|
||||
"""Represents a config with richer behaviour than just a default value.
|
||||
|
|
@ -49,7 +54,9 @@ class _Config(Generic[T]):
|
|||
This configs must be installed with install_config_module to be used
|
||||
|
||||
Precedence Order:
|
||||
env_name_force: If set, this environment variable overrides everything
|
||||
alias: If set, the directly use the value of the alias.
|
||||
env_name_force: If set, this environment variable has precedence over
|
||||
everything after this.
|
||||
user_override: If a user sets a value (i.e. foo.bar=True), that
|
||||
has precedence over everything after this.
|
||||
env_name_default: If set, this environment variable will override everything
|
||||
|
|
@ -65,25 +72,28 @@ class _Config(Generic[T]):
|
|||
Arguments:
|
||||
justknob: the name of the feature / JK. In OSS this is unused.
|
||||
default: is the value to default this knob to in OSS.
|
||||
alias: The alias config to read instead.
|
||||
env_name_force: The environment variable to read that is a FORCE
|
||||
environment variable. I.e. it overrides everything
|
||||
environment variable. I.e. it overrides everything except for alias.
|
||||
env_name_default: The environment variable to read that changes the
|
||||
default behaviour. I.e. user overrides take preference.
|
||||
"""
|
||||
|
||||
default: T
|
||||
default: Union[T, object]
|
||||
justknob: Optional[str] = None
|
||||
env_name_default: Optional[str] = None
|
||||
env_name_force: Optional[str] = None
|
||||
value_type: Optional[type] = None
|
||||
alias: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
default: T,
|
||||
default: Union[T, object] = _UNSET_SENTINEL,
|
||||
justknob: Optional[str] = None,
|
||||
env_name_default: Optional[str] = None,
|
||||
env_name_force: Optional[str] = None,
|
||||
value_type: Optional[type] = None,
|
||||
alias: Optional[str] = None,
|
||||
):
|
||||
# python 3.9 does not support kw_only on the dataclass :(.
|
||||
self.default = default
|
||||
|
|
@ -91,10 +101,18 @@ class _Config(Generic[T]):
|
|||
self.env_name_default = env_name_default
|
||||
self.env_name_force = env_name_force
|
||||
self.value_type = value_type
|
||||
self.alias = alias
|
||||
if self.justknob is not None:
|
||||
assert isinstance(
|
||||
self.default, bool
|
||||
), f"justknobs only support booleans, {self.default} is not a boolean"
|
||||
if self.alias is not None:
|
||||
assert (
|
||||
default is _UNSET_SENTINEL
|
||||
and justknob is None
|
||||
and env_name_default is None
|
||||
and env_name_force is None
|
||||
), "if alias is set, default, justknob or env var cannot be set"
|
||||
|
||||
|
||||
# In runtime, we unbox the Config[T] to a T, but typechecker cannot see this,
|
||||
|
|
@ -104,24 +122,28 @@ class _Config(Generic[T]):
|
|||
if TYPE_CHECKING:
|
||||
|
||||
def Config(
|
||||
default: T,
|
||||
default: Union[T, object] = _UNSET_SENTINEL,
|
||||
justknob: Optional[str] = None,
|
||||
env_name_default: Optional[str] = None,
|
||||
env_name_force: Optional[str] = None,
|
||||
value_type: Optional[type] = None,
|
||||
alias: Optional[str] = None,
|
||||
) -> T:
|
||||
...
|
||||
|
||||
else:
|
||||
|
||||
def Config(
|
||||
default: T,
|
||||
default: Union[T, object] = _UNSET_SENTINEL,
|
||||
justknob: Optional[str] = None,
|
||||
env_name_default: Optional[str] = None,
|
||||
env_name_force: Optional[str] = None,
|
||||
value_type: Optional[type] = None,
|
||||
alias: Optional[str] = None,
|
||||
) -> _Config[T]:
|
||||
return _Config(default, justknob, env_name_default, env_name_force, value_type)
|
||||
return _Config(
|
||||
default, justknob, env_name_default, env_name_force, value_type, alias
|
||||
)
|
||||
|
||||
|
||||
def _read_env_variable(name: str) -> Optional[bool]:
|
||||
|
|
@ -243,9 +265,6 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str
|
|||
return assignments
|
||||
|
||||
|
||||
_UNSET_SENTINEL = object()
|
||||
|
||||
|
||||
@dataclass
|
||||
class _ConfigEntry:
|
||||
# The default value specified in the configuration
|
||||
|
|
@ -272,6 +291,7 @@ class _ConfigEntry:
|
|||
# call so the final state is correct. It's just very unintuitive.
|
||||
# upstream bug - python/cpython#126886
|
||||
hide: bool = False
|
||||
alias: Optional[str] = None
|
||||
|
||||
def __init__(self, config: _Config):
|
||||
self.default = config.default
|
||||
|
|
@ -279,6 +299,7 @@ class _ConfigEntry:
|
|||
config.value_type if config.value_type is not None else type(self.default)
|
||||
)
|
||||
self.justknob = config.justknob
|
||||
self.alias = config.alias
|
||||
if config.env_name_default is not None:
|
||||
if (env_value := _read_env_variable(config.env_name_default)) is not None:
|
||||
self.env_value_default = env_value
|
||||
|
|
@ -309,6 +330,8 @@ class ConfigModule(ModuleType):
|
|||
super().__setattr__(name, value)
|
||||
elif name not in self._config:
|
||||
raise AttributeError(f"{self.__name__}.{name} does not exist")
|
||||
elif self._config[name].alias is not None:
|
||||
self._set_alias_val(self._config[name], value)
|
||||
else:
|
||||
self._config[name].user_override = value
|
||||
self._is_dirty = True
|
||||
|
|
@ -321,6 +344,10 @@ class ConfigModule(ModuleType):
|
|||
if config.hide:
|
||||
raise AttributeError(f"{self.__name__}.{name} does not exist")
|
||||
|
||||
alias_val = self._get_alias_val(config)
|
||||
if alias_val is not _UNSET_SENTINEL:
|
||||
return alias_val
|
||||
|
||||
if config.env_value_force is not _UNSET_SENTINEL:
|
||||
return config.env_value_force
|
||||
|
||||
|
|
@ -353,6 +380,33 @@ class ConfigModule(ModuleType):
|
|||
self._config[name].user_override = _UNSET_SENTINEL
|
||||
self._config[name].hide = True
|
||||
|
||||
def _get_alias_module_and_name(
|
||||
self, entry: _ConfigEntry
|
||||
) -> Optional[Tuple[ModuleType, str]]:
|
||||
alias = entry.alias
|
||||
if alias is None:
|
||||
return None
|
||||
module_name, constant_name = alias.rsplit(".", 1)
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except ImportError as e:
|
||||
raise AttributeError("config alias {alias} does not exist") from e
|
||||
return module, constant_name
|
||||
|
||||
def _get_alias_val(self, entry: _ConfigEntry) -> Any:
|
||||
data = self._get_alias_module_and_name(entry)
|
||||
if data is None:
|
||||
return _UNSET_SENTINEL
|
||||
module, constant_name = data
|
||||
constant_value = getattr(module, constant_name)
|
||||
return constant_value
|
||||
|
||||
def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None:
|
||||
data = self._get_alias_module_and_name(entry)
|
||||
assert data is not None
|
||||
module, constant_name = data
|
||||
setattr(module, constant_name, val)
|
||||
|
||||
def _is_default(self, name: str) -> bool:
|
||||
return self._config[name].user_override is _UNSET_SENTINEL
|
||||
|
||||
|
|
@ -369,6 +423,7 @@ class ConfigModule(ModuleType):
|
|||
This is used by a number of different user facing export methods
|
||||
which all have slightly different semantics re: how and what to
|
||||
skip.
|
||||
If a config is aliased, it skips this config.
|
||||
|
||||
Arguments:
|
||||
ignored_keys are keys that should not be exported.
|
||||
|
|
@ -391,7 +446,10 @@ class ConfigModule(ModuleType):
|
|||
continue
|
||||
if skip_default and self._is_default(key):
|
||||
continue
|
||||
if self._config[key].alias is not None:
|
||||
continue
|
||||
config[key] = copy.deepcopy(getattr(self, key))
|
||||
|
||||
return config
|
||||
|
||||
def get_type(self, config_name: str) -> type:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user