Add a RECORD_FUNCTION for Python fallback so it shows in profile (#160573)

Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160573
Approved by: https://github.com/bdhirsh, https://github.com/albanD
Edward Yang 2025-09-29 10:04:18 -07:00 committed by PyTorch MergeBot
parent 70d1043bdf
commit e901866dd7
5 changed files with 100 additions and 20 deletions
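
For orientation before the diffs: the sketch below is not part of the commit; it simply mirrors the new test_python_dispatch_mode_record_function test added at the bottom of this change, and LoggingMode is an invented name. It shows how the new "PythonDispatchMode" event would be observed in a profile once the RECORD_FUNCTION is in place.

# Minimal sketch, assuming the behavior added by this commit; LoggingMode is a
# hypothetical TorchDispatchMode, modeled on the new test further down.
import torch
from torch._dispatch.python import enable_python_dispatcher
from torch.utils._python_dispatch import TorchDispatchMode


class LoggingMode(TorchDispatchMode):
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Forward every op unchanged; the C++ Python fallback is what records it.
        return func(*args, **(kwargs or {}))


with torch.autograd.profiler.profile() as prof:
    with enable_python_dispatcher(), LoggingMode():
        torch.sin(torch.randn(3, 4))

# With the RECORD_FUNCTION in place, a "PythonDispatchMode" event shows up.
print(any(e.name == "PythonDispatchMode" for e in prof.function_events))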


@@ -2,6 +2,7 @@
 #include <c10/core/impl/PythonDispatcherTLS.h>
 #include <ATen/core/PythonFallbackKernel.h>
 #include <c10/core/SafePyObject.h>
+#include <ATen/record_function.h>

 namespace {
@@ -53,20 +54,24 @@ void pythonFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_
   TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
   // c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
   // StashTLSOnEntryGuard stash_guard;
-  c10::impl::ExcludeDispatchKeyGuard guard(after_Python_keyset);
+  c10::impl::ExcludeDispatchKeyGuard exclude_guard(after_Python_keyset);
+  const auto& schema = op.schema();
+  const auto num_arguments = schema.arguments().size();

   // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
   const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
   if (mode_stack_len > 0) {
+    RECORD_FUNCTION("PythonDispatchMode", torch::jit::last(*stack, num_arguments));
     const auto& cur_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
     cur_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
     return;
   }

+  RECORD_FUNCTION("PythonSubclass", torch::jit::last(*stack, num_arguments));
   // Otherwise, find a PyInterpreter on a Tensor
-  const auto& schema = op.schema();
-  const auto num_arguments = schema.arguments().size();
   // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
   // without checking the interpreters of any of the arguments, because when
   // we actually run dispatch(), we will take out PyObjects in the context
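
To make the effect of the hunk above concrete, here is a rough illustration of the "PythonSubclass" event showing up in a profiler table. It mirrors the test_python_subclass_record_function test added later in this commit; WrapperTensor is an invented name, not an API introduced here.

# Rough sketch only; WrapperTensor is a hypothetical wrapper subclass modeled
# on the new test below.
import torch
from torch._dispatch.python import enable_python_dispatcher


class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem):
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(), dtype=elem.dtype, device=elem.device
        )
        r.elem = elem
        return r

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        def unwrap(x):
            return x.elem if isinstance(x, WrapperTensor) else x

        kwargs = kwargs or {}
        # Run the real op on the unwrapped tensors; the result is left unwrapped.
        return func(*[unwrap(a) for a in args], **{k: unwrap(v) for k, v in kwargs.items()})


with torch.autograd.profiler.profile() as prof:
    with enable_python_dispatcher():
        torch.sin(WrapperTensor(torch.randn(3, 4)))

# The "PythonSubclass" record function now appears as its own row in the table.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))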


@@ -967,7 +967,7 @@ class TestProfiler(TestCase):
         profiler_output = prof.key_averages(group_by_input_shape=True).table(
             sort_by="cpu_time_total", row_limit=10
         )
-        self.assertIn("Total MFLOPs", profiler_output)
+        self.assertRegex(profiler_output, "Total M?FLOPs")
         if not (kineto_available() and torch.cuda.is_available()):
             return
@@ -983,7 +983,7 @@ class TestProfiler(TestCase):
         profiler_output = kineto_profiler.key_averages().table(
             sort_by="self_cuda_time_total", row_limit=-1
         )
-        self.assertIn("Total MFLOPs", profiler_output)
+        self.assertRegex(profiler_output, "Total M?FLOPs")

     def test_override_time_units(self):
         US_IN_SECOND = 1000.0 * 1000.0


@@ -762,21 +762,22 @@ class TestProfilerTree(TestCase):
 torch/profiler/profiler.py(...): __enter__
 ...
 aten::add
-torch/_library/simple_registry.py(...): find_torch_dispatch_rule
-torch/_library/simple_registry.py(...): find
-<built-in method get of dict object at 0xXXXXXXXXXXXX>
-torch/_library/simple_registry.py(...): find
-<built-in method get of dict object at 0xXXXXXXXXXXXX>
-test_profiler_tree.py(...): __torch_dispatch__
-torch/utils/_pytree.py(...): tree_map
-...
-torch/utils/_pytree.py(...): tree_map
-...
-torch/_ops.py(...): __call__
-<built-in method of PyCapsule object at 0xXXXXXXXXXXXX>
-aten::add
-torch/utils/_pytree.py(...): tree_map
-...
+PythonSubclass
+torch/_library/simple_registry.py(...): find_torch_dispatch_rule
+torch/_library/simple_registry.py(...): find
+<built-in method get of dict object at 0xXXXXXXXXXXXX>
+torch/_library/simple_registry.py(...): find
+<built-in method get of dict object at 0xXXXXXXXXXXXX>
+test_profiler_tree.py(...): __torch_dispatch__
+torch/utils/_pytree.py(...): tree_map
+...
+torch/utils/_pytree.py(...): tree_map
+...
+torch/_ops.py(...): __call__
+<built-in method of PyCapsule object at 0xXXXXXXXXXXXX>
+aten::add
+torch/utils/_pytree.py(...): tree_map
+...
 torch/profiler/profiler.py(...): __exit__
 torch/profiler/profiler.py(...): stop
 ...""",


@@ -7,6 +7,7 @@ import torch
 import torch.optim
 import torch.utils.data
 import torch.utils.data.datapipes as dp
+from torch._dispatch.python import enable_python_dispatcher
 from torch.autograd import (
     _record_function_with_args_enter,
     _record_function_with_args_exit,
@@ -152,6 +153,79 @@ class TestRecordFunction(TestCase):
         self.assertTrue(has_iter)
         self.assertTrue(has_child)

+    def test_python_dispatch_mode_record_function(self):
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        class TestDispatchMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                return func(*args, **kwargs)
+
+        with _profile() as prof:
+            with enable_python_dispatcher():
+                with TestDispatchMode():
+                    x = torch.randn(3, 4)
+                    y = torch.sin(x)
+
+        found_python_dispatch_mode = False
+        for e in prof.function_events:
+            if e.name == "PythonDispatchMode":
+                found_python_dispatch_mode = True
+                break
+
+        self.assertTrue(
+            found_python_dispatch_mode,
+            "PythonDispatchMode record function not found in profiler events",
+        )
+
+    def test_python_subclass_record_function(self):
+        class TestTensorSubclass(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem):
+                r = torch.Tensor._make_wrapper_subclass(
+                    cls,
+                    elem.size(),
+                    dtype=elem.dtype,
+                    device=elem.device,
+                    requires_grad=elem.requires_grad,
+                )
+                r.elem = elem
+                return r
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+
+                def unwrap(x):
+                    return x.elem if isinstance(x, TestTensorSubclass) else x
+
+                def wrap(x):
+                    return TestTensorSubclass(x) if isinstance(x, torch.Tensor) else x
+
+                unwrapped_args = tuple(unwrap(arg) for arg in args)
+                unwrapped_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
+                result = func(*unwrapped_args, **unwrapped_kwargs)
+                if isinstance(result, torch.Tensor):
+                    return TestTensorSubclass(result)
+                return result
+
+        with _profile() as prof:
+            with enable_python_dispatcher():
+                x = TestTensorSubclass(torch.randn(3, 4))
+                y = torch.sin(x)
+
+        found_python_subclass = False
+        for e in prof.function_events:
+            if e.name == "PythonSubclass":
+                found_python_subclass = True
+                break
+
+        self.assertTrue(
+            found_python_subclass,
+            "PythonSubclass record function not found in profiler events",
+        )
+

 if __name__ == "__main__":
     run_tests()