mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add a shortcut to test all torchbench models. (#57311)
Summary: This PR adds a shortcut of specifying all models in TorchBench CI Pull Request resolved: https://github.com/pytorch/pytorch/pull/57311 Test Plan: CI RUN_TORCHBENCH: ALL Reviewed By: bitfort Differential Revision: D28160198 Pulled By: xuzhao9 fbshipit-source-id: 67c292bc98868979d868d4cf1e599c38e0da94b5
This commit is contained in:
parent
33eea146ee
commit
d68ad3cb1e
10
.github/scripts/run_torchbench.py
vendored
10
.github/scripts/run_torchbench.py
vendored
|
|
@ -28,7 +28,7 @@ start: {control}
|
|||
end: {treatment}
|
||||
threshold: 100
|
||||
direction: decrease
|
||||
timeout: 60
|
||||
timeout: 720
|
||||
tests:"""
|
||||
|
||||
def gen_abtest_config(control: str, treatment: str, models: List[str]):
|
||||
|
|
@ -36,6 +36,8 @@ def gen_abtest_config(control: str, treatment: str, models: List[str]):
|
|||
d["control"] = control
|
||||
d["treatment"] = treatment
|
||||
config = ABTEST_CONFIG_TEMPLATE.format(**d)
|
||||
if models == ["ALL"]:
|
||||
return config + "\n"
|
||||
for model in models:
|
||||
config = f"{config}\n - {model}"
|
||||
config = config + "\n"
|
||||
|
|
@ -57,8 +59,12 @@ def extract_models_from_pr(torchbench_path: str, prbody_file: str) -> List[str]:
|
|||
if magic_lines:
|
||||
# Only the first magic line will be respected.
|
||||
model_list = list(map(lambda x: x.strip(), magic_lines[0][len(MAGIC_PREFIX):].split(",")))
|
||||
# Shortcut: if model_list is ["ALL"], run all the tests
|
||||
if model_list == ["ALL"]:
|
||||
return model_list
|
||||
# Sanity check: make sure all the user specified models exist in torchbench repository
|
||||
full_model_list = os.listdir(os.path.join(torchbench_path, "torchbenchmark", "models"))
|
||||
benchmark_path = os.path.join(torchbench_path, "torchbenchmark", "models")
|
||||
full_model_list = [model for model in os.listdir(benchmark_path) if os.path.isdir(os.path.join(benchmark_path, model))]
|
||||
for m in model_list:
|
||||
if m not in full_model_list:
|
||||
print(f"The model {m} you specified does not exist in TorchBench suite. Please double check.")
|
||||
|
|
|
|||
2
.github/workflows/run_torchbench.yml
vendored
2
.github/workflows/run_torchbench.yml
vendored
|
|
@ -14,6 +14,8 @@ jobs:
|
|||
# Only run the job when the body contains magic word "RUN_TORCHBENCH:"
|
||||
if: ${{ github.repository_owner == 'pytorch' && contains(github.event.pull_request.body, 'RUN_TORCHBENCH:') }}
|
||||
runs-on: [self-hosted, bm-runner]
|
||||
# Set to 12 hours
|
||||
timeout-minutes: 720
|
||||
steps:
|
||||
- name: Checkout PyTorch
|
||||
uses: actions/checkout@v2
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user