[ONNX] Cover all FX passes into backed size oblivious (#166151)

Found a bug where, after `run_decompositions()`, a dynamic shape could be fixed to 1. The root cause is that all FX graph surgery related to shape inference must happen inside the backed-size-oblivious patch.

```python
import torch
from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm

# Prior to this PR, this would generate a fixed batch size
op = torch.onnx.export(
    Phi3RMSNorm(256).eval(),
    args=(),
    kwargs={"hidden_states": torch.rand((1, 32, 256))},
    dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)

# The shapes remain dynamic when the export runs inside the patch
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = torch.onnx.export(
    Phi3RMSNorm(256).eval(),
    args=(),
    kwargs={"hidden_states": torch.rand((1, 32, 256))},
    dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)
# But when run_decompositions() is called outside of the patch, the shape becomes static.
# ep = ep.run_decompositions()
print(ep)

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166151
Approved by: https://github.com/justinchuby
This commit is contained in:
Ti-Tai Wang 2025-10-24 03:25:16 +00:00 committed by PyTorch MergeBot
parent 5a4997dcae
commit c12293dcbe

View File

@ -988,16 +988,22 @@ def _prepare_exported_program_for_export(
) -> torch.export.ExportedProgram:
"""Decompose and apply pre-export transformations to the exported program."""
# Decompose the graph given the implemented torch ops in ONNX
exported_program = _fx_passes.decompose_with_registry(exported_program, registry)
with (
# Support the dynamism with 0/1 input dim
torch.fx.experimental._config.patch(backed_size_oblivious=True), # type: ignore[attr-defined]
):
# Decompose the graph given the implemented torch ops in ONNX
exported_program = _fx_passes.decompose_with_registry(
exported_program, registry
)
graph_module = exported_program.graph_module
# Include explicit type promotion nodes
_fx_passes.insert_type_promotion_nodes(graph_module)
graph_module = _fx_passes.remove_assertion_nodes(graph_module)
# Reassign the graph module to save some runtime.
exported_program._graph_module = graph_module
return exported_program
graph_module = exported_program.graph_module
# Include explicit type promotion nodes
_fx_passes.insert_type_promotion_nodes(graph_module)
graph_module = _fx_passes.remove_assertion_nodes(graph_module)
# Reassign the graph module to save some runtime.
exported_program._graph_module = graph_module
return exported_program
def _get_scope_name(scoped_name: str) -> tuple[str, str]: