mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[draft export] avoid storing intermediate real tensors in proxies (#154630)"
This reverts commit 5acb8d5080. Reverted https://github.com/pytorch/pytorch/pull/154630 on behalf of https://github.com/malfet due to: this still OOMs, at least occasionally; see 78624679a8/1 ([comment](https://github.com/pytorch/pytorch/pull/154630#issuecomment-2923759745))
This commit is contained in:
parent
faf973da5e
commit
0fab32290a
|
|
@ -668,32 +668,6 @@ class TestDraftExport(TestCase):
|
|||
package_path=f.name,
|
||||
)
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
|
||||
def test_cuda_memory_usage(self):
|
||||
# This used to OOM
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
for _ in range(100):
|
||||
x = x + 1e-3
|
||||
return x
|
||||
|
||||
# measure base usage
|
||||
device = torch.device("cuda:0")
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
base_usage = torch.cuda.memory_allocated(device)
|
||||
|
||||
# usage with input tensor allocated
|
||||
x = torch.randn(2**10, 2**10, 2**8).to(device)
|
||||
x_usage = torch.cuda.memory_allocated(device)
|
||||
|
||||
# draft export peak memory usage
|
||||
draft_export(Foo(), (x,), strict=False)
|
||||
peak_mem_usage = torch.cuda.memory_stats(device)["allocated_bytes.all.peak"]
|
||||
|
||||
# right now it's actually exactly 4x;
|
||||
# I guess original tensor, 2 tensors per add op, 1 for clone stored in node.meta["val"]
|
||||
self.assertTrue((peak_mem_usage - base_usage) <= (x_usage - base_usage) * 4.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1059,12 +1059,10 @@ def make_fast_binary_impl(
|
|||
|
||||
# disable the python dispatcher to avoid decomposing detach() further
|
||||
# (proxy_mode should still decompose detach() though)
|
||||
def fast_detach(fake_mode, x, include_real=False):
|
||||
def fast_detach(fake_mode, x):
|
||||
with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
|
||||
out = torch.ops.aten.detach.default(x)
|
||||
if include_real:
|
||||
return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor)
|
||||
return FakeTensor(fake_mode, out, x.device)
|
||||
return FakeTensor(fake_mode, out, x.device, real_tensor=x.real_tensor)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
|
|
|
|||
|
|
@ -367,12 +367,12 @@ def get_proxy_slot(
|
|||
return res
|
||||
|
||||
|
||||
def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]:
|
||||
def snapshot_fake(val: Tensor) -> Optional[Tensor]:
|
||||
# val.detach() will also eventually call fast_detach(),
|
||||
# but this saves us a full trip into __torch_dispatch__
|
||||
# (snapshot_fake is called a lot)
|
||||
if isinstance(val, FakeTensor):
|
||||
return fast_detach(val.fake_mode, val, include_real)
|
||||
return fast_detach(val.fake_mode, val)
|
||||
else:
|
||||
return val.detach()
|
||||
|
||||
|
|
@ -393,9 +393,9 @@ _ExtractValType = Optional[
|
|||
]
|
||||
|
||||
|
||||
def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractValType:
|
||||
def extract_val(val: _ExtractValType) -> _ExtractValType:
|
||||
if is_fake(val):
|
||||
return snapshot_fake(val, include_real=include_real)
|
||||
return snapshot_fake(val)
|
||||
elif isinstance(val, py_sym_types):
|
||||
return val
|
||||
elif isinstance(val, _AnyScriptObject):
|
||||
|
|
@ -494,9 +494,7 @@ def maybe_enable_thunkify() -> Generator[None, None, None]:
|
|||
# grad_fn, _base (_base actually may be set due to recursive call to
|
||||
# ADInplaceOrView, but you shouldn't rely on it.)
|
||||
def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy:
|
||||
proxy.node.meta["val"] = extract_val(
|
||||
val, include_real=(proxy.node.op == "placeholder")
|
||||
)
|
||||
proxy.node.meta["val"] = extract_val(val)
|
||||
|
||||
with _enable_thunkify(proxy.tracer): # type: ignore[arg-type]
|
||||
# Best effort tensor_meta setting; prefer using val!
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user