Fix the parity of original and exported module parameters (#160600)

## Problem
Fixes a parameter-name mismatch that occurs during `torch.export` with strict mode (see the "How to reproduce the issue?" section below):

When two attributes map to the same tensor, strict mode will
1. Build a standard param/buffer table to standardize the names. The bug happens [here](f861dc1826/torch/export/_trace.py (L356)): when two parameters share the same `id(param)`, the later name overwrites the earlier one (a minimal sketch of this overwrite follows the list).
2. [Update](f861dc1826/torch/export/_trace.py (L1481)) the exported signature with the updated standard FQN (problematic).
3. When `exported_program.module()` is called, [_unlift_exported_program_lifted_states](f861dc1826/torch/export/exported_program.py (L1297)) recovers the attributes from the exported signature, where the parameter names were defined and standardized.

As a result, `named_parameters()` of the returned module reports the overwritten name instead of the original one.
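
A minimal sketch of the step-1 overwrite, using a stand-alone dict in place of the real `param_lookup` built inside `_get_param_buffer_mapping` (the module and feature names are illustrative):

```python
from torch import nn

# Two FQNs ("ActorId.weight" and "RootActorId.weight") alias the same Parameter.
tbl = nn.Embedding(100, 8)
module = nn.ModuleDict({"ActorId": tbl, "RootActorId": tbl})

# Pre-fix behavior: every occurrence writes into the dict keyed by id(param),
# so the later FQN silently overwrites the earlier one.
param_lookup: dict[int, str] = {}
for name, param in module.named_parameters(remove_duplicate=False):
    param_lookup[id(param)] = name

print(param_lookup)  # {<id>: 'RootActorId.weight'} -- 'ActorId.weight' is lost
```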

## How to reproduce the issue?
Reproduction shared by @taotaohuang001:

torch version: 2.8.0
```python
import torch
from torch import nn

# ---- Toy model with embedding weight sharing (aliasing) ----
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding_layers = nn.ModuleDict()
        tbl = nn.Embedding(100, 8)
        self.embedding_layers["ActorId"] = tbl
        # Alias: reuse the SAME module instance for another feature
        self.embedding_layers["RootActorId"] = self.embedding_layers["ActorId"]
        self.proj = nn.Linear(16, 1)

    def forward(self, feats: dict[str, torch.Tensor]):
        e1 = self.embedding_layers["ActorId"](feats["ActorId"])
        e2 = self.embedding_layers["RootActorId"](feats["RootActorId"])
        return self.proj(torch.cat([e1, e2], dim=-1))

torch.manual_seed(0)

m = Toy().eval()

# Show pre-export parameter names (canonicalized; shared weight appears once)
print("PRE-EXPORT named_parameters:")
print([name for name, _ in m.named_parameters()])

# Sanity: the two feature names point to the same weight object
w1 = m.embedding_layers["ActorId"].weight
w2 = m.embedding_layers["RootActorId"].weight
print("PRE-EXPORT alias -> same object:", w1 is w2, "| same storage:", w1.data_ptr() == w2.data_ptr())

# Example inputs (dict structure will be captured by export)
ex_in = {
    "ActorId":     torch.randint(0, 100, (4,)),
    "RootActorId": torch.randint(0, 100, (4,)),
}

# ---- Export (in memory) and materialize the runnable module ----
ep = torch.export.export(m, (ex_in,), strict=True)
gm = ep.module()  # GraphModule with new (canonical) parameter names

print("\nPOST-EXPORT named_parameters (GraphModule):")
post_names = [name for name, _ in gm.named_parameters()]
print(post_names)

# Prove alias persists after export: run fwd/bwd and check a single grad tensor exists
out = gm(ex_in).sum()
out.backward()

# Find the embedding weight in the exported module by shape (100, 8)
emb_names = [name for name, p in gm.named_parameters() if p.shape == torch.Size([100, 8])]
print("\nEmbedding param (post-export) canonical name:", emb_names[0] if emb_names else "<not found>")

# Show that only one grad exists for the shared table
for name, p in gm.named_parameters():
    if p.grad is not None and p.shape == torch.Size([100, 8]):
        print("Grad present on shared embedding weight:", name, "| grad shape:", tuple(p.grad.shape))
        break

```

And you will see that the parameter names differ before and after export:
```
PRE-EXPORT named_parameters:
['embedding_layers.ActorId.weight', 'proj.weight', 'proj.bias']
PRE-EXPORT alias -> same object: True | same storage: True

POST-EXPORT named_parameters (GraphModule):
['embedding_layers.RootActorId.weight', 'proj.weight', 'proj.bias']

Embedding param (post-export) canonical name: embedding_layers.RootActorId.weight
Grad present on shared embedding weight: embedding_layers.RootActorId.weight | grad shape: (100, 8)

```
## Solution
Fix this issue by making sure that a later named parameter does not overwrite the existing `param_buffer_table` entry when a parameter of the original model already maps to a name (the first occurrence wins).
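
For illustration, the same stand-alone sketch with the guarded assignment (mirroring the change to `_get_param_buffer_mapping` shown in the diff below) keeps the first FQN:

```python
from torch import nn

tbl = nn.Embedding(100, 8)
module = nn.ModuleDict({"ActorId": tbl, "RootActorId": tbl})

# Post-fix behavior: only the first FQN seen for a given parameter object is recorded,
# so the exported signature keeps the original name.
param_lookup: dict[int, str] = {}
for name, param in module.named_parameters(remove_duplicate=False):
    if param_lookup.get(id(param)) is None:
        param_lookup[id(param)] = name

print(param_lookup)  # {<id>: 'ActorId.weight'} -- first occurrence wins
```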
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160600
Approved by: https://github.com/angelayi
```diff
@@ -16163,6 +16163,29 @@ def forward(self, x):
         wrapper = Wrapper(pyt_model, example_inputs)
         wrapper.forward()
 
+    def test_strict_export_with_shared_parameters(self):
+        """Test that parameter names are preserved when there are shared parameters with the same name."""
+
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.n1 = torch.nn.Parameter(torch.ones(3))
+                self.n2 = self.n1
+
+            def forward(self, x):
+                res1 = x * self.n1
+                res2 = x * self.n2
+                return res1 + res2
+
+        m = M()
+        ep = torch.export.export(m, (torch.ones(3),), strict=True)
+        gm = ep.module()
+
+        # Check that named_parameters are preserved
+        original_param_names = [name for name, _ in m.named_parameters()]
+        exported_param_names = [name for name, _ in gm.named_parameters()]
+        self.assertEqual(original_param_names, exported_param_names)
+
 
 @unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
 class TestExportCustomClass(TorchTestCase):
```


```diff
@@ -384,7 +384,9 @@ def _get_param_buffer_mapping(
     param_lookup: dict[int, str] = {}
     buffer_lookup: dict[int, str] = {}
     for name, param in original_module.named_parameters(remove_duplicate=False):
-        param_lookup[id(param)] = name
+        if param_lookup.get(id(param)) is None:
+            # we only want to keep the first occurrence of a parameter to guarantee parity of original and traced module.
+            param_lookup[id(param)] = name
     for name, buffer in original_module.named_buffers(remove_duplicate=False):
         buffer_lookup[id(buffer)] = name
```