Add type annotations to Configs (#139833)

Summary:
Adds types to Configs, and fixes a bug in options that was caused by the lack of types.

fixes: https://github.com/pytorch/pytorch/issues/139822

Configs are used by many modules so not sure which label to put.

Types also allow https://github.com/pytorch/pytorch/pull/139736 to fuzz configs

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139833
Approved by: https://github.com/c00w
This commit is contained in:
Gabriel Ferns 2024-11-07 03:49:07 +00:00 committed by PyTorch MergeBot
parent 5203138483
commit 2037ea3e15
4 changed files with 55 additions and 21 deletions

View File

@ -6,6 +6,8 @@ import pickle
os.environ["ENV_TRUE"] = "1" os.environ["ENV_TRUE"] = "1"
os.environ["ENV_FALSE"] = "0" os.environ["ENV_FALSE"] = "0"
from typing import Optional
from torch.testing._internal import fake_config_module as config from torch.testing._internal import fake_config_module as config
from torch.testing._internal.common_utils import run_tests, TestCase from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _UNSET_SENTINEL from torch.utils._config_module import _UNSET_SENTINEL
@ -15,6 +17,7 @@ class TestConfigModule(TestCase):
def test_base_value_loading(self): def test_base_value_loading(self):
self.assertTrue(config.e_bool) self.assertTrue(config.e_bool)
self.assertTrue(config.nested.e_bool) self.assertTrue(config.nested.e_bool)
self.assertTrue(config.e_optional)
self.assertEqual(config.e_int, 1) self.assertEqual(config.e_int, 1)
self.assertEqual(config.e_float, 1.0) self.assertEqual(config.e_float, 1.0)
self.assertEqual(config.e_string, "string") self.assertEqual(config.e_string, "string")
@ -28,6 +31,10 @@ class TestConfigModule(TestCase):
): ):
config.does_not_exist config.does_not_exist
def test_type_loading(self):
self.assertEqual(config.get_type("e_optional"), Optional[bool])
self.assertEqual(config.get_type("e_none"), Optional[bool])
def test_overrides(self): def test_overrides(self):
config.e_bool = False config.e_bool = False
self.assertFalse(config.e_bool) self.assertFalse(config.e_bool)
@ -51,6 +58,10 @@ class TestConfigModule(TestCase):
self.assertEqual(config.e_none, "not none") self.assertEqual(config.e_none, "not none")
config.e_none = None config.e_none = None
self.assertEqual(config.e_none, None) self.assertEqual(config.e_none, None)
config.e_optional = None
self.assertEqual(config.e_optional, None)
config.e_optional = False
self.assertEqual(config.e_optional, False)
with self.assertRaises( with self.assertRaises(
AttributeError, msg="fake_config_module.does_not_exist does not exist" AttributeError, msg="fake_config_module.does_not_exist does not exist"
): ):
@ -112,6 +123,7 @@ class TestConfigModule(TestCase):
"e_env_default": True, "e_env_default": True,
"e_env_default_FALSE": False, "e_env_default_FALSE": False,
"e_env_force": True, "e_env_force": True,
"e_optional": True,
}, },
) )
config.e_bool = False config.e_bool = False
@ -145,6 +157,7 @@ class TestConfigModule(TestCase):
"e_env_default": True, "e_env_default": True,
"e_env_default_FALSE": False, "e_env_default_FALSE": False,
"e_env_force": True, "e_env_force": True,
"e_optional": True,
}, },
) )
config.e_bool = False config.e_bool = False
@ -173,30 +186,22 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
config._config[k].user_override = _UNSET_SENTINEL config._config[k].user_override = _UNSET_SENTINEL
def test_get_hash(self): def test_get_hash(self):
self.assertEqual( self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
# Test cached value # Test cached value
self.assertEqual( self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
)
self.assertEqual(
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
config._hash_digest = "fake" config._hash_digest = "fake"
self.assertEqual(config.get_hash(), "fake") self.assertEqual(config.get_hash(), "fake")
config.e_bool = False config.e_bool = False
self.assertNotEqual( self.assertNotEqual(
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c" config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ"
) )
config.e_bool = True config.e_bool = True
# Test ignored values # Test ignored values
config.e_compile_ignored = False config.e_compile_ignored = False
self.assertEqual( self.assertEqual(config.get_hash(), b"\xf2C\xdbo\x99qq\x12\x11\xf7\xb4\xeewVpZ")
config.get_hash(), b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"
)
for k in config._config: for k in config._config:
config._config[k].user_override = _UNSET_SENTINEL config._config[k].user_override = _UNSET_SENTINEL
@ -227,6 +232,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_env_default": True, "e_env_default": True,
"e_env_default_FALSE": False, "e_env_default_FALSE": False,
"e_env_force": True, "e_env_force": True,
"e_optional": True,
}, },
) )
p2 = config.to_dict() p2 = config.to_dict()
@ -255,6 +261,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_env_default": True, "e_env_default": True,
"e_env_default_FALSE": False, "e_env_default_FALSE": False,
"e_env_force": True, "e_env_force": True,
"e_optional": True,
}, },
) )
p3 = config.get_config_copy() p3 = config.get_config_copy()
@ -283,6 +290,7 @@ torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""
"e_env_default": True, "e_env_default": True,
"e_env_default_FALSE": False, "e_env_default_FALSE": False,
"e_env_force": True, "e_env_force": True,
"e_optional": True,
}, },
) )

View File

