mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Change run_test.py arg parsing to handle additional args better (#126709)
Do not inherit parser from common_utils * I don't think we use any variables in run_test that depend on those, and I think all tests except doctests run in a subprocess so they will parse the args in common_utils and set the variables. I don't think doctests wants any of those variables? Parse known args, add the extra args as extra, pass the extra ones along to the subprocess Removes the first instance of `--` I think I will miss run_test telling me if an arg is valid or not Pull Request resolved: https://github.com/pytorch/pytorch/pull/126709 Approved by: https://github.com/ZainRizvi, https://github.com/huydhn, https://github.com/Flamefire
This commit is contained in:
parent
09a73da190
commit
a31a60d85b
|
|
@ -29,7 +29,6 @@ from torch.testing._internal.common_utils import (
|
|||
IS_CI,
|
||||
IS_MACOS,
|
||||
IS_WINDOWS,
|
||||
parser as common_parser,
|
||||
retry_shell,
|
||||
set_cwd,
|
||||
shell,
|
||||
|
|
@ -384,7 +383,7 @@ def run_test(
|
|||
) -> int:
|
||||
env = env or os.environ.copy()
|
||||
maybe_set_hip_visible_devies()
|
||||
unittest_args = options.additional_unittest_args.copy()
|
||||
unittest_args = options.additional_args.copy()
|
||||
test_file = test_module.name
|
||||
stepcurrent_key = test_file
|
||||
|
||||
|
|
@ -1057,7 +1056,6 @@ def parse_args():
|
|||
description="Run the PyTorch unit test suite",
|
||||
epilog="where TESTS is any of: {}".format(", ".join(TESTS)),
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
parents=[common_parser],
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v",
|
||||
|
|
@ -1206,12 +1204,6 @@ def parse_args():
|
|||
and "debug" not in BUILD_ENVIRONMENT
|
||||
and "parallelnative" not in BUILD_ENVIRONMENT,
|
||||
)
|
||||
parser.add_argument(
|
||||
"additional_unittest_args",
|
||||
nargs="*",
|
||||
help="additional arguments passed through to unittest, e.g., "
|
||||
"python run_test.py -i sparse -- TestSparse.test_factory_size_check",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--shard",
|
||||
nargs=2,
|
||||
|
|
@ -1273,7 +1265,11 @@ def parse_args():
|
|||
help="Run tests with TorchInductor turned on",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
args, extra = parser.parse_known_args()
|
||||
if "--" in extra:
|
||||
extra.remove("--")
|
||||
args.additional_args = extra
|
||||
return args
|
||||
|
||||
|
||||
def exclude_tests(
|
||||
|
|
@ -1626,7 +1622,7 @@ def run_tests(
|
|||
options_clone = copy.deepcopy(options)
|
||||
if can_run_in_pytest(test):
|
||||
options_clone.pytest = True
|
||||
options_clone.additional_unittest_args.extend(["-m", "serial"])
|
||||
options_clone.additional_args.extend(["-m", "serial"])
|
||||
failure = run_test_module(test, test_directory, options_clone)
|
||||
test_failed = handle_error_messages(failure)
|
||||
if (
|
||||
|
|
@ -1641,7 +1637,7 @@ def run_tests(
|
|||
options_clone = copy.deepcopy(options)
|
||||
if can_run_in_pytest(test):
|
||||
options_clone.pytest = True
|
||||
options_clone.additional_unittest_args.extend(["-m", "not serial"])
|
||||
options_clone.additional_args.extend(["-m", "not serial"])
|
||||
pool.apply_async(
|
||||
run_test_module,
|
||||
args=(test, test_directory, options_clone),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user