Fix implicit state dict modification (#151436)

Summary: Previously we were modifying ep.state_dict in place while running decompositions, which we shouldn't do. run_decompositions now builds a new state dict and leaves the original ExportedProgram's state dict untouched.
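
A minimal sketch of the invariant being fixed (hypothetical module M with illustrative shapes; the regression test below exercises the same contract through a tensor-subclass parameter):

    import torch
    from torch.export import export

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.p = torch.nn.Parameter(torch.ones(3, 4))

        def forward(self, x):
            return x + self.p.sum()

    ep = export(M(), (torch.randn(3, 4),), strict=False)
    before = list(ep.state_dict.keys())
    # run_decompositions returns a new ExportedProgram; it must not
    # mutate ep.state_dict as a side effect (the bug fixed here).
    ep.run_decompositions()
    assert list(ep.state_dict.keys()) == before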

Test Plan: CI

Fixes: https://github.com/pytorch/pytorch/issues/151366

Differential Revision: D73102315

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151436
Approved by: https://github.com/angelayi
commit c2a202169d
parent 34266836d5
Author: Tugsbayasgalan (Tugsuu) Manlaibaatar
Date: 2025-04-18 00:58:52 +00:00
Committed by: PyTorch MergeBot
2 changed files with 39 additions and 7 deletions


@@ -4789,6 +4789,33 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
         self._test_export_same_as_eager(kw_func, args)

+    @testing.expectedFailureCppRuntime
+    @testing.expectedFailureLegacyExportNonStrict
+    @testing.expectedFailureLegacyExportStrict
+    def test_export_module(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.p1 = torch.nn.Parameter(torch.ones(3, 4))
+                self.p2 = torch.nn.Parameter(
+                    CustomTensorPlainOut(
+                        torch.ones(3, 4),
+                        torch.ones(3, 4),
+                    )
+                )
+
+            def forward(self, x):
+                a = (2 * self.p1 + self.p2).sum()
+                return x + a
+
+        model = Foo()
+        example_inputs = (torch.randn(3, 4),)
+        ep = export(model, example_inputs, strict=False)
+        before = list(ep.state_dict.keys())
+        ep.run_decompositions()
+        after = list(ep.state_dict.keys())
+        self.assertEqual(before, after)
+
     def test_export_func_with_keyword_only_args(self):
         class Module(torch.nn.Module):
             def forward(self, arg1, arg2, *args, kw1, kw2):

@@ -535,20 +535,25 @@ def _decompose_and_get_gm_with_new_signature_constants(
         # the state dict of ep.module but ep.module only stores params
         # buffers that participate in forward. If we undo this behaviour,
         # it would break some downstream users.
-        for name, p in unwrapped_params_buffers.items():
-            if name not in wrapped_params_buffers:
-                ep.state_dict[name] = p
+        new_state_dict = {
+            **ep.state_dict,
+            **{
+                name: p
+                for name, p in unwrapped_params_buffers.items()
+                if name not in wrapped_params_buffers
+            },
+        }
         for name, p in wrapped_params_buffers.items():
             # Buffers can be persistent/non-persistent
-            if name not in ep.state_dict:
+            if name not in new_state_dict:
                 assert not isinstance(p, torch.nn.Parameter)
-            if name in ep.state_dict:
+            if name in new_state_dict:
                 if name not in unwrapped_params_buffers:
-                    ep.state_dict.pop(name)
+                    new_state_dict.pop(name)

-        return gm, new_graph_signature, ep.state_dict
+        return gm, new_graph_signature, new_state_dict

     old_placeholders = [
         node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
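
For reference, a sketch of the copy-then-merge idiom the new code uses, with generic names (state_dict, unwrapped, wrapped) standing in for ep.state_dict, unwrapped_params_buffers, and wrapped_params_buffers:

    def merged_state_dict(state_dict, unwrapped, wrapped):
        # Layer the unwrapped-only entries on top of a shallow copy,
        # instead of writing into the caller's dict.
        new_state_dict = {
            **state_dict,
            **{name: p for name, p in unwrapped.items() if name not in wrapped},
        }
        # Drop wrapped names that no longer correspond to an unwrapped entry.
        for name in wrapped:
            if name in new_state_dict and name not in unwrapped:
                new_state_dict.pop(name)
        return new_state_dict  # the input state_dict is left untouched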