mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Close some sources of fake tensor leakages (#159923)
Differential Revision: D79694055 Couple of fixes: 1. When we run into an operation we didn't proxy, we end up emitting fake constants. We detect this and error using the FQN of the lifted constant 2. Previous attribute mutation detection logic in non-strict didn't account for nested module structure. This fixes silent incorrectness issue of exporting esm and qwen in non-strict 3. We modify yolov3 to fix the previous silent incorrect behaviour When upgrading torchbench pin, opacus_cifar10 seems to not run on eager anymore. I verified this by pushing a temporary PR on master with new pin. So i added it to expect_fail list. Pull Request resolved: https://github.com/pytorch/pytorch/pull/159923 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
parent
30384abcb1
commit
5afa4187df
|
|
@ -1 +1 @@
|
|||
e03a63be43e33596f7f0a43b0f530353785e4a59
|
||||
22bc29b4d503fc895ff73bc720ff396e9723465f
|
||||
|
|
|
|||
|
|
@ -4341,6 +4341,80 @@ def forward(self, x):
|
|||
x = torch.tensor([1, 2])
|
||||
self.assertTrue(torch.allclose(mod(x), ep.module()(x)))
|
||||
|
||||
def test_nested_module_fake_tensor_leak(self):
|
||||
class Bar(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self._tensor_cache = None
|
||||
|
||||
def forward(self, x):
|
||||
if self._tensor_cache is None:
|
||||
self._tensor_cache = x + 2
|
||||
return self._tensor_cache.sum() + x.sum()
|
||||
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self, bar):
|
||||
super().__init__()
|
||||
self.bar = bar
|
||||
|
||||
def forward(self, x):
|
||||
return self.bar(x)
|
||||
|
||||
foo = Foo(Bar())
|
||||
_ = export(foo, (torch.ones(4, 4),), strict=False)
|
||||
self.assertTrue(foo.bar._tensor_cache is None)
|
||||
|
||||
def test_export_leak_compile(self):
|
||||
class BaseModule(torch.nn.Module):
|
||||
def forward(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
class CacheModule(BaseModule):
|
||||
def __init__(self, cache: torch.Tensor):
|
||||
super().__init__()
|
||||
assert cache.ndim == 3
|
||||
self.cache = torch.nn.Parameter(cache, requires_grad=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
n_tokens = x.size(1)
|
||||
rolled_cache = torch.roll(self.cache.data, -n_tokens, dims=1)
|
||||
rolled_cache[:, -n_tokens:, :] = x
|
||||
self.cache.data = rolled_cache
|
||||
return self.cache
|
||||
|
||||
class LinearBlock(torch.nn.Module):
|
||||
def __init__(self, in_features, out_features, activation=None):
|
||||
super().__init__()
|
||||
self.linear = torch.nn.Linear(in_features, out_features)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
x = self.linear(x)
|
||||
return self.activation(x) if self.activation else x
|
||||
|
||||
class MyModel(BaseModule):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
default_cache = torch.zeros(1, 10, 5)
|
||||
self.cache_layer = CacheModule(default_cache)
|
||||
self.fc1 = LinearBlock(5, 10, activation=torch.nn.ReLU())
|
||||
self.fc2 = LinearBlock(10, 5)
|
||||
|
||||
def forward(self, x):
|
||||
cached = self.cache_layer(x)
|
||||
out = self.fc1(cached)
|
||||
out = self.fc2(out)
|
||||
return out
|
||||
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"We found a fake tensor in the exported program constant's list. "
|
||||
"This typically means our tracing system encountered an op that we can't trace through. "
|
||||
"For the potential source, you can refer to following model attribute: cache_layer.lifted_tensor_0. "
|
||||
"Please file an issue on github.",
|
||||
):
|
||||
_ = export(MyModel(), (torch.randn(1, 3, 5),), strict=False)
|
||||
|
||||
def test_export_for_training_with_container_type(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
|
|||
|
|
@ -221,10 +221,23 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
|
|||
# return any attributes of a module that are not standard attributes
|
||||
return {k: v for k, v in mod.__dict__.items() if k not in STD_ATTRS}
|
||||
|
||||
def _get_all_module_attributes(mod):
|
||||
# return attributes from all modules and submodules
|
||||
result = {}
|
||||
for name, submodule in mod.named_modules():
|
||||
result[name] = _get_attributes(submodule)
|
||||
return result
|
||||
|
||||
def _restore_all_module_attributes(mod, snapshot):
|
||||
# restore attributes to all modules and submodules
|
||||
for name, submodule in mod.named_modules():
|
||||
if name in snapshot:
|
||||
submodule.__dict__.update(snapshot[name])
|
||||
|
||||
# save state of attributes before enter
|
||||
snapshot = pytree.tree_map(
|
||||
lambda x: x,
|
||||
_get_attributes(mod),
|
||||
_get_all_module_attributes(mod),
|
||||
is_leaf=lambda x: type(x) in _pytree_subclasses_that_lose_info,
|
||||
)
|
||||
try:
|
||||
|
|
@ -236,41 +249,54 @@ def _detect_attribute_assignment(mod: torch.nn.Module):
|
|||
|
||||
def _collect_assigned_tensor_attributes(kp, v, _v):
|
||||
if _v is not v:
|
||||
attr, *rest = kp
|
||||
module_name, attr, *rest = kp
|
||||
if isinstance(v, torch.Tensor):
|
||||
module_prefix = f"{module_name.key}." if module_name.key else ""
|
||||
assigned_tensor_attributes.append(
|
||||
f"self.{attr.key}{pytree.keystr(rest)}"
|
||||
f"self.{module_prefix}{attr.key}{pytree.keystr(rest)}"
|
||||
)
|
||||
# TODO(avik): Assigning all other types are allowed right now.
|
||||
# Maybe in the future we want to limit this to primitive types?
|
||||
return v
|
||||
|
||||
new_attrs = _get_attributes(mod)
|
||||
if len(new_attrs) != len(snapshot):
|
||||
added_attrs = new_attrs.keys() - snapshot.keys()
|
||||
deleted_attrs = snapshot.keys() - new_attrs.keys()
|
||||
new_attrs = _get_all_module_attributes(mod)
|
||||
|
||||
if len(added_attrs) > 0:
|
||||
raise ValueError(
|
||||
f"During torch.export, following attrs were created in the model.forward: {added_attrs} "
|
||||
f"Such attributes must be registered as buffers using the `register_buffer` "
|
||||
f"API and must be initialized at model.__init__ "
|
||||
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
|
||||
)
|
||||
# Check for added/deleted attributes across all modules
|
||||
for module_name in snapshot.keys() | new_attrs.keys():
|
||||
old_module_attrs = snapshot.get(module_name, {})
|
||||
new_module_attrs = new_attrs.get(module_name, {})
|
||||
|
||||
if len(deleted_attrs) > 0:
|
||||
raise ValueError(
|
||||
f"During torch.export, following attrs were deleted in the model.forward: {deleted_attrs} "
|
||||
f"Such attributes must be registered as buffers using the `register_buffer` "
|
||||
f"API and must be initialized at model.__init__ "
|
||||
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
|
||||
)
|
||||
if len(new_module_attrs) != len(old_module_attrs):
|
||||
added_attrs = new_module_attrs.keys() - old_module_attrs.keys()
|
||||
deleted_attrs = old_module_attrs.keys() - new_module_attrs.keys()
|
||||
|
||||
module_prefix = f"self.{module_name}." if module_name else "self."
|
||||
|
||||
if len(added_attrs) > 0:
|
||||
formatted_attrs = [f"{module_prefix}{attr}" for attr in added_attrs]
|
||||
raise ValueError(
|
||||
f"During torch.export, following attrs were created in the model.forward: {formatted_attrs} "
|
||||
f"Such attributes must be registered as buffers using the `register_buffer` "
|
||||
f"API and must be initialized at model.__init__ "
|
||||
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
|
||||
)
|
||||
|
||||
if len(deleted_attrs) > 0:
|
||||
formatted_attrs = [
|
||||
f"{module_prefix}{attr}" for attr in deleted_attrs
|
||||
]
|
||||
raise ValueError(
|
||||
f"During torch.export, following attrs were deleted in the model.forward: {formatted_attrs} "
|
||||
f"Such attributes must be registered as buffers using the `register_buffer` "
|
||||
f"API and must be initialized at model.__init__ "
|
||||
f"(https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_buffer)."
|
||||
)
|
||||
|
||||
pytree.tree_map_with_path(
|
||||
_collect_assigned_tensor_attributes, snapshot, new_attrs
|
||||
)
|
||||
# restore state of all attributes (including, e.g., of primitive types)
|
||||
mod.__dict__.update(snapshot)
|
||||
_restore_all_module_attributes(mod, snapshot)
|
||||
|
||||
if assigned_tensor_attributes:
|
||||
if len(assigned_tensor_attributes) > 1:
|
||||
|
|
|
|||
|
|
@ -1850,6 +1850,14 @@ def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node:
|
|||
return next(iter(node for node in gm.graph.nodes if node.name == name))
|
||||
|
||||
|
||||
def _is_bogus_const_name(name: str):
|
||||
splitted_names = name.split(".")
|
||||
if len(splitted_names) < 1:
|
||||
return True
|
||||
|
||||
return splitted_names[-1].startswith("lifted_tensor")
|
||||
|
||||
|
||||
def _non_strict_export(
|
||||
mod: torch.nn.Module,
|
||||
args: tuple[Any, ...],
|
||||
|
|
@ -2049,6 +2057,11 @@ def _export_for_training(
|
|||
|
||||
original_state_dict = _get_original_state_dict(mod)
|
||||
|
||||
has_ambient_mode = False
|
||||
if not strict:
|
||||
flat_args, _ = pytree.tree_flatten((args, kwargs))
|
||||
has_ambient_mode = torch._guards.detect_fake_mode(flat_args) is not None
|
||||
|
||||
# Call the appropriate export function based on the strictness of tracing.
|
||||
export_func = _strict_export if strict else _non_strict_export
|
||||
|
||||
|
|
@ -2063,6 +2076,21 @@ def _export_for_training(
|
|||
_to_aten_func=_export_to_aten_ir_make_fx,
|
||||
)
|
||||
|
||||
# If we are tracing with fake inputs, it is expected to
|
||||
# see fake tensor constants.
|
||||
if not strict and not has_ambient_mode:
|
||||
for const, val in export_artifact.aten.constants.items():
|
||||
if isinstance(
|
||||
val, torch._subclasses.fake_tensor.FakeTensor
|
||||
) and _is_bogus_const_name(const):
|
||||
raise RuntimeError(
|
||||
f"We found a fake tensor in the exported program constant's list. "
|
||||
f"This typically means our tracing system encountered an op that "
|
||||
f"we can't trace through. For the potential source, you can refer to "
|
||||
f"following model attribute: {const}. "
|
||||
f"Please file an issue on github. "
|
||||
)
|
||||
|
||||
export_graph_signature = export_artifact.aten.sig
|
||||
|
||||
forward_arg_names = _get_forward_arg_names(mod, args, kwargs)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user