[dynamo][ca] support dynamic annotations on tensors in ListVariables/TupleVariables (#152119)
Together with https://github.com/pytorch/pytorch/pull/151962, fixes https://github.com/pytorch/pytorch/issues/133575

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152119
Approved by: https://github.com/jansel
ghstack dependencies: #151731, #151962
This commit is contained in:
parent 6dea8ef555
commit 500cbeee4e
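For context, a minimal sketch of the annotation API involved, assuming standard torch.compile behavior (the function g here is illustrative only; the new test in the first hunk below exercises the same pattern under compiled autograd):

    import torch

    @torch.compile
    def g(x):
        return x * 2

    x = torch.randn(2, 3)
    torch._dynamo.mark_dynamic(x, 0)  # dim 0 becomes symbolic; dim 1 stays static
    g(x)                              # compiles once with a dynamic batch size
    g(torch.randn(8, 3))              # should reuse the graph, no recompile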
@@ -1237,6 +1237,25 @@ main()
         self.check_output_and_recompiles(fn)
 
+    def test_dynamic_shapes_annotations(self):
+        @torch.compile
+        def f(x):
+            return x.sin().sin()
+
+        with torch._dynamo.compiled_autograd._enable(torch.compile):
+            x = torch.randn(2, 3, requires_grad=True)
+            torch._dynamo.mark_dynamic(x, 0)
+            out = f(x)
+            out.sum().backward()
+
+            x = torch.randn(4, 3, requires_grad=True)
+            torch._dynamo.mark_dynamic(x, 0)
+            out = f(x)
+            out.sum().backward()
+
+        # mark_dynamic should not cause ConstraintViolationError
+        self.assertEqual(counters["compiled_autograd"]["captures"], 1)
+
     def test_torch_compile_api_dynamic_shapes(self):
         # Here, we have no way of marking the symbolic sizes used in SumBackward as dynamic
         def fn(call_backward):
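Because dim 0 is marked dynamic, the batch-2 and batch-4 backward passes should share a single compiled autograd capture; before this change, the annotation on a tensor held inside a list/tuple graph input could instead surface as a ConstraintViolationError. A rough way to probe the same counter in isolation, assuming the internal torch._dynamo.utils counter table the test reads (a sketch, not a stable API):

    from torch._dynamo.utils import counters

    counters.clear()
    # ... run the two marked-dynamic backward() calls from the test above ...
    print(counters["compiled_autograd"]["captures"])  # expected: 1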
@@ -2438,12 +2438,21 @@ class SubgraphTracer(fx.Tracer):
         # Also see NOTE: [Export inputs must be explicitly passed in]
         is_strict_export = self.is_export
         is_non_strict_export = torch.compiler.is_compiling()
-        if (
-            not is_strict_export
-            and not is_non_strict_export
-            and isinstance(example_value, torch.Tensor)
-        ):
-            self._lift_basic_symbols(example_value, source)
+        if not is_strict_export and not is_non_strict_export:
+            if isinstance(example_value, torch.Tensor):
+                self._lift_basic_symbols(example_value, source)
+            elif isinstance(example_value, (list, tuple)):
+                for i, e in enumerate(example_value):
+                    if not isinstance(e, torch.Tensor):
+                        continue
+
+                    e_source = None
+                    if source:
+                        e_source = GetItemSource(
+                            base=source, index=i, index_is_slice=False
+                        )
+
+                    self._lift_basic_symbols(e, e_source)
 
         # Bind the symbol to ph if example_value is a SymInt with a basic symbol.
         if isinstance(example_value, torch.SymInt) and isinstance(
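To see what the new elif branch does, here is a simplified, self-contained model of the lifting logic (plain strings stand in for dynamo's Source objects; lift_symbols_demo and the string paths are illustrative only, not the real internals):

    from typing import Optional, Union

    import torch

    def lift_symbols_demo(
        example_value: Union[torch.Tensor, list, tuple],
        source: Optional[str],
    ) -> list:
        """Collect the source paths whose size symbols would be lifted."""
        lifted = []
        if isinstance(example_value, torch.Tensor):
            lifted.append(source or "<unsourced>")
        elif isinstance(example_value, (list, tuple)):
            for i, e in enumerate(example_value):
                if not isinstance(e, torch.Tensor):
                    continue
                # stands in for GetItemSource(base=source, index=i, index_is_slice=False)
                e_source = f"{source}[{i}]" if source else None
                lifted.append(e_source or "<unsourced>")
        return lifted

    print(lift_symbols_demo([torch.randn(2), 3, torch.randn(4)], "inputs"))
    # -> ['inputs[0]', 'inputs[2]']

Note the branch descends exactly one level into list/tuple example values, which is what gives each tensor element its own GetItemSource-style path so that per-element mark_dynamic annotations survive inside ListVariables/TupleVariables.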