mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add type annotations to Configs (#139833)
Summary: Adds types to Configs, and fixes a bug in options that was caused by the lack of types. fixes: https://github.com/pytorch/pytorch/issues/139822 Configs are used by many modules so not sure which label to put. Types also allow https://github.com/pytorch/pytorch/pull/139736 to fuzz configs Pull Request resolved: https://github.com/pytorch/pytorch/pull/139833 Approved by: https://github.com/c00w
This commit is contained in:
parent
5203138483
commit
2037ea3e15
|
|
@ -6,6 +6,8 @@ import pickle
|
||||||
os.environ["ENV_TRUE"] = "1"
|
os.environ["ENV_TRUE"] = "1"
|
||||||
os.environ["ENV_FALSE"] = "0"
|
os.environ["ENV_FALSE"] = "0"
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from torch.testing._internal import fake_config_module as config
|
from torch.testing._internal import fake_config_module as config
|
||||||
from torch.testing._internal.common_utils import run_tests, TestCase
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||||
from torch.utils._config_module import _UNSET_SENTINEL
|
from torch.utils._config_module import _UNSET_SENTINEL
|
||||||
|
|
@ -15,6 +17,7 @@ class TestConfigModule(TestCase):
|
||||||
def test_base_value_loading(self):
|
def test_base_value_loading(self):
|
||||||
self.assertTrue(config.e_bool)
|
self.assertTrue(config.e_bool)
|
||||||
self.assertTrue(config.nested.e_bool)
|
self.assertTrue(config.nested.e_bool)
|
||||||
|
self.assertTrue(config.e_optional)
|
||||||
self.assertEqual(config.e_int, 1)
|
self.assertEqual(config.e_int, 1)
|
||||||
self.assertEqual(config.e_float, 1.0)
|
self.assertEqual(config.e_float, 1.0)
|
||||||
self.assertEqual(config.e_string, "string")
|
self.assertEqual(config.e_string, "string")
|
||||||
|
|
@ -28,6 +31,10 @@ class TestConfigModule(TestCase):
|
||||||
):
|
):
|
||||||
config.does_not_exist
|
config.does_not_exist
|
||||||
|
|
||||||
|
def test_type_loading(self):
|
||||||
|
self.assertEqual(config.get_type("e_optional"), Optional[bool])
|
||||||
|
self.assertEqual(config.get_type("e_none"), Optional[bool])
|
||||||
|
|
||||||
def test_overrides(self):
|
def test_overrides(self):
|
||||||
config.e_bool = False
|
config.e_bool = False
|
||||||
self.assertFalse(config.e_bool)
|
self.assertFalse(config.e_bool)
|
||||||
|
|
@ -51,6 +58,10 @@ class TestConfigModule(TestCase):
|
||||||
self.assertEqual(config.e_none, "not none")
|
self.assertEqual(config.e_none, "not none")
|
||||||
config.e_none = None
|
config.e_none = None
|
||||||
self.assertEqual(config.e_none, None)
|
self.assertEqual(config.e_none, None)
|
||||||
|
config.e_optional = None
|
||||||
|
self.assertEqual(config.e_optional, None)
|
||||||
|
config.e_optional = False
|
||||||
|
self.assertEqual(config.e_optional, False)
|
||||||
with self.assertRaises(
|
with self.assertRaises(
|
||||||
AttributeError, msg="fake_config_module.does_not_exist does not exist"
|
AttributeError, msg="fake_config_module.does_not_exist does not exist"
|
||||||
):
|
):
|
||||||
|
|
@ -112,6 +123,7 @@ class TestConfigModule(TestCase):
|
||||||
"e_env_default": True,
|
"e_env_default": True,
|
||||||
"e_env_default_FALSE": False,
|
"e_env_default_FALSE": False,
|
||||||
"e_env_force": True,
|
"e_env_force": True,
|
||||||
|
"e_optional": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config.e_bool = False
|
config.e_bool = False
|
||||||
|
|
@ -145,6 +157,7 @@ class TestConfigModule(TestCase):
|
||||||
"e_env_default": True,
|
"e_env_default": True,
|
||||||
"e_env_default_FALSE": False,
|
"e_env_default_FALSE": False,
|
||||||
"e_env_force": True,
|
"e_env_force": True,
|
||||||
|
"e_optional": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
config.e_bool = False
|
config.e_bool = False
|
||||||
|
|
@ -173,30 +186,22 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||||
config._config[k].user_override = _UNSET_SENTINEL
|
config._config[k].user_override = _UNSET_SENTINEL
|
||||||
|
|
||||||
def test_get_hash(self):
|
def test_get_hash(self):
|
||||||
self.assertEqual(
|
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
|
||||||
)
|
|
||||||
# Test cached value
|
# Test cached value
|
||||||
self.assertEqual(
|
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||||
)
|
|
||||||
self.assertEqual(
|
|
||||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
|
||||||
)
|
|
||||||
config._hash_digest = "fake"
|
config._hash_digest = "fake"
|
||||||
self.assertEqual(config.get_hash(), "fake")
|
self.assertEqual(config.get_hash(), "fake")
|
||||||
|
|
||||||
config.e_bool = False
|
config.e_bool = False
|
||||||
self.assertNotEqual(
|
self.assertNotEqual(
|
||||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
|
||||||
)
|
)
|
||||||
config.e_bool = True
|
config.e_bool = True
|
||||||
|
|
||||||
# Test ignored values
|
# Test ignored values
|
||||||
config.e_compile_ignored = False
|
config.e_compile_ignored = False
|
||||||
self.assertEqual(
|
self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
|
||||||
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
|
|
||||||
)
|
|
||||||
for k in config._config:
|
for k in config._config:
|
||||||
config._config[k].user_override = _UNSET_SENTINEL
|
config._config[k].user_override = _UNSET_SENTINEL
|
||||||
|
|
||||||
|
|
@ -227,6 +232,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||||
"e_env_default": True,
|
"e_env_default": True,
|
||||||
"e_env_default_FALSE": False,
|
"e_env_default_FALSE": False,
|
||||||
"e_env_force": True,
|
"e_env_force": True,
|
||||||
|
"e_optional": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
p2 = config.to_dict()
|
p2 = config.to_dict()
|
||||||
|
|
@ -255,6 +261,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||||
"e_env_default": True,
|
"e_env_default": True,
|
||||||
"e_env_default_FALSE": False,
|
"e_env_default_FALSE": False,
|
||||||
"e_env_force": True,
|
"e_env_force": True,
|
||||||
|
"e_optional": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
p3 = config.get_config_copy()
|
p3 = config.get_config_copy()
|
||||||
|
|
@ -283,6 +290,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
|
||||||
"e_env_default": True,
|
"e_env_default": True,
|
||||||
"e_env_default_FALSE": False,
|
"e_env_default_FALSE": False,
|
||||||
"e_env_force": True,
|
"e_env_force": True,
|
||||||
|
"e_optional": True,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ from typing import (
|
||||||
Any as _Any,
|
Any as _Any,
|
||||||
Callable as _Callable,
|
Callable as _Callable,
|
||||||
Dict as _Dict,
|
Dict as _Dict,
|
||||||
|
get_origin as _get_origin,
|
||||||
Optional as _Optional,
|
Optional as _Optional,
|
||||||
overload as _overload,
|
overload as _overload,
|
||||||
Set as _Set,
|
Set as _Set,
|
||||||
|
|
@ -2280,13 +2281,18 @@ class _TorchCompileInductorWrapper:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
|
f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
|
||||||
)
|
)
|
||||||
if type(val) is not type(current_config[attr_name]):
|
attr_type = config.get_type(attr_name) # type: ignore[attr-defined]
|
||||||
val_type_str = type(val).__name__
|
# Subscriptable generic types don't support isinstance so skip the type
|
||||||
expected_type_str = type(current_config[attr_name]).__name__
|
# check. There doesn't seem to be a good way of checking membership without
|
||||||
raise RuntimeError(
|
# 3rd party libraries.
|
||||||
f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
|
if _get_origin(attr_type) is None:
|
||||||
)
|
if not isinstance(val, attr_type):
|
||||||
self.config[attr_name] = val
|
val_type_str = type(val).__name__
|
||||||
|
expected_type_str = type(current_config[attr_name]).__name__
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
|
||||||
|
)
|
||||||
|
self.config[attr_name] = val
|
||||||
|
|
||||||
def __call__(self, model_, inputs_):
|
def __call__(self, model_, inputs_):
|
||||||
from torch._inductor.compile_fx import compile_fx
|
from torch._inductor.compile_fx import compile_fx
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ e_set = {1}
|
||||||
e_tuple = (1,)
|
e_tuple = (1,)
|
||||||
e_dict = {1: 2}
|
e_dict = {1: 2}
|
||||||
e_none: Optional[bool] = None
|
e_none: Optional[bool] = None
|
||||||
|
e_optional: Optional[bool] = True
|
||||||
e_ignored = True
|
e_ignored = True
|
||||||
_e_ignored = True
|
_e_ignored = True
|
||||||
magic_cache_config_ignored = True
|
magic_cache_config_ignored = True
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ import inspect
|
||||||
import io
|
import io
|
||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
import sys
|
||||||
import tokenize
|
import tokenize
|
||||||
import unittest
|
import unittest
|
||||||
import warnings
|
import warnings
|
||||||
|
|
@ -54,6 +55,7 @@ class Config:
|
||||||
justknob: Optional[str] = None
|
justknob: Optional[str] = None
|
||||||
env_name_default: Optional[str] = None
|
env_name_default: Optional[str] = None
|
||||||
env_name_force: Optional[str] = None
|
env_name_force: Optional[str] = None
|
||||||
|
value_type: Optional[type] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -61,12 +63,14 @@ class Config:
|
||||||
justknob: Optional[str] = None,
|
justknob: Optional[str] = None,
|
||||||
env_name_default: Optional[str] = None,
|
env_name_default: Optional[str] = None,
|
||||||
env_name_force: Optional[str] = None,
|
env_name_force: Optional[str] = None,
|
||||||
|
value_type: Optional[type] = None,
|
||||||
):
|
):
|
||||||
# python 3.9 does not support kw_only on the dataclass :(.
|
# python 3.9 does not support kw_only on the dataclass :(.
|
||||||
self.default = default
|
self.default = default
|
||||||
self.justknob = justknob
|
self.justknob = justknob
|
||||||
self.env_name_default = env_name_default
|
self.env_name_default = env_name_default
|
||||||
self.env_name_force = env_name_force
|
self.env_name_force = env_name_force
|
||||||
|
self.value_type = value_type
|
||||||
|
|
||||||
|
|
||||||
# Types saved/loaded in configs
|
# Types saved/loaded in configs
|
||||||
|
|
@ -99,6 +103,10 @@ def install_config_module(module: ModuleType) -> None:
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Walk the module structure and move everything to module._config"""
|
"""Walk the module structure and move everything to module._config"""
|
||||||
|
if sys.version_info[:2] < (3, 10):
|
||||||
|
type_hints = getattr(source, "__annotations__", {})
|
||||||
|
else:
|
||||||
|
type_hints = inspect.get_annotations(source)
|
||||||
for key, value in list(source.__dict__.items()):
|
for key, value in list(source.__dict__.items()):
|
||||||
if (
|
if (
|
||||||
key.startswith("__")
|
key.startswith("__")
|
||||||
|
|
@ -111,7 +119,10 @@ def install_config_module(module: ModuleType) -> None:
|
||||||
|
|
||||||
name = f"{prefix}{key}"
|
name = f"{prefix}{key}"
|
||||||
if isinstance(value, CONFIG_TYPES):
|
if isinstance(value, CONFIG_TYPES):
|
||||||
config[name] = _ConfigEntry(Config(default=value))
|
annotated_type = type_hints.get(key, None)
|
||||||
|
config[name] = _ConfigEntry(
|
||||||
|
Config(default=value, value_type=annotated_type)
|
||||||
|
)
|
||||||
if dest is module:
|
if dest is module:
|
||||||
delattr(module, key)
|
delattr(module, key)
|
||||||
elif isinstance(value, Config):
|
elif isinstance(value, Config):
|
||||||
|
|
@ -192,6 +203,8 @@ _UNSET_SENTINEL = object()
|
||||||
class _ConfigEntry:
|
class _ConfigEntry:
|
||||||
# The default value specified in the configuration
|
# The default value specified in the configuration
|
||||||
default: Any
|
default: Any
|
||||||
|
# The type of the configuration value
|
||||||
|
value_type: type
|
||||||
# The value specified by the user when they overrode the configuration
|
# The value specified by the user when they overrode the configuration
|
||||||
# _UNSET_SENTINEL indicates the value is not set.
|
# _UNSET_SENTINEL indicates the value is not set.
|
||||||
user_override: Any = _UNSET_SENTINEL
|
user_override: Any = _UNSET_SENTINEL
|
||||||
|
|
@ -203,6 +216,9 @@ class _ConfigEntry:
|
||||||
|
|
||||||
def __init__(self, config: Config):
|
def __init__(self, config: Config):
|
||||||
self.default = config.default
|
self.default = config.default
|
||||||
|
self.value_type = (
|
||||||
|
config.value_type if config.value_type is not None else type(self.default)
|
||||||
|
)
|
||||||
self.justknob = config.justknob
|
self.justknob = config.justknob
|
||||||
if config.env_name_default is not None:
|
if config.env_name_default is not None:
|
||||||
if (env_value := _read_env_variable(config.env_name_default)) is not None:
|
if (env_value := _read_env_variable(config.env_name_default)) is not None:
|
||||||
|
|
@ -314,6 +330,9 @@ class ConfigModule(ModuleType):
|
||||||
config[key] = copy.deepcopy(getattr(self, key))
|
config[key] = copy.deepcopy(getattr(self, key))
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
def get_type(self, config_name: str) -> type:
|
||||||
|
return self._config[config_name].value_type
|
||||||
|
|
||||||
def save_config(self) -> bytes:
|
def save_config(self) -> bytes:
|
||||||
"""Convert config to a pickled blob"""
|
"""Convert config to a pickled blob"""
|
||||||
ignored_keys = getattr(self, "_save_config_ignore", [])
|
ignored_keys = getattr(self, "_save_config_ignore", [])
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user