Yuanhao Ji
d6b3ad4de2
[Dynamo] Replace torch._dynamo.optimize() with torch.compile() [2/N] ( #140238 )
...
related commits:
- #139706
- #140238
- #140247
- #140253
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140238
Approved by: https://github.com/soulitzer
2024-11-13 05:13:39 +00:00
Yuanhao Ji
5aadaaf2b5
[Dynamo] Allow filter() to handle infinite iterator ( #138305 )
...
Fixes #137380
```python
import itertools

import torch


def filt(x):
    return x < 10


@torch.compile(backend="eager", fullgraph=True)
def f(x):
    x = x + 1
    return zip(range(3), filter(filt, itertools.count()))


print(list(f(torch.ones(3))))  # [(0, 0), (1, 1), (2, 2)]


@torch.compile(backend="eager")
def g(x):
    x = x + 1
    return filter(filt, [1, 2, 3])


res = g(torch.ones(3))
assert isinstance(res, filter)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138305
Approved by: https://github.com/williamwen42
2024-11-12 17:32:56 +00:00
Yuanhao Ji
b9618c9b88
[Dynamo] Add itertools.compress() support ( #139061 )
...
Use polyfill to add `itertools.compress()` support in Dynamo.
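A minimal sketch of such a polyfill, following the standard recipe from the itertools docs (the actual polyfill in Dynamo may differ):
```python
# Sketch of an itertools.compress polyfill (standard itertools recipe);
# the actual Dynamo polyfill may differ in details.
def compress(data, selectors):
    # yield items of data whose corresponding selector is truthy
    return (d for d, s in zip(data, selectors) if s)
```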
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139061
Approved by: https://github.com/jansel
2024-10-29 10:25:55 +00:00
Michael Lazos
0a304d9048
[Dynamo] Handle extracted unbound tensor methods ( #137227 )
...
fixes2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227
Approved by: https://github.com/williamwen42 , https://github.com/anijain2305
ghstack dependencies: #137114 , #137115 , #137116 , #137117 , #137120
2024-10-09 02:29:40 +00:00
PyTorch MergeBot
76c5bdd2cc
Revert "[Dynamo] Handle extracted unbound tensor methods ( #137227 )"
...
This reverts commit 14eabd6915 .
Reverted https://github.com/pytorch/pytorch/pull/137227 on behalf of https://github.com/malfet due to Need to revert to be able to revert https://github.com/pytorch/pytorch/pull/136910 ([comment](https://github.com/pytorch/pytorch/pull/137227#issuecomment-2400406384 ))
2024-10-08 17:12:41 +00:00
Michael Lazos
14eabd6915
[Dynamo] Handle extracted unbound tensor methods ( #137227 )
...
fixes2
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137227
Approved by: https://github.com/williamwen42
ghstack dependencies: #137114 , #137115 , #137116 , #137117 , #137120
2024-10-07 18:55:26 +00:00
Salman Mohammadi
48c18ff850
[dynamo] Added support for tensor's is_inference method ( #136450 )
...
Fixes #135439
This PR adds support for the `is_inference` method on torch tensors, which lets the following example fn compile without graph breaks:
```python
def fn_simple(x):
    if x.is_inference():
        return x.sum()
    else:
        return x.min()
```
I've also tried to add guards on the tensor for `is_inference`. I wasn't 100% sure where these should go, so please don't hesitate to correct me.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136450
Approved by: https://github.com/ezyang
2024-09-27 09:15:07 +00:00
Animesh Jain
5c38aa72c0
[dynamo][dicts][nv-embed] Support update with kwargs ( #135588 )
...
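A hypothetical example of the pattern this enables (illustration only, not from the PR):
```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    d = {"a": 1}
    d.update(b=2, c=3)  # dict.update called with kwargs
    return x + d["b"] + d["c"]

fn(torch.ones(3))  # tensor([6., 6., 6.])
```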
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135588
Approved by: https://github.com/yanboliang
2024-09-10 23:50:23 +00:00
torotoki
6d7cbc20d2
Add dynamo itertools.pairwise support ( #135416 )
...
Fixes #133766
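A hypothetical example of the supported pattern (illustration only, not from the PR; `itertools.pairwise` requires Python 3.10+):
```python
import itertools

import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    # pairwise([x, x+1, x+2]) yields (x, x+1) and (x+1, x+2)
    return [a + b for a, b in itertools.pairwise([x, x + 1, x + 2])]

fn(torch.ones(3))
```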
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135416
Approved by: https://github.com/XuehaiPan , https://github.com/jansel
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
2024-09-10 11:37:59 +00:00
Yanbo Liang
d81731615f
[Dynamo] Adding CallFunctionNoArgsSource and CallFunctionNoArgsGuardAccessor to support torch.cuda.current_device() ( #135425 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135425
Approved by: https://github.com/anijain2305
2024-09-09 22:46:00 +00:00
William Wen
a4030e37be
[dynamo] reland map/zip iterator related changes ( #135074 )
...
Differential Revision: [D62211019](https://our.internmc.facebook.com/intern/diff/D62211019 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135074
Approved by: https://github.com/jansel , https://github.com/anijain2305 , https://github.com/mlazos
2024-09-06 20:38:02 +00:00
Xinyu
58f2477a26
[Dynamo] Support builtin function frozenset ( #134563 )
...
Support the builtin function `frozenset` in Dynamo.
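A hypothetical example of the pattern this enables (illustration only, not from the PR):
```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    s = frozenset({1, 2, 3})
    return x + len(s) if 2 in s else x

fn(torch.ones(3))  # tensor([4., 4., 4.])
```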
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134563
Approved by: https://github.com/anijain2305 , https://github.com/EikanWang , https://github.com/jansel
2024-09-05 12:15:10 +00:00
Michael Lazos
4e71418566
[dynamo] rewrite addcmul_ to remove graph break ( #134168 )
...
Context: Adding support for the beta parameters to be tensors
Details: Similarly to the previous two PRs, addcmul_ is used with the tensor betas as the value argument. When this occurs, an item() call is invoked in the aten op. To avoid this graph break, addcmul_ is decomposed into its constituent ops.
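A sketch of the equivalent decomposition, from the definition of addcmul (the exact rewrite in the PR may differ):
```python
# addcmul_ computes: self += value * tensor1 * tensor2
# When `value` is a 0-dim tensor, a decomposed form such as
#   self.add_(tensor1.mul(tensor2).mul_(value))
# avoids the internal item() call on `value`.
```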
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134168
Approved by: https://github.com/anijain2305
ghstack dependencies: #134166 , #134167
2024-08-31 10:24:39 +00:00
Michael Lazos
3fb4c6bc38
[dynamo] Rewrite foreach pow to broadcast scalar argument ( #134167 )
...
Context: Adding support for the beta parameters to be tensors
Details:
In this PR, similarly to the previous one, foreach_pow calls item() on the first argument when it is a scalar tensor, which results in a device copy of that tensor. To avoid the item() call, we broadcast the scalar tensor into a list of aliases of itself. Once again, I don't think we can change the foreach_pow API due to BC concerns, so this op rewrite allows us to avoid a graph break, generate semantically equivalent code, and not affect eager.
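A sketch of the rewrite's shape (hypothetical illustration; the actual rewrite operates on Dynamo variable trackers):
```python
# _foreach_pow(s, xs) with a 0-dim tensor `s` calls s.item() internally.
# Broadcasting `s` into a list of aliases sidesteps that, e.g.:
#   torch._foreach_pow([s] * len(xs), xs)  # elementwise s ** x per list entry
```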
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134167
Approved by: https://github.com/anijain2305
ghstack dependencies: #134166
2024-08-31 10:24:35 +00:00
Michael Lazos
471c33f007
[dynamo] Rewrite foreach_lerp to avoid aten item call ( #134166 )
...
Context: Adding support for the beta parameters to be tensors
Details:
In order to add support for the beta params to be tensors without graph breaks in the Adam family of optimizers, it is necessary to support foreach_lerp(x, y, s) where s is a scalar tensor. Today this isn't possible because when `s` is a scalar, the aten op internally calls item() on it to extract the value and distribute it to each of the ops on the individual list indices. To support this in dynamo without graph breaks, I decompose the lerp into its constituent ops, which accept a scalar tensor in the list argument positions and do not result in an item() call. To be clear, the item() call is likely more performant for eager, and for BC reasons I don't think we can modify that API, so this keeps eager performance while avoiding graph breaks in compile.
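For reference, the decomposition follows from the definition of lerp (a sketch, not the exact rewrite):
```python
# lerp(x, y, s) == x + s * (y - x)
# so foreach_lerp_(xs, ys, s) can be expressed via foreach sub/mul/add ops,
# none of which need to call item() on the scalar tensor `s`.
```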
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134166
Approved by: https://github.com/anijain2305
2024-08-31 10:24:31 +00:00
Yanbo Liang
090d9cf410
[Dynamo][autograd.Function][vmap] support torch._C._are_functorch_transforms_active ( #134889 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134889
Approved by: https://github.com/jansel
2024-08-31 04:39:09 +00:00
Animesh Jain
594162f7ab
[dynamo] Support reading attributes from pybind objects ( #134630 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134630
Approved by: https://github.com/jansel
2024-08-29 15:06:52 +00:00
Animesh Jain
2bf622685d
[dynamo][dicts] Support hasattr on dicts ( #134590 )
...
Fixes - https://github.com/pytorch/pytorch/issues/134577
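A hypothetical example of the pattern this fixes (illustration only, not from the PR):
```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    d = {"a": 1}
    return x + 1 if hasattr(d, "items") else x - 1

fn(torch.ones(3))  # tensor([2., 2., 2.])
```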
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134590
Approved by: https://github.com/Skylion007
ghstack dependencies: #134610
2024-08-29 09:14:42 +00:00
PyTorch MergeBot
67d7040fce
Revert "[dynamo][dicts] Support hasattr on dicts ( #134590 )"
...
This reverts commit c566f2465f .
Reverted https://github.com/pytorch/pytorch/pull/134590 on behalf of https://github.com/ZainRizvi due to Sorry, I had to revert this in order to revert another PR ([comment](https://github.com/pytorch/pytorch/pull/134610#issuecomment-2316568553 ))
2024-08-29 02:02:12 +00:00
Animesh Jain
c566f2465f
[dynamo][dicts] Support hasattr on dicts ( #134590 )
...
Fixes - https://github.com/pytorch/pytorch/issues/134577
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134590
Approved by: https://github.com/Skylion007
ghstack dependencies: #134610
2024-08-28 07:35:18 +00:00
PyTorch MergeBot
30094bedbc
Revert "[dynamo][dicts] Support hasattr on dicts ( #134590 )"
...
This reverts commit d23c0150f3 .
Reverted https://github.com/pytorch/pytorch/pull/134590 on behalf of https://github.com/anijain2305 due to causing trunk CI failures ([comment](https://github.com/pytorch/pytorch/pull/134590#issuecomment-2313705582 ))
2024-08-27 22:52:52 +00:00
Animesh Jain
d23c0150f3
[dynamo][dicts] Support hasattr on dicts ( #134590 )
...
Fixes - https://github.com/pytorch/pytorch/issues/134577
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134590
Approved by: https://github.com/Skylion007
ghstack dependencies: #134039
2024-08-27 20:43:40 +00:00
Yanbo Liang
7868b65c4d
[Dynamo] Support dict.setdefault ( #134083 )
...
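A hypothetical example of the pattern this enables (illustration only, not from the PR):
```python
import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    d = {}
    d.setdefault("k", 1)      # key absent: inserts 1 and returns it
    v = d.setdefault("k", 2)  # key present: keeps 1 and returns it
    return x + v

fn(torch.ones(3))  # tensor([2., 2., 2.])
```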
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134083
Approved by: https://github.com/williamwen42
2024-08-22 01:57:33 +00:00
Animesh Jain
bd0db490bf
[dynamo][set] Fix EQUALS_MATCH guard for constant sets and lists ( #134016 )
...
Fixes https://github.com/pytorch/pytorch/issues/133509
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134016
Approved by: https://github.com/laithsakka , https://github.com/jansel
ghstack dependencies: #133742
2024-08-21 12:41:52 +00:00
Isuru Fernando
e554f71d7e
Implement filter in dynamo ( #131674 )
...
Fixes https://github.com/pytorch/pytorch/issues/128944
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131674
Approved by: https://github.com/amjames , https://github.com/jansel
2024-08-14 14:54:13 +00:00
Yanbo Liang
9de023d44d
[Dynamo] Make torch.Size can be reconstructed by LOAD_CONST ( #133342 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133342
Approved by: https://github.com/mlazos , https://github.com/jansel
2024-08-13 23:18:38 +00:00
xinyu-intel
5ae979ab10
[Dynamo] Support torch.autograd._is_checkpoint_valid ( #132611 )
...
Hi, we got `torch._dynamo.exc.Unsupported: torch.* op returned non-Tensor bool call_function <function _is_checkpoint_valid at 0x7f0b0d22e290>` while tracing the activation [checkpointing function in deepspeed](324ee65cb0/deepspeed/runtime/activation_checkpointing/checkpointing.py (L630)). Consider adding it to the constant-folding list, similar to https://github.com/pytorch/pytorch/pull/126196.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132611
Approved by: https://github.com/anijain2305 , https://github.com/williamwen42
2024-08-08 04:05:08 +00:00
Animesh Jain
194ec49d27
[dynamo][lists][stable diffusion] Do not add source on list slice ( #132912 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132912
Approved by: https://github.com/williamwen42
ghstack dependencies: #132806 , #132899
2024-08-08 02:23:07 +00:00
William Wen
01cdcbf7c8
[dynamo] revert map/zip iterator related changes ( #132528 )
...
Need to revert due to internal hangs: S437700
This reverts commit b6c1490cc0 .
Revert "[dynamo] implement IteratorVariable and polyfill fallbacks for enumerate (#131725 )"
This reverts commit 2576dbbc35 .
Revert "[dynamo] add itertools repeat/count bytecode reconstruction (#131716 )"
This reverts commit 35b4de32fa .
Revert "[dynamo] add lazy IteratorVariable implementations for map and zip (#131413 )"
This reverts commit 7d282d8755 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132528
Approved by: https://github.com/ZainRizvi
2024-08-04 18:46:55 +00:00
PyTorch MergeBot
0a25666f92
Revert "[dynamo] revert map/zip iterator related changes ( #132528 )"
...
This reverts commit e81e74ca6c .
Reverted https://github.com/pytorch/pytorch/pull/132528 on behalf of https://github.com/ZainRizvi due to This stack entered a weird state in the diff train. Reverting and relanding to clean the state ([comment](https://github.com/pytorch/pytorch/pull/132528#issuecomment-2267628475 ))
2024-08-04 18:26:09 +00:00
William Wen
e81e74ca6c
[dynamo] revert map/zip iterator related changes ( #132528 )
...
Need to revert due to internal hangs: S437700
This reverts commit b6c1490cc0 .
Revert "[dynamo] implement IteratorVariable and polyfill fallbacks for enumerate (#131725 )"
This reverts commit 2576dbbc35 .
Revert "[dynamo] add itertools repeat/count bytecode reconstruction (#131716 )"
This reverts commit 35b4de32fa .
Revert "[dynamo] add lazy IteratorVariable implementations for map and zip (#131413 )"
This reverts commit 7d282d8755 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132528
Approved by: https://github.com/ZainRizvi
2024-08-02 19:40:57 +00:00
Oguz Ulgen
920f0426ae
Add None return type to init -- tests rest ( #132376 )
...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132376
Approved by: https://github.com/jamesjwu
ghstack dependencies: #132335 , #132351 , #132352
2024-08-01 15:44:51 +00:00
datagero
bdd7a0322d
[Dynamo] Fix - str handler for UserDefinedObjectVariable ( #130506 )
...
Fixes #130301
Adjusted the call_str method to handle str conversion for UserDefinedObjectVariable.
Re-attempted in a clean branch due to unrelated test errors.
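A hypothetical example of what this fixes (illustration only, not from the PR):
```python
import torch

class Config:
    def __str__(self):
        return "config(lr=0.1)"

@torch.compile(backend="eager", fullgraph=True)
def fn(x, cfg):
    return x + len(str(cfg))  # str() on a user-defined object

fn(torch.ones(3), Config())
```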
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130506
Approved by: https://github.com/oulgen , https://github.com/anijain2305
2024-07-31 16:39:59 +00:00
Animesh Jain
03e058189e
[dynamo] Support dict unpack of MutableMapping objects ( #131961 )
...
Fixes https://github.com/pytorch/pytorch/issues/128067
The basic functionality was already introduced earlier. This just ensures
that we support UserDefinedObjectVariable.
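A hypothetical example of the pattern this covers (illustration only, not from the PR):
```python
from collections.abc import MutableMapping

import torch

class MyMapping(MutableMapping):
    def __init__(self, data):
        self._data = dict(data)

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)

@torch.compile(backend="eager", fullgraph=True)
def fn(x, m):
    d = {**m}  # dict unpack of a MutableMapping-derived object
    return x + d["a"]

fn(torch.ones(3), MyMapping({"a": 1}))
```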
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131961
Approved by: https://github.com/williamwen42 , https://github.com/mlazos , https://github.com/yanboliang
ghstack dependencies: #131827 , #131956
2024-07-30 05:49:58 +00:00
William Wen
b6c1490cc0
[dynamo] make more unpack_var_sequence calls forced ( #132069 )
...
Fixes [T197204962](https://www.internalfb.com/intern/tasks/?t=197204962 ) (example failure: https://www.internalfb.com/intern/testinfra/diagnostics/11540474088277914.281475138576374.1722221031/ )
Added tests contain a simple repro for the observed failure (`test_map_unpack_vars`).
Also fixes https://github.com/pytorch/pytorch/issues/132044
Differential Revision: [D60420335](https://our.internmc.facebook.com/intern/diff/D60420335 )
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132069
Approved by: https://github.com/anijain2305
2024-07-30 02:30:08 +00:00
Chengji Yao
d47c470f47
[dynamo] implement var_getattr in UserFunctionVariable ( #130413 )
...
This PR addresses the `getattr` of UserFunctionVariable. Although this usage is uncommon, it does appear in [Megatron's code](https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py#L635 ).
```
def linear_with_grad_accumulation_and_async_allreduce(...):
    ...
    if not linear_with_grad_accumulation_and_async_allreduce.warned:
        ...
    ...

linear_with_grad_accumulation_and_async_allreduce.warned = False
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130413
Approved by: https://github.com/yanboliang
2024-07-29 08:29:59 +00:00
Xuehai Pan
918ece4f4d
[BE][Easy][11/19] enforce style for empty lines in import segments in test/dy*/ ( #129762 )
...
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501 . Most changes are auto-generated by linter.
You can review these PRs via:
```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129762
Approved by: https://github.com/anijain2305
2024-07-27 17:43:53 +00:00
William Wen
2576dbbc35
[dynamo] implement IteratorVariable and polyfill fallbacks for enumerate ( #131725 )
...
Fixes https://github.com/pytorch/pytorch/issues/112794 .
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131725
Approved by: https://github.com/anijain2305
ghstack dependencies: #131413 , #131716
2024-07-26 17:17:09 +00:00
William Wen
35b4de32fa
[dynamo] add itertools repeat/count bytecode reconstruction ( #131716 )
...
Also fix bugs in the count iterator variable implementation.
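A hypothetical illustration of why reconstruction is needed (illustration only, not from the PR):
```python
import itertools

import torch

@torch.compile(backend="eager")
def fn(x):
    # the repeat iterator escapes the compiled region, so Dynamo must be
    # able to rebuild it in the generated bytecode
    return x + 1, itertools.repeat(2, 3)

_, it = fn(torch.ones(3))
assert list(it) == [2, 2, 2]
```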
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131716
Approved by: https://github.com/anijain2305
ghstack dependencies: #131413
2024-07-26 17:17:09 +00:00
Yanbo Liang
e76e566cfb
[Dynamo] Support zip_longest ( #131497 )
...
Fixes #121348
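A hypothetical example (illustration only, not from the PR):
```python
import itertools

import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    pairs = itertools.zip_longest([1, 2, 3], [10, 20], fillvalue=0)
    return x + sum([a + b for a, b in pairs])  # 11 + 22 + 3 = 36

fn(torch.ones(3))  # tensor([37., 37., 37.])
```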
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131497
Approved by: https://github.com/mlazos , https://github.com/jansel , https://github.com/zou3519
2024-07-26 14:06:10 +00:00
William Wen
7d282d8755
[dynamo] add lazy IteratorVariable implementations for map and zip ( #131413 )
...
Fixes https://github.com/pytorch/pytorch/issues/130750 .
Repro of lazy/eager `map` discrepancy without `islice`:
```python
def fn(a, b):
    y = 1

    def f(x):
        nonlocal y
        y += 1
        return x

    l = list(zip([a, b], map(f, [1, 2, 3, 4])))
    return a + y
```
The major change is that we implement `MapVariable` and `ZipVariable` based on `IteratorVariable`. Before, `map` and `zip` were being traced by immediately unpacking the result as a `TupleVariable`, which is wrong in cases such as the example above.
`MapVariable`s are not allowed to be unpacked, while `ZipVariable`s can only be unpacked if all of their iterables can also be unpacked.
We also add new `[has_]force_unpack_var_sequence` methods to `VariableTracker` for the case where it is safe to unpack the entire sequence lazily, e.g., when building a list from a map (i.e. `list(map(f, ...))`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131413
Approved by: https://github.com/anijain2305
2024-07-26 10:47:38 +00:00
Yidi Wu
ffc6bf8149
[dynamo] lazily guard and specialize on the symint when used in f-string. ( #131529 )
...
Fixes https://github.com/pytorch/pytorch/issues/103602 .
This PR implements the idea from the issue above: "if someone creates a string and then ends up not using it, we would prefer to NOT have specialized." Specifically, we create a lazy variable tracker instead of a ConstantVariable when handling FORMAT_VALUE, and when the lazy variable tracker is realized (i.e., it's actually going to be used), we create a ConstantVariable; the specialization/guarding happens at realization time.
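A hypothetical illustration of the behavior change (illustration only, not from the PR):
```python
import torch

@torch.compile(backend="eager", dynamic=True)
def fn(x):
    # the f-string embeds a symbolic size; if `msg` is never used, the lazy
    # tracker is never realized and the size is not specialized
    msg = f"size is {x.shape[0]}"
    return x + 1

fn(torch.ones(4))
```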
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131529
Approved by: https://github.com/ezyang
2024-07-25 16:16:34 +00:00
Animesh Jain
e2b941a1b4
[dynamo] Rename TENSOR_ALIASING to OBJECT_ALIASING. Permit OBJECT_ALIASING for dict guards ( #131480 )
...
Fixes https://github.com/pytorch/pytorch/issues/129667
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131480
Approved by: https://github.com/williamwen42
ghstack dependencies: #131347 , #131367 , #131378 , #131389 , #131405
2024-07-24 00:06:53 +00:00
Animesh Jain
6bbef2a06b
[dynamo] Support set on KeysView ( #131389 )
...
Fixes https://github.com/pytorch/pytorch/issues/129664
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131389
Approved by: https://github.com/mlazos
ghstack dependencies: #131347 , #131367 , #131378
2024-07-23 14:15:26 +00:00
Animesh Jain
e7c5e06772
[dynamo] Support __contains__ on __dict__ on UserDefinedClassVariable ( #131378 )
...
Fixes https://github.com/pytorch/pytorch/issues/129665
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131378
Approved by: https://github.com/mlazos
ghstack dependencies: #131347 , #131367
2024-07-23 14:15:26 +00:00
Animesh Jain
0bc5e26067
[dynamo] Support dict conversion of objects derived from MutableMapping ( #131367 )
...
Fixes - https://github.com/pytorch/pytorch/issues/129662
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131367
Approved by: https://github.com/williamwen42
ghstack dependencies: #131347
2024-07-23 14:15:20 +00:00
Animesh Jain
a944cce5b8
[dynamo] Support if callable on list ( #131347 )
...
Fixes https://github.com/pytorch/pytorch/issues/130720
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131347
Approved by: https://github.com/williamwen42 , https://github.com/mlazos
2024-07-23 14:15:15 +00:00
Alex Dennis
7d4f50de19
dynamo add support for defaultdict(set) ( #130745 )
...
Fixes #130554
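A hypothetical example (illustration only, not from the PR):
```python
from collections import defaultdict

import torch

@torch.compile(backend="eager", fullgraph=True)
def fn(x):
    d = defaultdict(set)
    d["a"].add(1)
    d["a"].add(1)  # sets deduplicate
    return x + len(d["a"])

fn(torch.ones(3))  # tensor([2., 2., 2.])
```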
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130745
Approved by: https://github.com/Skylion007
2024-07-15 22:23:33 +00:00
PyTorch MergeBot
dff9d68f18
Revert "Fix names conflict when lifting ( #129817 )"
...
This reverts commit 53cf46b8c6 .
Reverted https://github.com/pytorch/pytorch/pull/129817 on behalf of https://github.com/clee2000 due to Failing inductor/test_flex_attention.py https://github.com/pytorch/pytorch/actions/runs/9940532858/job/27478084137 74da2a467f Sorry for the churn, possibly a landrace? ([comment](https://github.com/pytorch/pytorch/pull/129817#issuecomment-2229519886 ))
2024-07-15 22:08:45 +00:00
Zhanghan Wang
53cf46b8c6
Fix names conflict when lifting ( #129817 )
...
## Bug description
When pending args that may need to be lifted [here](58f346c874/torch/_dynamo/output_graph.py (L1866)) share the same base name, like `contiguous` and `contiguous_1`, the call into [create_graph_input](58f346c874/torch/_dynamo/output_graph.py (L2081)) can end up creating a name ([here](58f346c874/torch/fx/graph.py (L1008))) that overwrites an arg to be lifted, causing the graph to produce wrong output.
## Reproducing
Below is a reproducible example:
```python
import logging
from typing import List

import torch
from functorch.compile import aot_module_simplified, make_boxed_func


@torch.library.custom_op("mylib::somefunc_forward", mutates_args=())
def somefunc_forward(
    input_: torch.Tensor,
    weight: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    return torch.ones_like(input_)


@somefunc_forward.register_fake
def _(input_, shape, weight):
    return torch.empty_like(input_)


@torch.library.custom_op("mylib::somefunc_backward", mutates_args=())
def somefunc_backward(
    grad_output: torch.Tensor,
    input_: torch.Tensor,
    weight: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    print(f"backward.{grad_output.shape=}")
    print(f"backward.{input_.shape=}")
    print(f"backward.{weight.shape=}")
    print(f"backward.{shape=}")
    assert list(weight.shape) == shape
    return torch.ones_like(weight)


@somefunc_backward.register_fake
def _(grad_output, input_, weight, shape):
    return torch.empty_like(weight)


def a_func(grad_output, input_, weight_, shape):
    return torch.ones_like(input_.sum() * weight_)


class SomeFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape):
        ctx.normalized_shape = normalized_shape
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        output = somefunc_forward(input_, weight_, ctx.normalized_shape)
        ctx.save_for_backward(input_, weight_)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_ = ctx.saved_tensors
        # grad_weight = a_func(grad_output, input_, weight_, ctx.normalized_shape)
        grad_weight = somefunc_backward(
            grad_output.contiguous(),
            input_,
            weight_,
            ctx.normalized_shape,
        )
        return None, grad_weight, None


class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(7))

    def forward(self, x):
        return SomeFunc.apply(x, self.weight, [7])


model = MyModel()

torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, graph_code=True)


def aot_print_backend(gm, sample_inputs):
    # Forward compiler capture
    def fw(gm, sample_inputs):
        print(f"----- fw")
        gm.print_readable()
        return make_boxed_func(gm.forward)

    # Backward compiler capture
    def bw(gm, sample_inputs):
        print(f"----- bw")
        gm.print_readable()
        return make_boxed_func(gm.forward)

    # Call AOTAutograd
    gm_forward = aot_module_simplified(
        gm, sample_inputs, fw_compiler=fw, bw_compiler=bw
    )
    return gm_forward


model = torch.compile(
    model,
    backend=aot_print_backend,
    dynamic=False,
)
out = model(torch.rand((128, 4, 7)))
out.mean().backward()
```
The log shows the calls into create_graph_input:
```log
V0629 02:08:46.839914 8200981504 torch/_dynamo/output_graph.py:2042] [0/0] create_graph_input contiguous (none)
V0629 02:08:46.839998 8200981504 torch/_dynamo/output_graph.py:2042] [0/0] create_graph_input contiguous_1 (none)
```
And the generated backward graph looks like:
```log
class GraphModule(torch.nn.Module):
    def forward(self, function_ctx, somefunc_forward_default: "f32[128, 4, 7]", contiguous: "f32[128, 4, 7]", contiguous_1: "f32[7]"):
        contiguous_1 = contiguous
        contiguous_2 = contiguous_1

        # No stacktrace found for following nodes
        _set_grad_enabled = torch._C._set_grad_enabled(False)

        # File: /Users/bytedance/testtorch/test_custom_op_bug.py:61 in backward, code: grad_output.contiguous(),
        contiguous: "f32[128, 4, 7]" = somefunc_forward_default.contiguous(); somefunc_forward_default = None

        # File: /opt/tiger/pytorch/torch/_library/custom_ops.py:506 in __call__, code: return self._opoverload(*args, **kwargs)
        somefunc_backward_default: "f32[7]" = torch.ops.mylib.somefunc_backward.default(contiguous, contiguous_1, contiguous_2, [7]); contiguous = contiguous_1 = contiguous_2 = None

        # No stacktrace found for following nodes
        _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
        return (None, somefunc_backward_default)
```
The original `somefunc_backward` takes the inputs `grad_output`, `input_`, `weight`, and `shape`, where `weight` should have shape `torch.Size([7])`. However, in the graph, `contiguous_1` and `contiguous_2` are assigned from `contiguous`, which triggers the assertion failure I added in `somefunc_backward`.
## Environment
```log
Collecting environment information...
PyTorch version: 2.5.0a0+git0b7e8df
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.26.4
Libc version: N/A
Python version: 3.9.19 (main, May 6 2024, 14:39:30) [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3 Pro
Versions of relevant libraries:
[pip3] numpy==2.0.0
[pip3] optree==0.11.0
[pip3] torch==2.5.0a0+git0b7e8df
[pip3] torchgraph==0.0.1
[conda] numpy 2.0.0 pypi_0 pypi
[conda] optree 0.11.0 pypi_0 pypi
[conda] torch 2.5.0a0+git0b7e8df dev_0 <develop>
[conda] torchgraph 0.0.1 dev_0 <develop>
```
## How to fix?
I put in a naive fix that adds the potential args to lift into `used_names`. It touches private variables; I will clean that up if this issue makes sense to you.
@zou3519 @oulgen
Co-authored-by: rzou <zou3519@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129817
Approved by: https://github.com/zou3519
2024-07-15 18:49:12 +00:00