From b04173be9b44f01eec086f5d70fd02cfa36d89a5 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Fri, 24 Oct 2025 22:47:07 +0000 Subject: [PATCH] [ONNX] Add a test to backed_size_oblivious patch in onnx (#166196) Follow-up https://github.com/pytorch/pytorch/pull/166151 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166196 Approved by: https://github.com/justinchuby --- test/onnx/exporter/test_api.py | 37 ++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/test/onnx/exporter/test_api.py b/test/onnx/exporter/test_api.py index 7e6a487e18f..a81b7106084 100644 --- a/test/onnx/exporter/test_api.py +++ b/test/onnx/exporter/test_api.py @@ -572,6 +572,43 @@ class TestCustomTranslationTable(common_utils.TestCase): self.assertIn("Add", all_nodes_decomp) self.assertNotIn("Sub", all_nodes_decomp) + def test_01_specialization_with_run_decomp_is_supported(self): + # Phi3RMSNorm changes and redo shape inference after `run_decompositions` call + # We ned this test to make sure everything we do on fx graph is covered by + # backed_size_oblivious + class Phi3RMSNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + Phi3RMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt( + variance + self.variance_epsilon + ) + return self.weight * hidden_states.to(input_dtype) + + op = torch.onnx.export( + Phi3RMSNorm(256).eval(), + args=(), + kwargs={"hidden_states": torch.rand((1, 32, 256))}, + dynamic_shapes={ + "hidden_states": { + 0: "batch_size", + 1: "seq_len", + } + }, + dynamo=True, + ) + # batch size is not fixed to 1 + self.assertNotEqual(op.model.graph.outputs[0].shape[0], 1) + if __name__ == "__main__": common_utils.run_tests()