Close some sources of fake tensor leakages (#159923)

Differential Revision: D79694055

Couple of fixes:
1. When we run into an operation we didn't proxy, we end up emitting fake tensor constants. We now detect this and raise an error that reports the FQN of the lifted constant.
2. The previous attribute-mutation detection logic in non-strict mode didn't account for nested module structure. This fixes a silent incorrectness issue when exporting esm and qwen in non-strict mode.
3. We modify yolov3 to fix its previous silently incorrect behaviour.

When upgrading the torchbench pin, opacus_cifar10 no longer seems to run in eager mode. I verified this by pushing a temporary PR on master with the new pin, so I added it to the expect_fail list.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159923
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar 2025-08-20 22:24:23 +00:00 committed by PyTorch MergeBot
parent 30384abcb1
commit 5afa4187df
4 changed files with 151 additions and 23 deletions

View File

@ -1 +1 @@
e03a63be43e33596f7f0a43b0f530353785e4a59
22bc29b4d503fc895ff73bc720ff396e9723465f

View File

@ -4341,6 +4341,80 @@ def forward(self, x):
x = torch.tensor([1, 2])
self.assertTrue(torch.allclose(mod(x), ep.module()(x)))
def test_nested_module_fake_tensor_leak(self):
    """Check that non-strict export leaves a nested module's cache attribute as None.

    ``Bar.forward`` assigns a tensor to ``self._tensor_cache`` the first time
    it runs. ``Bar`` is held as a submodule of ``Foo``, so the assignment
    happens one level below the module passed to ``export``.
    """

    class Bar(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Filled in lazily by the first forward call.
            self._tensor_cache = None

        def forward(self, x):
            if self._tensor_cache is None:
                self._tensor_cache = x + 2
            return self._tensor_cache.sum() + x.sum()

    class Foo(torch.nn.Module):
        def __init__(self, bar):
            super().__init__()
            self.bar = bar

        def forward(self, x):
            return self.bar(x)

    foo = Foo(Bar())
    _ = export(foo, (torch.ones(4, 4),), strict=False)
    # assertIsNone reports the leaked value on failure, unlike
    # assertTrue(... is None), which only prints "False is not true".
    self.assertIsNone(foo.bar._tensor_cache)
def test_export_leak_compile(self):
    """Check that non-strict export raises when a fake tensor ends up as a constant.

    The model below rebinds ``self.cache.data`` inside ``forward``. The test
    expects ``export(..., strict=False)`` to raise a ``RuntimeError`` whose
    message names ``cache_layer.lifted_tensor_0`` as the potential source.
    """
    # Local import keeps this block self-contained.
    import re

    class BaseModule(torch.nn.Module):
        def forward(self, *args, **kwargs):
            raise NotImplementedError

    class CacheModule(BaseModule):
        def __init__(self, cache: torch.Tensor):
            super().__init__()
            assert cache.ndim == 3
            self.cache = torch.nn.Parameter(cache, requires_grad=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            n_tokens = x.size(1)
            rolled_cache = torch.roll(self.cache.data, -n_tokens, dims=1)
            rolled_cache[:, -n_tokens:, :] = x
            # NOTE(review): presumably this `.data` rebinding is the
            # operation tracing cannot proxy -- confirm.
            self.cache.data = rolled_cache
            return self.cache

    class LinearBlock(torch.nn.Module):
        def __init__(self, in_features, out_features, activation=None):
            super().__init__()
            self.linear = torch.nn.Linear(in_features, out_features)
            self.activation = activation

        def forward(self, x):
            x = self.linear(x)
            return self.activation(x) if self.activation else x

    class MyModel(BaseModule):
        def __init__(self):
            super().__init__()
            default_cache = torch.zeros(1, 10, 5)
            self.cache_layer = CacheModule(default_cache)
            self.fc1 = LinearBlock(5, 10, activation=torch.nn.ReLU())
            self.fc2 = LinearBlock(10, 5)

        def forward(self, x):
            cached = self.cache_layer(x)
            out = self.fc1(cached)
            out = self.fc2(out)
            return out

    expected_message = (
        "We found a fake tensor in the exported program constant's list. "
        "This typically means our tracing system encountered an op that we can't trace through. "
        "For the potential source, you can refer to following model attribute: cache_layer.lifted_tensor_0. "
        "Please file an issue on github."
    )
    # The message contains regex metacharacters ("."), so escape it to
    # match the text literally instead of as a pattern.
    with self.assertRaisesRegex(RuntimeError, re.escape(expected_message)):
        _ = export(MyModel(), (torch.randn(1, 3, 5),), strict=False)
def test_export_for_training_with_container_type(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:

View File

@ -221,10 +221,23 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
# return any attributes of a module that are not standard attributes
return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}
def _get_all_module_attributes(mod):
    """Map every name yielded by ``mod.named_modules()`` to that module's attributes.

    Each value is whatever ``_get_attributes`` returns for the corresponding
    module.
    """
    return {
        qualified_name: _get_attributes(child)
        for qualified_name, child in mod.named_modules()
    }
def _restore_all_module_attributes(mod, snapshot):
# restore attributes to all modules and submodules
for name, submodule in mod.named_modules():
if name in snapshot:
submodule.__dict__.update(snapshot[name])
# save state of attributes before enter
snapshot = pytree.tree_map(
lambda x: x,
_get_attributes(mod),
_get_all_module_attributes(mod),
is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
)
try:
@ -236,41 +249,54 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
def _collect_assigned_tensor_attributes(kp, v, _v):
if _v is not v:
attr, *rest = kp
module_name, attr, *rest = kp
if isinstance(v, torch.Tensor):
module_prefix = f"{module_name.key}." if module_name.key else ""
assigned_tensor_attributes.append(
f"self.{attr.key}{pytree.keystr(rest)}"
f"self.{module_prefix}{attr.key}{pytree.keystr(rest)}"
)
# TODO(avik): Assigning all other types are allowed right now.
# Maybe in the future we want to limit this to primitive types?
return v
new_attrs = _get_attributes(mod)
if len(new_attrs) != len(snapshot):
added_attrs = new_attrs.keys() - snapshot.keys()
deleted_attrs = snapshot.keys() - new_attrs.keys()
new_attrs = _get_all_module_attributes(mod)
if len(added_attrs) > 0:
raise ValueError(
f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
# Check for added/deleted attributes across all modules
for module_name in snapshot.keys() | new_attrs.keys():
old_module_attrs = snapshot.get(module_name, {})
new_module_attrs = new_attrs.get(module_name, {})
if len(deleted_attrs) > 0:
raise ValueError(
f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
if len(new_module_attrs) != len(old_module_attrs):
added_attrs = new_module_attrs.keys() - old_module_attrs.keys()
deleted_attrs = old_module_attrs.keys() - new_module_attrs.keys()
module_prefix = f"self.{module_name}." if module_name else "self."
if len(added_attrs) > 0:
formatted_attrs = [f"{module_prefix}{attr}" for attr in added_attrs]
raise ValueError(
f"During torch.export, following attrs were created in the model.forward: {formatted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
if len(deleted_attrs) > 0:
formatted_attrs = [
f"{module_prefix}{attr}" for attr in deleted_attrs
]
raise ValueError(
f"During torch.export, following attrs were deleted in the model.forward: {formatted_attrs} "
f"Such attributes must be registered as buffers using the `register_buffer` "
f"API and must be initialized at model.__init__ "
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
)
pytree.tree_map_with_path(
_collect_assigned_tensor_attributes, snapshot, new_attrs
)
# restore state of all attributes (including, e.g., of primitive types)
mod.__dict__.update(snapshot)
_restore_all_module_attributes(mod, snapshot)
if assigned_tensor_attributes:
if len(assigned_tensor_attributes) > 1:

View File

@ -1850,6 +1850,14 @@ def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
return next(iter(node for node in gm.graph.nodes if node.name == name))
def _is_bogus_const_name(name: str):
splitted_names = name.split(".")
if len(splitted_names) < 1:
return True
return splitted_names[-1].startswith("lifted_tensor")
def _non_strict_export(
mod: torch.nn.Module,
args: tuple[Any, ...],
@ -2049,6 +2057,11 @@ def _export_for_training(
original_state_dict = _get_original_state_dict(mod)
has_ambient_mode = False
if not strict:
flat_args, _ = pytree.tree_flatten((args, kwargs))
has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
# Call the appropriate export function based on the strictness of tracing.
export_func = _strict_export if strict else _non_strict_export
@ -2063,6 +2076,21 @@ def _export_for_training(
_to_aten_func=_export_to_aten_ir_make_fx,
)
# If we are tracing with fake inputs, it is expected to
# see fake tensor constants.
if not strict and not has_ambient_mode:
for const, val in export_artifact.aten.constants.items():
if isinstance(
val, torch._subclasses.fake_tensor.FakeTensor
) and _is_bogus_const_name(const):
raise RuntimeError(
f"We found a fake tensor in the exported program constant's list. "
f"This typically means our tracing system encountered an op that "
f"we can't trace through. For the potential source, you can refer to "
f"following model attribute: {const}. "
f"Please file an issue on github. "
)
export_graph_signature = export_artifact.aten.sig
forward_arg_names = _get_forward_arg_names(mod, args, kwargs)