Fix implicit state dict modification (#151436)
Summary: Previously we were modifying ep.state_dict while running decompositions, which we shouldn't do.
Test Plan: CI
Fixes: https://github.com/pytorch/pytorch/issues/151366
Differential Revision: D73102315
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151436
Approved by: https://github.com/angelayi
parent 34266836d5
commit c2a202169d
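For context, the contract this change restores is that run_decompositions() must not rewrite the ExportedProgram's state_dict behind the caller's back. Below is a minimal sketch of that expectation; the Tiny module and its plain nn.Parameter are made up for illustration, while the regression in #151366 specifically involves a tensor-subclass parameter (see the new test's CustomTensorPlainOut).

import torch
from torch.export import export


class Tiny(torch.nn.Module):  # hypothetical toy module, not from the patch
    def __init__(self):
        super().__init__()
        self.p = torch.nn.Parameter(torch.ones(3, 4))

    def forward(self, x):
        return x + self.p.sum()


ep = export(Tiny(), (torch.randn(3, 4),), strict=False)
before = list(ep.state_dict.keys())
ep.run_decompositions()  # should leave ep.state_dict untouched
after = list(ep.state_dict.keys())
assert before == after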
@@ -4789,6 +4789,33 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         args = (torch.ones(2, 3), torch.ones(3, 4), torch.ones(2, 3), torch.ones(3, 4))
         self._test_export_same_as_eager(kw_func, args)
 
+    @testing.expectedFailureCppRuntime
+    @testing.expectedFailureLegacyExportNonStrict
+    @testing.expectedFailureLegacyExportStrict
+    def test_export_module(self):
+        class Foo(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.p1 = torch.nn.Parameter(torch.ones(3, 4))
+                self.p2 = torch.nn.Parameter(
+                    CustomTensorPlainOut(
+                        torch.ones(3, 4),
+                        torch.ones(3, 4),
+                    )
+                )
+
+            def forward(self, x):
+                a = (2 * self.p1 + self.p2).sum()
+                return x + a
+
+        model = Foo()
+        example_inputs = (torch.randn(3, 4),)
+        ep = export(model, example_inputs, strict=False)
+        before = list(ep.state_dict.keys())
+        ep.run_decompositions()
+        after = list(ep.state_dict.keys())
+        self.assertEqual(before, after)
+
     def test_export_func_with_keyword_only_args(self):
         class Module(torch.nn.Module):
             def forward(self, arg1, arg2, *args, kw1, kw2):

@@ -535,20 +535,25 @@ def _decompose_and_get_gm_with_new_signature_constants(
         # the state dict of ep.module but ep.module only stores params
         # buffers that participate in forward. If we undo this behaviour,
         # it would break some downstream users.
-        for name, p in unwrapped_params_buffers.items():
-            if name not in wrapped_params_buffers:
-                ep.state_dict[name] = p
+        new_state_dict = {
+            **ep.state_dict,
+            **{
+                name: p
+                for name, p in unwrapped_params_buffers.items()
+                if name not in wrapped_params_buffers
+            },
+        }
 
         for name, p in wrapped_params_buffers.items():
             # Buffers can be persistent/non-persistent
-            if name not in ep.state_dict:
+            if name not in new_state_dict:
                 assert not isinstance(p, torch.nn.Parameter)
 
-            if name in ep.state_dict:
+            if name in new_state_dict:
                 if name not in unwrapped_params_buffers:
-                    ep.state_dict.pop(name)
+                    new_state_dict.pop(name)
 
-        return gm, new_graph_signature, ep.state_dict
+        return gm, new_graph_signature, new_state_dict
 
     old_placeholders = [
         node for node in ep.graph_module.graph.nodes if node.op == "placeholder"
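The structure of the fix above is the usual copy-then-return pattern: merge the additions into a fresh dict and hand that back, instead of inserting into and popping from the ep.state_dict that callers still hold. A stripped-down sketch of that pattern follows; the helper name is illustrative, not the actual function in the patch.

def rebuild_state_dict(state_dict, unwrapped_params_buffers, wrapped_params_buffers):
    # Build a new mapping instead of mutating the caller's dict in place.
    new_state_dict = {
        **state_dict,
        **{
            name: p
            for name, p in unwrapped_params_buffers.items()
            if name not in wrapped_params_buffers
        },
    }
    # Entries that exist only in wrapped form are dropped from the copy,
    # leaving the original state_dict exactly as the caller saw it.
    for name in wrapped_params_buffers:
        if name in new_state_dict and name not in unwrapped_params_buffers:
            new_state_dict.pop(name)
    return new_state_dict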