Change run_test.py arg parsing to handle additional args better (#126709)

Do not inherit parser from common_utils
* I don't think we use any variables in run_test that depend on those, and I think all tests except doctests run in a subprocess so they will parse the args in common_utils and set the variables.  I don't think doctests wants any of those variables?

Parse known args, add the extra args as extra, pass the extra ones along to the subprocess
Removes the first instance of `--`

I think I will miss run_test telling me if an arg is valid or not

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126709
Approved by: https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/Flamefire
This commit is contained in:
Catherine Lee 2024-05-23 21:08:12 +00:00 committed by PyTorch MergeBot
parent 09a73da190
commit a31a60d85b

View File

@ -29,7 +29,6 @@ from torch.testing._internal.common_utils import (
IS_CI,
IS_MACOS,
IS_WINDOWS,
parser as common_parser,
retry_shell,
set_cwd,
shell,
@ -384,7 +383,7 @@ def run_test(
) -> int:
env = env or os.environ.copy()
maybe_set_hip_visible_devies()
unittest_args = options.additional_unittest_args.copy()
unittest_args = options.additional_args.copy()
test_file = test_module.name
stepcurrent_key = test_file
@ -1057,7 +1056,6 @@ def parse_args():
description="Run the PyTorch unit test suite",
epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
formatter_class=argparse.RawTextHelpFormatter,
parents=[common_parser],
)
parser.add_argument(
"-v",
@ -1206,12 +1204,6 @@ def parse_args():
and "debug" not in BUILD_ENVIRONMENT
and "parallelnative" not in BUILD_ENVIRONMENT,
)
parser.add_argument(
"additional_unittest_args",
nargs="*",
help="additional arguments passed through to unittest, e.g., "
"python run_test.py -i sparse -- TestSparse.test_factory_size_check",
)
parser.add_argument(
"--shard",
nargs=2,
@ -1273,7 +1265,11 @@ def parse_args():
help="Run tests with TorchInductor turned on",
)
return parser.parse_args()
args, extra = parser.parse_known_args()
if "--" in extra:
extra.remove("--")
args.additional_args = extra
return args
def exclude_tests(
@ -1626,7 +1622,7 @@ def run_tests(
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
options_clone.additional_unittest_args.extend(["-m", "serial"])
options_clone.additional_args.extend(["-m", "serial"])
failure = run_test_module(test, test_directory, options_clone)
test_failed = handle_error_messages(failure)
if (
@ -1641,7 +1637,7 @@ def run_tests(
options_clone = copy.deepcopy(options)
if can_run_in_pytest(test):
options_clone.pytest = True
options_clone.additional_unittest_args.extend(["-m", "not serial"])
options_clone.additional_args.extend(["-m", "not serial"])
pool.apply_async(
run_test_module,
args=(test, test_directory, options_clone),