@ -25,6 +25,7 @@ from typing import (
Any as _Any, Any as _Any,
Callable as _Callable, Callable as _Callable,
Dict as _Dict, Dict as _Dict,
get_origin as _get_origin,
Optional as _Optional, Optional as _Optional,
overload as _overload, overload as _overload,
Set as _Set, Set as _Set,
@ -2280,13 +2281,18 @@ class _TorchCompileInductorWrapper:
raise RuntimeError( raise RuntimeError(
f"Unexpected optimization option {key}, known options are {list(current_config.keys())}" f"Unexpected optimization option {key}, known options are {list(current_config.keys())}"
) )
if type(val) is not type(current_config[attr_name]): attr_type = config.get_type(attr_name) # type: ignore[attr-defined]
val_type_str = type(val).__name__ # Subscriptable generic types don't support isinstance so skip the type
expected_type_str = type(current_config[attr_name]).__name__ # check. There doesn't seem to be a good way of checking membership without
raise RuntimeError( # 3rd party libraries.
f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}" if _get_origin(attr_type) is None:
) if not isinstance(val, attr_type):
self.config[attr_name] = val val_type_str = type(val).__name__
expected_type_str = type(current_config[attr_name]).__name__
raise RuntimeError(
f"Unexpected type of attr {key}, got {val_type_str} should be {expected_type_str}"
)
self.config[attr_name] = val
def __call__(self, model_, inputs_): def __call__(self, model_, inputs_):
from torch._inductor.compile_fx import compile_fx from torch._inductor.compile_fx import compile_fx

View File

@ -13,6 +13,7 @@ e_set = {1}
e_tuple = (1,) e_tuple = (1,)
e_dict = {1: 2} e_dict = {1: 2}
e_none: Optional[bool] = None e_none: Optional[bool] = None
e_optional: Optional[bool] = True
e_ignored = True e_ignored = True
_e_ignored = True _e_ignored = True
magic_cache_config_ignored = True magic_cache_config_ignored = True

View File

@ -5,6 +5,7 @@ import inspect
import io import io
import os import os
import pickle import pickle
import sys
import tokenize import tokenize
import unittest import unittest
import warnings import warnings
@ -54,6 +55,7 @@ class Config:
justknob: Optional[str] = None justknob: Optional[str] = None
env_name_default: Optional[str] = None env_name_default: Optional[str] = None
env_name_force: Optional[str] = None env_name_force: Optional[str] = None
value_type: Optional[type] = None
def __init__( def __init__(
self, self,
@ -61,12 +63,14 @@ class Config:
justknob: Optional[str] = None, justknob: Optional[str] = None,
env_name_default: Optional[str] = None, env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None, env_name_force: Optional[str] = None,
value_type: Optional[type] = None,
): ):
# python 3.9 does not support kw_only on the dataclass :(. # python 3.9 does not support kw_only on the dataclass :(.
self.default = default self.default = default
self.justknob = justknob self.justknob = justknob
self.env_name_default = env_name_default self.env_name_default = env_name_default
self.env_name_force = env_name_force self.env_name_force = env_name_force
self.value_type = value_type
# Types saved/loaded in configs # Types saved/loaded in configs
@ -99,6 +103,10 @@ def install_config_module(module: ModuleType) -> None:
prefix: str, prefix: str,
) -> None: ) -> None:
"""Walk the module structure and move everything to module._config""" """Walk the module structure and move everything to module._config"""
if sys.version_info[:2] < (3, 10):
type_hints = getattr(source, "__annotations__", {})
else:
type_hints = inspect.get_annotations(source)
for key, value in list(source.__dict__.items()): for key, value in list(source.__dict__.items()):
if ( if (
key.startswith("__") key.startswith("__")
@ -111,7 +119,10 @@ def install_config_module(module: ModuleType) -> None:
name = f"{prefix}{key}" name = f"{prefix}{key}"
if isinstance(value, CONFIG_TYPES): if isinstance(value, CONFIG_TYPES):
config[name] = _ConfigEntry(Config(default=value)) annotated_type = type_hints.get(key, None)
config[name] = _ConfigEntry(
Config(default=value, value_type=annotated_type)
)
if dest is module: if dest is module:
delattr(module, key) delattr(module, key)
elif isinstance(value, Config): elif isinstance(value, Config):
@ -192,6 +203,8 @@ _UNSET_SENTINEL = object()
class _ConfigEntry: class _ConfigEntry:
# The default value specified in the configuration # The default value specified in the configuration
default: Any default: Any
# The type of the configuration value
value_type: type
# The value specified by the user when they overrode the configuration # The value specified by the user when they overrode the configuration
# _UNSET_SENTINEL indicates the value is not set. # _UNSET_SENTINEL indicates the value is not set.
user_override: Any = _UNSET_SENTINEL user_override: Any = _UNSET_SENTINEL
@ -203,6 +216,9 @@ class _ConfigEntry:
def __init__(self, config: Config): def __init__(self, config: Config):
self.default = config.default self.default = config.default
self.value_type = (
config.value_type if config.value_type is not None else type(self.default)
)
self.justknob = config.justknob self.justknob = config.justknob
if config.env_name_default is not None: if config.env_name_default is not None:
if (env_value := _read_env_variable(config.env_name_default)) is not None: if (env_value := _read_env_variable(config.env_name_default)) is not None:
@ -314,6 +330,9 @@ class ConfigModule(ModuleType):
config[key] = copy.deepcopy(getattr(self, key)) config[key] = copy.deepcopy(getattr(self, key))
return config return config
def get_type(self, config_name: str) -> type:
return self._config[config_name].value_type
def save_config(self) -> bytes: def save_config(self) -> bytes:
"""Convert config to a pickled blob""" """Convert config to a pickled blob"""
ignored_keys = getattr(self, "_save_config_ignore", []) ignored_keys = getattr(self, "_save_config_ignore", [])