Add a shortcut to test all torchbench models. (#57311)

Summary:
This PR adds a shortcut of specifying all models in TorchBench CI

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57311

Test Plan:
CI

RUN_TORCHBENCH: ALL

Reviewed By: bitfort

Differential Revision: D28160198

Pulled By: xuzhao9

fbshipit-source-id: 67c292bc98868979d868d4cf1e599c38e0da94b5
This commit is contained in:
Xu Zhao 2021-05-03 13:49:24 -07:00 committed by Facebook GitHub Bot
parent 33eea146ee
commit d68ad3cb1e
2 changed files with 10 additions and 2 deletions

View File

@ -28,7 +28,7 @@ start: {control}
end: {treatment}
threshold: 100
direction: decrease
timeout: 60
timeout: 720
tests:"""
def gen_abtest_config(control: str, treatment: str, models: List[str]):
@ -36,6 +36,8 @@ def gen_abtest_config(control: str, treatment: str, models: List[str]):
d["control"] = control
d["treatment"] = treatment
config = ABTEST_CONFIG_TEMPLATE.format(**d)
if models == ["ALL"]:
return config + "\n"
for model in models:
config = f"{config}\n - {model}"
config = config + "\n"
@ -57,8 +59,12 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> List[str]:
if magic_lines:
# Only the first magic line will be respected.
model_list = list(map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX):].split(",")))
# Shortcut: if model_list is ["ALL"], run all the tests
if model_list == ["ALL"]:
return model_list
# Sanity check: make sure all the user specified models exist in torchbench repository
full_model_list = os.listdir(os.path.join(torchbench_path, "torchbenchmark", "models"))
benchmark_path = os.path.join(torchbench_path, "torchbenchmark", "models")
full_model_list = [model for model in os.listdir(benchmark_path) if os.path.isdir(os.path.join(benchmark_path, model))]
for m in model_list:
if m not in full_model_list:
print(f"The model {m} you specified does not exist in TorchBench suite. Please double check.")

View File

@ -14,6 +14,8 @@ jobs:
# Only run the job when the body contains magic word "RUN_TORCHBENCH:"
if: ${{ github.repository_owner == 'pytorch' && contains(github.event.pull_request.body, 'RUN_TORCHBENCH:') }}
runs-on: [self-hosted, bm-runner]
# Set to 12 hours
timeout-minutes: 720
steps:
- name: Checkout PyTorch
uses: actions/checkout@v2