Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
[ONNX] Add a test to backed_size_oblivious patch in onnx (#166196)
Follow-up to https://github.com/pytorch/pytorch/pull/166151

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166196
Approved by: https://github.com/justinchuby
This commit is contained in:
parent 32ac38f85d
commit b04173be9b
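For context, here is a minimal sketch of the mechanism this patch exercises, outside the exporter. Everything in it is an assumption inferred from the test's name rather than shown in this diff: that `torch.fx.experimental._config.backed_size_oblivious` is the flag the exporter patches, and that enabling it lets a size-1 example dim stay dynamic through `run_decompositions`:

import torch

# Assumption: this module-level flag enables backed-size-oblivious
# reasoning during shape inference; the ONNX exporter is understood to
# patch it internally (per #166151). Here it is flipped by hand.
import torch.fx.experimental._config as fx_config

fx_config.backed_size_oblivious = True
try:
    ep = torch.export.export(
        torch.nn.Linear(256, 256).eval(),
        (torch.rand(1, 32, 256),),  # example input with batch dim == 1
        dynamic_shapes=({0: torch.export.Dim("batch")},),
    )
    # The step the new test covers: decompositions re-run shape
    # inference, which must not specialize the batch dim back to 1.
    ep = ep.run_decompositions()
finally:
    fx_config.backed_size_oblivious = False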
@@ -572,6 +572,43 @@ class TestCustomTranslationTable(common_utils.TestCase):
         self.assertIn("Add", all_nodes_decomp)
         self.assertNotIn("Sub", all_nodes_decomp)
+
+    def test_01_specialization_with_run_decomp_is_supported(self):
+        # Phi3RMSNorm changes and redoes shape inference after the `run_decompositions` call
+        # We need this test to make sure everything we do on the fx graph is
+        # covered by backed_size_oblivious
+        class Phi3RMSNorm(torch.nn.Module):
+            def __init__(self, hidden_size, eps=1e-6):
+                """
+                Phi3RMSNorm is equivalent to T5LayerNorm
+                """
+                super().__init__()
+                self.weight = torch.nn.Parameter(torch.ones(hidden_size))
+                self.variance_epsilon = eps
+
+            def forward(self, hidden_states):
+                input_dtype = hidden_states.dtype
+                hidden_states = hidden_states.to(torch.float32)
+                variance = hidden_states.pow(2).mean(-1, keepdim=True)
+                hidden_states = hidden_states * torch.rsqrt(
+                    variance + self.variance_epsilon
+                )
+                return self.weight * hidden_states.to(input_dtype)
+
+        op = torch.onnx.export(
+            Phi3RMSNorm(256).eval(),
+            args=(),
+            kwargs={"hidden_states": torch.rand((1, 32, 256))},
+            dynamic_shapes={
+                "hidden_states": {
+                    0: "batch_size",
+                    1: "seq_len",
+                }
+            },
+            dynamo=True,
+        )
+        # batch size is not fixed to 1
+        self.assertNotEqual(op.model.graph.outputs[0].shape[0], 1)
 
 
 if __name__ == "__main__":
     common_utils.run_tests()
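As a quick standalone check of what the new assertion verifies, a sketch that replays the test outside the harness (same module and exporter call as in the diff above; the printed dimension names are illustrative):

import torch

class Phi3RMSNorm(torch.nn.Module):
    # Same module as in the diff above, condensed.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

op = torch.onnx.export(
    Phi3RMSNorm(256).eval(),
    args=(),
    kwargs={"hidden_states": torch.rand((1, 32, 256))},
    dynamic_shapes={"hidden_states": {0: "batch_size", 1: "seq_len"}},
    dynamo=True,
)
# With the backed_size_oblivious patch in effect, the output batch dim
# stays symbolic instead of specializing to the example value of 1:
print(op.model.graph.outputs[0].shape)  # expected: symbolic, e.g. ('batch_size', 'seq_len', 256)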