mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add skip_first_wait to profiler.schedule (V2) (#141512)
Summary: Another try for D66198138. Original diff had some weird issue with type checking. Setting everything to int this time to get around it. Addresses https://github.com/pytorch/pytorch/issues/91888 We use wait as the amount you wait in between cycles when profiling and skip_first to delay the start of said profiling. However, once skip_first steps are completed, we immediately go to the wait phase. This is not problematic if wait is smaller than skip_first because we can just lower the values of skip_first, but if it is larger then we end up starting the first profile much later than desired. For example imagine a skip first of 1 and a wait of 100 with repeat of 2. We do want to wait 100 steps in between cycle 1 and 2 but we may not want to start warmup of cycle 1 at step 101 (forced because wait occurs directly after first steps skipped). This diff addresses this by adding a flag to skip the first wait. Adds new flag but sets to false by default so that existing impl is not affected. Test Plan: Got following traces with this schedule: schedule=torch.profiler.schedule( wait=10, warmup=3, active=1, repeat=1, skip_first=1, skip_first_wait=1 ) Differential Revision: D66465860 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141512 Approved by: https://github.com/aaronenyeshi
This commit is contained in:
parent
809de05693
commit
29ca44839e
|
|
@ -2081,6 +2081,48 @@ assert KinetoStepTracker.current_step() == initial_step + 2 * niters
|
|||
else:
|
||||
os.waitpid(pid, 0)
|
||||
|
||||
@skipIfTorchDynamo("profiler gets ignored if dynamo activated")
|
||||
def test_skip_first_wait(self):
|
||||
# Other tests test when skip_first_wait is false (default) so just test the true case
|
||||
test_schedule = torch.profiler.schedule(
|
||||
skip_first=3, wait=5, warmup=1, active=2, repeat=2, skip_first_wait=1
|
||||
)
|
||||
test_schedule_expected_outputs = [
|
||||
# repeat No. 1 begin
|
||||
# skip first 3
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
# warmup 1
|
||||
ProfilerAction.WARMUP,
|
||||
# active 1 begin
|
||||
ProfilerAction.RECORD,
|
||||
ProfilerAction.RECORD_AND_SAVE,
|
||||
# active 1 end
|
||||
# repeat No. 1 end
|
||||
# ---
|
||||
# repeat No. 2 begin
|
||||
# wait 5
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
# warmup 1
|
||||
ProfilerAction.WARMUP,
|
||||
# active 2 begin
|
||||
ProfilerAction.RECORD,
|
||||
ProfilerAction.RECORD_AND_SAVE,
|
||||
# active 2 end
|
||||
# repeat No. 2 end
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
ProfilerAction.NONE,
|
||||
]
|
||||
for step in range(len(test_schedule_expected_outputs)):
|
||||
self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])
|
||||
|
||||
|
||||
class SimpleNet(nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -434,7 +434,13 @@ class ProfilerAction(Enum):
|
|||
|
||||
|
||||
def schedule(
|
||||
*, wait: int, warmup: int, active: int, repeat: int = 0, skip_first: int = 0
|
||||
*,
|
||||
wait: int,
|
||||
warmup: int,
|
||||
active: int,
|
||||
repeat: int = 0,
|
||||
skip_first: int = 0,
|
||||
skip_first_wait: int = 0,
|
||||
) -> Callable:
|
||||
"""
|
||||
Returns a callable that can be used as profiler ``schedule`` argument. The profiler will skip
|
||||
|
|
@ -442,6 +448,13 @@ def schedule(
|
|||
then do the active recording for the next ``active`` steps and then repeat the cycle starting with ``wait`` steps.
|
||||
The optional number of cycles is specified with the ``repeat`` parameter, the zero value means that
|
||||
the cycles will continue until the profiling is finished.
|
||||
|
||||
The ``skip_first_wait`` parameter controls whether the first ``wait`` stage should be skipped.
|
||||
This can be useful if a user wants to wait longer than ``skip_first`` between cycles, but not
|
||||
for the first profile. For example, if ``skip_first`` is 10 and ``wait`` is 20, the first cycle will
|
||||
wait 10 + 20 = 30 steps before warmup if ``skip_first_wait`` is zero, but will wait only 10
|
||||
steps if ``skip_first_wait`` is non-zero. All subsequent cycles will then wait 20 steps between the
|
||||
last active and warmup.
|
||||
"""
|
||||
|
||||
def schedule_fn(step: int) -> ProfilerAction:
|
||||
|
|
@ -450,6 +463,9 @@ def schedule(
|
|||
return ProfilerAction.NONE
|
||||
else:
|
||||
step -= skip_first
|
||||
# If wait >> skip_first and we want to grab profiling early, shift left by wait if skip_first_wait is True
|
||||
if skip_first_wait != 0:
|
||||
step += wait
|
||||
num_steps = wait + warmup + active
|
||||
if repeat > 0 and step / num_steps >= repeat:
|
||||
return ProfilerAction.NONE
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user