Before the change in this PR, the following code raised an error:
```python
import torch

torch._dynamo.config.capture_scalar_outputs = True


class M(torch.nn.Module):
    def forward(self, idx, x):
        u0 = idx.item()
        x0 = x.select(0, u0)

        def fn():
            return x0.sin()

        return torch.cond(x0.sum() > 0, fn, fn)


m = M()
out = torch.compile(m, fullgraph=True)(
    torch.tensor(0, dtype=torch.int64), torch.randn(3, 3)
)
```
The error occurs when Dynamo speculates `fn` and tries to lift the symbol in `x0.storage_offset()`, only to find that the symbol has no source associated with it.
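For intuition, here is a plain eager-mode illustration (no Dynamo involved; `u0` is a concrete stand-in for the symbolic `idx.item()` value) of why `x0.storage_offset()` carries the symbol:

```python
import torch

x = torch.randn(3, 3)
u0 = 1  # concrete stand-in for the symbolic idx.item() value

# select(0, u0) returns a view; its storage offset is u0 * x.stride(0),
# so a symbolic u0 makes the storage offset symbolic under compilation.
x0 = x.select(0, u0)
assert x0.storage_offset() == u0 * x.stride(0)
```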
What actually happens is that, when the input tensor is a scalar tensor of int type residing on CPU, there is a shortcut that creates a normal (backed) symint when `.item()` is called; see https://github.com/pytorch/pytorch/pull/126245.
However, we previously only tracked the unbacked symint output of an operation, on the assumption that every backed symint has a source associated with it and has already been lifted as an input at the top level. That invariant no longer holds, so we end up with an error saying the symbol doesn't have a source (only inputs and symbols derived from inputs have sources, and the result of `.item()` doesn't have one).
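The following conceptual sketch restates that invariant; the class and field names are hypothetical, not Dynamo's real data structures:

```python
# Hypothetical sketch: Dynamo can recover a symbol either from a source
# (an input expression such as L['x'].size()[0]) or from the proxy of the
# op that produced it.
class SymbolRecord:
    def __init__(self, source=None, proxy=None):
        self.source = source  # set for symbols derived from inputs
        self.proxy = proxy    # set for symbols produced by data-dependent ops

# Before this PR, only unbacked symbols got a proxy record; backed symbols
# were assumed to always carry a source. The CPU scalar-int .item() shortcut
# yields a backed symbol with neither, which is what triggered the error.
```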
In this PR, we start to also track the normal (backed) symint together with the proxy that created it (in this case, the proxy for `.item()`).
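Continuing the hypothetical sketch above, the fix amounts to recording backed symints produced by `.item()` in the same proxy table that unbacked symints already use:

```python
tracked_symbols = {}  # hypothetical: symbol -> SymbolRecord

def record_item_output(symbol, proxy):
    # After this PR, a backed symint returned by .item() is tracked with its
    # creating proxy too, so subgraph speculation (e.g. inside torch.cond)
    # can lift the symbol from the proxy instead of requiring a source.
    tracked_symbols[symbol] = SymbolRecord(proxy=proxy)
```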
Pull Request resolved: https://github.com/pytorch/pytorch/pull/161198
Approved by: https://github.com/zou3519