Add skip_first_wait to profiler.schedule (V2) (#141512)

Summary:
Another try for D66198138. Original diff had some weird issue with type checking. Setting everything to int this time to get around it.

Addresses https://github.com/pytorch/pytorch/issues/91888
We use wait as the amount you wait in between cycles when profiling and skip_first to delay the start of said profiling. However, once skip_first steps are completed, we immediately go to the wait phase. This is not problematic if wait is smaller than skip_first because we can just lower the values of skip_first, but if it is larger then we end up starting the first profile much later than desired. For example imagine a skip first of 1 and a wait of 100 with repeat of 2. We do want to wait 100 steps in between cycle 1 and 2 but we may not want to start warmup of cycle 1 at step 101 (forced because wait occurs directly after first steps skipped). This diff addresses this by adding a flag to skip the first wait.
Adds new flag but sets to false by default so that existing impl is not affected.

Test Plan:
Got following traces with this schedule:
schedule=torch.profiler.schedule(
          wait=10, warmup=3, active=1, repeat=1, skip_first=1, skip_first_wait=1
      )

Differential Revision: D66465860

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141512
Approved by: https://github.com/aaronenyeshi
This commit is contained in:
Shivam Raikundalia 2024-11-26 18:10:54 +00:00 committed by PyTorch MergeBot
parent 809de05693
commit 29ca44839e
2 changed files with 59 additions and 1 deletions

View File

@ -2081,6 +2081,48 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
else:
os.waitpid(pid, 0)
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
def test_skip_first_wait(self):
# Other tests test when skip_first_wait is false (default) so just test the true case
test_schedule = torch.profiler.schedule(
skip_first=3, wait=5, warmup=1, active=2, repeat=2, skip_first_wait=1
)
test_schedule_expected_outputs = [
# repeat No. 1 begin
# skip first 3
ProfilerAction.NONE,
ProfilerAction.NONE,
ProfilerAction.NONE,
# warmup 1
ProfilerAction.WARMUP,
# active 1 begin
ProfilerAction.RECORD,
ProfilerAction.RECORD_AND_SAVE,
# active 1 end
# repeat No. 1 end
# ---
# repeat No. 2 begin
# wait 5
ProfilerAction.NONE,
ProfilerAction.NONE,
ProfilerAction.NONE,
ProfilerAction.NONE,
ProfilerAction.NONE,
# warmup 1
ProfilerAction.WARMUP,
# active 2 begin
ProfilerAction.RECORD,
ProfilerAction.RECORD_AND_SAVE,
# active 2 end
# repeat No. 2 end
ProfilerAction.NONE,
ProfilerAction.NONE,
ProfilerAction.NONE,
ProfilerAction.NONE,
]
for step in range(len(test_schedule_expected_outputs)):
self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])
class SimpleNet(nn.Module):
def __init__(self) -> None:

View File

@ -434,7 +434,13 @@ class ProfilerAction(Enum):
def schedule(
*, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0
*,
wait: int,
warmup: int,
active: int,
repeat: int = 0,
skip_first: int = 0,
skip_first_wait: int = 0,
) -> Callable:
"""
Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip
@ -442,6 +448,13 @@ def schedule(
then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps.
The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that
the cycles will continue until the profiling is finished.
The ``skip_first_wait`` parameter controls whether the first ``wait`` stage should be skipped.
This can be useful if a user wants to wait longer than ``skip_first`` between cycles, but not
for the first profile. For example, if ``skip_first`` is 10 and ``wait`` is 20, the first cycle will
wait 10 + 20 = 30 steps before warmup if ``skip_first_wait`` is zero, but will wait only 10
steps if ``skip_first_wait`` is non-zero. All subsequent cycles will then wait 20 steps between the
last active and warmup.
"""
def schedule_fn(step: int) -> ProfilerAction:
@ -450,6 +463,9 @@ def schedule(
return ProfilerAction.NONE
else:
step -= skip_first
# If wait >> skip_first and we want to grab profiling early, shift left by wait if skip_first_wait is True
if skip_first_wait != 0:
step += wait
num_steps = wait + warmup + active
if repeat > 0 and step / num_steps >= repeat:
return ProfilerAction.NONE