mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[ONNX] Cover all FX passes into backed size oblivious (#166151)
Found a bug: after `run_decompositions()`, a dynamic shape could be fixed to 1. The cause is that all FX graph surgery related to shape inference must happen inside the backed-size-oblivious patch.
```python
import torch
from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm
# Prior to this PR, this generated a fixed batch size
op = torch.onnx.export(
Phi3RMSNorm(256).eval(),
args=(),
kwargs={"hidden_states": torch.rand((1, 32, 256))},
dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)
# The shapes stay dynamic when torch.export runs inside the patch
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
ep = torch.onnx.export(
Phi3RMSNorm(256).eval(),
args=(),
kwargs={"hidden_states": torch.rand((1, 32, 256))},
dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)
# But when run_decompositions() is called outside the patch, the shape becomes static.
# ep = ep.run_decompositions()
print(ep)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166151
Approved by: https://github.com/justinchuby
This commit is contained in:
parent
5a4997dcae
commit
c12293dcbe
```diff
@@ -988,16 +988,22 @@ def _prepare_exported_program_for_export(
 ) -> torch.export.ExportedProgram:
     """Decompose and apply pre-export transformations to the exported program."""
-    # Decompose the graph given the implemented torch ops in ONNX
-    exported_program = _fx_passes.decompose_with_registry(exported_program, registry)
+    with (
+        # Support the dynamism with 0/1 input dim
+        torch.fx.experimental._config.patch(backed_size_oblivious=True),  # type: ignore[attr-defined]
+    ):
+        # Decompose the graph given the implemented torch ops in ONNX
+        exported_program = _fx_passes.decompose_with_registry(
+            exported_program, registry
+        )
 
-    graph_module = exported_program.graph_module
-    # Include explicit type promotion nodes
-    _fx_passes.insert_type_promotion_nodes(graph_module)
-    graph_module = _fx_passes.remove_assertion_nodes(graph_module)
-    # Reassign the graph module to save some runtime.
-    exported_program._graph_module = graph_module
-    return exported_program
+        graph_module = exported_program.graph_module
+        # Include explicit type promotion nodes
+        _fx_passes.insert_type_promotion_nodes(graph_module)
+        graph_module = _fx_passes.remove_assertion_nodes(graph_module)
+        # Reassign the graph module to save some runtime.
+        exported_program._graph_module = graph_module
+        return exported_program
 
 
 def _get_scope_name(scoped_name: str) -> tuple[str, str]:
```
Loading…
Reference in New Issue
Block a user