pytorch/test/test_utils_config_module.py
Colin L. Rice 2a857e940d config: Add env_name_default and env_name_force to Config (#138956)
This allows Configs to handle setting their defaults (or overriding
themselves) via environment variables.

The environment variables are resolved at install time (which is usually
import time). This is done 1) to avoid race conditions between threads,
etc., and 2) to encourage people to modify the configs directly, rather
than overriding environment variables to change PyTorch behaviour.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138956
Approved by: https://github.com/ezyang
ghstack dependencies: #138766
2024-11-06 21:20:42 +00:00

318 lines
11 KiB
Python

# Owner(s): ["module: unknown"]
import os
import pickle

# These environment variables must be set BEFORE fake_config_module is
# imported: config entries backed by environment variables resolve them when
# the config module is installed, which is usually at import time.
# NOTE(review): presumably ENV_TRUE / ENV_FALSE are the variable names used by
# the env_name_default / env_name_force entries (e_env_default,
# e_env_default_FALSE, e_env_force) in fake_config_module -- confirm there.
os.environ["ENV_TRUE"] = "1"
os.environ["ENV_FALSE"] = "0"

from torch.testing._internal import fake_config_module as config
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._config_module import _UNSET_SENTINEL
class TestConfigModule(TestCase):
    """Exercise the ConfigModule API through ``fake_config_module``.

    ``config`` is one module object shared by every test, and config changes
    persist between test cases. Tests that override entries therefore
    register ``_reset_config`` as a cleanup. unittest runs cleanups even when
    the test body fails, so a failing assertion cannot leak overrides into
    the tests that follow.
    """

    # Digest returned by ``config.get_hash()`` while the hashed entries hold
    # their default values.
    _DEFAULT_HASH = b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"

    @staticmethod
    def _reset_config():
        """Drop every user override on ``config``.

        Each entry's ``user_override`` is written directly rather than going
        through attribute assignment on ``config``.
        """
        for k in config._config:
            config._config[k].user_override = _UNSET_SENTINEL

    def test_base_value_loading(self):
        """Entries read back their expected values before any override."""
        # NOTE(review): unlike most tests here, this one does not reset the
        # config. test_codegen_config expects e_list / e_set / e_dict in its
        # output without assigning them, which may rely on state left behind
        # by reading them here -- confirm before adding a reset.
        self.assertTrue(config.e_bool)
        self.assertTrue(config.nested.e_bool)
        self.assertEqual(config.e_int, 1)
        self.assertEqual(config.e_float, 1.0)
        self.assertEqual(config.e_string, "string")
        self.assertEqual(config.e_list, [1])
        self.assertEqual(config.e_set, {1})
        self.assertEqual(config.e_tuple, (1,))
        self.assertEqual(config.e_dict, {1: 2})
        self.assertEqual(config.e_none, None)
        # NOTE: ``msg`` is only the message reported if nothing is raised; it
        # does not check the AttributeError's text.
        with self.assertRaises(
            AttributeError, msg="fake_config_module.does_not_exist does not exist"
        ):
            config.does_not_exist

    def test_overrides(self):
        """Assigning to an entry overrides it; unknown names are rejected."""
        # Config changes get persisted between test cases.
        self.addCleanup(self._reset_config)
        config.e_bool = False
        self.assertFalse(config.e_bool)
        config.nested.e_bool = False
        self.assertFalse(config.nested.e_bool)
        config.e_int = 2
        self.assertEqual(config.e_int, 2)
        config.e_float = 2.0
        self.assertEqual(config.e_float, 2.0)
        config.e_string = "string2"
        self.assertEqual(config.e_string, "string2")
        config.e_list = [2]
        self.assertEqual(config.e_list, [2])
        config.e_set = {2}
        self.assertEqual(config.e_set, {2})
        config.e_tuple = (2,)
        self.assertEqual(config.e_tuple, (2,))
        config.e_dict = {2: 3}
        self.assertEqual(config.e_dict, {2: 3})
        config.e_none = "not none"
        self.assertEqual(config.e_none, "not none")
        config.e_none = None
        self.assertEqual(config.e_none, None)
        # NOTE: ``msg`` is only the failure message, not a check on the text.
        with self.assertRaises(
            AttributeError, msg="fake_config_module.does_not_exist does not exist"
        ):
            config.does_not_exist = 0

    def test_none_override_semantics(self):
        """``None`` is a real override value, not treated as "unset"."""
        self.addCleanup(self._reset_config)
        config.e_bool = None
        self.assertIsNone(config.e_bool)

    def test_reference_semantics(self):
        """In-place mutation of a returned list/set/dict is seen by later reads."""
        self.addCleanup(self._reset_config)
        config.e_list.append(2)
        self.assertEqual(config.e_list, [1, 2])
        config.e_set.add(2)
        self.assertEqual(config.e_set, {1, 2})
        config.e_dict[2] = 3
        self.assertEqual(config.e_dict, {1: 2, 2: 3})

    def test_env_name_semantics(self):
        """Environment-backed entries: defaults yield to overrides, forces do not."""
        self.addCleanup(self._reset_config)
        self.assertTrue(config.e_env_default)
        self.assertFalse(config.e_env_default_FALSE)
        self.assertTrue(config.e_env_force)
        # A user override beats an environment-provided default...
        config.e_env_default = False
        self.assertFalse(config.e_env_default)
        # ...but an environment-forced value beats a user override.
        config.e_env_force = False
        self.assertTrue(config.e_env_force)

    def test_save_config(self):
        """save_config()/load_config() round-trip, skipping _save_config_ignore keys."""
        self.addCleanup(self._reset_config)
        p = config.save_config()
        # e_ignored is absent from the snapshot, presumably because it is
        # listed in _save_config_ignore.
        self.assertEqual(
            pickle.loads(p),
            {
                "_cache_config_ignore_prefix": ["magic_cache_config"],
                "e_bool": True,
                "e_dict": {1: 2},
                "e_float": 1.0,
                "e_int": 1,
                "e_list": [1],
                "e_none": None,
                "e_set": {1},
                "e_string": "string",
                "e_tuple": (1,),
                "nested.e_bool": True,
                "_e_ignored": True,
                "e_compile_ignored": True,
                "magic_cache_config_ignored": True,
                "_save_config_ignore": ["e_ignored"],
                "e_config": True,
                "e_jk": True,
                "e_jk_false": False,
                "e_env_default": True,
                "e_env_default_FALSE": False,
                "e_env_force": True,
            },
        )
        config.e_bool = False
        config.e_ignored = False
        config.load_config(p)
        # Saved entries are restored; the unsaved one keeps its new value.
        self.assertTrue(config.e_bool)
        self.assertFalse(config.e_ignored)

    def test_save_config_portable(self):
        """The portable snapshot drops private and magic-prefixed entries."""
        # Config changes get persisted between test cases.
        self.addCleanup(self._reset_config)
        p = config.save_config_portable()
        # Underscore-prefixed entries and magic_cache_config_ignored are
        # omitted (the latter presumably via _cache_config_ignore_prefix),
        # while e_ignored is kept.
        self.assertEqual(
            p,
            {
                "e_bool": True,
                "e_dict": {1: 2},
                "e_float": 1.0,
                "e_int": 1,
                "e_list": [1],
                "e_none": None,
                "e_set": {1},
                "e_string": "string",
                "e_tuple": (1,),
                "nested.e_bool": True,
                "e_ignored": True,
                "e_compile_ignored": True,
                "e_config": True,
                "e_jk": True,
                "e_jk_false": False,
                "e_env_default": True,
                "e_env_default_FALSE": False,
                "e_env_force": True,
            },
        )
        config.e_bool = False
        config._e_ignored = False
        config.load_config(p)
        # Entries in the snapshot are restored; _e_ignored was not in it.
        self.assertTrue(config.e_bool)
        self.assertFalse(config._e_ignored)

    def test_codegen_config(self):
        """codegen_config() emits assignment statements for overridden entries."""
        # Config changes get persisted between test cases.
        self.addCleanup(self._reset_config)
        config.e_bool = False
        config.e_ignored = False
        code = config.codegen_config()
        # e_ignored was set but does not appear, presumably because it is
        # listed in _save_config_ignore.
        self.assertEqual(
            code,
            """torch.testing._internal.fake_config_module.e_bool = False
torch.testing._internal.fake_config_module.e_list = [1]
torch.testing._internal.fake_config_module.e_set = {1}
torch.testing._internal.fake_config_module.e_dict = {1: 2}
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
        )

    def test_get_hash(self):
        """get_hash() is cached, reacts to hashed entries, ignores compile-ignored ones."""
        self.addCleanup(self._reset_config)
        self.assertEqual(config.get_hash(), self._DEFAULT_HASH)
        # Test cached value
        self.assertEqual(config.get_hash(), self._DEFAULT_HASH)
        self.assertEqual(config.get_hash(), self._DEFAULT_HASH)
        # A digest planted directly is returned unchanged while nothing changed.
        config._hash_digest = "fake"
        self.assertEqual(config.get_hash(), "fake")
        # Changing a hashed entry yields a freshly computed, different digest.
        config.e_bool = False
        self.assertNotEqual(config.get_hash(), self._DEFAULT_HASH)
        config.e_bool = True
        # Test ignored values
        config.e_compile_ignored = False
        self.assertEqual(config.get_hash(), self._DEFAULT_HASH)

    def test_dict_copy_semantics(self):
        """All copy APIs return the full snapshot and are detached from config."""
        self.addCleanup(self._reset_config)
        # shallow_copy_dict(), to_dict() and get_config_copy() are all expected
        # to return this same snapshot, including private and ignored entries.
        expected = {
            "e_bool": True,
            "e_dict": {1: 2},
            "e_float": 1.0,
            "e_int": 1,
            "e_list": [1],
            "e_none": None,
            "e_set": {1},
            "e_string": "string",
            "e_tuple": (1,),
            "nested.e_bool": True,
            "e_ignored": True,
            "_e_ignored": True,
            "e_compile_ignored": True,
            "_cache_config_ignore_prefix": ["magic_cache_config"],
            "_save_config_ignore": ["e_ignored"],
            "magic_cache_config_ignored": True,
            "e_config": True,
            "e_jk": True,
            "e_jk_false": False,
            "e_env_default": True,
            "e_env_default_FALSE": False,
            "e_env_force": True,
        }
        p = config.shallow_copy_dict()
        self.assertDictEqual(p, expected)
        p2 = config.to_dict()
        self.assertEqual(p2, expected)
        p3 = config.get_config_copy()
        self.assertEqual(p3, expected)
        # Shallow + deep copy semantics: mutating the live config afterwards
        # must not show up in any of the copies taken above.
        config.e_dict[2] = 3
        self.assertEqual(p["e_dict"], {1: 2})
        self.assertEqual(p2["e_dict"], {1: 2})
        self.assertEqual(p3["e_dict"], {1: 2})

    def test_patch(self):
        """patch() overrides temporarily and rejects unknown names."""
        self.assertTrue(config.e_bool)
        with config.patch("e_bool", False):
            self.assertFalse(config.e_bool)
        self.assertTrue(config.e_bool)
        with config.patch(e_bool=False):
            self.assertFalse(config.e_bool)
        self.assertTrue(config.e_bool)
        with self.assertRaises(AssertionError):
            with config.patch("does_not_exist"):
                pass

    # NOTE(review): "closur" is a typo in the method name; it is kept so the
    # test ID does not change.
    def test_make_closur_patcher(self):
        """_make_closure_patcher applies an override and returns its undo function."""
        revert = config._make_closure_patcher(e_bool=False)()
        try:
            self.assertFalse(config.e_bool)
        finally:
            # Always undo the patch so a failure here cannot leak e_bool=False
            # into other tests.
            revert()
        self.assertTrue(config.e_bool)
if __name__ == "__main__":
    # Allow running this file directly; run_tests is PyTorch's shared test
    # entry point from torch.testing._internal.common_utils.
    run_tests()