Add multi env variable support to configs (#145288)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145288
Approved by: https://github.com/c00w
This commit is contained in:
Oguz Ulgen 2025-01-23 17:28:11 -08:00 committed by PyTorch MergeBot
parent 10bdd0a1cc
commit d3989ca636
3 changed files with 47 additions and 17 deletions

View File

@ -100,6 +100,10 @@ class TestConfigModule(TestCase):
config.e_env_force = False
self.assertTrue(config.e_env_force)
def test_multi_env(self):
self.assertTrue(config2.e_env_default_multi)
self.assertTrue(config2.e_env_force_multi)
def test_save_config(self):
p = config.save_config()
self.assertDictEqual(

View File

@ -1,8 +1,13 @@
import sys
from torch.utils._config_module import install_config_module
from torch.utils._config_module import Config, install_config_module
e_aliasing_bool = False
e_env_default_multi: bool = Config(
env_name_default=["ENV_TRUE", "ENV_FALSE"], default=False
)
e_env_force_multi: bool = Config(env_name_force=["ENV_FAKE", "ENV_TRUE"], default=False)
install_config_module(sys.modules[__name__])

View File

@ -53,10 +53,14 @@ class _Config(Generic[T]):
alias: If set, the directly use the value of the alias.
env_name_force: If set, this environment variable has precedence over
everything after this.
If multiple env variables are given, the precendence order is from
left to right.
user_override: If a user sets a value (i.e. foo.bar=True), that
has precedence over everything after this.
env_name_default: If set, this environment variable will override everything
after this.
If multiple env variables are given, the precendence order is from
left to right.
justknob: If this pytorch installation supports justknobs, that will
override defaults, but will not override the user_override precendence.
default: This value is the lowest precendance, and will be used if nothing is
@ -69,16 +73,16 @@ class _Config(Generic[T]):
justknob: the name of the feature / JK. In OSS this is unused.
default: is the value to default this knob to in OSS.
alias: The alias config to read instead.
env_name_force: The environment variable to read that is a FORCE
env_name_force: The environment variable, or list of, to read that is a FORCE
environment variable. I.e. it overrides everything except for alias.
env_name_default: The environment variable to read that changes the
env_name_default: The environment variable, or list of, to read that changes the
default behaviour. I.e. user overrides take preference.
"""
default: Union[T, object]
justknob: Optional[str] = None
env_name_default: Optional[str] = None
env_name_force: Optional[str] = None
env_name_default: Optional[list[str]] = None
env_name_force: Optional[list[str]] = None
value_type: Optional[type] = None
alias: Optional[str] = None
@ -86,16 +90,18 @@ class _Config(Generic[T]):
self,
default: Union[T, object] = _UNSET_SENTINEL,
justknob: Optional[str] = None,
env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None,
env_name_default: Optional[Union[str, list[str]]] = None,
env_name_force: Optional[Union[str, list[str]]] = None,
value_type: Optional[type] = None,
alias: Optional[str] = None,
):
# python 3.9 does not support kw_only on the dataclass :(.
self.default = default
self.justknob = justknob
self.env_name_default = env_name_default
self.env_name_force = env_name_force
self.env_name_default = _Config.string_or_list_of_string_to_list(
env_name_default
)
self.env_name_force = _Config.string_or_list_of_string_to_list(env_name_force)
self.value_type = value_type
self.alias = alias
if self.justknob is not None:
@ -110,6 +116,17 @@ class _Config(Generic[T]):
and env_name_force is None
), "if alias is set, default, justknob or env var cannot be set"
@staticmethod
def string_or_list_of_string_to_list(
val: Optional[Union[str, list[str]]]
) -> Optional[list[str]]:
if val is None:
return None
if isinstance(val, str):
return [val]
assert isinstance(val, list)
return val
# In runtime, we unbox the Config[T] to a T, but typechecker cannot see this,
# so in order to allow for this dynamic behavior to work correctly with
@ -120,8 +137,8 @@ if TYPE_CHECKING:
def Config(
default: Union[T, object] = _UNSET_SENTINEL,
justknob: Optional[str] = None,
env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None,
env_name_default: Optional[Union[str, list[str]]] = None,
env_name_force: Optional[Union[str, list[str]]] = None,
value_type: Optional[type] = None,
alias: Optional[str] = None,
) -> T:
@ -132,8 +149,8 @@ else:
def Config(
default: Union[T, object] = _UNSET_SENTINEL,
justknob: Optional[str] = None,
env_name_default: Optional[str] = None,
env_name_force: Optional[str] = None,
env_name_default: Optional[Union[str, list[str]]] = None,
env_name_force: Optional[Union[str, list[str]]] = None,
value_type: Optional[type] = None,
alias: Optional[str] = None,
) -> _Config[T]:
@ -300,11 +317,15 @@ class _ConfigEntry:
self.justknob = config.justknob
self.alias = config.alias
if config.env_name_default is not None:
if (env_value := _read_env_variable(config.env_name_default)) is not None:
self.env_value_default = env_value
for val in config.env_name_default:
if (env_value := _read_env_variable(val)) is not None:
self.env_value_default = env_value
break
if config.env_name_force is not None:
if (env_value := _read_env_variable(config.env_name_force)) is not None:
self.env_value_force = env_value
for val in config.env_name_force:
if (env_value := _read_env_variable(val)) is not None:
self.env_value_force = env_value
break
class ConfigModule(ModuleType):