This PR and the previous one:
- Moves parts of `eval_frame.c` to C++.
- Reduces code duplication in `dynamo__custom_eval_frame` and makes the control flow more clear.
- Enables `convert_frame` to signal to `eval_frame.cpp`, in a general manner, how to evaluate this frame, recursive frames, and future frames with the same code object (default/compile, skip, run-only); see the sketch below. For example, this allows us to change skipping / cache-limit-hit eval_frame behavior directly from `convert_frame` without requiring changes to C/C++.
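As a rough illustration of the kind of signal this enables (the names below are illustrative, not necessarily the actual `torch._dynamo` types):
```python
import enum

class FrameAction(enum.Enum):
    DEFAULT = 0   # compile this frame as usual
    SKIP = 1      # never try to trace frames with this code object again
    RUN_ONLY = 2  # reuse existing cache entries, but never recompile

def decide(cache_limit_hit: bool) -> FrameAction:
    # e.g. on a cache-size-limit hit, convert_frame can tell eval_frame.cpp to
    # treat future frames with this code object as run-only, with no C/C++ changes.
    return FrameAction.RUN_ONLY if cache_limit_hit else FrameAction.DEFAULT
```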
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146355
Approved by: https://github.com/jansel
ghstack dependencies: #145603
Fixes #108942
This PR converts eval_frame.c's static extension types to heap types, making the module thread- and sub-interpreter-safe.
The current change only lifts one state variable; other variables present similar opportunities and can be addressed in this PR.
Todo / suggestions:
1. Uplift `eval_frame_callback_key` to module state.
2. Add `.m_slots` to the module definition so that initialization happens within Python's module lifecycle rather than via an explicit `torch_c_dynamo_eval_frame_init`.
3. Define a module configuration for whether sub-interpreters are allowed.
```c
static int module_exec(PyObject *m) {
    // Per-module initialization; runs as part of multi-phase init.
    return 0;
}

static PyModuleDef_Slot module_slots[] = {
    {Py_mod_exec, module_exec},
    {0, NULL}
};

static struct PyModuleDef module = {
    PyModuleDef_HEAD_INIT,
    /* ... name, doc, size, methods ... */
    .m_slots = module_slots
};
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141357
Approved by: https://github.com/jansel
Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Implements https://github.com/pytorch/pytorch/issues/93753 - move frame local guard accessors to C++.
Before, we used dict accessors on a Python dict representing the frame's fastlocals that we manually built. We move these accessors to C++ and additionally use the fastlocal index whenever possible.
Some implementation notes:
- `FrameLocalsMapping` is now initialized as a C++ vector of `PyObject`s. We do not just use the frame's localsplus/fastlocals buffer because we also unbox cells.
- `FrameLocalsMapping` can still be converted into a Python dict representing the frame's fastlocals, but the conversion is done lazily (a minimal sketch of this idea follows this list).
- We update `LeafGuard`, `GuardAccessor`, and `GuardManager`'s `check_nopybind` methods to accept `FrameLocalsMapping`. By default, we convert the `FrameLocalsMapping` to a Python dict and run the original `check_nopybind` on it, but in some cases, conversion is not needed.
- We add a new guard accessor `FrameLocalsGuardAccessor`, which is similar to `DictGetItemGuardAccessor` but has special handling for `FrameLocalsMapping`. We create a separate class to emphasize the different use cases, but the two could probably be combined (can do in a follow-up).
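A minimal Python sketch of the lazy-conversion idea (the real `FrameLocalsMapping` is a C++ type; names and structure here are illustrative only):
```python
class LazyFrameLocals:
    def __init__(self, names, values):
        self._names = names    # fastlocal names, in index order
        self._values = values  # unboxed values (cells already dereferenced)
        self._dict = None      # materialized only on demand

    def get(self, index):
        # Fast path: a guard that knows the fastlocal index avoids dict creation.
        return self._values[index]

    def to_dict(self):
        # Slow path: the dict is built only when a guard actually needs it.
        if self._dict is None:
            self._dict = dict(zip(self._names, self._values))
        return self._dict
```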
dynamo_guard_eval.py microbenchmark update:
- 713.2us -> 630.0us (3.10)
- 598.8us -> 530.7us (3.12)
Other followups:
- Add `FrameLocalsMapping` version for `check_verbose_nopybind` in order to match behavior between `check_nopybind` and `check_verbose_nopybind`. This can prevent difficult debugging situations where guards fail (`check_nopybind` returns false) but no guard error message is generated (`check_verbose_nopybind` succeeds).
- Rewrite the `SHAPE_ENV` guard in C++: it is a fairly common guard that currently forces the `FrameLocalsMapping` to be converted to a dict.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140063
Approved by: https://github.com/jansel
ghstack dependencies: #142117, #142430
In `match_nested_cell`, Dynamo tried to identify pre-existing captured
cells by `(cell_name, id(cell_contents))`. This works in most cases, but
as the test added in this patch shows, it's not a complete solution.
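For illustration, plain CPython shows why keying on the cell contents is ambiguous:
```python
def make_counter():
    count = 0
    def bump():
        nonlocal count
        count += 1
        return count
    return bump

a, b = make_counter(), make_counter()
cell_a, cell_b = a.__closure__[0], b.__closure__[0]
# The two cells hold the identical object (the cached int 0), so
# (cell_name, id(cell_contents)) cannot tell them apart...
assert id(cell_a.cell_contents) == id(cell_b.cell_contents)
# ...but the cell objects themselves are distinct.
assert cell_a is not cell_b
```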
This patch
1. changes `match_nested_cell` to `lookup_variable_for_captured_cell`,
   and does the lookup based on the identity of the cell objects, not their
   contents. This requires plumbing a tuple of captured cell objects (obtained
   differently across CPython versions) all the way to
   `InstructionTranslator.__init__`, where we store a mapping keyed by the ids
   of these cell objects and use it later in `UserFunctionVariable.bind_args`
   to look up these unboxed cells.
2. builds off (1): rather than using a `VariableTracker` that represents the
   contents of the unboxed cells, use `ClosureVariable`, which enables codegen
   in case these cells escape as the closure of a `NestedUserFunctionVariable`.
The patch adds a regression test for each of the scenarios above:
1. `test_write_to_cells_with_name_shadowing`, where Dynamo mistakenly
   thought the program was writing to a cell captured by the root frame (which
   it doesn't support at the moment), resulting in
```
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/symbolic_convert.py", line 3340, in STORE_DEREF
unimplemented("write to __closure__ while inlining")
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/exc.py", line 313, in unimplemented
raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: write to __closure__ while inlining
```
2. `test_existing_func_that_creates_capturing_nested_func`, where Dynamo
   ended up trying to codegen a `NestedUserFunctionVariable` that captures a
   cell which was also captured by the root frame; the cell was therefore
   unboxed, and codegen emitted `LOAD_DEREF` rather than
   `LOAD_FAST`/`LOAD_CLOSURE`, resulting in
```
File "/Users/ryanguo99/Documents/work/pytorch/torch/_dynamo/variables/functions.py", line 105, in _create_nested_fn
func = FunctionType(code, f_globals, name, defaults, closure)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: arg 5 (closure) expected cell, found int
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140436
Approved by: https://github.com/jansel, https://github.com/williamwen42
ghstack dependencies: #140330, #140152
This patch introduces a `DynamoFrameType` to serve as a layer between
Dynamo and the different versions of the Python frame object. In
`DynamoFrameType`, we only register the attributes Dynamo cares about (e.g.,
`f_code`, `f_locals`).
This will be helpful when it comes to adding new attributes to this
`DynamoFrameType`, or dealing with Python version changes.
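A rough Python-level sketch of the idea (the actual implementation and attribute set may differ):
```python
import sys

class DynamoFrameTypeSketch:
    """Expose only the frame attributes Dynamo uses, so differences between
    Python versions' frame objects stay confined to this one layer."""

    def __init__(self, frame):
        self.f_code = frame.f_code
        self.f_locals = frame.f_locals
        self.f_globals = frame.f_globals
        self.f_lasti = frame.f_lasti

wrapped = DynamoFrameTypeSketch(sys._getframe())
print(wrapped.f_code.co_name)
```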
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140330
Approved by: https://github.com/jansel, https://github.com/williamwen42
This PR enables you to inspect PyObjects in C using `INSPECT(...)` without requiring the gdb helpers described at https://docs.python.org/3/howto/gdb_helpers.html. `torch._dynamo.eval_frame.raise_sigtrap` can also be used to set gdb breakpoints while running Python code, e.g.
```python
x = x + 1
torch._dynamo.eval_frame.raise_sigtrap()
# We can now set a gdb breakpoint at ceval.c:CALL to stop at the `sin` call in C.
x = torch.sin(x)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138030
Approved by: https://github.com/jansel
Summary: In D60803317, we added CompileContext (trace_id) information to Kineto traces by caching it when a CompileContext exits. As some users pointed out, this gives inaccurate IDs because it is not the context that is actually looked up within eval_frame. For this reason, we revert that change and instead get the trace_id associated with a given CacheEntry. To do this, we add a trace_id to the GuardedCode so that it can be passed on to a CacheEntry. Then, we change the lookup function to return that trace_id alongside the code so that we can pass both into our eval function. Once we get to a Torch-Compiled Region, we can simply append the context information to the name of the annotation, bypassing any need for kwargs.
Test Plan: Added more comprehensive unit test. Saw that all the trace_ids appeared within the graph.
Differential Revision: D63138786
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136460
Approved by: https://github.com/ezyang
Summary:
We want to add compile IDs and frames to each Torch-Compiled Region in order to help users cross-reference the section they are checking against data obtained from tools such as tlparse.
This diff operates on the assumption that each graph section will enter and exit a CompileContext before it is run, either to compile the graph or to look it up in the cache. Based on this assumption, we can save the value of the graph section from the exited CompileContext in eval_frame.c using the Python C API. After this, we create a new interface in the cpp shim to wrap record_function so that we can pass in the new keyword argument for "context".
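As a rough illustration of what this surfaces (event name taken from the description above; the exact formatting of the compile context in the name may differ):
```python
import torch

@torch.compile
def f(x):
    return torch.sin(x) + 1

f(torch.randn(8))  # first call compiles
with torch.profiler.profile() as prof:
    f(torch.randn(8))

# Each compiled graph run should appear as a "Torch-Compiled Region" event,
# now annotated with the compile id/frame of the corresponding graph section.
regions = [e for e in prof.events() if "Torch-Compiled Region" in e.name]
assert regions, "expected at least one Torch-Compiled Region event"
```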
Test Plan:
Enhanced test_profiler_dynamo_compiled_region to look for kwinputs as well as a name to verify that the context is now labeled. Also changed the test to run the graph with more contexts so that we exercise a wider range of profiling.
Differential Revision: D60803317
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132765
Approved by: https://github.com/anijain2305
Construct frame localsplus in 3.12+ using our own simplified approach rather than copy-pasting from CPython.
This is necessary for 3.13 since we can no longer generate frame `f_locals` before executing the interpreter frame.
We also enable this for 3.12 since the `f_locals` construction between 3.12 and 3.13 is the same, so we can test for correctness with 3.12.
This is also one of the first steps toward completing https://github.com/pytorch/pytorch/issues/93753 - we will implement simplified f_locals generation for previous Python versions in the future.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129185
Approved by: https://github.com/jansel
This is mainly:
- Fix refcount access macro
- Hide all the Dynamo code that needs update as usual
- Add _PyWeakref_ClearRef as an extern provided by CPython. Including the pycore header that defines it would require raw C include shenanigans that I don't think are worth it.
This allows building with both the regular and the nogil versions of CPython.
Note that this requires the 3.13 branch at least past [d3094744d40de2deefbda9b1996d5029c9ebf0b0](d3094744d4), which we need for the mimalloc include and for the weakref function being exposed.
Debug-only issues in pybind11 with PyMem_MALLOC vs PyObject_MALLOC should be resolved by syncing either pybind11 or CPython. @colesbury I can send a PR to ifdef the proper use in pybind11 if you think that is the best solution here?
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126033
Approved by: https://github.com/colesbury
Fixes https://github.com/pytorch/pytorch/issues/119607 for 3.11+.
In 3.11+, `_PyFrame_FastToLocalsWithError` could implicitly run `COPY_FREE_VARS` on the original frame, leading to double increfs since the Dynamo shadow frame can rerun `COPY_FREE_VARS`. The solution is to skip the first `COPY_FREE_VARS` instruction in the shadow frame if it was already executed in the original frame.
Also move the location for clearing the original frame in 3.12 to handle error cases more thoroughly.
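For context, a minimal example of the kind of frame this affects: a Dynamo-compiled closure whose bytecode starts with `COPY_FREE_VARS` on 3.11+ (illustrative only; whether the original bug reproduces depends on the Python version):
```python
import torch

def outer():
    k = 2

    @torch.compile
    def inner(x):
        # `k` is a free variable, so inner's bytecode begins with COPY_FREE_VARS
        # on 3.11+, and the Dynamo shadow frame must not re-run it.
        return x * k

    return inner(torch.ones(3))

print(outer())
```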
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124238
Approved by: https://github.com/jansel
Python 3.12 changed a few things with how `_PyInterpreterFrame`s are allocated and freed:
- Frames are now required to be placed on the Python frame stack. In 3.11, we could allocate frames anywhere in memory. In 3.12, we now need to use `THP_PyThreadState_BumpFramePointerSlow`/`push_chunk`/`allocate_chunk`. This method of allocating/freeing frames is also compatible with 3.11.
- The eval frame function is now responsible for clearing the frame (see https://docs.python.org/3/whatsnew/changelog.html#id128, the point about "...which now clear the frame.")
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122146
Approved by: https://github.com/jansel
Attempt number 2 at https://github.com/pytorch/pytorch/issues/108950.
Improves debugging for guard failures/recompilations by:
- only running guard fail reason generation during recompilation, instead of when a guard fails during dynamo cache lookup (so generating guard failure reasons is not on the critical path)
- ~~always reporting all guard failures~~ Reports the first-failing guard failure for each cache entry.
We don't expect a performance hit since the guard fail reasons are only generated at recompile time rather than runtime. Perf benchmark to check this (https://hud.pytorch.org/benchmark/torchbench/inductor_with_cudagraphs?startTime=Fri,%2027%20Oct%202023%2017:42:43%20GMT&stopTime=Fri,%2003%20Nov%202023%2017:42:43%20GMT&granularity=hour&mode=training&dtype=amp&lBranch=gh/williamwen42/62/head&lCommit=f4724f5ffc6d17ceae513a42fc18627be7b85482&rBranch=main&rCommit=29f3d392bf230072e3bffae37b078e770cae1956). We may also need to verify this on benchmarks where guard fails are common.
Sample script:
```python
import torch

def generate_data(b):
    return (
        torch.randn(b, 3, 32, 32).to(torch.float32).cuda(),
        torch.randint(1000, (b,)).cuda(),
    )

from torchvision.models import resnet18

def init_model():
    return resnet18().to(torch.float32).cuda()

model = init_model()
model_opt = torch.compile(model, dynamic=False)

for b in range(16, 32):
    data = generate_data(b)
    model_opt(data[0])
```
Sample logs:
```bash
(/data/users/williamwen/py310-env) [williamwen@devgpu020.odn1 /data/users/williamwen/pytorch (wwen/log-all-guards)]$ python playground5.py
/data/users/williamwen/pytorch/torch/_inductor/compile_fx.py:141: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
[2023-11-06 14:50:47,605] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2023-11-06 14:50:47,605] torch._dynamo.convert_frame: [WARNING] function: 'forward' (/data/users/williamwen/torchvision/torchvision/models/resnet.py:284)
[2023-11-06 14:50:47,605] torch._dynamo.convert_frame: [WARNING] last reason: tensor 'L['x']' size mismatch at index 0. expected 16, actual 24
[2023-11-06 14:50:47,605] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2023-11-06 14:50:47,605] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
(/data/users/williamwen/py310-env) [williamwen@devgpu020.odn1 /data/users/williamwen/pytorch (wwen/log-all-guards)]$ TORCH_LOGS="recompiles" python playground5.py
/data/users/williamwen/pytorch/torch/_inductor/compile_fx.py:141: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
[2023-11-06 14:53:31,591] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /data/users/williamwen/torchvision/torchvision/models/resnet.py:284
[2023-11-06 14:53:31,591] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:53:31,591] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 16, actual 17
[2023-11-06 14:53:41,333] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /data/users/williamwen/torchvision/torchvision/models/resnet.py:284
[2023-11-06 14:53:41,333] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:53:41,333] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 17, actual 18
[2023-11-06 14:53:41,333] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 16, actual 18
[2023-11-06 14:53:50,463] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /data/users/williamwen/torchvision/torchvision/models/resnet.py:284
[2023-11-06 14:53:50,463] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:53:50,463] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 18, actual 19
[2023-11-06 14:53:50,463] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 17, actual 19
[2023-11-06 14:53:50,463] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 16, actual 19
[2023-11-06 14:53:59,848] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /data/users/williamwen/torchvision/torchvision/models/resnet.py:284
[2023-11-06 14:53:59,848] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:53:59,848] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 19, actual 20
[2023-11-06 14:53:59,848] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 18, actual 20
[2023-11-06 14:53:59,848] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 17, actual 20
[2023-11-06 14:53:59,848] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 16, actual 20
[2023-11-06 14:54:08,549] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /data/users/williamwen/torchvision/torchvision/models/resnet.py:284
[2023-11-06 14:54:08,549] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:54:08,549] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 20, actual 21
[2023-11-06 14:54:08,549] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 19, actual 21
[2023-11-06 14:54:08,549] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 18, actual 21
[2023-11-06 14:54:08,549] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 17, actual 21
[2023-11-06 14:54:08,549] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 16, actual 21
[2023-11-06 14:54:17,795] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /data/users/williamwen/torchvision/torchvision/models/resnet.py:284
[2023-11-06 14:54:17,795] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:54:17,795] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 21, actual 22
[2023-11-06 14:54:17,795] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 20, actual 22
[2023-11-06 14:54:17,795] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 19, actual 22
[2023-11-06 14:54:17,795] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 18, actual 22
[2023-11-06 14:54:17,795] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 17, actual 22
[2023-11-06 14:54:17,795] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 16, actual 22
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /data/users/williamwen/torchvision/torchvision/models/resnet.py:284
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 22, actual 23
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 21, actual 23
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 20, actual 23
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 19, actual 23
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 18, actual 23
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 17, actual 23
[2023-11-06 14:54:27,430] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 16, actual 23
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function forward in /data/users/williamwen/torchvision/torchvision/models/resnet.py:284
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 23, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 22, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 21, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 20, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 19, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 18, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 17, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 16, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.convert_frame: [WARNING] torch._dynamo hit config.cache_size_limit (8)
[2023-11-06 14:54:36,744] torch._dynamo.convert_frame: [WARNING] function: 'forward' (/data/users/williamwen/torchvision/torchvision/models/resnet.py:284)
[2023-11-06 14:54:36,744] torch._dynamo.convert_frame: [WARNING] last reason: tensor 'L['x']' size mismatch at index 0. expected 16, actual 24
[2023-11-06 14:54:36,744] torch._dynamo.convert_frame: [WARNING] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[2023-11-06 14:54:36,744] torch._dynamo.convert_frame: [WARNING] To diagnose recompilation issues, see https://pytorch.org/docs/master/compile/troubleshooting.html.
[2023-11-06 14:54:45,922] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function _forward_impl in /data/users/williamwen/torchvision/torchvision/models/resnet.py:266
[2023-11-06 14:54:45,922] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:54:45,922] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 24, actual 25
[2023-11-06 14:54:54,691] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function _forward_impl in /data/users/williamwen/torchvision/torchvision/models/resnet.py:266
[2023-11-06 14:54:54,691] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:54:54,691] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 25, actual 26
[2023-11-06 14:54:54,691] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 24, actual 26
[2023-11-06 14:55:03,591] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function _forward_impl in /data/users/williamwen/torchvision/torchvision/models/resnet.py:266
[2023-11-06 14:55:03,591] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:55:03,591] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 26, actual 27
[2023-11-06 14:55:03,591] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 25, actual 27
[2023-11-06 14:55:03,591] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 24, actual 27
[2023-11-06 14:55:12,384] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function _forward_impl in /data/users/williamwen/torchvision/torchvision/models/resnet.py:266
[2023-11-06 14:55:12,384] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:55:12,384] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 27, actual 28
[2023-11-06 14:55:12,384] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 26, actual 28
[2023-11-06 14:55:12,384] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 25, actual 28
[2023-11-06 14:55:12,384] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 24, actual 28
[2023-11-06 14:55:21,442] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function _forward_impl in /data/users/williamwen/torchvision/torchvision/models/resnet.py:266
[2023-11-06 14:55:21,442] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:55:21,442] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 28, actual 29
[2023-11-06 14:55:21,442] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 27, actual 29
[2023-11-06 14:55:21,442] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 26, actual 29
[2023-11-06 14:55:21,442] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 25, actual 29
[2023-11-06 14:55:21,442] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 24, actual 29
[2023-11-06 14:55:30,315] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function _forward_impl in /data/users/williamwen/torchvision/torchvision/models/resnet.py:266
[2023-11-06 14:55:30,315] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:55:30,315] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 29, actual 30
[2023-11-06 14:55:30,315] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 28, actual 30
[2023-11-06 14:55:30,315] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 27, actual 30
[2023-11-06 14:55:30,315] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 26, actual 30
[2023-11-06 14:55:30,315] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 25, actual 30
[2023-11-06 14:55:30,315] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 24, actual 30
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] Recompiling function _forward_impl in /data/users/williamwen/torchvision/torchvision/models/resnet.py:266
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] triggered by the following guard failure(s):
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 30, actual 31
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 29, actual 31
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 28, actual 31
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 27, actual 31
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 26, actual 31
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 25, actual 31
[2023-11-06 14:55:39,839] torch._dynamo.guards.__recompiles: [DEBUG] - tensor 'L['x']' size mismatch at index 0. expected 24, actual 31
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110325
Approved by: https://github.com/ezyang, https://github.com/jon-chuang
**Motivation**: We already have a `CompiledFunction` event that comes from the autograd.Function added by aot_autograd. However, this doesn't appear during inference, or if none of the inputs to a graph require grad. It also doesn't appear if your backend doesn't use aot_autograd.
This adds a profiler event that will always appear.
<img width="615" alt="Screenshot 2023-09-01 at 4 46 28 PM" src="https://github.com/pytorch/pytorch/assets/5067123/fed90ca9-a8e7-458c-80eb-b4160de55218">
Perf: the increase in latency (with the profiler turned off) was within noise when I measured a simple CPU-only torch-compiled function that returned `x.view_as(x)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108462
Approved by: https://github.com/anijain2305
**Background**: "TorchDynamo Cache Lookup" events appear in traces to indicate a dynamo cache lookup; it's useful to check when cache lookups are taking a long time. To add a profiler event, one can use the `torch.profiler.record_function` context manager, or the C++ equivalent. Previously, the python version was used; first, when the profiler was enabled, callbacks for record_function_enter and record_function_exit were registered; then those would be called before and after every cache lookup.
**This PR**: Instead of calling the python bindings for `torch.profiler.record_function`, directly call the C++ implementation. This simplifies a lot of the code for binding C/C++. It also improves performance; previously there was a lot of overhead in the "TorchDynamo Cache Lookup" event, making the event artificially take a long time. After this change the events now appear shorter, because there's less overhead in starting/stopping the event: in other words, the profiler no longer distorts the results as much.
**Performance results**:
I ran the script below on a CPU-only 1.6 GHz machine. I report the median time (from 100 measurements) of a "TorchDynamo Cache Lookup" event before and after this PR. I think it is reasonable to attribute the difference to a reduction in overhead.
<details>
<summary>Benchmarking script</summary>
```python
import torch

def fn(x, y):
    return (x * y).relu()

a, b = [torch.rand((4, 4), requires_grad=True) for _ in range(2)]

opt_fn = torch.compile(fn)

opt_fn(a, b)
opt_fn(a, b)

with torch.profiler.profile() as prof:
    opt_fn(a, b)
```
</details>
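For reference, a hedged way to pull the lookup-event durations out of the profile collected above (not part of the original script; assumes the standard `events()` / `FunctionEvent.cpu_time` profiler API):
```python
# `prof` comes from the benchmarking script above; cpu_time is in microseconds.
lookups = [e for e in prof.events() if e.name == "TorchDynamo Cache Lookup"]
if lookups:
    durations = sorted(e.cpu_time for e in lookups)
    print("median lookup:", durations[len(durations) // 2], "us")
```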
Median before PR: 198-228 us (median of 100, measured 5 times)
Median after PR: 27 us
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108436
Approved by: https://github.com/anijain2305, https://github.com/jansel
In https://github.com/pytorch/pytorch/pull/106673, I created a private API `_debug_get_cache_entry_list` to help pull out cache entries from compiled functions.
Recently, I found that @anijain2305 had commented in the code that this API should be revisited, so I created this PR.
First, this API cannot be removed even if the cache entry becomes a first-class Python class `torch._C._dynamo.eval_frame._CacheEntry`. Because `extra_index` is static and `get_extra_state` is inline static, they are not accessible elsewhere; `_debug_get_cache_entry_list` is the only way for users to get all the cache entries from a code object.
Second, since the `torch._C._dynamo.eval_frame._CacheEntry` class is a Python class, I simplified the C-side code and removed the need to create a namedtuple in the Python code.
Third, I add a small improvement: if the argument is a function, we automatically pass its `__code__` to the API.
The above change slightly changes the output, from a list of named tuples to a list of `torch._C._dynamo.eval_frame._CacheEntry`. I will update the corresponding docs that use this API.
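A rough usage sketch of the updated API (module path and printed types assumed from the description above; details may differ across versions):
```python
import torch
from torch._dynamo.eval_frame import _debug_get_cache_entry_list

def f(x):
    return x + 1

opt_f = torch.compile(f)
opt_f(torch.randn(4))

# With this PR, the function itself can be passed; its __code__ is used.
entries = _debug_get_cache_entry_list(f)
print(entries)  # list of torch._C._dynamo.eval_frame._CacheEntry objects
```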
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108335
Approved by: https://github.com/jansel, https://github.com/anijain2305