Revert "Alternate sharding (#119078)"
This reverts commit 861acda205. Reverted https://github.com/pytorch/pytorch/pull/119078 on behalf of https://github.com/clee2000 due to failing 861acda205 ([comment](https://github.com/pytorch/pytorch/pull/119078#issuecomment-1946583857))
This commit is contained in:
parent a83a1bc43b
commit 9b38ee2343
@@ -1563,25 +1563,6 @@ def run_tests(
            pool.terminate()

    try:
        for test in selected_tests_serial:
            options_clone = copy.deepcopy(options)
            if can_run_in_pytest(test):
                options_clone.pytest = True
            failure = run_test_module(test, test_directory, options_clone)
            test_failed = handle_error_messages(failure)
            if (
                test_failed
                and not options.continue_through_error
                and not RERUN_DISABLED_TESTS
            ):
                raise RuntimeError(
                    failure.message
                    + "\n\nTip: You can keep running tests even on failure by "
                    "passing --keep-going to run_test.py.\n"
                    "If running on CI, add the 'keep-going' label to "
                    "your PR and rerun your jobs."
                )

        os.environ["NUM_PARALLEL_PROCS"] = str(NUM_PROCS)
        for test in selected_tests_parallel:
            options_clone = copy.deepcopy(options)
@@ -1596,6 +1577,32 @@ def run_tests(
        pool.join()
        del os.environ["NUM_PARALLEL_PROCS"]

        if (
            not options.continue_through_error
            and not RERUN_DISABLED_TESTS
            and len(failures) != 0
        ):
            raise RuntimeError(
                "\n".join(x.message for x in failures)
                + "\n\nTip: You can keep running tests even on failure by "
                "passing --keep-going to run_test.py.\n"
                "If running on CI, add the 'keep-going' label to "
                "your PR and rerun your jobs."
            )

        for test in selected_tests_serial:
            options_clone = copy.deepcopy(options)
            if can_run_in_pytest(test):
                options_clone.pytest = True
            failure = run_test_module(test, test_directory, options_clone)
            test_failed = handle_error_messages(failure)
            if (
                test_failed
                and not options.continue_through_error
                and not RERUN_DISABLED_TESTS
            ):
                raise RuntimeError(failure.message)

    finally:
        pool.terminate()
        pool.join()
@@ -1677,14 +1684,14 @@ def main():

        def __str__(self):
            s = f"Name: {self.name}\n"
            s += " Serial tests:\n"
            s += "".join(
                f" {test}\n" for test in self.sharded_tests if must_serial(test)
            )
            s += " Parallel tests:\n"
            s += "".join(
                f" {test}\n" for test in self.sharded_tests if not must_serial(test)
            )
            s += " Serial tests:\n"
            s += "".join(
                f" {test}\n" for test in self.sharded_tests if must_serial(test)
            )
            return s.strip()

    test_batches: List[TestBatch] = []
@@ -322,9 +322,6 @@ class TestCalculateShards(unittest.TestCase):
            ),
        )

    def test_zero_tests(self) -> None:
        self.assertListEqual([(0.0, []), (0.0, [])], calculate_shards(2, [], {}, None))

    def test_split_shards_random(self) -> None:
        random.seed(120)
        for _ in range(100):
@@ -333,32 +330,27 @@ class TestCalculateShards(unittest.TestCase):
            random_times: Dict[str, float] = {
                str(i): random.randint(0, THRESHOLD * 10) for i in range(num_tests)
            }
            serial = [str(i) for i in range(num_tests) if random.randint(0, 1) == 0]

            shards = calculate_shards(
                num_shards,
                [TestRun(t) for t in random_times.keys()],
                random_times,
                None,
                must_serial=lambda x: x in serial,
                sort_by_time=random.randint(0, 1) == 0,
                gen_class_times(random_times),
            )

            times = [x[0] for x in shards]
            max_diff = max(times) - min(times)
            self.assertTrue(max_diff <= THRESHOLD)

            all_sharded_tests: Dict[str, List[ShardedTest]] = defaultdict(list)
            for _, sharded_tests in shards:
            all_sharded_tests = defaultdict(list)
            for time, sharded_tests in shards:
                self.assertEqual(time, sum(x.time for x in sharded_tests))
                for sharded_test in sharded_tests:
                    all_sharded_tests[sharded_test.name].append(sharded_test)

            # Check that all test files are represented in the shards
            self.assertListEqual(
                sorted(random_times.keys()), sorted(all_sharded_tests.keys())
            )
            # Check that for each test file, the pytest shards' times adds up to
            # original and all shards are present
            for test, sharded_tests in all_sharded_tests.items():
                self.assertAlmostEqual(
                    random_times[test], sum(x.time or 0 for x in sharded_tests)
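Aside, not part of the diff: the randomized test above asserts two properties of the sharding result — the shard totals stay within THRESHOLD of each other, and the pytest shards of each file add back up to that file's recorded time. A minimal self-contained sketch of those two checks, using made-up shard data and an illustrative THRESHOLD rather than the repo's fixtures:

from collections import defaultdict

THRESHOLD = 60 * 10  # illustrative value only; the real constant lives in the sharding module

# shards: list of (total_time, [(test_file, time), ...]) pairs, shaped like the
# (float, List[ShardedTest]) tuples the test iterates over
shards = [
    (700.0, [("test_a", 400.0), ("test_c", 300.0)]),
    (650.0, [("test_c", 300.0), ("test_b", 350.0)]),
]

# Property 1: shard durations are balanced to within THRESHOLD of each other
times = [t for t, _ in shards]
assert max(times) - min(times) <= THRESHOLD

# Property 2: each shard's total matches its contents, and per-file times add back up
per_file = defaultdict(float)
for total, sharded_tests in shards:
    assert total == sum(t for _, t in sharded_tests)
    for name, t in sharded_tests:
        per_file[name] += t
assert per_file == {"test_a": 400.0, "test_b": 350.0, "test_c": 600.0}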
@@ -67,8 +67,44 @@ def get_with_pytest_shard(
) -> List[ShardedTest]:
    sharded_tests: List[ShardedTest] = []

    def get_duration_for_classes(
        test_file: str, test_classes: Set[str]
    ) -> Optional[float]:
        duration: float = 0
        if not test_class_times:
            return None

        for test_class in test_classes:
            class_duration = test_class_times.get(test_file, {}).get(test_class, None)
            if class_duration is None:
                return None
            if class_duration:
                duration += class_duration
        return duration

    for test in tests:
        duration = get_duration(test, test_file_times, test_class_times or {})
        file_duration = test_file_times.get(test.test_file, None)
        included = test.included()
        excluded = test.excluded()
        included_classes_duration = get_duration_for_classes(test.test_file, included)
        excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)

        if included:
            # If we don't have the time for all included classes, our upper bound is the file duration
            duration = (
                included_classes_duration
                if included_classes_duration is not None
                else file_duration
            )
        elif excluded:
            # If we don't have the time for all excluded classes, our upper bound is file duration
            duration = (
                file_duration - excluded_classes_duration
                if excluded_classes_duration is not None and file_duration is not None
                else file_duration
            )
        else:
            duration = file_duration

        if duration and duration > THRESHOLD:
            num_shards = math.ceil(duration / THRESHOLD)
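Aside, not part of the diff: the last two lines of the hunk above split a file whose estimated duration exceeds THRESHOLD into math.ceil(duration / THRESHOLD) pytest shards. A standalone sketch of that splitting step, using a hypothetical helper name, plain tuples instead of ShardedTest, and an assumed 10-minute threshold:

import math
from typing import List, Tuple

THRESHOLD = 60 * 10  # assumed 10-minute threshold, for illustration only


def split_into_pytest_shards(test_file: str, duration: float) -> List[Tuple[str, int, int, float]]:
    """Return (file, shard id, num shards, estimated time) tuples for one test file."""
    if duration <= THRESHOLD:
        return [(test_file, 1, 1, duration)]
    num_shards = math.ceil(duration / THRESHOLD)
    # Each pytest shard is assumed to take an equal slice of the file's time
    return [(test_file, i + 1, num_shards, duration / num_shards) for i in range(num_shards)]


print(split_into_pytest_shards("test_ops", 1500.0))
# [('test_ops', 1, 3, 500.0), ('test_ops', 2, 3, 500.0), ('test_ops', 3, 3, 500.0)]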
@@ -81,111 +117,6 @@ def get_with_pytest_shard(
    return sharded_tests


def get_duration(
    test: TestRun,
    test_file_times: Dict[str, float],
    test_class_times: Dict[str, Dict[str, float]],
) -> Optional[float]:
    file_duration = test_file_times.get(test.test_file, None)
    if test.is_full_file():
        return file_duration

    def get_duration_for_classes(
        test_file: str, test_classes: Set[str]
    ) -> Optional[float]:
        duration: float = 0

        for test_class in test_classes:
            class_duration = test_class_times.get(test_file, {}).get(test_class, None)
            if class_duration is None:
                return None
            duration += class_duration
        return duration

    included = test.included()
    excluded = test.excluded()
    included_classes_duration = get_duration_for_classes(test.test_file, included)
    excluded_classes_duration = get_duration_for_classes(test.test_file, excluded)

    if included_classes_duration is None or excluded_classes_duration is None:
        # Didn't get the time for all classes, so time is unknown
        return None

    if included:
        return included_classes_duration
    assert (
        excluded
    ), f"TestRun {test} is not full file but doesn't have included or excluded classes"
    if file_duration is None:
        return None
    return file_duration - excluded_classes_duration


def shard(
    sharded_jobs: List[ShardJob],
    tests: Sequence[TestRun],
    test_file_times: Dict[str, float],
    test_class_times: Dict[str, Dict[str, float]],
    estimated_time_limit: Optional[float] = None,
    sort_by_time: bool = True,
    serial: bool = False,
) -> None:
    if len(sharded_jobs) == 0:
        assert len(tests) == 0, "No shards provided but there are tests to shard"
        return
    # Modifies sharded_jobs in place
    known_tests = tests
    unknown_tests = []
    if sort_by_time:
        known_tests = [
            x
            for x in tests
            if get_duration(x, test_file_times, test_class_times) is not None
        ]
        unknown_tests = [x for x in tests if x not in known_tests]

    assert (
        unknown_tests == [] or serial
    ), f"Attmempting to parallelize unknown tests {unknown_tests}"
    del tests

    known_tests = get_with_pytest_shard(known_tests, test_file_times, test_class_times)

    if sort_by_time:
        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)

    def _shard_serial(tests: List[ShardedTest], sharded_jobs: List[ShardJob]) -> None:
        assert estimated_time_limit is not None, "Estimated time limit must be provided"
        new_sharded_jobs = sharded_jobs
        for test in tests:
            if (
                len(sharded_jobs) > 1
                and sharded_jobs[-1].get_total_time() > estimated_time_limit
            ):
                new_sharded_jobs = sharded_jobs[:-1]
            min_sharded_job = min(new_sharded_jobs, key=lambda j: j.get_total_time())
            min_sharded_job.serial.append(test)

    def _shard_parallel(tests: List[ShardedTest], sharded_jobs: List[ShardJob]) -> None:
        for test in tests:
            min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
            min_sharded_job.parallel.append(test)

    if serial:
        _shard_serial(known_tests, sharded_jobs)
    else:
        _shard_parallel(known_tests, sharded_jobs)

    # Round robin the unknown jobs starting with the smallest shard
    num_shards = len(sharded_jobs)
    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
    for unknown_test in unknown_tests:
        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
        index = (index + 1) % num_shards

    return


def calculate_shards(
    num_shards: int,
    tests: Sequence[TestRun],
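Aside, not part of the diff: the shard() function above combines two placement policies — tests with known durations go onto whichever shard currently has the smallest total, and tests with unknown durations are round-robined starting from the smallest shard. A self-contained sketch of that idea with plain (name, time) tuples standing in for ShardJob/ShardedTest:

from typing import List, Optional, Tuple


def assign(tests: List[Tuple[str, Optional[float]]], num_shards: int) -> List[List[str]]:
    shards: List[List[str]] = [[] for _ in range(num_shards)]
    totals = [0.0] * num_shards

    known = sorted((t for t in tests if t[1] is not None), key=lambda t: t[1], reverse=True)
    unknown = [t for t in tests if t[1] is None]

    # Greedy: longest known test first, always onto the currently smallest shard
    for name, time in known:
        i = min(range(num_shards), key=lambda k: totals[k])
        shards[i].append(name)
        totals[i] += time

    # Round robin the unknown tests starting with the smallest shard
    index = min(range(num_shards), key=lambda k: totals[k])
    for name, _ in unknown:
        shards[index].append(name)
        index = (index + 1) % num_shards
    return shards


print(assign([("a", 5.0), ("b", 3.0), ("c", None), ("d", 1.0), ("e", None)], 2))
# [['a', 'e'], ['b', 'd', 'c']]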
@@ -195,65 +126,38 @@ def calculate_shards(
    sort_by_time: bool = True,
) -> List[Tuple[float, List[ShardedTest]]]:
    must_serial = must_serial or (lambda x: True)
    test_class_times = test_class_times or {}
    serial_tests = [
        test
        for test in tests
        if get_duration(test, test_file_times, test_class_times) is None
        or must_serial(test.test_file)
    ]
    parallel_tests = [test for test in tests if test not in serial_tests]

    serial_time = sum(
        get_duration(test, test_file_times, test_class_times) or 0
        for test in serial_tests
    )
    parallel_time = sum(
        get_duration(test, test_file_times, test_class_times) or 0
        for test in parallel_tests
    )
    total_time = serial_time + parallel_time / NUM_PROCS_FOR_SHARDING_CALC
    estimated_time_per_shard = total_time / num_shards
    # Separate serial tests from parallel tests as much as possible to maximize
    # parallelism by putting all the serial tests on the first num_serial_shards
    # shards. The estimated_time_limit is the estimated time it should take for
    # the least filled serial shard. Ex if we have 8 min of serial tests, 20 min
    # of parallel tests, 6 shards, and 2 procs per machine, we would expect each
    # machine to take 3 min and should aim for 3 serial shards, with shards 1
    # and 2 taking 3 min and shard 3 taking 2 min. The estimated time limit
    # would be 2 min. This ensures that the first few shard contains as many
    # serial tests as possible and as few parallel tests as possible. The least
    # filled/last (in the example, the 3rd) shard may contain a lot of both
    # serial and parallel tests.
    estimated_time_limit = 0.0
    if estimated_time_per_shard != 0:
        estimated_time_limit = serial_time % estimated_time_per_shard
    if estimated_time_limit <= 0.01:
        estimated_time_limit = estimated_time_per_shard
    if total_time == 0:
        num_serial_shards = num_shards
    else:
        num_serial_shards = math.ceil(serial_time / total_time * num_shards)
    known_tests: Sequence[TestRun] = tests
    unknown_tests: Sequence[TestRun] = []

    sharded_jobs = [ShardJob() for _ in range(num_shards)]
    shard(
        sharded_jobs[:num_serial_shards],
        serial_tests,
        test_file_times,
        test_class_times,
        estimated_time_limit=estimated_time_limit,
        sort_by_time=sort_by_time,
        serial=True,
    )
    shard(
        sharded_jobs,
        parallel_tests,
        test_file_times,
        test_class_times,
        sort_by_time=sort_by_time,
        serial=False,
    )
    if sort_by_time:
        known_tests = [
            x
            for x in tests
            if x.test_file in test_file_times
            or (test_class_times and x.test_file in test_class_times)
        ]
        unknown_tests = [x for x in tests if x not in known_tests]

    known_tests = get_with_pytest_shard(known_tests, test_file_times, test_class_times)

    if sort_by_time:
        known_tests = sorted(known_tests, key=lambda j: j.get_time(), reverse=True)

    sharded_jobs: List[ShardJob] = [ShardJob() for _ in range(num_shards)]
    for test in known_tests:
        if must_serial(test.name):
            min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
            min_sharded_job.serial.append(test)
        else:
            min_sharded_job = min(sharded_jobs, key=lambda j: j.get_total_time())
            min_sharded_job.parallel.append(test)

    # Round robin the unknown jobs starting with the smallest shard
    index = min(range(num_shards), key=lambda i: sharded_jobs[i].get_total_time())
    for unknown_test in unknown_tests:
        sharded_jobs[index].serial.append(ShardedTest(unknown_test, 1, 1, None))
        index = (index + 1) % num_shards
    return [job.convert_to_tuple() for job in sharded_jobs]
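Aside, not part of the diff: the long comment in the hunk above works through a concrete case — 8 min of serial tests, 20 min of parallel tests, 6 shards, 2 procs per machine. A small sketch reproducing that arithmetic with plain numbers (num_procs stands in for NUM_PROCS_FOR_SHARDING_CALC, taken to be 2 as in the comment):

import math

serial_time, parallel_time = 8.0, 20.0  # minutes, from the comment's example
num_shards, num_procs = 6, 2

total_time = serial_time + parallel_time / num_procs                  # 8 + 20/2 = 18.0
estimated_time_per_shard = total_time / num_shards                    # 18 / 6 = 3.0
num_serial_shards = math.ceil(serial_time / total_time * num_shards)  # ceil(8/18 * 6) = 3
estimated_time_limit = serial_time % estimated_time_per_shard         # 8 % 3 = 2.0
if estimated_time_limit <= 0.01:
    estimated_time_limit = estimated_time_per_shard

print(total_time, estimated_time_per_shard, num_serial_shards, estimated_time_limit)
# 18.0 3.0 3 2.0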