mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[export] Fix tensor_constant and buffer naming conflicts in TS converter (#148803)
Summary: In TS converter, tensor constants are traced as BUFFER and later we will convert them back to CONSTANT_TENSOR. So we need to prevent naming conflicts during lift constant pass. Test Plan: CI Differential Revision: D70826426 Pull Request resolved: https://github.com/pytorch/pytorch/pull/148803 Approved by: https://github.com/angelayi
This commit is contained in:
parent
49570cb402
commit
15cd6921a5
|
|
@ -1463,6 +1463,26 @@ class TestConverter(TestCase):
|
|||
inp = (torch.randn(1, 10),)
|
||||
self._check_equal_ts_ep_converter(m, inp, ["script"])
|
||||
|
||||
def test_ts2ep_convert_quantized_model_with_opcontext_and_constant(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self, linear_op):
|
||||
super().__init__()
|
||||
self.linear_op = linear_op
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.ops.prepacked.linear_clamp_run(
|
||||
x + torch.ones(1), self.linear_op
|
||||
)
|
||||
return x
|
||||
|
||||
linear_op = torch.ops.prepacked.linear_clamp_prepack(
|
||||
torch.randn(10, 10), torch.randn(10)
|
||||
)
|
||||
|
||||
m = M(linear_op)
|
||||
inp = (torch.randn(1, 10),)
|
||||
self._check_equal_ts_ep_converter(m, inp, ["script"])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -1549,6 +1549,7 @@ DEBUG: (TORCH_LOGS="+export" <cmd>), additionally
|
|||
name_to_constant[spec.target], torch.Tensor
|
||||
), f"{type(name_to_constant[spec.target])} has been erroneously marked as buffer"
|
||||
spec.kind = InputKind.CONSTANT_TENSOR
|
||||
spec.persistent = None
|
||||
ep.verifier().check(ep)
|
||||
|
||||
return ep
|
||||
|
|
|
|||
|
|
@ -148,11 +148,13 @@ def lift_constants_pass(
|
|||
)
|
||||
|
||||
first_user_input_loc, first_user_input = 0, next(iter(gm.graph.nodes))
|
||||
used_target_names = set()
|
||||
for node in gm.graph.nodes:
|
||||
if node.op == "placeholder":
|
||||
if node.name in graph_signature.user_inputs:
|
||||
first_user_input = node
|
||||
break
|
||||
used_target_names.add(inputs[first_user_input_loc].target)
|
||||
first_user_input_loc += 1
|
||||
# If we ever hit here, it means that
|
||||
# there was no user input so the constants
|
||||
|
|
@ -194,6 +196,10 @@ def lift_constants_pass(
|
|||
else:
|
||||
constant_name = f"lifted_custom_{num_custom_obj}"
|
||||
constant_fqn = get_constant_fqn(node, constant_name)
|
||||
while constant_fqn in used_target_names:
|
||||
num_custom_obj += 1
|
||||
constant_name = f"lifted_custom_{num_custom_obj}"
|
||||
constant_fqn = get_constant_fqn(node, constant_name)
|
||||
num_custom_obj += 1
|
||||
elif isinstance(constant_val, torch.Tensor):
|
||||
# Remove the parameterness of constant_val
|
||||
|
|
@ -212,6 +218,10 @@ def lift_constants_pass(
|
|||
else:
|
||||
constant_name = f"lifted_tensor_{num_tensor_constants}"
|
||||
constant_fqn = get_constant_fqn(node, constant_name)
|
||||
while constant_fqn in used_target_names:
|
||||
num_tensor_constants += 1
|
||||
constant_name = f"lifted_tensor_{num_tensor_constants}"
|
||||
constant_fqn = get_constant_fqn(node, constant_name)
|
||||
num_tensor_constants += 1
|
||||
elif isinstance(constant_val, torch.fx.GraphModule):
|
||||
continue
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user