mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This allows Configs to set their defaults (or override themselves) via environment variables. The environment variables are resolved at install time (which is usually import time). This is done 1) to avoid race conditions between threads, and 2) to encourage people to modify the configs directly rather than overriding environment variables to change PyTorch behaviour. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138956 Approved by: https://github.com/ezyang ghstack dependencies: #138766
318 lines
11 KiB
Python
318 lines
11 KiB
Python
# Owner(s): ["module: unknown"]
|
|
import os
|
|
import pickle
|
|
|
|
|
|
# These variables are exported before the fake config module is imported
# further down, since config environment variables are resolved when the
# config module is installed (usually at import time).
# NOTE(review): presumably these back the e_env_* entries of
# fake_config_module -- confirm against that module.
os.environ.update({"ENV_TRUE": "1", "ENV_FALSE": "0"})
|
|
|
|
from torch.testing._internal import fake_config_module as config
|
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
|
from torch.utils._config_module import _UNSET_SENTINEL
|
|
|
|
|
|
class TestConfigModule(TestCase):
    """Tests for the config-module machinery, driven through ``fake_config_module``.

    The ``config`` object is module-global state shared by every test in this
    class.  Tests that assign to it (or mutate one of its containers) register
    ``_reset_user_overrides`` via ``addCleanup`` so the overrides are cleared
    even when an assertion fails part-way through the test.
    """

    def _reset_user_overrides(self):
        """Mark every entry of ``config._config`` as not overridden by the user."""
        # Config changes get persisted between test cases
        for k in config._config:
            config._config[k].user_override = _UNSET_SENTINEL

    def test_base_value_loading(self):
        """Read each kind of config entry and check the value it reports."""
        self.assertTrue(config.e_bool)
        self.assertTrue(config.nested.e_bool)
        self.assertEqual(config.e_int, 1)
        self.assertEqual(config.e_float, 1.0)
        self.assertEqual(config.e_string, "string")
        self.assertEqual(config.e_list, [1])
        self.assertEqual(config.e_set, {1})
        self.assertEqual(config.e_tuple, (1,))
        self.assertEqual(config.e_dict, {1: 2})
        self.assertIsNone(config.e_none)
        # NOTE(review): with assertRaises used as a context manager, ``msg`` is
        # the message shown when *no* exception is raised; it is not compared
        # against the raised exception's text.  Switch to assertRaisesRegex if
        # the exception message itself should be verified.
        with self.assertRaises(
            AttributeError, msg="fake_config_module.does_not_exist does not exist"
        ):
            config.does_not_exist

    def test_overrides(self):
        """Assigning to a config entry changes the value read back from it."""
        self.addCleanup(self._reset_user_overrides)

        config.e_bool = False
        self.assertFalse(config.e_bool)
        config.nested.e_bool = False
        self.assertFalse(config.nested.e_bool)
        config.e_int = 2
        self.assertEqual(config.e_int, 2)
        config.e_float = 2.0
        self.assertEqual(config.e_float, 2.0)
        config.e_string = "string2"
        self.assertEqual(config.e_string, "string2")
        config.e_list = [2]
        self.assertEqual(config.e_list, [2])
        config.e_set = {2}
        self.assertEqual(config.e_set, {2})
        config.e_tuple = (2,)
        self.assertEqual(config.e_tuple, (2,))
        config.e_dict = {2: 3}
        self.assertEqual(config.e_dict, {2: 3})
        # A None-valued entry can be overridden and then set back to None.
        config.e_none = "not none"
        self.assertEqual(config.e_none, "not none")
        config.e_none = None
        self.assertIsNone(config.e_none)
        # Assigning to a name that is not a config entry must fail.
        # NOTE(review): ``msg`` here is only the failure message of the
        # assertion; the exception text is not checked.
        with self.assertRaises(
            AttributeError, msg="fake_config_module.does_not_exist does not exist"
        ):
            config.does_not_exist = 0

    def test_none_override_semantics(self):
        """Overriding an entry with ``None`` makes it read back as ``None``."""
        self.addCleanup(self._reset_user_overrides)

        config.e_bool = None
        self.assertIsNone(config.e_bool)

    def test_reference_semantics(self):
        """In-place mutation of list/set/dict entries is visible on the next read."""
        self.addCleanup(self._reset_user_overrides)

        config.e_list.append(2)
        self.assertEqual(config.e_list, [1, 2])
        config.e_set.add(2)
        self.assertEqual(config.e_set, {1, 2})
        config.e_dict[2] = 3
        self.assertEqual(config.e_dict, {1: 2, 2: 3})

    def test_env_name_semantics(self):
        """Check the env-backed entries before and after assigning to them."""
        self.addCleanup(self._reset_user_overrides)

        self.assertTrue(config.e_env_default)
        self.assertFalse(config.e_env_default_FALSE)
        self.assertTrue(config.e_env_force)
        # An assignment to the "default" flavour takes effect...
        config.e_env_default = False
        self.assertFalse(config.e_env_default)
        # ...whereas the "force" flavour still reads True after the assignment.
        config.e_env_force = False
        self.assertTrue(config.e_env_force)

    def test_save_config(self):
        """``save_config`` output unpickles to the expected dict and can be reloaded."""
        self.addCleanup(self._reset_user_overrides)

        p = config.save_config()
        # The payload was produced in-process just above, so unpickling is safe.
        self.assertEqual(
            pickle.loads(p),
            {
                "_cache_config_ignore_prefix": ["magic_cache_config"],
                "e_bool": True,
                "e_dict": {1: 2},
                "e_float": 1.0,
                "e_int": 1,
                "e_list": [1],
                "e_none": None,
                "e_set": {1},
                "e_string": "string",
                "e_tuple": (1,),
                "nested.e_bool": True,
                "_e_ignored": True,
                "e_compile_ignored": True,
                "magic_cache_config_ignored": True,
                "_save_config_ignore": ["e_ignored"],
                "e_config": True,
                "e_jk": True,
                "e_jk_false": False,
                "e_env_default": True,
                "e_env_default_FALSE": False,
                "e_env_force": True,
            },
        )
        config.e_bool = False
        config.e_ignored = False
        config.load_config(p)
        # e_bool is part of the saved payload, so loading restores it; e_ignored
        # is absent from the payload and keeps the value assigned above.
        self.assertTrue(config.e_bool)
        self.assertFalse(config.e_ignored)

    def test_save_config_portable(self):
        """``save_config_portable`` returns the expected dict and can be reloaded."""
        self.addCleanup(self._reset_user_overrides)

        p = config.save_config_portable()
        self.assertEqual(
            p,
            {
                "e_bool": True,
                "e_dict": {1: 2},
                "e_float": 1.0,
                "e_int": 1,
                "e_list": [1],
                "e_none": None,
                "e_set": {1},
                "e_string": "string",
                "e_tuple": (1,),
                "nested.e_bool": True,
                "e_ignored": True,
                "e_compile_ignored": True,
                "e_config": True,
                "e_jk": True,
                "e_jk_false": False,
                "e_env_default": True,
                "e_env_default_FALSE": False,
                "e_env_force": True,
            },
        )
        config.e_bool = False
        config._e_ignored = False
        config.load_config(p)
        # _e_ignored is absent from the portable payload, so loading leaves the
        # value assigned above untouched while e_bool is restored.
        self.assertTrue(config.e_bool)
        self.assertFalse(config._e_ignored)

    def test_codegen_config(self):
        """``codegen_config`` emits the expected assignment statements."""
        self.addCleanup(self._reset_user_overrides)

        config.e_bool = False
        config.e_ignored = False
        code = config.codegen_config()
        self.assertEqual(
            code,
            """torch.testing._internal.fake_config_module.e_bool = False
torch.testing._internal.fake_config_module.e_list = [1]
torch.testing._internal.fake_config_module.e_set = {1}
torch.testing._internal.fake_config_module.e_dict = {1: 2}
torch.testing._internal.fake_config_module._save_config_ignore = ['e_ignored']""",
        )

    def test_get_hash(self):
        """``get_hash`` is stable, honours ``_hash_digest`` and skips ignored entries."""
        self.addCleanup(self._reset_user_overrides)

        expected_hash = b"U\x8bi\xc2~PY\x98\x18\x9d\xf8<\xe4\xbc%\x0c"

        self.assertEqual(config.get_hash(), expected_hash)
        # Test cached value
        self.assertEqual(config.get_hash(), expected_hash)
        self.assertEqual(config.get_hash(), expected_hash)
        # A pre-populated digest is returned as-is.
        config._hash_digest = "fake"
        self.assertEqual(config.get_hash(), "fake")

        # Changing a regular entry changes the hash.
        config.e_bool = False
        self.assertNotEqual(config.get_hash(), expected_hash)
        config.e_bool = True

        # Test ignored values
        config.e_compile_ignored = False
        self.assertEqual(config.get_hash(), expected_hash)

    def test_dict_copy_semantics(self):
        """The dict-export helpers agree on content and are detached from ``config``."""
        self.addCleanup(self._reset_user_overrides)

        expected = {
            "e_bool": True,
            "e_dict": {1: 2},
            "e_float": 1.0,
            "e_int": 1,
            "e_list": [1],
            "e_none": None,
            "e_set": {1},
            "e_string": "string",
            "e_tuple": (1,),
            "nested.e_bool": True,
            "e_ignored": True,
            "_e_ignored": True,
            "e_compile_ignored": True,
            "_cache_config_ignore_prefix": ["magic_cache_config"],
            "_save_config_ignore": ["e_ignored"],
            "magic_cache_config_ignored": True,
            "e_config": True,
            "e_jk": True,
            "e_jk_false": False,
            "e_env_default": True,
            "e_env_default_FALSE": False,
            "e_env_force": True,
        }

        p = config.shallow_copy_dict()
        self.assertDictEqual(p, expected)
        p2 = config.to_dict()
        self.assertEqual(p2, expected)
        p3 = config.get_config_copy()
        self.assertEqual(p3, expected)

        # Shallow + deep copy semantics
        # Mutating the live config afterwards must not show up in any snapshot.
        config.e_dict[2] = 3
        self.assertEqual(p["e_dict"], {1: 2})
        self.assertEqual(p2["e_dict"], {1: 2})
        self.assertEqual(p3["e_dict"], {1: 2})

    def test_patch(self):
        """``config.patch`` overrides a value only inside its ``with`` block."""
        self.assertTrue(config.e_bool)
        # Positional (name, value) form.
        with config.patch("e_bool", False):
            self.assertFalse(config.e_bool)
        self.assertTrue(config.e_bool)
        # Keyword form.
        with config.patch(e_bool=False):
            self.assertFalse(config.e_bool)
        self.assertTrue(config.e_bool)
        # Patching a name with no value supplied is rejected.
        with self.assertRaises(AssertionError):
            with config.patch("does_not_exist"):
                pass

    def test_make_closur_patcher(self):
        """``_make_closure_patcher`` applies a change and returns a callable that reverts it."""
        revert = config._make_closure_patcher(e_bool=False)()
        self.assertFalse(config.e_bool)
        revert()
        self.assertTrue(config.e_bool)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly as a script: run this file's tests via run_tests().
    run_tests()
|