Revert "Close some sources of fake tensor leakages (#159923)"

This reverts commit 5afa4187df.

Reverted https://github.com/pytorch/pytorch/pull/159923 on behalf of https://github.com/zou3519 due to broke aoti test in inductor periodic ([comment](https://github.com/pytorch/pytorch/pull/159923#issuecomment-3215580688))
This commit is contained in:
PyTorch MergeBot 2025-08-22 20:42:50 +00:00
parent 3ea6cc8c2d
commit 981ac533c6
4 changed files with 23 additions and 151 deletions

View File

@ -1 +1 @@
22bc29b4d503fc895ff73bc720ff396e9723465f
e03a63be43e33596f7f0a43b0f530353785e4a59

View File

@ -4367,80 +4367,6 @@ def forward(self, x):
x = torch.tensor([1, 2])
self.assertTrue(torch.allclose(mod(x), ep.module()(x)))
def test_nested_module_fake_tensor_leak(self):
class Bar(torch.nn.Module):
def __init__(self):
super().__init__()
self._tensor_cache = None
def forward(self, x):
if self._tensor_cache is None:
self._tensor_cache = x + 2
return self._tensor_cache.sum() + x.sum()
class Foo(torch.nn.Module):
def __init__(self, bar):
super().__init__()
self.bar = bar
def forward(self, x):
return self.bar(x)
foo = Foo(Bar())
_ = export(foo, (torch.ones(4, 4),), strict=False)
self.assertTrue(foo.bar._tensor_cache is None)
def test_export_leak_compile(self):
class BaseModule(torch.nn.Module):
def forward(self, *args, **kwargs):
raise NotImplementedError
class CacheModule(BaseModule):
def __init__(self, cache: torch.Tensor):
super().__init__()
assert cache.ndim == 3
self.cache = torch.nn.Parameter(cache, requires_grad=False)
def forward(self, x: torch.Tensor) -> torch.Tensor:
n_tokens = x.size(1)
rolled_cache = torch.roll(self.cache.data, -n_tokens, dims=1)
rolled_cache[:, -n_tokens:, :] = x
self.cache.data = rolled_cache
return self.cache
class LinearBlock(torch.nn.Module):
def __init__(self, in_features, out_features, activation=None):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features)
self.activation = activation
def forward(self, x):
x = self.linear(x)
return self.activation(x) if self.activation else x
class MyModel(BaseModule):
def __init__(self):
super().__init__()
default_cache = torch.zeros(1, 10, 5)
self.cache_layer = CacheModule(default_cache)
self.fc1 = LinearBlock(5, 10, activation=torch.nn.ReLU())
self.fc2 = LinearBlock(10, 5)
def forward(self, x):
cached = self.cache_layer(x)
out = self.fc1(cached)
out = self.fc2(out)
return out
with self.assertRaisesRegex(
RuntimeError,
"We found a fake tensor in the exported program constant's list. "
"This typically means our tracing system encountered an op that we can't trace through. "
"For the potential source, you can refer to following model attribute: cache_layer.lifted_tensor_0. "
"Please file an issue on github.",
):
_ = export(MyModel(), (torch.randn(1, 3, 5),), strict=False)
def test_export_for_training_with_container_type(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:

View File

@ -221,23 +221,10 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
# return any attributes of a module that are not standard attributes
return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}
def _get_all_module_attributes(mod):
# return attributes from all modules and submodules
result = {}
for name, submodule in mod.named_modules():
result[name] = _get_attributes(submodule)
return result
def _restore_all_module_attributes(mod, snapshot):
# restore attributes to all modules and submodules
for name, submodule in mod.named_modules():
if name in snapshot:
submodule.__dict__.update(snapshot[name])
# save state of attributes before enter
snapshot = pytree.tree_map(
lambda x: x,
_get_all_module_attributes(mod),
_get_attributes(mod),
is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
)
try:
@ -249,54 +236,41 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
def _collect_assigned_tensor_attributes(kp, v, _v):
if _v is not v:
module_name, attr, *rest = kp
attr, *rest = kp
if isinstance(v, torch.Tensor):
module_prefix = f"{module_name.key}." if module_name.key else ""
assigned_tensor_attributes.append(
f"self.{module_prefix}{attr.key}{pytree.keystr(rest)}"
f"self.{attr.key}{pytree.keystr(rest)}"
)
# TODO(avik): Assigning all other types are allowed right now.
# Maybe in the future we want to limit this to primitive types?
return v
new_attrs = _get_all_module_attributes(mod)
new_attrs = _get_attributes(mod)
if len(new_attrs) != len(snapshot):
added_attrs = new_attrs.keys() - snapshot.keys()
deleted_attrs = snapshot.keys() - new_attrs.keys()
# Check for added/deleted attributes across all modules
for module_name in snapshot.keys() | new_attrs.keys():
old_module_attrs = snapshot.get(module_name, {})
new_module_attrs = new_attrs.get(module_name, {})
if len(added_attrs) > 0:
raise ValueError(
f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
if len(new_module_attrs) != len(old_module_attrs):
added_attrs = new_module_attrs.keys() - old_module_attrs.keys()
deleted_attrs = old_module_attrs.keys() - new_module_attrs.keys()
module_prefix = f"self.{module_name}." if module_name else "self."
if len(added_attrs) > 0:
formatted_attrs = [f"{module_prefix}{attr}" for attr in added_attrs]
raise ValueError(
f"During torch.export, following attrs were created in the model.forward: {formatted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
if len(deleted_attrs) > 0:
formatted_attrs = [
f"{module_prefix}{attr}" for attr in deleted_attrs
]
raise ValueError(
f"During torch.export, following attrs were deleted in the model.forward: {formatted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
if len(deleted_attrs) > 0:
raise ValueError(
f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
pytree.tree_map_with_path(
_collect_assigned_tensor_attributes, snapshot, new_attrs
)
# restore state of all attributes (including, e.g., of primitive types)
_restore_all_module_attributes(mod, snapshot)
mod.__dict__.update(snapshot)
if assigned_tensor_attributes:
if len(assigned_tensor_attributes) > 1:

View File

@ -1850,14 +1850,6 @@ def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
return next(iter(node for node in gm.graph.nodes if node.name == name))
def _is_bogus_const_name(name: str):
splitted_names = name.split(".")
if len(splitted_names) < 1:
return True
return splitted_names[-1].startswith("lifted_tensor")
def _non_strict_export(
mod: torch.nn.Module,
args: tuple[Any, ...],
@ -2057,11 +2049,6 @@ def _export_for_training(
original_state_dict = _get_original_state_dict(mod)
has_ambient_mode = False
if not strict:
flat_args, _ = pytree.tree_flatten((args, kwargs))
has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
# Call the appropriate export function based on the strictness of tracing.
export_func = _strict_export if strict else _non_strict_export
@ -2076,21 +2063,6 @@ def _export_for_training(
_to_aten_func=_export_to_aten_ir_make_fx,
)
# If we are tracing with fake inputs, it is expected to
# see fake tensor constants.
if not strict and not has_ambient_mode:
for const, val in export_artifact.aten.constants.items():
if isinstance(
val, torch._subclasses.fake_tensor.FakeTensor
) and _is_bogus_const_name(const):
raise RuntimeError(
f"We found a fake tensor in the exported program constant's list. "
f"This typically means our tracing system encountered an op that "
f"we can't trace through. For the potential source, you can refer to "
f"following model attribute: {const}. "
f"Please file an issue on github. "
)
export_graph_signature = export_artifact.aten.sig
forward_arg_names = _get_forward_arg_names(mod, args, kwargs)