Fix the parity of original and exported module parameters (#160600)
## Problem

Fixing a parameter-name mismatch during torch.export with strict mode (see "How to reproduce the issue" below). When two attributes map to the same tensor, strict mode will:

1. Build a standard param/buffer table to standardize names (the bug happens [here](f861dc1826/torch/export/_trace.py (L356)): when two parameters have the same `id(param)`, the latter name overwrites the previous one).
2. [Update](f861dc1826/torch/export/_trace.py (L1481)) the exported signature with the updated standard FQN (problematic).
3. When `exported_program.module()` is called, [_unlift_exported_program_lifted_states](f861dc1826/torch/export/exported_program.py (L1297)) recovers attributes from the exported signature, where the parameter name is defined and standardized.

As a result, `named_parameters()` of the returned module reports the overwritten name instead of the original one.

## How to reproduce the issue?

Reproduction shared by @taotaohuang001, torch version: 2.8.0

```python
import torch
from torch import nn

# ---- Toy model with embedding weight sharing (aliasing) ----
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_layers = nn.ModuleDict()
        tbl = nn.Embedding(100, 8)
        self.embedding_layers["ActorId"] = tbl
        # Alias: reuse the SAME module instance for another feature
        self.embedding_layers["RootActorId"] = self.embedding_layers["ActorId"]
        self.proj = nn.Linear(16, 1)

    def forward(self, feats: dict[str, torch.Tensor]):
        e1 = self.embedding_layers["ActorId"](feats["ActorId"])
        e2 = self.embedding_layers["RootActorId"](feats["RootActorId"])
        return self.proj(torch.cat([e1, e2], dim=-1))

torch.manual_seed(0)
m = Toy().eval()

# Show pre-export parameter names (canonicalized; shared weight appears once)
print("PRE-EXPORT named_parameters:")
print([name for name, _ in m.named_parameters()])

# Sanity: the two feature names point to the same weight object
w1 = m.embedding_layers["ActorId"].weight
w2 = m.embedding_layers["RootActorId"].weight
print("PRE-EXPORT alias -> same object:", w1 is w2, "| same storage:", w1.data_ptr() == w2.data_ptr())

# Example inputs (dict structure will be captured by export)
ex_in = {
    "ActorId": torch.randint(0, 100, (4,)),
    "RootActorId": torch.randint(0, 100, (4,)),
}

# ---- Export (in memory) and materialize the runnable module ----
ep = torch.export.export(m, (ex_in,), strict=True)
gm = ep.module()  # GraphModule with new (canonical) parameter names

print("\nPOST-EXPORT named_parameters (GraphModule):")
post_names = [name for name, _ in gm.named_parameters()]
print(post_names)

# Prove alias persists after export: run fwd/bwd and check a single grad tensor exists
out = gm(ex_in).sum()
out.backward()

# Find the embedding weight in the exported module by shape (100, 8)
emb_names = [name for name, p in gm.named_parameters() if p.shape == torch.Size([100, 8])]
print("\nEmbedding param (post-export) canonical name:", emb_names[0] if emb_names else "<not found>")

# Show that only one grad exists for the shared table
for name, p in gm.named_parameters():
    if p.grad is not None and p.shape == torch.Size([100, 8]):
        print("Grad present on shared embedding weight:", name, "| grad shape:", tuple(p.grad.shape))
        break
```

And you will see that the parameter names differ before and after export:

```
PRE-EXPORT named_parameters:
['embedding_layers.ActorId.weight', 'proj.weight', 'proj.bias']
PRE-EXPORT alias -> same object: True | same storage: True

POST-EXPORT named_parameters (GraphModule):
['embedding_layers.RootActorId.weight', 'proj.weight', 'proj.bias']

Embedding param (post-export) canonical name: embedding_layers.RootActorId.weight
Grad present on shared embedding weight: embedding_layers.RootActorId.weight | grad shape: (100, 8)
```

## Solution

Fix the issue by making sure that a later named parameter does not overwrite an existing entry in `param_buffer_table` when the original model's named parameter already maps to that parameter.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160600
Approved by: https://github.com/angelayi
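To make the failure mode concrete, here is a minimal standalone sketch of the overwrite and of the first-occurrence guard; `named_params`, `buggy_lookup`, and `fixed_lookup` are illustrative names, not the actual export internals:

```python
import torch

# One parameter object registered under two fully qualified names (aliasing),
# mimicking what named_parameters(remove_duplicate=False) yields for the Toy model.
shared = torch.nn.Parameter(torch.ones(3))
named_params = [
    ("embedding_layers.ActorId.weight", shared),
    ("embedding_layers.RootActorId.weight", shared),
]

# Before the fix: a plain assignment lets the second alias overwrite the first.
buggy_lookup: dict[int, str] = {}
for name, p in named_params:
    buggy_lookup[id(p)] = name

# After the fix: only the first occurrence is kept, preserving the original FQN.
fixed_lookup: dict[int, str] = {}
for name, p in named_params:
    if fixed_lookup.get(id(p)) is None:
        fixed_lookup[id(p)] = name

print(buggy_lookup[id(shared)])  # embedding_layers.RootActorId.weight (wrong name wins)
print(fixed_lookup[id(shared)])  # embedding_layers.ActorId.weight (original name preserved)
```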
parent 3e210f90c2
commit ffa1ce7650
```diff
@@ -16163,6 +16163,29 @@ def forward(self, x):
         wrapper = Wrapper(pyt_model, example_inputs)
         wrapper.forward()
 
+    def test_strict_export_with_shared_parameters(self):
+        """Test that parameter names are preserved when there are shared parameters with the same name."""
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.n1 = torch.nn.Parameter(torch.ones(3))
+                self.n2 = self.n1
+
+            def forward(self, x):
+                res1 = x * self.n1
+                res2 = x * self.n2
+                return res1 + res2
+
+        m = M()
+        ep = torch.export.export(m, (torch.ones(3),), strict=True)
+        gm = ep.module()
+
+        # Check that named_parameters are preserved
+        original_param_names = [name for name, _ in m.named_parameters()]
+        exported_param_names = [name for name, _ in gm.named_parameters()]
+        self.assertEqual(original_param_names, exported_param_names)
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
 class TestExportCustomClass(TorchTestCase):
```
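For reference, the new test can be run on its own; assuming the hunk above belongs to `test/export/test_export.py` (the file path is not shown in this capture), something like `python -m pytest test/export/test_export.py -k test_strict_export_with_shared_parameters` should exercise it.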
```diff
@@ -384,6 +384,8 @@ def _get_param_buffer_mapping(
     param_lookup: dict[int, str] = {}
     buffer_lookup: dict[int, str] = {}
     for name, param in original_module.named_parameters(remove_duplicate=False):
-        param_lookup[id(param)] = name
+        if param_lookup.get(id(param)) is None:
+            # we only want to keep the first occurrence of a parameter to guarantee parity of original and traced module.
+            param_lookup[id(param)] = name
     for name, buffer in original_module.named_buffers(remove_duplicate=False):
         buffer_lookup[id(buffer)] = name
```
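As background for the `remove_duplicate=False` iteration in the hunk above, the following small sketch (reusing the shape of the new test's `M` module) shows that aliased parameters are only visited twice when de-duplication is turned off, which is what makes the first-occurrence check necessary:

```python
import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n1 = torch.nn.Parameter(torch.ones(3))
        self.n2 = self.n1  # alias: same Parameter object under a second name

m = M()
# The default de-duplicates aliases, so only the first FQN appears.
print([n for n, _ in m.named_parameters()])                        # ['n1']
# The mapping code iterates without de-duplication, so both FQNs are visited
# and the id(param) check decides which one is kept.
print([n for n, _ in m.named_parameters(remove_duplicate=False)])  # ['n1', 'n2']
```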