mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[ONNX] Add a test to backed_size_oblivious patch in onnx (#166196)
Follow-up https://github.com/pytorch/pytorch/pull/166151 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166196 Approved by: https://github.com/justinchuby
This commit is contained in:
parent
32ac38f85d
commit
b04173be9b
|
|
@ -572,6 +572,43 @@ class TestCustomTranslationTable(common_utils.TestCase):
|
|||
self.assertIn("Add", all_nodes_decomp)
|
||||
self.assertNotIn("Sub", all_nodes_decomp)
|
||||
|
||||
def test_01_specialization_with_run_decomp_is_supported(self):
|
||||
# Phi3RMSNorm changes and redo shape inference after `run_decompositions` call
|
||||
# We ned this test to make sure everything we do on fx graph is covered by
|
||||
# backed_size_oblivious
|
||||
class Phi3RMSNorm(torch.nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
"""
|
||||
Phi3RMSNorm is equivalent to T5LayerNorm
|
||||
"""
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(
|
||||
variance + self.variance_epsilon
|
||||
)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
op = torch.onnx.export(
|
||||
Phi3RMSNorm(256).eval(),
|
||||
args=(),
|
||||
kwargs={"hidden_states": torch.rand((1, 32, 256))},
|
||||
dynamic_shapes={
|
||||
"hidden_states": {
|
||||
0: "batch_size",
|
||||
1: "seq_len",
|
||||
}
|
||||
},
|
||||
dynamo=True,
|
||||
)
|
||||
# batch size is not fixed to 1
|
||||
self.assertNotEqual(op.model.graph.outputs[0].shape[0], 1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
common_utils.run_tests()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user