[dynamo] "TorchDynamo Cache Lookup" event: use C++ api (#108436)

**Background**: "TorchDynamo Cache Lookup" events appear in profiler traces to mark dynamo cache lookups; they are useful for spotting when cache lookups take a long time. To add a profiler event, one can use the `torch.profiler.record_function` context manager or its C++ equivalent. Previously, the Python version was used: when the profiler was enabled, callbacks for `record_function_enter` and `record_function_exit` were registered, and these were then called before and after every cache lookup.
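
For reference, here is a minimal sketch (not part of this PR; the wrapped function and event name are purely illustrative) of adding a custom profiler event from Python with `torch.profiler.record_function`:

```python
import torch


def my_lookup(x):
    # Anything run inside record_function() shows up as a named event in the trace.
    with torch.profiler.record_function("My Custom Event"):
        return x * 2


with torch.profiler.profile() as prof:
    my_lookup(torch.ones(4))

print(prof.key_averages().table())
```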

**This PR**: Instead of going through the Python bindings for `torch.profiler.record_function`, call the C++ implementation directly. This simplifies a lot of the C/C++ binding code. It also improves performance: previously the "TorchDynamo Cache Lookup" event carried significant overhead, making it appear artificially long. After this change the events appear shorter because starting/stopping the event costs less; in other words, the profiler no longer distorts the results as much.

**Performance results**:
I ran the script below on a CPU-only 1.6 GHz machine. I report the median time (over 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR. I think it is reasonable to attribute the difference to reduced overhead.

<details>

<summary>Benchmarking script</summary>

```python
import torch


def fn(x, y):
    return (x * y).relu()

a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)]

opt_fn = torch.compile(fn)

opt_fn(a, b)
opt_fn(a, b)

with torch.profiler.profile() as prof:
    opt_fn(a, b)
```

</details>
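
The exact measurement code isn't shown; below is a sketch of one way to collect and summarize the lookup durations, reusing the benchmark above and assuming the `prof.events()` / `cpu_time_total` event fields (times in microseconds):

```python
import statistics

import torch


def fn(x, y):
    return (x * y).relu()


a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)]
opt_fn = torch.compile(fn)

# Warm up first so the profiled calls only do cache lookups, not compilation.
opt_fn(a, b)
opt_fn(a, b)

with torch.profiler.profile() as prof:
    for _ in range(100):
        opt_fn(a, b)

lookups = [e for e in prof.events() if e.name == "TorchDynamo Cache Lookup"]
print("median lookup time (us):", statistics.median(e.cpu_time_total for e in lookups))
```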

Median before PR: 198-228 us (median of 100, measured 5 times)
Median after PR: 27 us

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108436
Approved by: https://github.com/anijain2305, https://github.com/jansel
Author: David Berard, 2023-09-01 17:32:57 -07:00 (committed by PyTorch MergeBot)
Commit: 06b173780d (parent: 621463a3e6)
8 changed files with 79 additions and 138 deletions


@@ -829,6 +829,7 @@ libtorch_python_core_sources = [
     "torch/csrc/autograd/python_variable.cpp",
     "torch/csrc/autograd/python_variable_indexing.cpp",
     "torch/csrc/dynamo/python_compiled_autograd.cpp",
+    "torch/csrc/dynamo/cpp_shim.cpp",
     "torch/csrc/dynamo/cpython_defs.c",
     "torch/csrc/dynamo/eval_frame.c",
     "torch/csrc/dynamo/guards.cpp",


@@ -47,7 +47,6 @@ from torch.ao.quantization import MinMaxObserver
 from torch.ao.quantization.fake_quantize import FakeQuantize
 from torch.ao.quantization.qconfig import QConfig
 from torch.ao.quantization.quantize_fx import prepare_qat_fx
-from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler
 from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
 from torch.nn import functional as F
 from torch.testing._internal.common_cuda import (
@@ -2506,54 +2505,6 @@ def fn():
         result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
         self.assertTrue(result[1] == fn.__code__.co_lnotab)
-    def test_profiler_cache_lookup(self):
-        def fn(x):
-            y = x**2
-            y = y + 2
-            z = y**3
-            return z
-        for profiler, get_events in (
-            (torch.autograd.profiler.profile, lambda prof: prof.function_events),
-            (torch.profiler.profiler.profile, lambda prof: prof.events()),
-        ):
-            x = torch.randn((2, 2), requires_grad=True)
-            ref = fn(x)
-            opt_fn = torch.compile(fn, backend="aot_eager")
-            # warmup
-            opt_fn(x)
-            # whenver we enter the profiler context, hooks are automatically registered
-            with profiler() as prof:
-                res = opt_fn(x)
-            events = list(
-                filter(
-                    lambda event: event.name == "TorchDynamo Cache Lookup",
-                    get_events(prof),
-                )
-            )
-            self.assertTrue(same(ref, res))
-            self.assertTrue(
-                len(events) == 1,
-                "Expected one lookup profiler event for one opt_fn run",
-            )
-            with profiler() as prof:
-                # just make sure the disable functionality works
-                _enable_dynamo_cache_lookup_profiler(False)
-                res = opt_fn(x)
-            events = list(
-                filter(
-                    lambda event: event.name == "TorchDynamo Cache Lookup",
-                    get_events(prof),
-                )
-            )
-            self.assertTrue(same(ref, res))
-            self.assertTrue(len(events) == 0, "Expected disabled profiling")
     def test_tensor_is_contiguous(self):
         def fn(x):
             input = torch.randn((1, 16, 1, 1))


@@ -7,6 +7,7 @@ import torch._dynamo.test_case
 import torch._dynamo.testing
 import torch._dynamo.utils
+from torch._dynamo.testing import same
 from torch._dynamo.utils import dynamo_timed
@@ -91,6 +92,39 @@ class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
         with torch.profiler.profile(record_shapes=True):
             opt_fn(*inputs)
+    def test_profiler_cache_lookup(self):
+        def fn(x):
+            y = x**2
+            y = y + 2
+            z = y**3
+            return z
+        for profiler, get_events in (
+            (torch.autograd.profiler.profile, lambda prof: prof.function_events),
+            (torch.profiler.profiler.profile, lambda prof: prof.events()),
+        ):
+            x = torch.randn((2, 2), requires_grad=True)
+            ref = fn(x)
+            opt_fn = torch.compile(fn, backend="aot_eager")
+            # warmup
+            opt_fn(x)
+            with profiler() as prof:
+                res = opt_fn(x)
+            events = list(
+                filter(
+                    lambda event: "TorchDynamo Cache Lookup" in event.name,
+                    get_events(prof),
+                )
+            )
+            self.assertTrue(same(ref, res))
+            self.assertTrue(
+                len(events) == 1,
+                "Expected one lookup profiler event for one opt_fn run",
+            )
     def test_profiler_cache_lookup_profiler_step(self):
         def fn(x, y, z):
             return torch.add(torch.sub(x, y), z)


@@ -1,11 +1,6 @@
 import types
-from torch._dynamo.types import (
-    DynamoCallback,
-    DynamoGuardHook,
-    ProfilerEndHook,
-    ProfilerStartHook,
-)
+from torch._dynamo.types import DynamoCallback, DynamoGuardHook
 def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
 def reset_code(code: types.CodeType) -> None: ...
@@ -13,5 +8,3 @@ def unsupported(obj1: object, obj2: object) -> object: ...
 def skip_code(code: types.CodeType) -> None: ...
 def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ...
 def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
-def set_profiler_hooks(start: ProfilerStartHook, end: ProfilerEndHook) -> None: ...
-def clear_profiler_hooks() -> None: ...


@@ -79,40 +79,12 @@ def _set_is_profiler_enabled(enable: bool):
     _is_profiler_enabled = enable
-def _enable_dynamo_cache_lookup_profiler(enable: bool):
-    from torch._dynamo.eval_frame import (  # type: ignore[attr-defined]
-        clear_profiler_hooks,
-        set_profiler_hooks,
-    )
-    """
-    Registers a hook within dynamo eval_frame.c called before and after
-    the lookup process, which runs guards associated with each cached frame.
-    Clear deregisters the hooks, saving overhead.
-    """
-    if enable:
-        def _profiler_start(name):
-            return torch.ops.profiler._record_function_enter_new(name, None)
-        def _profiler_end(record):
-            torch.ops.profiler._record_function_exit._RecordFunction(record)
-        set_profiler_hooks(_profiler_start, _profiler_end)
-    else:
-        clear_profiler_hooks()
 def _run_on_profiler_start():
     _set_is_profiler_enabled(True)
-    _enable_dynamo_cache_lookup_profiler(True)
 def _run_on_profiler_stop():
     _set_is_profiler_enabled(False)
-    _enable_dynamo_cache_lookup_profiler(False)
 class profile:


@@ -0,0 +1,22 @@
+#include <torch/csrc/dynamo/cpp_shim.h>
+#include <ATen/record_function.h>
+struct _PytorchRecordFunctionState {
+  at::RecordFunction guard;
+  _PytorchRecordFunctionState() : guard(at::RecordScope::FUNCTION) {}
+};
+_PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name) {
+  _PytorchRecordFunctionState* state = new _PytorchRecordFunctionState();
+  state->guard.before(name);
+  return state;
+}
+void _pytorch_record_function_exit(_PytorchRecordFunctionState* state) {
+  if (state == nullptr) {
+    return;
+  }
+  delete state;
+}


@@ -0,0 +1,15 @@
+#pragma once
+#ifdef __cplusplus
+extern "C" {
+#endif
+struct _PytorchRecordFunctionState;
+typedef struct _PytorchRecordFunctionState _PytorchRecordFunctionState;
+_PytorchRecordFunctionState* _pytorch_record_function_enter(const char* name);
+void _pytorch_record_function_exit(_PytorchRecordFunctionState* state);
+#ifdef __cplusplus
+} // extern "C"
+#endif


@@ -1,4 +1,5 @@
 #define PY_SSIZE_T_CLEAN
+#include <torch/csrc/dynamo/cpp_shim.h>
 #include <torch/csrc/dynamo/cpython_defs.h>
 #include <torch/csrc/utils/python_compat.h>
 #include <opcode.h>
@@ -179,9 +180,7 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
 bool is_dynamo_compiling = false;
 static PyObject* guard_fail_hook = NULL;
 static PyObject* guard_error_hook = NULL;
-static PyObject* profiler_start_hook = NULL;
-static PyObject* profiler_end_hook = NULL;
-static PyObject* guard_profiler_name_str = NULL; /* cached py str */
+const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
 // Points to the extra scratch space on the code object
 static Py_ssize_t extra_index = -1;
@@ -645,22 +644,6 @@ static PyObject* call_guard_fail_hook(
       (e->next == (CacheEntry*)Py_None ? Py_True : Py_False));
 }
-static PyObject* call_profiler_start_hook(PyObject* name_str) {
-  if (profiler_start_hook == NULL) return NULL;
-  return PyObject_CallOneArg(profiler_start_hook, name_str);
-}
-static void call_profiler_end_hook(PyObject* record) {
-  // 'record' obj is the return value of calling _start_hook()
-  if (profiler_end_hook == NULL || record == NULL) return;
-  PyObject* res = PyObject_CallOneArg(profiler_end_hook, record);
-  if (res == NULL) {
-    PyErr_WriteUnraisable(profiler_end_hook);
-    return;
-  }
-  Py_DECREF(res);
-}
 // Return value: borrowed reference
 // Is either Py_None or a PyCodeObject
 static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev, size_t index) {
@@ -939,10 +922,9 @@ static PyObject* _custom_eval_frame(
   // we never compile.
   if (callback == Py_False) {
     DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
-    PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
+    _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
     PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
-    call_profiler_end_hook(hook_record);
-    Py_XDECREF(hook_record);
+    _pytorch_record_function_exit(rf);
     if (maybe_cached_code == NULL) {
       // guard eval failed, keep propagating
@@ -965,10 +947,9 @@ static PyObject* _custom_eval_frame(
   // in the shim.
   eval_frame_callback_set(Py_None);
-  PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
+  _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
   PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
-  call_profiler_end_hook(hook_record);
-  Py_XDECREF(hook_record);
+  _pytorch_record_function_exit(rf);
   if (maybe_cached_code == NULL) {
     // Python error
     return NULL;
@@ -1131,27 +1112,6 @@ static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
   Py_RETURN_NONE;
 }
-static PyObject* clear_profiler_hooks(PyObject* module, PyObject* unused) {
-  Py_CLEAR(profiler_start_hook);
-  Py_CLEAR(profiler_end_hook);
-  Py_RETURN_NONE;
-}
-static PyObject* set_profiler_hooks(PyObject* module, PyObject* args) {
-  PyObject* start = NULL;
-  PyObject* end = NULL;
-  if (!PyArg_ParseTuple(args, "OO:set_profiler_hooks", &start, &end)) {
-    return NULL;
-  }
-  if (start == Py_None || end == Py_None) {
-    clear_profiler_hooks(module, NULL);
-  } else {
-    Py_XSETREF(profiler_start_hook, Py_NewRef(start));
-    Py_XSETREF(profiler_end_hook, Py_NewRef(end));
-  }
-  Py_RETURN_NONE;
-}
 static PyMethodDef _methods[] = {
     {"set_eval_frame", set_eval_frame_py, METH_O, NULL},
     {"reset_code", reset_code, METH_O, NULL},
@@ -1159,8 +1119,6 @@ static PyMethodDef _methods[] = {
     {"skip_code", skip_code, METH_O, NULL},
     {"set_guard_fail_hook", set_guard_fail_hook, METH_O, NULL},
     {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
-    {"set_profiler_hooks", set_profiler_hooks, METH_VARARGS, NULL},
-    {"clear_profiler_hooks", clear_profiler_hooks, METH_NOARGS, NULL},
     {"_debug_get_cache_entry_list", _debug_get_cache_entry_list, METH_VARARGS, NULL},
     {NULL, NULL, 0, NULL}};
@@ -1180,11 +1138,6 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
     return NULL;
   }
-  guard_profiler_name_str = PyUnicode_FromString("TorchDynamo Cache Lookup");
-  if (guard_profiler_name_str == NULL) {
-    return NULL;
-  }
   int result = PyThread_tss_create(&eval_frame_callback_key);
   CHECK(result == 0);