pytorch/torch/_dynamo/variables
Yidi Wu ba6ce66698 [dynamo] lift backed symint output of item() (#161198)
Before the change in this PR, we got an error for the following code:
```python
import torch

torch._dynamo.config.capture_scalar_outputs = True

class M(torch.nn.Module):
    def forward(self, idx, x):
        u0 = idx.item()
        x0 = x.select(0, u0)
        def fn():
            return x0.sin()
        return torch.cond(x0.sum() > 0, fn, fn)

m = M()
out = torch.compile(m, fullgraph=True)(torch.tensor(0, dtype=torch.int64), torch.randn(3, 3))
```

The error is raised while speculating fn: Dynamo tries to lift the symbol of x0.storage_offset() but finds the symbol doesn't have a source associated with it.

What really happens is that, when the input tensor is a scalar tensor of int type residing on CPU, we have a shortcut that creates a normal (backed) symint when .item() is called; see https://github.com/pytorch/pytorch/pull/126245.

However, previously we only tracked the unbacked symint outputs of an operation, on the assumption that every backed symint must have a source associated with it and has already been lifted as an input at the top level. With the shortcut above, this invariant no longer holds, so we end up with an error saying the symbol has no source (only inputs and symbols derived from inputs have a source, and the result of .item() doesn't).

In this PR, we also start tracking the normal (backed) symint together with the proxy that created it (in this case, the proxy for .item()).
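
A minimal conceptual sketch of the idea (hypothetical names; the actual change lives in builder.py and differs in detail): record the creating proxy for backed symbols as well, so that when a HOP subgraph such as torch.cond needs to lift a free symbol like x0.storage_offset(), there is always a proxy to point at even when the symbol has no input source.

```python
# Conceptual sketch only -- not the actual Dynamo code. Names such as
# `symbol_to_proxy` and `track_symint` are hypothetical.
symbol_to_proxy = {}  # sympy symbol/expression -> fx proxy that produced the SymInt


def track_symint(symint, proxy):
    """Remember which proxy created this SymInt.

    Previously only unbacked symbols were recorded, because backed symbols
    were assumed to come from graph inputs (and thus have a source). After
    this change, the backed symint returned by .item() on a CPU int scalar
    is recorded too, so subgraph speculation (e.g. torch.cond) can lift it.
    """
    expr = symint.node.expr  # the underlying sympy symbol/expression
    symbol_to_proxy.setdefault(expr, proxy)
```

With such a mapping in place, lifting x0.storage_offset() while speculating fn can fall back to the recorded proxy instead of requiring the symbol to have a source.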

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161198
Approved by: https://github.com/zou3519
2025-08-26 17:06:54 +00:00
..
__init__.py [itertools] Implement itertools.cycle with a polyfill (#159102) 2025-07-31 23:28:57 +00:00
base.py [dynamo][guards] More small guard optimizations (#159345) 2025-07-29 18:36:49 +00:00
builder.py [dynamo] lift backed symint output of item() (#161198) 2025-08-26 17:06:54 +00:00
builtin.py [dynamo] Support method calls on complex ConstantVariables (#161122) 2025-08-22 21:40:03 +00:00
constant.py [dynamo] Support method calls on complex ConstantVariables (#161122) 2025-08-22 21:40:03 +00:00
ctx_manager.py Add support for tracing vmap in pre-dispatch export (#154650) 2025-08-20 19:31:07 +00:00
dicts.py [BE] [dynamo] Simplify two methods in ConstDictVariable (#159361) 2025-08-22 11:11:30 +00:00
distributed.py [dynamo][dist] trace DeviceMesh's get_local_rank and get_rank as constants (#160805) 2025-08-20 01:12:24 +00:00
functions.py [dynamo] [guard] Add caching for inside torch.compile.disable function to avoid unnecessary recompilation. (#160934) 2025-08-19 06:01:26 +00:00
higher_order_ops.py [hop][exc] make UncapturedHigherOrderOpError print user code and avoid re-raise (#159296) 2025-08-11 22:48:10 +00:00
iter.py Fixes for collections.NamedTuple (#159367) 2025-08-18 17:32:59 +00:00
lazy.py [dynamo] Avoid recompiling over unused objects (#156891) 2025-07-09 20:14:34 +00:00
lists.py Fixes for collections.NamedTuple (#159367) 2025-08-18 17:32:59 +00:00
misc.py [dynamo] propagate tensor metadata on Tensor.__setitem__(tensor) (#161036) 2025-08-22 04:43:22 +00:00
nn_module.py [dynamo] Trace nn.Module __delattr__ (#159969) 2025-08-06 23:43:19 +00:00
optimizer.py Allow bypasses for Precompile when guards, etc. cannot be serialized (#160902) 2025-08-21 18:20:42 +00:00
script_object.py [dynamo] Replace unimplemented with unimplemented_v2 in torch/_dynamo/variables/script_object.py (#159343) 2025-08-01 21:30:41 +00:00
sdpa.py [Dynamo][Misc] Apply typing hints for codegen (#150289) 2025-04-04 14:26:22 +00:00
tensor.py [dynamo] propagate tensor metadata on Tensor.__setitem__(tensor) (#161036) 2025-08-22 04:43:22 +00:00
torch_function.py [dynamo] Be consistent with UserMethodVariable source (#160155) 2025-08-09 04:16:14 +00:00
torch.py [dynamo] Pass requires_grad to nn.Parameter construction (#161364) 2025-08-25 16:49:28 +00:00
user_defined.py [dynamo] Fix graph break on calling functions decorated with special context manager (#160703) 2025-08-18 20:33:45 +00:00