mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add config alias (#142088)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142088 Approved by: https://github.com/c00w
This commit is contained in:
parent
1b6b86fad7
commit
17b71e5d6a
|
|
@ -9,7 +9,10 @@ os.environ["ENV_FALSE"] = "0"
|
||||||
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from torch.testing._internal import fake_config_module as config
|
from torch.testing._internal import (
|
||||||
|
fake_config_module as config,
|
||||||
|
fake_config_module2 as config2,
|
||||||
|
)
|
||||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||||
from torch.utils._config_module import _UNSET_SENTINEL, Config
|
from torch.utils._config_module import _UNSET_SENTINEL, Config
|
||||||
|
|
||||||
|
|
@ -98,7 +101,7 @@ class TestConfigModule(TestCase):
|
||||||
|
|
||||||
def test_save_config(self):
|
def test_save_config(self):
|
||||||
p = config.save_config()
|
p = config.save_config()
|
||||||
self.assertEqual(
|
self.assertDictEqual(
|
||||||
pickle.loads(p),
|
pickle.loads(p),
|
||||||
{
|
{
|
||||||
"_cache_config_ignore_prefix": ["magic_cache_config"],
|
"_cache_config_ignore_prefix": ["magic_cache_config"],
|
||||||
|
|
@ -133,7 +136,7 @@ class TestConfigModule(TestCase):
|
||||||
|
|
||||||
def test_save_config_portable(self):
|
def test_save_config_portable(self):
|
||||||
p = config.save_config_portable()
|
p = config.save_config_portable()
|
||||||
self.assertEqual(
|
self.assertDictEqual(
|
||||||
p,
|
p,
|
||||||
{
|
{
|
||||||
"e_bool": True,
|
"e_bool": True,
|
||||||
|
|
@ -174,22 +177,36 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_get_hash(self):
|
def test_get_hash(self):
|
||||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
hash_value = b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
|
||||||
|
self.assertEqual(
|
||||||
|
config.get_hash(),
|
||||||
|
hash_value,
|
||||||
|
)
|
||||||
# Test cached value
|
# Test cached value
|
||||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
self.assertEqual(
|
||||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
config.get_hash(),
|
||||||
|
hash_value,
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
config.get_hash(),
|
||||||
|
hash_value,
|
||||||
|
)
|
||||||
config._hash_digest = "fake"
|
config._hash_digest = "fake"
|
||||||
self.assertEqual(config.get_hash(), "fake")
|
self.assertEqual(config.get_hash(), "fake")
|
||||||
|
|
||||||
config.e_bool = False
|
config.e_bool = False
|
||||||
self.assertNotEqual(
|
self.assertNotEqual(
|
||||||
config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
|
config.get_hash(),
|
||||||
|
hash_value,
|
||||||
)
|
)
|
||||||
config.e_bool = True
|
config.e_bool = True
|
||||||
|
|
||||||
# Test ignored values
|
# Test ignored values
|
||||||
config.e_compile_ignored = False
|
config.e_compile_ignored = False
|
||||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
self.assertEqual(
|
||||||
|
config.get_hash(),
|
||||||
|
hash_value,
|
||||||
|
)
|
||||||
|
|
||||||
def test_dict_copy_semantics(self):
|
def test_dict_copy_semantics(self):
|
||||||
p = config.shallow_copy_dict()
|
p = config.shallow_copy_dict()
|
||||||
|
|
@ -319,6 +336,15 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||||
):
|
):
|
||||||
Config(default="bad", justknob="fake_knob")
|
Config(default="bad", justknob="fake_knob")
|
||||||
|
|
||||||
|
def test_alias(self):
|
||||||
|
self.assertFalse(config2.e_aliasing_bool)
|
||||||
|
self.assertFalse(config.e_aliased_bool)
|
||||||
|
with config2.patch(e_aliasing_bool=True):
|
||||||
|
self.assertTrue(config2.e_aliasing_bool)
|
||||||
|
self.assertTrue(config.e_aliased_bool)
|
||||||
|
with config.patch(e_aliased_bool=True):
|
||||||
|
self.assertTrue(config2.e_aliasing_bool)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
|
|
||||||
|
|
@ -19,12 +19,15 @@ _e_ignored = True
|
||||||
magic_cache_config_ignored = True
|
magic_cache_config_ignored = True
|
||||||
# [@compile_ignored: debug]
|
# [@compile_ignored: debug]
|
||||||
e_compile_ignored = True
|
e_compile_ignored = True
|
||||||
e_config = Config(default=True)
|
e_config: bool = Config(default=True)
|
||||||
e_jk = Config(justknob="does_not_exist", default=True)
|
e_jk: bool = Config(justknob="does_not_exist", default=True)
|
||||||
e_jk_false = Config(justknob="does_not_exist", default=False)
|
e_jk_false: bool = Config(justknob="does_not_exist", default=False)
|
||||||
e_env_default = Config(env_name_default="ENV_TRUE", default=False)
|
e_env_default: bool = Config(env_name_default="ENV_TRUE", default=False)
|
||||||
e_env_default_FALSE = Config(env_name_default="ENV_FALSE", default=True)
|
e_env_default_FALSE: bool = Config(env_name_default="ENV_FALSE", default=True)
|
||||||
e_env_force = Config(env_name_force="ENV_TRUE", default=False)
|
e_env_force: bool = Config(env_name_force="ENV_TRUE", default=False)
|
||||||
|
e_aliased_bool: bool = Config(
|
||||||
|
alias="torch.testing._internal.fake_config_module2.e_aliasing_bool"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class nested:
|
class nested:
|
||||||
|
|
|
||||||
8
torch/testing/_internal/fake_config_module2.py
Normal file
8
torch/testing/_internal/fake_config_module2.py
Normal file
|
|
@ -0,0 +1,8 @@
|
||||||
|
import sys
|
||||||
|
|
||||||
|
from torch.utils._config_module import install_config_module
|
||||||
|
|
||||||
|
|
||||||
|
e_aliasing_bool = False
|
||||||
|
|
||||||
|
install_config_module(sys.modules[__name__])
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
import contextlib
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
|
|
@ -20,6 +21,7 @@ from typing import (
|
||||||
NoReturn,
|
NoReturn,
|
||||||
Optional,
|
Optional,
|
||||||
Set,
|
Set,
|
||||||
|
Tuple,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
|
|
@ -38,6 +40,9 @@ CONFIG_TYPES = (int, float, bool, type(None), str, list, set, tuple, dict)
|
||||||
T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict])
|
T = TypeVar("T", bound=Union[int, float, bool, None, str, list, set, tuple, dict])
|
||||||
|
|
||||||
|
|
||||||
|
_UNSET_SENTINEL = object()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _Config(Generic[T]):
|
class _Config(Generic[T]):
|
||||||
"""Represents a config with richer behaviour than just a default value.
|
"""Represents a config with richer behaviour than just a default value.
|
||||||
|
|
@ -49,7 +54,9 @@ class _Config(Generic[T]):
|
||||||
This configs must be installed with install_config_module to be used
|
This configs must be installed with install_config_module to be used
|
||||||
|
|
||||||
Precedence Order:
|
Precedence Order:
|
||||||
env_name_force: If set, this environment variable overrides everything
|
alias: If set, the directly use the value of the alias.
|
||||||
|
env_name_force: If set, this environment variable has precedence over
|
||||||
|
everything after this.
|
||||||
user_override: If a user sets a value (i.e. foo.bar=True), that
|
user_override: If a user sets a value (i.e. foo.bar=True), that
|
||||||
has precedence over everything after this.
|
has precedence over everything after this.
|
||||||
env_name_default: If set, this environment variable will override everything
|
env_name_default: If set, this environment variable will override everything
|
||||||
|
|
@ -65,25 +72,28 @@ class _Config(Generic[T]):
|
||||||
Arguments:
|
Arguments:
|
||||||
justknob: the name of the feature / JK. In OSS this is unused.
|
justknob: the name of the feature / JK. In OSS this is unused.
|
||||||
default: is the value to default this knob to in OSS.
|
default: is the value to default this knob to in OSS.
|
||||||
|
alias: The alias config to read instead.
|
||||||
env_name_force: The environment variable to read that is a FORCE
|
env_name_force: The environment variable to read that is a FORCE
|
||||||
environment variable. I.e. it overrides everything
|
environment variable. I.e. it overrides everything except for alias.
|
||||||
env_name_default: The environment variable to read that changes the
|
env_name_default: The environment variable to read that changes the
|
||||||
default behaviour. I.e. user overrides take preference.
|
default behaviour. I.e. user overrides take preference.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
default: T
|
default: Union[T, object]
|
||||||
justknob: Optional[str] = None
|
justknob: Optional[str] = None
|
||||||
env_name_default: Optional[str] = None
|
env_name_default: Optional[str] = None
|
||||||
env_name_force: Optional[str] = None
|
env_name_force: Optional[str] = None
|
||||||
value_type: Optional[type] = None
|
value_type: Optional[type] = None
|
||||||
|
alias: Optional[str] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
default: T,
|
default: Union[T, object] = _UNSET_SENTINEL,
|
||||||
justknob: Optional[str] = None,
|
justknob: Optional[str] = None,
|
||||||
env_name_default: Optional[str] = None,
|
env_name_default: Optional[str] = None,
|
||||||
env_name_force: Optional[str] = None,
|
env_name_force: Optional[str] = None,
|
||||||
value_type: Optional[type] = None,
|
value_type: Optional[type] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
):
|
):
|
||||||
# python 3.9 does not support kw_only on the dataclass :(.
|
# python 3.9 does not support kw_only on the dataclass :(.
|
||||||
self.default = default
|
self.default = default
|
||||||
|
|
@ -91,10 +101,18 @@ class _Config(Generic[T]):
|
||||||
self.env_name_default = env_name_default
|
self.env_name_default = env_name_default
|
||||||
self.env_name_force = env_name_force
|
self.env_name_force = env_name_force
|
||||||
self.value_type = value_type
|
self.value_type = value_type
|
||||||
|
self.alias = alias
|
||||||
if self.justknob is not None:
|
if self.justknob is not None:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
self.default, bool
|
self.default, bool
|
||||||
), f"justknobs only support booleans, {self.default} is not a boolean"
|
), f"justknobs only support booleans, {self.default} is not a boolean"
|
||||||
|
if self.alias is not None:
|
||||||
|
assert (
|
||||||
|
default is _UNSET_SENTINEL
|
||||||
|
and justknob is None
|
||||||
|
and env_name_default is None
|
||||||
|
and env_name_force is None
|
||||||
|
), "if alias is set, default, justknob or env var cannot be set"
|
||||||
|
|
||||||
|
|
||||||
# In runtime, we unbox the Config[T] to a T, but typechecker cannot see this,
|
# In runtime, we unbox the Config[T] to a T, but typechecker cannot see this,
|
||||||
|
|
@ -104,24 +122,28 @@ class _Config(Generic[T]):
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
|
||||||
def Config(
|
def Config(
|
||||||
default: T,
|
default: Union[T, object] = _UNSET_SENTINEL,
|
||||||
justknob: Optional[str] = None,
|
justknob: Optional[str] = None,
|
||||||
env_name_default: Optional[str] = None,
|
env_name_default: Optional[str] = None,
|
||||||
env_name_force: Optional[str] = None,
|
env_name_force: Optional[str] = None,
|
||||||
value_type: Optional[type] = None,
|
value_type: Optional[type] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
) -> T:
|
) -> T:
|
||||||
...
|
...
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
def Config(
|
def Config(
|
||||||
default: T,
|
default: Union[T, object] = _UNSET_SENTINEL,
|
||||||
justknob: Optional[str] = None,
|
justknob: Optional[str] = None,
|
||||||
env_name_default: Optional[str] = None,
|
env_name_default: Optional[str] = None,
|
||||||
env_name_force: Optional[str] = None,
|
env_name_force: Optional[str] = None,
|
||||||
value_type: Optional[type] = None,
|
value_type: Optional[type] = None,
|
||||||
|
alias: Optional[str] = None,
|
||||||
) -> _Config[T]:
|
) -> _Config[T]:
|
||||||
return _Config(default, justknob, env_name_default, env_name_force, value_type)
|
return _Config(
|
||||||
|
default, justknob, env_name_default, env_name_force, value_type, alias
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _read_env_variable(name: str) -> Optional[bool]:
|
def _read_env_variable(name: str) -> Optional[bool]:
|
||||||
|
|
@ -243,9 +265,6 @@ def get_assignments_with_compile_ignored_comments(module: ModuleType) -> Set[str
|
||||||
return assignments
|
return assignments
|
||||||
|
|
||||||
|
|
||||||
_UNSET_SENTINEL = object()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _ConfigEntry:
|
class _ConfigEntry:
|
||||||
# The default value specified in the configuration
|
# The default value specified in the configuration
|
||||||
|
|
@ -272,6 +291,7 @@ class _ConfigEntry:
|
||||||
# call so the final state is correct. It's just very unintuitive.
|
# call so the final state is correct. It's just very unintuitive.
|
||||||
# upstream bug - python/cpython#126886
|
# upstream bug - python/cpython#126886
|
||||||
hide: bool = False
|
hide: bool = False
|
||||||
|
alias: Optional[str] = None
|
||||||
|
|
||||||
def __init__(self, config: _Config):
|
def __init__(self, config: _Config):
|
||||||
self.default = config.default
|
self.default = config.default
|
||||||
|
|
@ -279,6 +299,7 @@ class _ConfigEntry:
|
||||||
config.value_type if config.value_type is not None else type(self.default)
|
config.value_type if config.value_type is not None else type(self.default)
|
||||||
)
|
)
|
||||||
self.justknob = config.justknob
|
self.justknob = config.justknob
|
||||||
|
self.alias = config.alias
|
||||||
if config.env_name_default is not None:
|
if config.env_name_default is not None:
|
||||||
if (env_value := _read_env_variable(config.env_name_default)) is not None:
|
if (env_value := _read_env_variable(config.env_name_default)) is not None:
|
||||||
self.env_value_default = env_value
|
self.env_value_default = env_value
|
||||||
|
|
@ -309,6 +330,8 @@ class ConfigModule(ModuleType):
|
||||||
super().__setattr__(name, value)
|
super().__setattr__(name, value)
|
||||||
elif name not in self._config:
|
elif name not in self._config:
|
||||||
raise AttributeError(f"{self.__name__}.{name} does not exist")
|
raise AttributeError(f"{self.__name__}.{name} does not exist")
|
||||||
|
elif self._config[name].alias is not None:
|
||||||
|
self._set_alias_val(self._config[name], value)
|
||||||
else:
|
else:
|
||||||
self._config[name].user_override = value
|
self._config[name].user_override = value
|
||||||
self._is_dirty = True
|
self._is_dirty = True
|
||||||
|
|
@ -321,6 +344,10 @@ class ConfigModule(ModuleType):
|
||||||
if config.hide:
|
if config.hide:
|
||||||
raise AttributeError(f"{self.__name__}.{name} does not exist")
|
raise AttributeError(f"{self.__name__}.{name} does not exist")
|
||||||
|
|
||||||
|
alias_val = self._get_alias_val(config)
|
||||||
|
if alias_val is not _UNSET_SENTINEL:
|
||||||
|
return alias_val
|
||||||
|
|
||||||
if config.env_value_force is not _UNSET_SENTINEL:
|
if config.env_value_force is not _UNSET_SENTINEL:
|
||||||
return config.env_value_force
|
return config.env_value_force
|
||||||
|
|
||||||
|
|
@ -353,6 +380,33 @@ class ConfigModule(ModuleType):
|
||||||
self._config[name].user_override = _UNSET_SENTINEL
|
self._config[name].user_override = _UNSET_SENTINEL
|
||||||
self._config[name].hide = True
|
self._config[name].hide = True
|
||||||
|
|
||||||
|
def _get_alias_module_and_name(
|
||||||
|
self, entry: _ConfigEntry
|
||||||
|
) -> Optional[Tuple[ModuleType, str]]:
|
||||||
|
alias = entry.alias
|
||||||
|
if alias is None:
|
||||||
|
return None
|
||||||
|
module_name, constant_name = alias.rsplit(".", 1)
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
except ImportError as e:
|
||||||
|
raise AttributeError("config alias {alias} does not exist") from e
|
||||||
|
return module, constant_name
|
||||||
|
|
||||||
|
def _get_alias_val(self, entry: _ConfigEntry) -> Any:
|
||||||
|
data = self._get_alias_module_and_name(entry)
|
||||||
|
if data is None:
|
||||||
|
return _UNSET_SENTINEL
|
||||||
|
module, constant_name = data
|
||||||
|
constant_value = getattr(module, constant_name)
|
||||||
|
return constant_value
|
||||||
|
|
||||||
|
def _set_alias_val(self, entry: _ConfigEntry, val: Any) -> None:
|
||||||
|
data = self._get_alias_module_and_name(entry)
|
||||||
|
assert data is not None
|
||||||
|
module, constant_name = data
|
||||||
|
setattr(module, constant_name, val)
|
||||||
|
|
||||||
def _is_default(self, name: str) -> bool:
|
def _is_default(self, name: str) -> bool:
|
||||||
return self._config[name].user_override is _UNSET_SENTINEL
|
return self._config[name].user_override is _UNSET_SENTINEL
|
||||||
|
|
||||||
|
|
@ -369,6 +423,7 @@ class ConfigModule(ModuleType):
|
||||||
This is used by a number of different user facing export methods
|
This is used by a number of different user facing export methods
|
||||||
which all have slightly different semantics re: how and what to
|
which all have slightly different semantics re: how and what to
|
||||||
skip.
|
skip.
|
||||||
|
If a config is aliased, it skips this config.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
ignored_keys are keys that should not be exported.
|
ignored_keys are keys that should not be exported.
|
||||||
|
|
@ -391,7 +446,10 @@ class ConfigModule(ModuleType):
|
||||||
continue
|
continue
|
||||||
if skip_default and self._is_default(key):
|
if skip_default and self._is_default(key):
|
||||||
continue
|
continue
|
||||||
|
if self._config[key].alias is not None:
|
||||||
|
continue
|
||||||
config[key] = copy.deepcopy(getattr(self, key))
|
config[key] = copy.deepcopy(getattr(self, key))
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def get_type(self, config_name: str) -> type:
|
def get_type(self, config_name: str) -> type:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user