Fix set_unbacked_bindings when list of Tensors is returned (#133585)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133585
Approved by: https://github.com/albanD
This commit is contained in:
Edward Z. Yang 2024-09-01 15:49:37 -04:00 committed by PyTorch MergeBot
parent 2443507acc
commit 2a49296d75
2 changed files with 39 additions and 3 deletions

View File

@ -652,6 +652,40 @@ graph():
foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False
) )
def test_unbacked_to_cond(self):
class M(torch.nn.Module):
def forward(self, a):
az = a.nonzero()
def true_fn(x):
return (x + 1).sum()
def false_fn(x):
return (x + 3).sum()
r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
return r * 2
M()(torch.randn(7))
torch.export.export(M(), (torch.randn(7),))
def test_unbacked_to_cond_passthrough(self):
class M(torch.nn.Module):
def forward(self, a):
az = a.nonzero()
def true_fn(x):
return x + 1
def false_fn(x):
return x + 3
r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
return r * 2
M()(torch.randn(7))
torch.export.export(M(), (torch.randn(7),))
def test_state_tensors(self): def test_state_tensors(self):
class M(torch.nn.Module): # simple with register buffer class M(torch.nn.Module): # simple with register buffer
def __init__(self) -> None: def __init__(self) -> None:

View File

@ -596,8 +596,6 @@ def track_tensor_tree(
constant: Optional[_NestedTensors], constant: Optional[_NestedTensors],
tracer: _ProxyTracer, tracer: _ProxyTracer,
) -> T: ) -> T:
_set_unbacked_bindings(inner_res, proxy_res)
def wrap_with_proxy( def wrap_with_proxy(
e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors] e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors]
) -> None: ) -> None:
@ -606,11 +604,13 @@ def track_tensor_tree(
assert constant is None or isinstance(constant, Tensor) assert constant is None or isinstance(constant, Tensor)
track_tensor(e, proxy, tracer=tracer, constant=constant) track_tensor(e, proxy, tracer=tracer, constant=constant)
set_meta(proxy, e) set_meta(proxy, e)
_set_unbacked_bindings(e, proxy)
elif isinstance(e, py_sym_types): elif isinstance(e, py_sym_types):
assert isinstance(proxy, Proxy) assert isinstance(proxy, Proxy)
# NB: eagerly set meta here, so that the numbering is in order # NB: eagerly set meta here, so that the numbering is in order
set_meta(proxy, e) set_meta(proxy, e)
set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
_set_unbacked_bindings(e, proxy)
elif isinstance(e, _AnyScriptObject): elif isinstance(e, _AnyScriptObject):
assert isinstance(proxy, Proxy) assert isinstance(proxy, Proxy)
set_proxy_slot(e, tracer, proxy) set_proxy_slot(e, tracer, proxy)
@ -2188,6 +2188,8 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
"""A helper function for setting up unbacked_bindings on the destination FX graph.""" """A helper function for setting up unbacked_bindings on the destination FX graph."""
from .symbolic_shapes import compute_unbacked_bindings from .symbolic_shapes import compute_unbacked_bindings
log.debug("_set_unbacked_bindings %s", out_proxy)
# Can't use detect_fake_mode here, # Can't use detect_fake_mode here,
# #
# python test/distributed/_tensor/test_dtensor_compile.py -k # python test/distributed/_tensor/test_dtensor_compile.py -k
@ -2198,5 +2200,5 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
if fake_mode and fake_mode.shape_env: if fake_mode and fake_mode.shape_env:
if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out):
assert isinstance(out_proxy, Proxy) assert isinstance(out_proxy, Proxy), out_proxy
out_proxy.node.meta["unbacked_bindings"] = symbol_to_path out_proxy.node.meta["unbacked_bindings"] = symbol_to_path