mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix set_unbacked_bindings when list of Tensors is returned (#133585)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/133585 Approved by: https://github.com/albanD
This commit is contained in:
parent
2443507acc
commit
2a49296d75
|
|
@ -652,6 +652,40 @@ graph():
|
|||
foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False
|
||||
)
|
||||
|
||||
def test_unbacked_to_cond(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a):
|
||||
az = a.nonzero()
|
||||
|
||||
def true_fn(x):
|
||||
return (x + 1).sum()
|
||||
|
||||
def false_fn(x):
|
||||
return (x + 3).sum()
|
||||
|
||||
r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
|
||||
return r * 2
|
||||
|
||||
M()(torch.randn(7))
|
||||
torch.export.export(M(), (torch.randn(7),))
|
||||
|
||||
def test_unbacked_to_cond_passthrough(self):
|
||||
class M(torch.nn.Module):
|
||||
def forward(self, a):
|
||||
az = a.nonzero()
|
||||
|
||||
def true_fn(x):
|
||||
return x + 1
|
||||
|
||||
def false_fn(x):
|
||||
return x + 3
|
||||
|
||||
r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
|
||||
return r * 2
|
||||
|
||||
M()(torch.randn(7))
|
||||
torch.export.export(M(), (torch.randn(7),))
|
||||
|
||||
def test_state_tensors(self):
|
||||
class M(torch.nn.Module): # simple with register buffer
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -596,8 +596,6 @@ def track_tensor_tree(
|
|||
constant: Optional[_NestedTensors],
|
||||
tracer: _ProxyTracer,
|
||||
) -> T:
|
||||
_set_unbacked_bindings(inner_res, proxy_res)
|
||||
|
||||
def wrap_with_proxy(
|
||||
e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors]
|
||||
) -> None:
|
||||
|
|
@ -606,11 +604,13 @@ def track_tensor_tree(
|
|||
assert constant is None or isinstance(constant, Tensor)
|
||||
track_tensor(e, proxy, tracer=tracer, constant=constant)
|
||||
set_meta(proxy, e)
|
||||
_set_unbacked_bindings(e, proxy)
|
||||
elif isinstance(e, py_sym_types):
|
||||
assert isinstance(proxy, Proxy)
|
||||
# NB: eagerly set meta here, so that the numbering is in order
|
||||
set_meta(proxy, e)
|
||||
set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
|
||||
_set_unbacked_bindings(e, proxy)
|
||||
elif isinstance(e, _AnyScriptObject):
|
||||
assert isinstance(proxy, Proxy)
|
||||
set_proxy_slot(e, tracer, proxy)
|
||||
|
|
@ -2188,6 +2188,8 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
|
|||
"""A helper function for setting up unbacked_bindings on the destination FX graph."""
|
||||
from .symbolic_shapes import compute_unbacked_bindings
|
||||
|
||||
log.debug("_set_unbacked_bindings %s", out_proxy)
|
||||
|
||||
# Can't use detect_fake_mode here,
|
||||
#
|
||||
# python test/distributed/_tensor/test_dtensor_compile.py -k
|
||||
|
|
@ -2198,5 +2200,5 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
|
|||
fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
|
||||
if fake_mode and fake_mode.shape_env:
|
||||
if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out):
|
||||
assert isinstance(out_proxy, Proxy)
|
||||
assert isinstance(out_proxy, Proxy), out_proxy
|
||||
out_proxy.node.meta["unbacked_bindings"] = symbol_to_path
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user