[ONNX] Cover all FX passes into backed size oblivious (#166151)

Found a bug where, after `run_decomposition()`, a dynamic shape could be fixed to 1. The cause is that all FX graph surgery related to shape inference must happen inside the backed-size-oblivious patch.

```python
import torch
from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm

# Prior to this PR, this would generate a fixed batch size
op = torch.onnx.export(
    Phi3RMSNorm(256).eval(),
    args=(),
    kwargs={"hidden_states": torch.rand((1, 32, 256))},
    dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
)

# The shape stays dynamic when torch.export runs inside the patch
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
    ep = torch.onnx.export(
        Phi3RMSNorm(256).eval(),
        args=(),
        kwargs={"hidden_states": torch.rand((1, 32, 256))},
        dynamic_shapes={"hidden_states": {0: torch.export.Dim.DYNAMIC, 1: torch.export.Dim.DYNAMIC}},
    )
# But when run_decompositions() is called outside of the patch, the shape becomes static.
# ep = ep.run_decompositions()
print(ep)

```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166151
Approved by: https://github.com/justinchuby
This commit is contained in:
Ti-Tai Wang 2025-10-24 03:25:16 +00:00 committed by PyTorch MergeBot
parent 5a4997dcae
commit c12293dcbe

View File

@ -988,8 +988,14 @@ def _prepare_exported_program_for_export(
) -> torch.export.ExportedProgram:
"""Decompose and apply pre-export transformations to the exported program."""
with (
# Support the dynamism with 0/1 input dim
torch.fx.experimental._config.patch(backed_size_oblivious=True), # type: ignore[attr-defined]
):
# Decompose the graph given the implemented torch ops in ONNX
exported_program = _fx_passes.decompose_with_registry(exported_program, registry)
exported_program = _fx_passes.decompose_with_registry(
exported_program, registry
)
graph_module = exported_program.graph_module
# Include explicit type promotion nodes