diff --git a/.ci/pytorch/macos-test.sh b/.ci/pytorch/macos-test.sh index 90d02de26ca..a50d505bffc 100755 --- a/.ci/pytorch/macos-test.sh +++ b/.ci/pytorch/macos-test.sh @@ -42,6 +42,16 @@ test_python_all() { assert_git_not_dirty } +test_python_mps() { + setup_test_python + + time python test/run_test.py --verbose --mps + MTL_CAPTURE_ENABLED=1 ${CONDA_RUN} python3 test/test_mps.py --verbose -k test_metal_capture + + assert_git_not_dirty +} + + test_python_shard() { if [[ -z "$NUM_TEST_SHARDS" ]]; then echo "NUM_TEST_SHARDS must be defined to run a Python test shard" @@ -305,6 +315,8 @@ elif [[ $TEST_CONFIG == *"perf_timm"* ]]; then test_timm_perf elif [[ $TEST_CONFIG == *"perf_smoketest"* ]]; then test_torchbench_smoketest +elif [[ $TEST_CONFIG == *"mps"* ]]; then + test_python_mps elif [[ $NUM_TEST_SHARDS -gt 1 ]]; then test_python_shard "${SHARD_NUMBER}" if [[ "${SHARD_NUMBER}" == 1 ]]; then