Add a RECORD_FUNCTION for Python fallback so it shows in profile (#160573)
Signed-off-by: Edward Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160573
Approved by: https://github.com/bdhirsh, https://github.com/albanD
parent 70d1043bdf
commit e901866dd7
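What this buys you in practice: an operator call that falls back to Python dispatch (an active TorchDispatchMode, or a tensor subclass with __torch_dispatch__) now shows up in a profile as a named "PythonDispatchMode" or "PythonSubclass" event instead of unattributed time. A minimal sketch of how to observe it, mirroring the tests added below; the mode class and shapes here are illustrative and not part of this commit:

import torch
from torch.profiler import profile
from torch.utils._python_dispatch import TorchDispatchMode

class NoopMode(TorchDispatchMode):  # illustrative mode, not part of this commit
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        # Route every op through Python dispatch and redispatch it unchanged.
        return func(*args, **(kwargs or {}))

with profile(record_shapes=True) as prof:
    with NoopMode():
        torch.sin(torch.randn(3, 4))

# With this change the dispatch into the mode is wrapped in a RECORD_FUNCTION,
# so it appears as a named "PythonDispatchMode" event and as a table row.
print([e.name for e in prof.events() if e.name == "PythonDispatchMode"])
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))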
@@ -2,6 +2,7 @@
 #include <c10/core/impl/PythonDispatcherTLS.h>
 #include <ATen/core/PythonFallbackKernel.h>
 #include <c10/core/SafePyObject.h>
+#include <ATen/record_function.h>

 namespace {

@@ -53,20 +54,24 @@ void pythonFallback(const c10::OperatorHandle& op, c10::DispatchKeySet dispatch_
   TORCH_INTERNAL_ASSERT(tls_on_entry.has_value());
   // c10::impl::ForceDispatchKeyGuard dispatcher_guard(tls_on_entry.value());
   // StashTLSOnEntryGuard stash_guard;
-  c10::impl::ExcludeDispatchKeyGuard guard(after_Python_keyset);
+  c10::impl::ExcludeDispatchKeyGuard exclude_guard(after_Python_keyset);

+  const auto& schema = op.schema();
+  const auto num_arguments = schema.arguments().size();
+
   // If Torch Dispatch Mode is active, use its PyInterpreter for dispatch
   const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len();
   if (mode_stack_len > 0) {
+    RECORD_FUNCTION("PythonDispatchMode", torch::jit::last(*stack, num_arguments));
     const auto& cur_torch_dispatch_mode_state = c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1);
     cur_torch_dispatch_mode_state->pyinterpreter()->dispatch(op, stack);
     return;
   }

+  RECORD_FUNCTION("PythonSubclass", torch::jit::last(*stack, num_arguments));
+
   // Otherwise, find a PyInterpreter on a Tensor
-  const auto& schema = op.schema();
-  const auto num_arguments = schema.arguments().size();

   // It is safe to dispatch on the very first Tensor with a pyobj_interpreter
   // without checking the interpreters of any of the arguments, because when
   // we actually run dispatch(), we will take out PyObjects in the context
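A note on the hunk above: schema and num_arguments are now computed before the mode check because both new RECORD_FUNCTION calls pass the operator's inputs, torch::jit::last(*stack, num_arguments), to the profiler, so the argument count has to be available on either path. Passing the inputs is also what lets input shapes be attached to these events when shape recording is enabled. Continuing the sketch from before the diff (this assumes prof from the earlier illustrative example, collected with record_shapes=True; not code from this commit):

# Because the record functions receive the operator's inputs, input-shape
# grouping can also distinguish the Python-dispatch events by argument shape.
print(prof.key_averages(group_by_input_shape=True).table(sort_by="cpu_time_total", row_limit=10))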
@@ -967,7 +967,7 @@ class TestProfiler(TestCase):
         profiler_output = prof.key_averages(group_by_input_shape=True).table(
             sort_by="cpu_time_total", row_limit=10
         )
-        self.assertIn("Total MFLOPs", profiler_output)
+        self.assertRegex(profiler_output, "Total M?FLOPs")
         if not (kineto_available() and torch.cuda.is_available()):
             return

@@ -983,7 +983,7 @@ class TestProfiler(TestCase):
         profiler_output = kineto_profiler.key_averages().table(
             sort_by="self_cuda_time_total", row_limit=-1
         )
-        self.assertIn("Total MFLOPs", profiler_output)
+        self.assertRegex(profiler_output, "Total M?FLOPs")

     def test_override_time_units(self):
         US_IN_SECOND = 1000.0 * 1000.0
@@ -762,21 +762,22 @@ class TestProfilerTree(TestCase):
 torch/profiler/profiler.py(...): __enter__
 ...
 aten::add
-torch/_library/simple_registry.py(...): find_torch_dispatch_rule
-torch/_library/simple_registry.py(...): find
-<built-in method get of dict object at 0xXXXXXXXXXXXX>
-torch/_library/simple_registry.py(...): find
-<built-in method get of dict object at 0xXXXXXXXXXXXX>
-test_profiler_tree.py(...): __torch_dispatch__
-torch/utils/_pytree.py(...): tree_map
-...
-torch/utils/_pytree.py(...): tree_map
-...
-torch/_ops.py(...): __call__
-<built-in method of PyCapsule object at 0xXXXXXXXXXXXX>
-aten::add
-torch/utils/_pytree.py(...): tree_map
-...
+PythonSubclass
+torch/_library/simple_registry.py(...): find_torch_dispatch_rule
+torch/_library/simple_registry.py(...): find
+<built-in method get of dict object at 0xXXXXXXXXXXXX>
+torch/_library/simple_registry.py(...): find
+<built-in method get of dict object at 0xXXXXXXXXXXXX>
+test_profiler_tree.py(...): __torch_dispatch__
+torch/utils/_pytree.py(...): tree_map
+...
+torch/utils/_pytree.py(...): tree_map
+...
+torch/_ops.py(...): __call__
+<built-in method of PyCapsule object at 0xXXXXXXXXXXXX>
+aten::add
+torch/utils/_pytree.py(...): tree_map
+...
 torch/profiler/profiler.py(...): __exit__
 torch/profiler/profiler.py(...): stop
 ...""",
@@ -7,6 +7,7 @@ import torch
 import torch.optim
 import torch.utils.data
 import torch.utils.data.datapipes as dp
+from torch._dispatch.python import enable_python_dispatcher
 from torch.autograd import (
     _record_function_with_args_enter,
     _record_function_with_args_exit,
@@ -152,6 +153,79 @@ class TestRecordFunction(TestCase):
         self.assertTrue(has_iter)
         self.assertTrue(has_child)

+    def test_python_dispatch_mode_record_function(self):
+        from torch.utils._python_dispatch import TorchDispatchMode
+
+        class TestDispatchMode(TorchDispatchMode):
+            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+                return func(*args, **kwargs)
+
+        with _profile() as prof:
+            with enable_python_dispatcher():
+                with TestDispatchMode():
+                    x = torch.randn(3, 4)
+                    y = torch.sin(x)
+
+        found_python_dispatch_mode = False
+        for e in prof.function_events:
+            if e.name == "PythonDispatchMode":
+                found_python_dispatch_mode = True
+                break
+        self.assertTrue(
+            found_python_dispatch_mode,
+            "PythonDispatchMode record function not found in profiler events",
+        )
+
+    def test_python_subclass_record_function(self):
+        class TestTensorSubclass(torch.Tensor):
+            @staticmethod
+            def __new__(cls, elem):
+                r = torch.Tensor._make_wrapper_subclass(
+                    cls,
+                    elem.size(),
+                    dtype=elem.dtype,
+                    device=elem.device,
+                    requires_grad=elem.requires_grad,
+                )
+                r.elem = elem
+                return r
+
+            @classmethod
+            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+                if kwargs is None:
+                    kwargs = {}
+
+                def unwrap(x):
+                    return x.elem if isinstance(x, TestTensorSubclass) else x
+
+                def wrap(x):
+                    return TestTensorSubclass(x) if isinstance(x, torch.Tensor) else x
+
+                unwrapped_args = tuple(unwrap(arg) for arg in args)
+                unwrapped_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
+                result = func(*unwrapped_args, **unwrapped_kwargs)
+
+                if isinstance(result, torch.Tensor):
+                    return TestTensorSubclass(result)
+                return result
+
+        with _profile() as prof:
+            with enable_python_dispatcher():
+                x = TestTensorSubclass(torch.randn(3, 4))
+                y = torch.sin(x)
+
+        found_python_subclass = False
+        for e in prof.function_events:
+            if e.name == "PythonSubclass":
+                found_python_subclass = True
+                break
+        self.assertTrue(
+            found_python_subclass,
+            "PythonSubclass record function not found in profiler events",
+        )
+

 if __name__ == "__main__":
     run_tests()