[test stats] use published test stats for sharding (#81116)

Use the nightly-published test stats to perform sharding, instead of calculating them in every build job.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81116
Approved by: https://github.com/janeyx99
This commit is contained in: parent fb93c3988a, commit 9f58d5d7ce
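At a high level, the build job no longer exports a .pytorch-test-times.json artifact; each test job instead downloads the nightly-published stats and decides its own shard. The snippet below is a condensed, illustrative sketch of that new selection path, pieced together from the get_selected_tests() hunk of test/run_test.py further down; the wrapper function pick_shard and its argument names are invented for this summary and are not part of the commit.

# Illustrative summary only; pick_shard is a made-up wrapper around the logic
# this commit adds to get_selected_tests() in test/run_test.py.
from tools.stats.import_test_stats import get_test_times
from tools.testing.test_selections import calculate_shards

def pick_shard(selected_tests, which_shard, num_shards, repo_root, times_file):
    if num_shards == 1:
        return selected_tests
    # Nightly-published stats, keyed by BUILD_ENVIRONMENT and TEST_CONFIG.
    test_file_times = get_test_times(str(repo_root), filename=times_file)
    if not test_file_times:
        # No stats available: fall back to a plain round-robin split.
        return selected_tests[which_shard - 1 :: num_shards]
    # Otherwise balance the shards using the published per-file durations.
    shards = calculate_shards(num_shards, selected_tests, test_file_times)
    _, tests_from_shard = shards[which_shard - 1]
    return tests_from_shard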
.github/workflows/_linux-build.yml
@@ -135,7 +135,7 @@ jobs:
       - name: Archive artifacts into zip
         if: inputs.build-generates-artifacts
         run: |
-          zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin .pytorch-test-times.json
+          zip -1 -r artifacts.zip dist/ build/custom_test_artifacts build/lib build/bin

       - name: Store PyTorch Build Artifacts on S3
         uses: seemethere/upload-artifact-s3@v5
@@ -296,10 +296,4 @@ else
   fi
 fi

-if [[ "$BUILD_ENVIRONMENT" != *libtorch* && "$BUILD_ENVIRONMENT" != *bazel* ]]; then
-  # export test times so that potential sharded tests that'll branch off this build will use consistent data
-  # don't do this for libtorch as libtorch is C++ only and thus won't have python tests run on its build
-  python test/run_test.py --export-past-test-times
-fi
-
 print_sccache_stats
@@ -146,9 +146,6 @@ python setup.py install --cmake && sccache --show-stats && (
   if errorlevel 1 exit /b
   if not errorlevel 0 exit /b

-  :: export test times so that potential sharded tests that'll branch off this build will use consistent data
-  python test/run_test.py --export-past-test-times %PYTORCH_FINAL_PACKAGE_DIR%/.pytorch-test-times.json
-
   :: Also save build/.ninja_log as an artifact
   copy /Y "build\.ninja_log" "%PYTORCH_FINAL_PACKAGE_DIR%\"
 )
@@ -1,7 +1,6 @@
 call %SCRIPT_HELPERS_DIR%\setup_pytorch_env.bat

 echo Copying over test times file
-copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"

 pushd test

@@ -21,9 +21,6 @@ if "%SHARD_NUMBER%" == "1" (
   )
 )

-echo Copying over test times file
-copy /Y "%PYTORCH_FINAL_PACKAGE_DIR_WIN%\.pytorch-test-times.json" "%TEST_DIR_WIN%"
-
 echo Run nn tests
 python run_test.py --exclude-jit-executor --exclude-distributed-tests --shard "%SHARD_NUMBER%" "%NUM_TEST_SHARDS%" --verbose
 if ERRORLEVEL 1 goto fail
@@ -32,11 +32,11 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
 try:
     # using tools/ to optimize test run.
     sys.path.append(str(REPO_ROOT))
+    from tools.stats.import_test_stats import get_test_times
     from tools.testing.test_selections import (
-        export_S3_test_times,
-        get_shard_based_on_S3,
         get_reordered_tests,
         get_test_case_configs,
+        calculate_shards,
     )
     HAVE_TEST_SELECTION_TOOLS = True
 except ImportError:
@@ -677,13 +677,6 @@ def parse_args():
         help="additional arguments passed through to unittest, e.g., "
         "python run_test.py -i sparse -- TestSparse.test_factory_size_check",
     )
-    parser.add_argument(
-        "--export-past-test-times",
-        nargs="?",
-        type=str,
-        const=TEST_TIMES_FILE,
-        help="dumps test times from previous S3 stats into a file, format JSON",
-    )
     parser.add_argument(
         "--shard",
         nargs=2,
@@ -838,11 +831,21 @@ def get_selected_tests(options):
         assert num_shards <= len(
             selected_tests
         ), f"Number of shards must be less than {len(selected_tests)}"
-        # TODO: fix this to use test_times_filename, but currently this is not working
-        # because setting the export arg immeidately halts the test execution.
-        selected_tests = get_shard_based_on_S3(
-            which_shard, num_shards, selected_tests, TEST_TIMES_FILE
-        )
+
+        if num_shards == 1:
+            return selected_tests
+
+        # Download previous test times to make sharding decisions
+        test_file_times = get_test_times(str(REPO_ROOT), filename=TEST_TIMES_FILE)
+        if len(test_file_times) == 0:
+            print(
+                "::warning:: Gathered no stats from S3. Proceeding with default sharding plan."
+            )
+            selected_tests = selected_tests[which_shard - 1 :: num_shards]
+        else:
+            shards = calculate_shards(num_shards, selected_tests, test_file_times)
+            _, tests_from_shard = shards[which_shard - 1]
+            selected_tests = tests_from_shard

     # skip all distributed tests if distributed package is not available.
     if not dist.is_available():
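calculate_shards itself lives in tools/testing/test_selections.py and is not part of this diff, so its exact behavior is not shown here. As a rough mental model only, a time-balanced sharder with the call shape used above (one (total_time, tests) pair per shard) can be sketched as a greedy longest-first assignment; this is an assumed, generic implementation, not PyTorch's actual one.

import heapq
from typing import Dict, List, Tuple

def greedy_shards(
    num_shards: int, tests: List[str], times: Dict[str, float]
) -> List[Tuple[float, List[str]]]:
    # Assumed model: place the longest known test on the currently lightest shard.
    # Tests with no recorded time are treated as zero-cost here.
    heap: List[Tuple[float, int]] = [(0.0, i) for i in range(num_shards)]
    heapq.heapify(heap)
    totals = [0.0] * num_shards
    buckets: List[List[str]] = [[] for _ in range(num_shards)]
    for test in sorted(tests, key=lambda t: times.get(t, 0.0), reverse=True):
        total, idx = heapq.heappop(heap)
        buckets[idx].append(test)
        totals[idx] = total + times.get(test, 0.0)
        heapq.heappush(heap, (totals[idx], idx))
    return list(zip(totals, buckets))

For example, with times {"test_a": 60, "test_b": 30, "test_c": 30} and two shards this yields one shard of roughly 60 seconds holding test_a and another of 60 seconds holding test_b and test_c, a balance the round-robin fallback above cannot guarantee.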
@@ -882,15 +885,6 @@ def run_test_module(test: str, test_directory: str, options) -> Optional[str]:
 def main():
     options = parse_args()

-    # TODO: move this export & download function in tools/ folder
-    test_times_filename = options.export_past_test_times
-    if test_times_filename:
-        print(
-            f"Exporting past test times from S3 to {test_times_filename}, no tests will be run."
-        )
-        export_S3_test_times(test_times_filename)
-        return
-
     test_directory = str(REPO_ROOT / "test")
     selected_tests = get_selected_tests(options)

@@ -41,6 +41,7 @@ def fetch_and_cache(
     This fetch and cache utils allows sharing between different process.
     """
     path = os.path.join(dirpath, name)
+    print(f"Downloading {url} to {path}")

     def is_cached_file_valid() -> bool:
         # Check if the file is new enough (see: FILE_CACHE_LIFESPAN_SECONDS). A real check
@@ -80,6 +81,21 @@ def get_slow_tests(
         return {}


+def get_test_times(dirpath: str, filename: str) -> Dict[str, float]:
+    url = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-times.json"
+
+    def process_response(the_response: Dict[str, Any]) -> Any:
+        build_environment = os.environ["BUILD_ENVIRONMENT"]
+        test_config = os.environ["TEST_CONFIG"]
+        return the_response[build_environment][test_config]
+
+    try:
+        return fetch_and_cache(dirpath, filename, url, process_response)
+    except Exception:
+        print("Couldn't download test times...")
+        return {}
+
+
 def get_disabled_tests(
     dirpath: str, filename: str = DISABLED_TESTS_FILE
 ) -> Optional[Dict[str, Any]]:
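The nested process_response above indexes the downloaded payload first by BUILD_ENVIRONMENT and then by TEST_CONFIG, and get_test_times is annotated to return Dict[str, float], so the published test-times.json presumably maps build environment -> test config -> per-test-file duration in seconds. The sketch below illustrates that assumed shape and a call; the environment values, config names, and numbers are invented examples, not real CI data.

import os

# Assumed layout of the published test-times.json (keys and values invented):
# {
#     "linux-bionic-py3.7-clang9": {
#         "default": {"test_nn": 2301.5, "test_torch": 1187.0},
#         "distributed": {"test_c10d_common": 900.0},
#     },
# }

# CI sets these; the example values here are placeholders for local experimentation.
os.environ.setdefault("BUILD_ENVIRONMENT", "linux-bionic-py3.7-clang9")
os.environ.setdefault("TEST_CONFIG", "default")

from tools.stats.import_test_stats import get_test_times

# Fetches (or reuses a cached copy of) the published stats and returns the
# {test_file: seconds} mapping for this environment/config, or {} on failure.
test_file_times = get_test_times(".", filename=".pytorch-test-times.json")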