[export] Fix tensor_constant and buffer naming conflicts in TS converter (#148803)

Summary: In the TS converter, tensor constants are traced as BUFFER and later converted back to CONSTANT_TENSOR, so we need to prevent naming conflicts during the lift-constants pass (a standalone sketch of the dedup scheme follows below, after the commit metadata).

Test Plan: CI

Differential Revision: D70826426

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148803
Approved by: https://github.com/angelayi
Yiming Zhou 2025-03-14 00:38:12 +00:00 committed by PyTorch MergeBot
parent 49570cb402
commit 15cd6921a5
3 changed files with 31 additions and 0 deletions
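
To make the fix concrete, here is a minimal standalone sketch of the dedup scheme the lift-constants pass now applies when generating names for lifted tensor constants. The helper make_fresh_constant_name and the existing_targets set are illustrative names for this sketch only, not the actual pass API.

def make_fresh_constant_name(existing_targets: set, prefix: str = "lifted_tensor") -> str:
    # Keep bumping the counter until the candidate name no longer collides with
    # any target already claimed by a buffer or a previously lifted constant.
    counter = 0
    candidate = f"{prefix}_{counter}"
    while candidate in existing_targets:
        counter += 1
        candidate = f"{prefix}_{counter}"
    existing_targets.add(candidate)  # reserve the name for later lookups
    return candidate

# Example: a TS-traced module already exposes a buffer named "lifted_tensor_0"
# (the tensor constant was traced as a BUFFER), so the next lifted tensor
# constant must skip ahead to "lifted_tensor_1".
targets = {"lifted_tensor_0", "my_buffer"}
assert make_fresh_constant_name(targets) == "lifted_tensor_1"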


@@ -1463,6 +1463,26 @@ class TestConverter(TestCase):
        inp = (torch.randn(1, 10),)
        self._check_equal_ts_ep_converter(m, inp, ["script"])

    def test_ts2ep_convert_quantized_model_with_opcontext_and_constant(self):
        class M(torch.nn.Module):
            def __init__(self, linear_op):
                super().__init__()
                self.linear_op = linear_op

            def forward(self, x):
                x = torch.ops.prepacked.linear_clamp_run(
                    x + torch.ones(1), self.linear_op
                )
                return x

        linear_op = torch.ops.prepacked.linear_clamp_prepack(
            torch.randn(10, 10), torch.randn(10)
        )
        m = M(linear_op)
        inp = (torch.randn(1, 10),)
        self._check_equal_ts_ep_converter(m, inp, ["script"])


if __name__ == "__main__":
    run_tests()


@@ -1549,6 +1549,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
                    name_to_constant[spec.target], torch.Tensor
                ), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
                spec.kind = InputKind.CONSTANT_TENSOR
                spec.persistent = None

    ep.verifier().check(ep)
    return ep


@@ -148,11 +148,13 @@ def lift_constants_pass(
    )

    first_user_input_loc, first_user_input = 0, next(iter(gm.graph.nodes))
    used_target_names = set()
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            if node.name in graph_signature.user_inputs:
                first_user_input = node
                break
            used_target_names.add(inputs[first_user_input_loc].target)
        first_user_input_loc += 1

    # If we ever hit here, it means that
    # there was no user input so the constants
@@ -194,6 +196,10 @@ def lift_constants_pass(
            else:
                constant_name = f"lifted_custom_{num_custom_obj}"
                constant_fqn = get_constant_fqn(node, constant_name)
                while constant_fqn in used_target_names:
                    num_custom_obj += 1
                    constant_name = f"lifted_custom_{num_custom_obj}"
                    constant_fqn = get_constant_fqn(node, constant_name)
                num_custom_obj += 1
        elif isinstance(constant_val, torch.Tensor):
            # Remove the parameterness of constant_val
@@ -212,6 +218,10 @@ def lift_constants_pass(
            else:
                constant_name = f"lifted_tensor_{num_tensor_constants}"
                constant_fqn = get_constant_fqn(node, constant_name)
                while constant_fqn in used_target_names:
                    num_tensor_constants += 1
                    constant_name = f"lifted_tensor_{num_tensor_constants}"
                    constant_fqn = get_constant_fqn(node, constant_name)
                num_tensor_constants += 1
        elif isinstance(constant_val, torch.fx.GraphModule):
            continue