mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix set_unbacked_bindings when list of Tensors is returned (#133585)
Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/133585 Approved by: https://github.com/albanD
This commit is contained in:
parent
2443507acc
commit
2a49296d75
|
|
@ -652,6 +652,40 @@ graph():
|
||||||
foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False
|
foo, bad_example_inp, dynamic_shapes=dynamic_shapes, strict=False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_unbacked_to_cond(self):
|
||||||
|
class M(torch.nn.Module):
|
||||||
|
def forward(self, a):
|
||||||
|
az = a.nonzero()
|
||||||
|
|
||||||
|
def true_fn(x):
|
||||||
|
return (x + 1).sum()
|
||||||
|
|
||||||
|
def false_fn(x):
|
||||||
|
return (x + 3).sum()
|
||||||
|
|
||||||
|
r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
|
||||||
|
return r * 2
|
||||||
|
|
||||||
|
M()(torch.randn(7))
|
||||||
|
torch.export.export(M(), (torch.randn(7),))
|
||||||
|
|
||||||
|
def test_unbacked_to_cond_passthrough(self):
|
||||||
|
class M(torch.nn.Module):
|
||||||
|
def forward(self, a):
|
||||||
|
az = a.nonzero()
|
||||||
|
|
||||||
|
def true_fn(x):
|
||||||
|
return x + 1
|
||||||
|
|
||||||
|
def false_fn(x):
|
||||||
|
return x + 3
|
||||||
|
|
||||||
|
r = torch.cond(az.size(0) > 3, true_fn, false_fn, (az,))
|
||||||
|
return r * 2
|
||||||
|
|
||||||
|
M()(torch.randn(7))
|
||||||
|
torch.export.export(M(), (torch.randn(7),))
|
||||||
|
|
||||||
def test_state_tensors(self):
|
def test_state_tensors(self):
|
||||||
class M(torch.nn.Module): # simple with register buffer
|
class M(torch.nn.Module): # simple with register buffer
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
|
|
|
||||||
|
|
@ -596,8 +596,6 @@ def track_tensor_tree(
|
||||||
constant: Optional[_NestedTensors],
|
constant: Optional[_NestedTensors],
|
||||||
tracer: _ProxyTracer,
|
tracer: _ProxyTracer,
|
||||||
) -> T:
|
) -> T:
|
||||||
_set_unbacked_bindings(inner_res, proxy_res)
|
|
||||||
|
|
||||||
def wrap_with_proxy(
|
def wrap_with_proxy(
|
||||||
e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors]
|
e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors]
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
@ -606,11 +604,13 @@ def track_tensor_tree(
|
||||||
assert constant is None or isinstance(constant, Tensor)
|
assert constant is None or isinstance(constant, Tensor)
|
||||||
track_tensor(e, proxy, tracer=tracer, constant=constant)
|
track_tensor(e, proxy, tracer=tracer, constant=constant)
|
||||||
set_meta(proxy, e)
|
set_meta(proxy, e)
|
||||||
|
_set_unbacked_bindings(e, proxy)
|
||||||
elif isinstance(e, py_sym_types):
|
elif isinstance(e, py_sym_types):
|
||||||
assert isinstance(proxy, Proxy)
|
assert isinstance(proxy, Proxy)
|
||||||
# NB: eagerly set meta here, so that the numbering is in order
|
# NB: eagerly set meta here, so that the numbering is in order
|
||||||
set_meta(proxy, e)
|
set_meta(proxy, e)
|
||||||
set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
|
set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy))
|
||||||
|
_set_unbacked_bindings(e, proxy)
|
||||||
elif isinstance(e, _AnyScriptObject):
|
elif isinstance(e, _AnyScriptObject):
|
||||||
assert isinstance(proxy, Proxy)
|
assert isinstance(proxy, Proxy)
|
||||||
set_proxy_slot(e, tracer, proxy)
|
set_proxy_slot(e, tracer, proxy)
|
||||||
|
|
@ -2188,6 +2188,8 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
|
||||||
"""A helper function for setting up unbacked_bindings on the destination FX graph."""
|
"""A helper function for setting up unbacked_bindings on the destination FX graph."""
|
||||||
from .symbolic_shapes import compute_unbacked_bindings
|
from .symbolic_shapes import compute_unbacked_bindings
|
||||||
|
|
||||||
|
log.debug("_set_unbacked_bindings %s", out_proxy)
|
||||||
|
|
||||||
# Can't use detect_fake_mode here,
|
# Can't use detect_fake_mode here,
|
||||||
#
|
#
|
||||||
# python test/distributed/_tensor/test_dtensor_compile.py -k
|
# python test/distributed/_tensor/test_dtensor_compile.py -k
|
||||||
|
|
@ -2198,5 +2200,5 @@ def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None:
|
||||||
fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
|
fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE)
|
||||||
if fake_mode and fake_mode.shape_env:
|
if fake_mode and fake_mode.shape_env:
|
||||||
if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out):
|
if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out):
|
||||||
assert isinstance(out_proxy, Proxy)
|
assert isinstance(out_proxy, Proxy), out_proxy
|
||||||
out_proxy.node.meta["unbacked_bindings"] = symbol_to_path
|
out_proxy.node.meta["unbacked_bindings"] = symbol_to_path
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user