[ONNX] Add a test to backed_size_oblivious patch in onnx (#166196)

Follow-up https://github.com/pytorch/pytorch/pull/166151

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166196
Approved by: https://github.com/justinchuby
This commit is contained in:
Ti-Tai Wang 2025-10-24 22:47:07 +00:00 committed by PyTorch MergeBot
parent 32ac38f85d
commit b04173be9b

View File

@ -572,6 +572,43 @@ class TestCustomTranslationTable(common_utils.TestCase):
self.assertIn("Add", all_nodes_decomp) self.assertIn("Add", all_nodes_decomp)
self.assertNotIn("Sub", all_nodes_decomp) self.assertNotIn("Sub", all_nodes_decomp)
def test_01_specialization_with_run_decomp_is_supported(self):
# Phi3RMSNorm changes and redo shape inference after `run_decompositions` call
# We ned this test to make sure everything we do on fx graph is covered by
# backed_size_oblivious
class Phi3RMSNorm(torch.nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
Phi3RMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(
variance + self.variance_epsilon
)
return self.weight * hidden_states.to(input_dtype)
op = torch.onnx.export(
Phi3RMSNorm(256).eval(),
args=(),
kwargs={"hidden_states": torch.rand((1, 32, 256))},
dynamic_shapes={
"hidden_states": {
0: "batch_size",
1: "seq_len",
}
},
dynamo=True,
)
# batch size is not fixed to 1
self.assertNotEqual(op.model.graph.outputs[0].shape[0], 1)
if __name__ == "__main__": if __name__ == "__main__":
common_utils.run_tests() common_utils.run_tests()