pytorch/torch/ao/quantization/pt2e/lowering.py
Tugsbayasgalan (Tugsuu) Manlaibaatar de05dbc39c Replace export_for_training with export (#162396)
Summary: replace export_for_training with export

Test Plan:
CI

Rollback Plan:

Differential Revision: D81935792

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162396
Approved by: https://github.com/angelayi, https://github.com/jerryzh168
2025-09-10 14:19:34 +00:00

61 lines
1.8 KiB
Python

import torch
from torch._inductor.constant_folding import constant_fold
from torch._inductor.fx_passes.freezing_patterns import freezing_passes
# Names exported by `from ... import *`; the lowering entry point is the only
# public symbol of this module.
__all__ = ["lower_pt2e_quantized_to_x86"]
def lower_pt2e_quantized_to_x86(
model: torch.fx.GraphModule,
example_inputs: tuple[torch.Tensor, ...],
) -> torch.fx.GraphModule:
"""Lower a PT2E-qantized model to x86 backend.
Args:
* `model` (torch.fx.GraphModule): a model quantized by PT2E quantization flow.
* `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model.
Return:
A GraphModule lowered to x86 backend.
"""
def _post_autograd_decomp_table(): # type: ignore[no-untyped-def]
decomp_table = torch.export.default_decompositions()
# if we are post-autograd, we shouldn't
# decomp prim ops.
for k in list(decomp_table.keys()):
if not torch._export.utils._is_cia_op(k):
del decomp_table[k]
return decomp_table
def _node_replace(m): # type: ignore[no-untyped-def]
# Replace aten.t(x) with aten.permute(x, [1, 0])
aten = torch.ops.aten
g = m.graph
for node in g.nodes:
if node.target == aten.t.default:
with g.inserting_before(node):
x = node.args[0]
dims = [1, 0]
perm_node = g.call_function(aten.permute.default, args=(x, dims))
node.replace_all_uses_with(perm_node)
g.erase_node(node)
g.lint()
m.recompile()
lowered_model = (
torch.export.export(model, example_inputs, strict=True)
.run_decompositions(_post_autograd_decomp_table())
.module()
)
_node_replace(lowered_model)
freezing_passes(lowered_model, example_inputs)
constant_fold(lowered_model)
return lowered_model