move benchmarking out of `torch._inductor.runtime.runtime_utils` and into `torch._inductor.runtime.benchmarking`, and prefer this path over directly accessing Triton's benchmarking
Fixes #ISSUE_NUMBER
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132827
Approved by: https://github.com/eellison
Adds support for SymInts in the FakeTensor cache.
A couple notes:
1. When a SymInt is present in the input key for a FakeTensor operation we cache on the ShapeEnv instead of using the FakeTensorMode cache. This is necessary so we don't have to remember and check the guards. It reduces the cache hits but there's diminishing return on how much work we can do before the cache becomes more of a burden than a gain.
2. We need to be careful that when we cache an output SymInt that is a direct copy from the input that when we have a cache-hit we copy the SymNode from the input to the output. This is important because the fx-graph building code actually uses SymNode ids in the process of building the graph so constructing a same-content-but-different-id SymNode will fail.
3. In the cache key we store SymInts as a _PySymInputStub. These represent SymInt (and friends) but support `__hash__` and `__eq__` (which SymInt do not).
4. In the cache entry we store SymInts as a _SymIntOutputStub.
Perf example:
```
python benchmarks/dynamo/timm_models.py --ci --accuracy --timing
--explain --inductor --dynamic-shapes --dynamic-batch-only --device cuda
--training --amp --total-partitions 2 --partition-id 0 --output
/tmp/training_timm_models.csv --filter crossvit_9_240
```
fake tensor cache before:
```
INFO: FakeTensor cache stats:
INFO: cache_hits: 68137
INFO: cache_misses: 837
INFO: cache_bypasses:
INFO: symbolic shape: 48224
INFO: CompositeImplicitAutograd: 917
INFO: non-fake tensor: 70
INFO: non-FakeTensor output: 62
INFO: non-builtin: 8
INFO: dynamic output shape: 1
```
and after:
```
INFO: FakeTensor cache stats:
INFO: cache_hits: 88187
INFO: cache_misses: 14233
INFO: cache_bypasses:
INFO: CompositeImplicitAutograd: 1037
INFO: non-FakeTensor output: 602
INFO: non-fake tensor: 70
INFO: unsafe view: 36
INFO: non-builtin: 8
INFO: dynamic output shape: 1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127596
Approved by: https://github.com/eellison
ghstack dependencies: #131014, #129780
Summary:
Add '`TORCH_LOGS=+fsdp`' in the CLI to print fsdp logs
Example:
`TORCH_LOGS=+fsdp torchrun --standalone --nproc_per_node=2 run_fsdp.py`
Description:
Add logging to `FSDPParamGroup.pre_forward`, `FSDPParamGroup.post_forward`, `FSDPParamGroup.pre_backward`, and `FSDPParamGroup.post_backward`, `FSDPState._root_pre_forward` if is the root, and `FSDPState._root_post_backward_final_callback`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128663
Approved by: https://github.com/weifengpy, https://github.com/awgu
e.g. dist_ddp -> ddp
'distributed' shortcut remains unchained
Feedback has been that it is not appealing to have the dist_ prefix,
and the main reason for it was to keep the distributed shortcuts grouped
together in the help menu. It's nice to have shorter shortcuts.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126499
Approved by: https://github.com/XilunWu, https://github.com/kwen2501
ghstack dependencies: #126322
- sets it as a fake stack trace as we don't have a generic comment feature
- when verbose is disabled, still adds a contextmanager and flag checks. the alternative is to use MACROS, but that wouldn't be usable with TORCH_LOGS
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124954
Approved by: https://github.com/jansel
Partially fixes https://github.com/pytorch/pytorch/issues/105077
Repro:
```python
import tempfile
import torch
from torch._subclasses import fake_tensor
class TheModelClass(torch.nn.Module):
def __init__(self):
super(TheModelClass, self).__init__()
self.fc1 = torch.nn.Linear(5, 10)
def forward(self, x):
return self.fc1(x)
with tempfile.NamedTemporaryFile() as state_dict_file:
# Create state_dict to be loaded later
model = TheModelClass()
torch.save(model.state_dict(), state_dict_file.name)
fake_mode = fake_tensor.FakeTensorMode()
with fake_mode:
# This is where the bug is triggered
state_dict = torch.load(state_dict_file.name)
```
Error:
```bash
Traceback (most recent call last):
File "issue_gh_torch_105077.py", line 22, in <module>
state_dict = torch.load(state_dict_file.name)
File "/opt/pytorch/torch/serialization.py", line 1014, in load
return _load(opened_zipfile,
File "/opt/pytorch/torch/serialization.py", line 1422, in _load
result = unpickler.load()
File "/opt/pytorch/torch/_utils.py", line 205, in _rebuild_tensor_v2
tensor = _rebuild_tensor(storage, storage_offset, size, stride)
File "/opt/pytorch/torch/_utils.py", line 184, in _rebuild_tensor
return t.set_(storage._untyped_storage, storage_offset, size, stride)
File "/opt/pytorch/torch/utils/_stats.py", line 20, in wrapper
return fn(*args, **kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1288, in __torch_dispatch__
return self.dispatch(func, types, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1468, in dispatch
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
File "/opt/pytorch/torch/_subclasses/fake_tensor.py", line 1733, in invalidate_written_to_constants
_, new_kwargs = normalize_function(
File "/opt/pytorch/torch/fx/operator_schemas.py", line 297, in normalize_function
torch_op_schemas = get_signature_for_torch_op(target)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in get_signature_for_torch_op
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 167, in <listcomp>
signatures = [_torchscript_schema_to_signature(schema) for schema in schemas]
File "/opt/pytorch/torch/fx/operator_schemas.py", line 70, in _torchscript_schema_to_signature
arg_type = _torchscript_type_to_python_type(arg.type)
File "/opt/pytorch/torch/fx/operator_schemas.py", line 64, in _torchscript_type_to_python_type
return eval(ts_type.annotation_str, _type_eval_globals)
File "<string>", line 1, in <module>
NameError: name 'Storage' is not defined
```
This PR adds the ability to create fake tensors during `torch.load` by wrapping the `torch.tensor.set_` call around a `torch.utils._mode_utils.no_dispatch()` to skip fake mode dispatcher for it and thus create a real tensor. It later calls `fake_mode.from_tensor(t)` to finally create the fake tensor.
Co-authored-by: Edward Z. Yang <ezyang@mit.edu>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108186
Approved by: https://github.com/ezyang
This PR is the start to enable the integrate pytorch distributed logs in Torch LOGs. We now already have one tag "distributed" for all distributed components but distributed is a very large component and we want to have some hierarchy and give users options to only turn on logs for certain submodules. So we also added tags starting with "dist_*" for each submodule. (This PR only adds some of them and we are going to add more down the road)
Related discussions can be found here: https://github.com/pytorch/pytorch/issues/113544
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116434
Approved by: https://github.com/awgu, https://github.com/wanchaol
There are now 3 ways to see logs from ddpoptimzer.
1) TORCH_LOGS="distributed"
2) TORCH_LOGS="dynamo"
3) TORCH_LOGS="torch._dynamo.backends.distributed"
(1 and 2 are different supersets of 3 that also include other content)
Note: ddp_graphs is still a separate 'artifact' logger, which just
includes graph dumps from the graph-splitting process.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114376
Approved by: https://github.com/wanchaol
Followup to https://github.com/pytorch/pytorch/pull/110325 - re-add the `report_all_guard_failures config` as a logging artifact `recompiles_verbose` with the following changes:
- evaluating the check must be wrapped with exception handling because subsequent code parts following the first failure may result in errors if evaluated (e.g. if a guard checks first for size, then tries to index - a guard failure due to insufficient size would result in an index error for the latter check).
- Adding a test for this case
Sample:
```python
import torch
def fn(x):
return torch.rand(x[-1], len(x))
opt_fn = torch.compile(fn)
opt_fn([4, 5, 6])
opt_fn([7, 8])
opt_fn([9])
```
Output (with `TORCH_LOGS="recompiles_verbose"`):
```bash
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] Recompiling function fn in /data/users/williamwen/pytorch/playground5.py:15
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] triggered by the following guard failure(s):
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 0 failures:
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 3
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][0] == 4
[2023-11-15 16:13:26,741] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][1] == 5
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] Recompiling function fn in /data/users/williamwen/pytorch/playground5.py:15
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] triggered by the following guard failure(s):
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 0 failures:
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 2
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG]
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] guard 1 failures:
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - len(L['x']) == 3
[2023-11-15 16:13:26,970] torch._dynamo.guards.__recompiles_verbose: [DEBUG] - L['x'][0] == 4
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113585
Approved by: https://github.com/jon-chuang, https://github.com/ezyang
When SymNode was refactored into its own module, this broke logging for this file, as the `dynamic` alias no longer covered it. This PR adds supports for an alias to point to multiple qualified module names. To drive the refactor, I renamed `log_alias_to_log_qname` to `log_alias_to_log_qnames` and then audited all use sites. I invite you to do so as well.
For good measure, I also add dynamic to dynamo, so that I always get dynamic logs when dynamo is enabled. Empirically this will be helpful because people keep sending me dynamo debug logs that don't have dynamic logs.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113567
Approved by: https://github.com/Skylion007, https://github.com/lezcano, https://github.com/mlazos
ghstack dependencies: #113566
It looks like this:
```
[DEBUG] GUARD: ___check_type_id(L['z'][L["MyEnum"].BAR], 7640416) and L['z'][L["MyEnum"].BAR] == 10
[DEBUG] Stack:
[DEBUG] File "/data/users/ezyang/b/pytorch/test/dynamo/test_misc.py", line 6657, in <module>
[DEBUG] run_tests()
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/test_case.py", line 38, in run_tests
[DEBUG] run_tests()
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/testing/_internal/common_utils.py", line 985, in run_tests
[DEBUG] unittest.main(argv=argv)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/main.py", line 101, in __init__
[DEBUG] self.runTests()
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/main.py", line 271, in runTests
[DEBUG] self.result = testRunner.run(self.test)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/runner.py", line 184, in run
[DEBUG] test(result)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/suite.py", line 84, in __call__
[DEBUG] return self.run(*args, **kwds)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/suite.py", line 122, in run
[DEBUG] test(result)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/suite.py", line 84, in __call__
[DEBUG] return self.run(*args, **kwds)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/suite.py", line 122, in run
[DEBUG] test(result)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/case.py", line 650, in __call__
[DEBUG] return self.run(*args, **kwds)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/testing/_internal/common_utils.py", line 2521, in run
[DEBUG] self._run_with_retry(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/testing/_internal/common_utils.py", line 2450, in _run_with_retry
[DEBUG] super_run(result=result)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/case.py", line 591, in run
[DEBUG] self._callTestMethod(testMethod)
[DEBUG] File "/home/ezyang/local/b/pytorch-env/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
[DEBUG] method()
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/testing/_internal/common_utils.py", line 2377, in wrapper
[DEBUG] method(*args, **kwargs)
[DEBUG] File "/data/users/ezyang/b/pytorch/test/dynamo/test_misc.py", line 2529, in test_enum_as_dict_key_with_overloaded_str
[DEBUG] res = opt_fn(x)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/eval_frame.py", line 333, in _fn
[DEBUG] return fn(*args, **kwargs)
[DEBUG] File "/data/users/ezyang/b/pytorch/test/dynamo/test_misc.py", line 2519, in fn
[DEBUG] torch._dynamo.graph_break()
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/eval_frame.py", line 493, in catch_errors
[DEBUG] return callback(frame, cache_size, hooks, frame_state)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 637, in _convert_frame
[DEBUG] result = inner_convert(frame, cache_size, hooks, frame_state)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 133, in _fn
[DEBUG] return fn(*args, **kwargs)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 371, in _convert_frame_assert
[DEBUG] return _compile(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 567, in _compile
[DEBUG] guarded_code = compile_inner(code, one_graph, hooks, transform)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/utils.py", line 181, in time_wrapper
[DEBUG] r = func(*args, **kwargs)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 466, in compile_inner
[DEBUG] out_code = transform_code_object(code, transform)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
[DEBUG] transformations(instructions, code_options)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/convert_frame.py", line 416, in transform
[DEBUG] tracer = InstructionTranslator(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 2018, in __init__
[DEBUG] self.symbolic_locals = collections.OrderedDict(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/symbolic_convert.py", line 2021, in <genexpr>
[DEBUG] VariableBuilder(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 211, in __call__
[DEBUG] vt = self._wrap(value).clone(**self.options())
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 404, in _wrap
[DEBUG] result = {
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 405, in <dictcomp>
[DEBUG] k: VariableBuilder(
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 211, in __call__
[DEBUG] vt = self._wrap(value).clone(**self.options())
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 354, in _wrap
[DEBUG] return type_dispatch(self, value)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 837, in wrap_literal
[DEBUG] return self.wrap_unspecialized_primitive(value)
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 1073, in wrap_unspecialized_primitive
[DEBUG] guards=self.make_guards(GuardBuilder.CONSTANT_MATCH),
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 269, in make_guards
[DEBUG] return {source.make_guard(guard) for guard in guards}
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_dynamo/variables/builder.py", line 269, in <setcomp>
[DEBUG] return {source.make_guard(guard) for guard in guards}
[DEBUG] File "/data/users/ezyang/b/pytorch/torch/_guards.py", line 641, in make_guard
[DEBUG] return Guard(self.name(), self.guard_sou
```
One downside is I can't report *why* the guard was added. I'm not entirely sure how to do this; the problem is guards will propagate to a bunch of variables before finally getting included as part of the final set. Maybe a very very verbose version could report stack traces at every handoff point.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107388
Approved by: https://github.com/mlazos
ghstack dependencies: #107438, #107358
Since Python 3.11 bytecode contains endline and column information, for each bytecode, we attribute the source code corresponding to the bytecode in a more accurate way. For example, we can highlight a function call in a series of nested function calls, or highlight a function call spanning multiple lines.
Sample:
```python
import torch
import torch._dynamo
from functorch.experimental.control_flow import cond
def h(x):
return x * 5
def true_fn(x):
return x * 2
def false_fn(x):
return x * 3
def f(pred, x):
x = h(
h(h(x))
)
x = x[1:][:2]
torch._dynamo.graph_break()
x = cond(pred, true_fn, false_fn, [x])
opt_f = torch.compile(f, backend="eager")
opt_f(torch.tensor(True), torch.randn(3, 3, 3, 3))
```
Output:
```
$ TORCH_LOGS="trace_call" python playground9.py
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^
TRACE FX call mul from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:16
h(h(x))
~^^^^^^
TRACE FX call mul_1 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE inlined call h from f /scratch/williamwen/work/pytorch/playground9.py:15
x = h(
~^
h(h(x))
^^^^^^^
)
^
TRACE FX call mul_2 from h /scratch/williamwen/work/pytorch/playground9.py:6 (inline depth: 1)
return x * 5
~~^~~
TRACE FX call getitem from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~^^^^
TRACE FX call getitem_1 from f /scratch/williamwen/work/pytorch/playground9.py:18
x = x[1:][:2]
~~~~~^^^^
TRACE inlined call true_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from true_fn /scratch/williamwen/work/pytorch/playground9.py:9 (inline depth: 1)
return x * 2
~~^~~
TRACE inlined call false_fn from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TRACE FX call mul from false_fn /scratch/williamwen/work/pytorch/playground9.py:12 (inline depth: 1)
return x * 3
~~^~~
TRACE FX call cond from <resume in f> /scratch/williamwen/work/pytorch/playground9.py:20
x = cond(pred, true_fn, false_fn, [x])
~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104676
Approved by: https://github.com/ezyang
Prototype for the feature request:
>When working on a codebase that is unfamiliar to you, it can be helpful to single step through all of the code to see what is getting executed, what conditional branches are taken, and where indirect function jumps go. Model x-ray uses dynamo to give you a single step log of every source code line that does something relevant (i.e., a Tensor operation)
Dynamo logs to the ~`starts_line`~ `trace_source` logging artifact at the start of tracing new bytecode with a new line. It logs the line of source code associated with that bytecode.
~~Dynamo logs to the `graph_source` logging when a FX GraphModule is constructed. For each node in the graph, it logs the location of the original source code associated with that node.~~
Development notes: https://docs.google.com/document/d/1LjFeHzCgDDt535QUq5HydcQs56d7jWl5RvW8TLZN19g/edit?usp=sharing
Since the draft, we removed the `graph_source` logging artifact since printing the code of `GraphModule`s already displays the original source.
Sample:
```python
import torch
from functorch.experimental.control_flow import cond
def true_fn(x):
return x * 2
def false_fn(x):
return x * 3
def f_cond(pred, x):
return cond(pred, true_fn, false_fn, [x])
def f_outer(pred, x):
y = f_cond(pred, x)
if x.sum() > 0:
x = x * 2
else:
x = x * 3
return x, y
opt_f_cond = torch.compile(f_outer, backend="eager")
opt_f_cond(torch.tensor(True), torch.randn(3, 3))
```
Logs:
```shell
$ TORCH_LOGS="trace_source" python playground8.py
TRACE starts_line f_outer playground8.py:54
def f_outer(pred, x):
TRACE starts_line f_outer playground8.py:55
y = f_cond(pred, x)
TRACE starts_line f_cond playground8.py:51 (inline depth: 1)
def f_cond(pred, x):
TRACE starts_line f_cond playground8.py:52 (inline depth: 1)
return cond(pred, true_fn, false_fn, [x])
TRACE starts_line true_fn playground8.py:45 (inline depth: 2)
def true_fn(x):
TRACE starts_line true_fn playground8.py:46 (inline depth: 2)
return x * 2
TRACE starts_line false_fn playground8.py:48 (inline depth: 2)
def false_fn(x):
TRACE starts_line false_fn playground8.py:49 (inline depth: 2)
return x * 3
TRACE starts_line f_outer playground8.py:56
if x.sum() > 0:
TRACE starts_line <resume in f_outer> playground8.py:56
if x.sum() > 0:
TRACE starts_line <resume in f_outer> playground8.py:57
x = x * 2
TRACE starts_line <resume in f_outer> playground8.py:60
return x, y
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104013
Approved by: https://github.com/ezyang
When investigating failures in https://github.com/pytorch/pytorch/pull/100017 I realized that we were reentering FakeTensorMode even though there was already one on the stack. Although we have attempted assert for these cases in the past, e.g., as in https://github.com/pytorch/pytorch/pull/97186 it seems that the existing protections were insufficient.
In this particular case, the reapplication of FakeTensorMode was due to an interaction with NotImplemented multiple dispatch handling. If proxy tensor mode detects an unrecognized tensor type (this includes FakeTensor, if it is not tracked with a proxy), it will return NotImplemented to give this tensor a chance to unpack itself into proxyable operation. However, this is never the right thing for FakeTensor, where no unpacking is possible. However, today, FakeTensor attempts to reapply the FakeTensorMode, resulting in FakeTensorMode being twice on the stack.
This PR does a number of things:
* It adds an assert in `FakeTensorMode.__torch_dispatch__` that you must not already have this mode on the stack, this is ALWAYS an error
* It modifies `FakeTensor.__torch_dispatch__` to return `NotImplemented` if the mode is already active. This prevents us from readding the mode on the stack
* It adds a new logging artifact `not_implemented` which you can use to get debug logs about all of the times a `__torch_dispatch__` handler returned NotImplemented and why it did so. Your subclass has to manually opt into this logging, but I inserted the necessary logs for ProxyTensorMode and FakeTensor(Mode)
* `with fake_mode` now no-ops if the fake mode is already on the stack, which is what users want anyway
* I am BREAKING pre-autograd tracing, because it is currently doing something weird with the original C++ mode stack. Brian is going to follow up with a fix next week.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102091
Approved by: https://github.com/thiagocrepaldi, https://github.com/eellison, https://github.com/wanchaol, https://github.com/bdhirsh