Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Fixes https://github.com/pytorch/pytorch/issues/158164

This was fixed by applying `skip_code_recursive` to any function registered to `sys.monitoring` (via `PyThreadState_GET()->interp->monitoring_callables`). This check is done whenever we attempt to set the eval frame callback from Python.

Microbenchmark: `benchmarks/dynamo/microbenchmarks/overheads.py`:

BEFORE:
```
requires_grad=False
eager    7.1us (warmup=0.0s)
compiled 24.6us (warmup=10.0s)

requires_grad=True
eager    8.9us (warmup=0.0s)
compiled 57.8us (warmup=0.1s)

inference_mode()
eager    6.5us (warmup=0.0s)
compiled 23.4us (warmup=0.1s)
```

AFTER:
```
requires_grad=False
eager    7.0us (warmup=0.0s)
compiled 23.2us (warmup=15.2s)

requires_grad=True
eager    9.0us (warmup=0.0s)
compiled 55.1us (warmup=0.1s)

inference_mode()
eager    6.4us (warmup=0.0s)
compiled 22.2us (warmup=0.1s)
```

Followup thought: how do we let users know that a frame is skipped because the code object is a callable registered to sys.monitoring? (or any other reason?)

Differential Revision: [D78530528](https://our.internmc.facebook.com/intern/diff/D78530528)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158171
Approved by: https://github.com/jansel
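To make the fix described above easier to picture, here is a toy model in plain C. It is a sketch under stated assumptions, not the code from the pull request: it does not use the CPython API, and every `Toy*` / `toy_*` name, the table dimensions, and the return value are hypothetical stand-ins. It only illustrates the idea of walking a per-interpreter table of monitoring callbacks and flagging each callback's code object so that frame evaluation skips it.

```c
#include <assert.h>
#include <stdbool.h>
#include <stddef.h>

/* Hypothetical stand-in for a code object; NOT CPython's PyCodeObject. */
typedef struct ToyCode {
  const char* name;
  bool skip_recursive; /* true once this code (and its callees) should be skipped */
} ToyCode;

/* Hypothetical stand-in for a callable registered as a monitoring callback. */
typedef struct ToyCallable {
  ToyCode* code; /* NULL models a callable with no Python-level code object */
} ToyCallable;

/* Table dimensions are illustrative assumptions, not CPython's constants. */
#define TOY_TOOL_IDS 8
#define TOY_EVENTS 4

/* Models a per-interpreter callback table indexed by tool id and event. */
typedef struct ToyInterp {
  ToyCallable* monitoring_callables[TOY_TOOL_IDS][TOY_EVENTS];
} ToyInterp;

/* Toy counterpart of skip_code_recursive: just sets the flag. */
static void toy_skip_code_recursive(ToyCode* code) {
  assert(code != NULL);
  code->skip_recursive = true;
}

/*
 * Visits every registered callback and flags its code object.
 * Returns how many code objects were newly flagged, so calling it
 * again (e.g. each time the eval frame callback is set) is harmless.
 */
static int toy_skip_monitoring_callables(ToyInterp* interp) {
  assert(interp != NULL);
  int newly_flagged = 0;
  for (size_t tool = 0; tool < TOY_TOOL_IDS; tool++) {
    for (size_t event = 0; event < TOY_EVENTS; event++) {
      ToyCallable* cb = interp->monitoring_callables[tool][event];
      if (cb == NULL || cb->code == NULL) {
        continue; /* empty slot, or nothing to flag */
      }
      if (!cb->code->skip_recursive) {
        toy_skip_code_recursive(cb->code);
        newly_flagged++;
      }
    }
  }
  return newly_flagged;
}
```

In this model the walk is idempotent, which matches the description that the check runs on every attempt to set the eval frame callback; whether the real implementation tracks already-flagged code objects the same way is not stated in the commit message.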
27 lines
517 B
C
#pragma once
#include <Python.h>

#include <torch/csrc/dynamo/eval_frame.h>
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/framelocals_mapping.h>
#ifdef __cplusplus

extern "C" {

#endif

PyObject* dynamo__custom_eval_frame(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag,
    PyObject* callback);

PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* obj);
void skip_code_recursive(PyCodeObject* code);

#ifdef __cplusplus

} // extern "C"

#endif