[dynamo][ca] support dynamic annotations on tensors in ListVariables/TupleVariables (#152119)

Together with https://github.com/pytorch/pytorch/pull/151962, FIXES https://github.com/pytorch/pytorch/issues/133575

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152119
Approved by: https://github.com/jansel
ghstack dependencies: #151731, #151962
This commit is contained in:
Simon Fan 2025-05-07 14:14:46 -07:00 committed by PyTorch MergeBot
parent 6dea8ef555
commit 500cbeee4e
2 changed files with 34 additions and 6 deletions

View File

@ -1237,6 +1237,25 @@ main()
self.check_output_and_recompiles(fn)
def test_dynamic_shapes_annotations(self):
@torch.compile
def f(x):
return x.sin().sin()
with torch._dynamo.compiled_autograd._enable(torch.compile):
x = torch.randn(2, 3, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)
out = f(x)
out.sum().backward()
x = torch.randn(4, 3, requires_grad=True)
torch._dynamo.mark_dynamic(x, 0)
out = f(x)
out.sum().backward()
# mark_dynamic should not cause ConstraintViolationError
self.assertEqual(counters["compiled_autograd"]["captures"], 1)
def test_torch_compile_api_dynamic_shapes(self):
# Here, we have no way of marking the symbolic sizes using in SumBackward as dynamic
def fn(call_backward):

View File

@ -2438,12 +2438,21 @@ class SubgraphTracer(fx.Tracer):
# Also see NOTE: [Export inputs must be explicitly passed in]
is_strict_export = self.is_export
is_non_strict_export = torch.compiler.is_compiling()
if (
not is_strict_export
and not is_non_strict_export
and isinstance(example_value, torch.Tensor)
):
self._lift_basic_symbols(example_value, source)
if not is_strict_export and not is_non_strict_export:
if isinstance(example_value, torch.Tensor):
self._lift_basic_symbols(example_value, source)
elif isinstance(example_value, (list, tuple)):
for i, e in enumerate(example_value):
if not isinstance(e, torch.Tensor):
continue
e_source = None
if source:
e_source = GetItemSource(
base=source, index=i, index_is_slice=False
)
self._lift_basic_symbols(e, e_source)
# Bound the symbol to ph if example_value is a SymInt with basic symbol.
if isinstance(example_value, torch.SymInt) and isinstance(