[Profiler] Hide Kineto Step Tracker Behind Env Var (#144494)

Summary:
To support iteration-based on-demand tracing, we have step tracker hooks for both the profiler schedule and the optimizer that drive Kineto's backend FSM. We already hide the optimizer step tracker behind an env var to avoid any extra overhead from the frontend profiler down to the Kineto backend, but we do no such thing for the profiler step tracker. Having both auto-trace and on-demand active at the same time also seems to cause occasional errors in the FSM.

To remedy this issue, let's add a patch that guards the step incrementer in the frontend step function behind an env var. This bypasses all of the on-demand logic, which shouldn't run during auto-trace.
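As a rough sketch of the intended behavior (not part of the PR; the env-var names come from the diff below, and the requirement to set them before process start is my assumption), auto-trace keeps working with no env vars, while iteration-based on-demand still needs the daemon env var:

```python
# Minimal sketch, assuming KINETO_USE_DAEMON must be visible to the process
# before torch/Kineto initialize for the daemon (on-demand) path to apply.
import torch
from torch.profiler import ProfilerActivity, profile, schedule

# Auto-trace (schedule-based): KINETO_USE_DAEMON is unset, so with this patch
# prof.step() only advances the schedule FSM and never increments the Kineto
# step tracker, keeping all on-demand logic out of the auto-trace path.
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=2),
) as p:
    for _ in range(8):
        torch.mm(torch.randn(32, 32), torch.randn(32, 32))
        p.step()

# Iteration-based on-demand: launch with KINETO_USE_DAEMON=1 (or, in fbcode,
# KINETO_FORCE_STEP_HOOK=1) so the frontend step hook keeps feeding Kineto's
# step tracker for the daemon to react to.
```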

Test Plan:
Ran
`buck run mode/dev-nosan kineto/libkineto/fb/integration_tests:pytorch_resnet_integration_test -- --enable_profiling --trace_handler=auto_trace --with_stack` with prints added in the on-demand functions (performLoopStep and collectTrace): neither was called, whereas both were called on main.
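As a complementary OSS-side check (my own sketch, not part of the internal test plan above), one can spy on the step tracker and confirm that a schedule-driven profile never reaches it when the daemon env var is unset:

```python
# Hedged sketch: with KINETO_USE_DAEMON unset, prof.step() should skip
# KinetoStepTracker.increment_step entirely after this patch.
import os
from unittest import mock

import torch
import torch.autograd.profiler as autograd_profiler
from torch.profiler import ProfilerActivity, profile, schedule

os.environ.pop("KINETO_USE_DAEMON", None)

with mock.patch.object(
    autograd_profiler.KinetoStepTracker, "increment_step"
) as spy:
    with profile(
        activities=[ProfilerActivity.CPU],
        schedule=schedule(wait=1, warmup=1, active=1),
    ) as p:
        for _ in range(4):
            torch.mm(torch.randn(8, 8), torch.randn(8, 8))
            p.step()
assert spy.call_count == 0  # auto-trace stays off the on-demand step path
```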

Also got the following healthy traces:

Auto-Trace (schedule-based):
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/devvm2185.cco0.facebook.com/rank-0.Jan_09_12_43_37.1122140.pt.trace.json.gz&bucket=gpu_traces

Timing Based On-demand:
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/0/1736456722/localhost/libkineto_activities_1286261.json.gz&bucket=gpu_traces

Iteration Based On-demand:
https://www.internalfb.com/intern/perfdoctor/trace_view?filepath=tree/traces/dynocli/0/1736456889/localhost/libkineto_activities_1304781.json.gz&bucket=gpu_traces

Differential Revision: D67990080

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144494
Approved by: https://github.com/ngimel
commit f295eff512 (parent 8cc8989b26)
Author: Shivam Raikundalia
Date: 2025-01-10 07:00:55 +00:00
Committed by: PyTorch MergeBot
2 changed files with 9 additions and 1 deletion


@@ -950,6 +950,8 @@ class TestProfiler(TestCase):
             )
         self.assertIn("Total MFLOPs", profiler_output)

+    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
+    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
     def test_kineto_profiler_api(self):
         called_num = [0]
@@ -1034,6 +1036,8 @@ class TestProfiler(TestCase):
         for step in range(len(test_schedule_expected_outputs)):
            self.assertEqual(test_schedule(step), test_schedule_expected_outputs[step])

+    @patch.dict(os.environ, {"KINETO_USE_DAEMON": "1"})
+    @patch.dict(os.environ, {"KINETO_DAEMON_INIT_DELAY_S": "1"})
     def test_kineto_profiler_multiple_steppers(self):
         niters = 8
         use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()


@@ -21,6 +21,7 @@ from torch._C._profiler import (
     _ExperimentalConfig,
     _remove_execution_trace_observer,
 )
+from torch._environment import is_fbcode
 from torch.autograd import kineto_available, ProfilerActivity
 from torch.profiler._memory_profiler import MemoryProfile, MemoryProfileTimeline
@@ -821,6 +822,9 @@ class profile(_KinetoProfile):
         self.current_action = self.schedule(self.step_num)
         self._transit_action(prev_action, self.current_action)
-        prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)
+        if os.environ.get("KINETO_USE_DAEMON", "") or (
+            is_fbcode() and os.environ.get("KINETO_FORCE_STEP_HOOK", "")
+        ):
+            prof.KinetoStepTracker.increment_step(PROFILER_STEP_NAME)
         if self.record_steps: