[dynamo] "TorchDynamo Cache Lookup" event: use C++ api (#108436)
**Background**: "TorchDynamo Cache Lookup" events appear in traces to indicate a Dynamo cache lookup; they are useful for spotting cache lookups that take a long time. A profiler event can be added with the `torch.profiler.record_function` context manager, or with its C++ equivalent. Previously, the Python version was used: when the profiler was enabled, callbacks for `record_function_enter` and `record_function_exit` were registered, and these were then called before and after every cache lookup.
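As context, here is a minimal sketch (not from this PR) of adding a custom profiler event from Python with the `record_function` context manager; the event name `"my_custom_event"` is illustrative:

```python
import torch

x = torch.randn(4, 4)

with torch.profiler.profile() as prof:
    # record_function emits one profiler event spanning
    # everything executed inside the context.
    with torch.profiler.record_function("my_custom_event"):
        y = (x * 2).relu()

# The custom event shows up alongside the builtin ones.
print(prof.key_averages().table(sort_by="cpu_time_total"))
```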
**This PR**: Instead of calling the Python bindings for `torch.profiler.record_function`, call the C++ implementation directly. This simplifies much of the C/C++ binding code and also improves performance: previously the "TorchDynamo Cache Lookup" event carried a lot of overhead, making it look artificially long. After this change the events appear shorter because there is less overhead in starting and stopping them; in other words, the profiler no longer distorts the results as much.
**Performance results**:
I ran the script below on a CPU-only 1.6 GHz machine and report the median duration (over 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR. It is reasonable to attribute the difference to reduced overhead.
<details>
<summary>Benchmarking script</summary>
```python
import torch

def fn(x, y):
    return (x * y).relu()

a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)]

opt_fn = torch.compile(fn)

# warm up so the profiled call hits the Dynamo cache
opt_fn(a, b)
opt_fn(a, b)

with torch.profiler.profile() as prof:
    opt_fn(a, b)
```
</details>
Median before this PR: 198-228 us (median of 100 measurements, repeated 5 times)
Median after this PR: 27 us
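For reference, a sketch of how such a median could be pulled out of one profiled run; `median_cache_lookup_us` is a hypothetical helper, not the exact script used for the numbers above:

```python
import statistics

def median_cache_lookup_us(prof):
    # FunctionEvent durations are reported in microseconds.
    durations = [
        evt.cpu_time_total
        for evt in prof.events()
        if "TorchDynamo Cache Lookup" in evt.name
    ]
    return statistics.median(durations) if durations else None
```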
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108436
Approved by: https://github.com/anijain2305, https://github.com/jansel
This commit is contained in:
parent 621463a3e6
commit 06b173780d
build_variables.bzl

```diff
@@ -829,6 +829,7 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/python_variable.cpp",
     "torch/csrc/autograd/python_variable_indexing.cpp",
     "torch/csrc/dynamo/python_compiled_autograd.cpp",
+    "torch/csrc/dynamo/cpp_shim.cpp",
     "torch/csrc/dynamo/cpython_defs.c",
     "torch/csrc/dynamo/eval_frame.c",
     "torch/csrc/dynamo/guards.cpp",
```
test/dynamo/test_misc.py

```diff
@@ -47,7 +47,6 @@ from torch.ao.quantization import MinMaxObserver
 from torch.ao.quantization.fake_quantize import FakeQuantize
 from torch.ao.quantization.qconfig import QConfig
 from torch.ao.quantization.quantize_fx import prepare_qat_fx
-from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler
 from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
 from torch.nn import functional as F
 from torch.testing._internal.common_cuda import (
@@ -2506,54 +2505,6 @@ def fn():
         result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
         self.assertTrue(result[1] == fn.__code__.co_lnotab)
 
-    def test_profiler_cache_lookup(self):
-        def fn(x):
-            y = x**2
-            y = y + 2
-            z = y**3
-            return z
-
-        for profiler, get_events in (
-            (torch.autograd.profiler.profile, lambda prof: prof.function_events),
-            (torch.profiler.profiler.profile, lambda prof: prof.events()),
-        ):
-            x = torch.randn((2, 2), requires_grad=True)
-            ref = fn(x)
-            opt_fn = torch.compile(fn, backend="aot_eager")
-
-            # warmup
-            opt_fn(x)
-
-            # whenever we enter the profiler context, hooks are automatically registered
-            with profiler() as prof:
-                res = opt_fn(x)
-            events = list(
-                filter(
-                    lambda event: event.name == "TorchDynamo Cache Lookup",
-                    get_events(prof),
-                )
-            )
-
-            self.assertTrue(same(ref, res))
-            self.assertTrue(
-                len(events) == 1,
-                "Expected one lookup profiler event for one opt_fn run",
-            )
-
-            with profiler() as prof:
-                # just make sure the disable functionality works
-                _enable_dynamo_cache_lookup_profiler(False)
-                res = opt_fn(x)
-            events = list(
-                filter(
-                    lambda event: event.name == "TorchDynamo Cache Lookup",
-                    get_events(prof),
-                )
-            )
-
-            self.assertTrue(same(ref, res))
-            self.assertTrue(len(events) == 0, "Expected disabled profiling")
-
     def test_tensor_is_contiguous(self):
         def fn(x):
             input = torch.randn((1, 16, 1, 1))
```
test/dynamo/test_profiler.py

```diff
@@ -7,6 +7,7 @@ import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
 
+from torch._dynamo.testing import same
 from torch._dynamo.utils import dynamo_timed
 
 
@@ -91,6 +92,39 @@ class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
         with torch.profiler.profile(record_shapes=True):
             opt_fn(*inputs)
 
+    def test_profiler_cache_lookup(self):
+        def fn(x):
+            y = x**2
+            y = y + 2
+            z = y**3
+            return z
+
+        for profiler, get_events in (
+            (torch.autograd.profiler.profile, lambda prof: prof.function_events),
+            (torch.profiler.profiler.profile, lambda prof: prof.events()),
+        ):
+            x = torch.randn((2, 2), requires_grad=True)
+            ref = fn(x)
+            opt_fn = torch.compile(fn, backend="aot_eager")
+
+            # warmup
+            opt_fn(x)
+
+            with profiler() as prof:
+                res = opt_fn(x)
+            events = list(
+                filter(
+                    lambda event: "TorchDynamo Cache Lookup" in event.name,
+                    get_events(prof),
+                )
+            )
+
+            self.assertTrue(same(ref, res))
+            self.assertTrue(
+                len(events) == 1,
+                "Expected one lookup profiler event for one opt_fn run",
+            )
+
     def test_profiler_cache_lookup_profiler_step(self):
         def fn(x, y, z):
             return torch.add(torch.sub(x, y), z)
```
torch/_C/_dynamo/eval_frame.pyi

```diff
@@ -1,11 +1,6 @@
 import types
 
-from torch._dynamo.types import (
-    DynamoCallback,
-    DynamoGuardHook,
-    ProfilerEndHook,
-    ProfilerStartHook,
-)
+from torch._dynamo.types import DynamoCallback, DynamoGuardHook
 
 def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
 def reset_code(code: types.CodeType) -> None: ...
@@ -13,5 +8,3 @@ def unsupported(obj1: object, obj2: object) -> object: ...
 def skip_code(code: types.CodeType) -> None: ...
 def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ...
 def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
-def set_profiler_hooks(start: ProfilerStartHook, end: ProfilerEndHook) -> None: ...
-def clear_profiler_hooks() -> None: ...
```
torch/autograd/profiler.py

```diff
@@ -79,40 +79,12 @@ def _set_is_profiler_enabled(enable: bool):
     _is_profiler_enabled = enable
 
 
-def _enable_dynamo_cache_lookup_profiler(enable: bool):
-    from torch._dynamo.eval_frame import (  # type: ignore[attr-defined]
-        clear_profiler_hooks,
-        set_profiler_hooks,
-    )
-
-    """
-    Registers a hook within dynamo eval_frame.c called before and after
-    the lookup process, which runs guards associated with each cached frame.
-
-    Clear deregisters the hooks, saving overhead.
-    """
-
-    if enable:
-
-        def _profiler_start(name):
-            return torch.ops.profiler._record_function_enter_new(name, None)
-
-        def _profiler_end(record):
-            torch.ops.profiler._record_function_exit._RecordFunction(record)
-
-        set_profiler_hooks(_profiler_start, _profiler_end)
-    else:
-        clear_profiler_hooks()
-
-
 def _run_on_profiler_start():
     _set_is_profiler_enabled(True)
-    _enable_dynamo_cache_lookup_profiler(True)
 
 
 def _run_on_profiler_stop():
     _set_is_profiler_enabled(False)
-    _enable_dynamo_cache_lookup_profiler(False)
 
 
 class profile:
```
torch/csrc/dynamo/cpp_shim.cpp (new file, 22 lines)

```diff
@@ -0,0 +1,22 @@
+#include <torch/csrc/dynamo/cpp_shim.h>
+
+#include <ATen/record_function.h>
+
+struct _PytorchRecordFunctionState {
+  at::RecordFunction guard;
+
+  _PytorchRecordFunctionState() : guard(at::RecordScope::FUNCTION) {}
+};
+
+_PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) {
+  _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
+  state->guard.before(name);
+  return state;
+}
+
+void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) {
+  if (state == nullptr) {
+    return;
+  }
+  delete state;
+}
```
torch/csrc/dynamo/cpp_shim.h (new file, 15 lines)

```diff
@@ -0,0 +1,15 @@
+#pragma once
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct _PytorchRecordFunctionState;
+typedef struct _PytorchRecordFunctionState _PytorchRecordFunctionState;
+
+_PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name);
+void _pytorch_record_function_exit(_PytorchRecordFunctionState* state);
+
+#ifdef __cplusplus
+} // extern "C"
+#endif
```
torch/csrc/dynamo/eval_frame.c

```diff
@@ -1,4 +1,5 @@
 #define PY_SSIZE_T_CLEAN
+#include <torch/csrc/dynamo/cpp_shim.h>
 #include <torch/csrc/dynamo/cpython_defs.h>
 #include <torch/csrc/utils/python_compat.h>
 #include <opcode.h>
@@ -179,9 +180,7 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
 bool is_dynamo_compiling = false;
 static PyObject* guard_fail_hook = NULL;
 static PyObject* guard_error_hook = NULL;
-static PyObject* profiler_start_hook = NULL;
-static PyObject* profiler_end_hook = NULL;
-static PyObject* guard_profiler_name_str = NULL; /* cached py str */
+const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
 
 // Points to the extra scratch space on the code object
 static Py_ssize_t extra_index = -1;
@@ -645,22 +644,6 @@ static PyObject* call_guard_fail_hook(
       (e->next == (CacheEntry*)Py_None ? Py_True : Py_False));
 }
 
-static PyObject* call_profiler_start_hook(PyObject* name_str) {
-  if (profiler_start_hook == NULL) return NULL;
-  return PyObject_CallOneArg(profiler_start_hook, name_str);
-}
-
-static void call_profiler_end_hook(PyObject* record) {
-  // 'record' obj is the return value of calling _start_hook()
-  if (profiler_end_hook == NULL || record == NULL) return;
-  PyObject* res = PyObject_CallOneArg(profiler_end_hook, record);
-  if (res == NULL) {
-    PyErr_WriteUnraisable(profiler_end_hook);
-    return;
-  }
-  Py_DECREF(res);
-}
-
 // Return value: borrowed reference
 // Is either Py_None or a PyCodeObject
 static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev, size_t index) {
@@ -939,10 +922,9 @@ static PyObject* _custom_eval_frame(
   // we never compile.
   if (callback == Py_False) {
     DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
-    PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
+    _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
     PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
-    call_profiler_end_hook(hook_record);
-    Py_XDECREF(hook_record);
+    _pytorch_record_function_exit(rf);
 
     if (maybe_cached_code == NULL) {
       // guard eval failed, keep propagating
@@ -965,10 +947,9 @@ static PyObject* _custom_eval_frame(
   // in the shim.
   eval_frame_callback_set(Py_None);
 
-  PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
+  _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
   PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
-  call_profiler_end_hook(hook_record);
-  Py_XDECREF(hook_record);
+  _pytorch_record_function_exit(rf);
   if (maybe_cached_code == NULL) {
     // Python error
     return NULL;
@@ -1131,27 +1112,6 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
   Py_RETURN_NONE;
 }
 
-static PyObject* clear_profiler_hooks(PyObject* module, PyObject* unused) {
-  Py_CLEAR(profiler_start_hook);
-  Py_CLEAR(profiler_end_hook);
-  Py_RETURN_NONE;
-}
-
-static PyObject* set_profiler_hooks(PyObject* module, PyObject* args) {
-  PyObject* start = NULL;
-  PyObject* end = NULL;
-  if (!PyArg_ParseTuple(args, "OO:set_profiler_hooks", &start, &end)) {
-    return NULL;
-  }
-  if (start == Py_None || end == Py_None) {
-    clear_profiler_hooks(module, NULL);
-  } else {
-    Py_XSETREF(profiler_start_hook, Py_NewRef(start));
-    Py_XSETREF(profiler_end_hook, Py_NewRef(end));
-  }
-  Py_RETURN_NONE;
-}
-
 static PyMethodDef _methods[] = {
     {"set_eval_frame", set_eval_frame_py, METH_O, NULL},
     {"reset_code", reset_code, METH_O, NULL},
@@ -1159,8 +1119,6 @@ static PyMethodDef _methods[] = {
     {"skip_code", skip_code, METH_O, NULL},
     {"set_guard_fail_hook", set_guard_fail_hook, METH_O, NULL},
     {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
-    {"set_profiler_hooks", set_profiler_hooks, METH_VARARGS, NULL},
-    {"clear_profiler_hooks", clear_profiler_hooks, METH_NOARGS, NULL},
     {"_debug_get_cache_entry_list", _debug_get_cache_entry_list, METH_VARARGS, NULL},
     {NULL, NULL, 0, NULL}};
 
@@ -1180,11 +1138,6 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
     return NULL;
  }
 
-  guard_profiler_name_str = PyUnicode_FromString("TorchDynamo Cache Lookup");
-  if (guard_profiler_name_str == NULL) {
-    return NULL;
-  }
-
   int result = PyThread_tss_create(&eval_frame_callback_key);
   CHECK(result == 0);
 
```