mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add type annotations to Configs (#139833)
Summary: Adds types to Configs, and fixes a bug in options that was caused by the lack of types. fixes: https://github.com/pytorch/pytorch/issues/139822 Configs are used by many modules so not sure which label to put. Types also allow https://github.com/pytorch/pytorch/pull/139736 to fuzz configs Pull Request resolved: https://github.com/pytorch/pytorch/pull/139833 Approved by: https://github.com/c00w
This commit is contained in:
parent
5203138483
commit
2037ea3e15
|
|
@ -6,6 +6,8 @@ import pickle
|
|||
os.environ["ENV_TRUE"] = "1"
|
||||
os.environ["ENV_FALSE"] = "0"
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from torch.testing._internal import fake_config_module as config
|
||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||
from torch.utils._config_module import _UNSET_SENTINEL
|
||||
|
|
@ -15,6 +17,7 @@ class TestConfigModule(TestCase):
|
|||
def test_base_value_loading(self):
|
||||
self.assertTrue(config.e_bool)
|
||||
self.assertTrue(config.nested.e_bool)
|
||||
self.assertTrue(config.e_optional)
|
||||
self.assertEqual(config.e_int, 1)
|
||||
self.assertEqual(config.e_float, 1.0)
|
||||
self.assertEqual(config.e_string, "string")
|
||||
|
|
@ -28,6 +31,10 @@ class TestConfigModule(TestCase):
|
|||
):
|
||||
config.does_not_exist
|
||||
|
||||
def test_type_loading(self):
|
||||
self.assertEqual(config.get_type("e_optional"), Optional[bool])
|
||||
self.assertEqual(config.get_type("e_none"), Optional[bool])
|
||||
|
||||
def test_overrides(self):
|
||||
config.e_bool = False
|
||||
self.assertFalse(config.e_bool)
|
||||
|
|
@ -51,6 +58,10 @@ class TestConfigModule(TestCase):
|
|||
self.assertEqual(config.e_none, "not none")
|
||||
config.e_none = None
|
||||
self.assertEqual(config.e_none, None)
|
||||
config.e_optional = None
|
||||
self.assertEqual(config.e_optional, None)
|
||||
config.e_optional = False
|
||||
self.assertEqual(config.e_optional, False)
|
||||
with self.assertRaises(
|
||||
AttributeError, msg="fake_config_module.does_not_exist does not exist"
|
||||
):
|
||||
|
|
@ -112,6 +123,7 @@ class TestConfigModule(TestCase):
|
|||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
config.e_bool = False
|
||||
|
|
@ -145,6 +157,7 @@ class TestConfigModule(TestCase):
|
|||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
config.e_bool = False
|
||||
|
|
@ -173,30 +186,22 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
def test_get_hash(self):
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
)
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
# Test cached value
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
)
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
)
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
config._hash_digest = "fake"
|
||||
self.assertEqual(config.get_hash(), "fake")
|
||||
|
||||
config.e_bool = False
|
||||
self.assertNotEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
|
||||
)
|
||||
config.e_bool = True
|
||||
|
||||
# Test ignored values
|
||||
config.e_compile_ignored = False
|
||||
self.assertEqual(
|
||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
||||
)
|
||||
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||
for k in config._config:
|
||||
config._config[k].user_override = _UNSET_SENTINEL
|
||||
|
||||
|
|
@ -227,6 +232,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
p2 = config.to_dict()
|
||||
|
|
@ -255,6 +261,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
p3 = config.get_config_copy()
|
||||
|
|
@ -283,6 +290,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
|||
"e_env_default": True,
|
||||
"e_env_default_FALSE": False,
|
||||
"e_env_force": True,
|
||||
"e_optional": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ from typing import (
|
|||
Any as _Any,
|
||||
Callable as _Callable,
|
||||
Dict as _Dict,
|
||||
get_origin as _get_origin,
|
||||
Optional as _Optional,
|
||||
overload as _overload,
|
||||
Set as _Set,
|
||||
|
|
@ -2280,7 +2281,12 @@ class _TorchCompileInductorWrapper:
|
|||
raise RuntimeError(
|
||||
f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
|
||||
)
|
||||
if type(val) is not type(current_config[attr_name]):
|
||||
attr_type = config.get_type(attr_name) # type: ignore[attr-defined]
|
||||
# Subscriptable generic types don't support isinstance so skip the type
|
||||
# check. There doesn't seem to be a good way of checking membership without
|
||||
# 3rd party libraries.
|
||||
if _get_origin(attr_type) is None:
|
||||
if not isinstance(val, attr_type):
|
||||
val_type_str = type(val).__name__
|
||||
expected_type_str = type(current_config[attr_name]).__name__
|
||||
raise RuntimeError(
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ e_set = {1}
|
|||
e_tuple = (1,)
|
||||
e_dict = {1: 2}
|
||||
e_none: Optional[bool] = None
|
||||
e_optional: Optional[bool] = True
|
||||
e_ignored = True
|
||||
_e_ignored = True
|
||||
magic_cache_config_ignored = True
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import inspect
|
|||
import io
|
||||
import os
|
||||
import pickle
|
||||
import sys
|
||||
import tokenize
|
||||
import unittest
|
||||
import warnings
|
||||
|
|
@ -54,6 +55,7 @@ class Config:
|
|||
justknob: Optional[str] = None
|
||||
env_name_default: Optional[str] = None
|
||||
env_name_force: Optional[str] = None
|
||||
value_type: Optional[type] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -61,12 +63,14 @@ class Config:
|
|||
justknob: Optional[str] = None,
|
||||
env_name_default: Optional[str] = None,
|
||||
env_name_force: Optional[str] = None,
|
||||
value_type: Optional[type] = None,
|
||||
):
|
||||
# python 3.9 does not support kw_only on the dataclass :(.
|
||||
self.default = default
|
||||
self.justknob = justknob
|
||||
self.env_name_default = env_name_default
|
||||
self.env_name_force = env_name_force
|
||||
self.value_type = value_type
|
||||
|
||||
|
||||
# Types saved/loaded in configs
|
||||
|
|
@ -99,6 +103,10 @@ def install_config_module(module: ModuleType) -> None:
|
|||
prefix: str,
|
||||
) -> None:
|
||||
"""Walk the module structure and move everything to module._config"""
|
||||
if sys.version_info[:2] < (3, 10):
|
||||
type_hints = getattr(source, "__annotations__", {})
|
||||
else:
|
||||
type_hints = inspect.get_annotations(source)
|
||||
for key, value in list(source.__dict__.items()):
|
||||
if (
|
||||
key.startswith("__")
|
||||
|
|
@ -111,7 +119,10 @@ def install_config_module(module: ModuleType) -> None:
|
|||
|
||||
name = f"{prefix}{key}"
|
||||
if isinstance(value, CONFIG_TYPES):
|
||||
config[name] = _ConfigEntry(Config(default=value))
|
||||
annotated_type = type_hints.get(key, None)
|
||||
config[name] = _ConfigEntry(
|
||||
Config(default=value, value_type=annotated_type)
|
||||
)
|
||||
if dest is module:
|
||||
delattr(module, key)
|
||||
elif isinstance(value, Config):
|
||||
|
|
@ -192,6 +203,8 @@ _UNSET_SENTINEL = object()
|
|||
class _ConfigEntry:
|
||||
# The default value specified in the configuration
|
||||
default: Any
|
||||
# The type of the configuration value
|
||||
value_type: type
|
||||
# The value specified by the user when they overrode the configuration
|
||||
# _UNSET_SENTINEL indicates the value is not set.
|
||||
user_override: Any = _UNSET_SENTINEL
|
||||
|
|
@ -203,6 +216,9 @@ class _ConfigEntry:
|
|||
|
||||
def __init__(self, config: Config):
|
||||
self.default = config.default
|
||||
self.value_type = (
|
||||
config.value_type if config.value_type is not None else type(self.default)
|
||||
)
|
||||
self.justknob = config.justknob
|
||||
if config.env_name_default is not None:
|
||||
if (env_value := _read_env_variable(config.env_name_default)) is not None:
|
||||
|
|
@ -314,6 +330,9 @@ class ConfigModule(ModuleType):
|
|||
config[key] = copy.deepcopy(getattr(self, key))
|
||||
return config
|
||||
|
||||
def get_type(self, config_name: str) -> type:
|
||||
return self._config[config_name].value_type
|
||||
|
||||
def save_config(self) -> bytes:
|
||||
"""Convert config to a pickled blob"""
|
||||
ignored_keys = getattr(self, "_save_config_ignore", [])
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user