pytorch/torch/ao/quantization/pt2e/lowering.py
linhaifeng 369f2d6951 [3/N] fix typo in other folders (#166606)
fix typo in other folders

#166374
#166126

_typos.toml
```toml
[files]
extend-exclude = ["tools/linter/dictionary.txt"]
[default.extend-words]
nd = "nd"
arange = "arange"
Nd = "Nd"
GLOBALs = "GLOBALs"
hte = "hte"
iy = "iy"
PN = "PN"
Dout = "Dout"
optin = "optin"
gam = "gam"
PTD = "PTD"
Sur = "Sur"
nin = "nin"
tme = "tme"
inpt = "inpt"
mis = "mis"
Raison = "Raison"
ouput = "ouput"
nto = "nto"
Onwer = "Onwer"
callibrate = "callibrate"
ser = "ser"
Metdata = "Metdata"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166606
Approved by: https://github.com/ezyang
2025-10-30 10:30:40 +00:00

import torch
from torch._inductor.constant_folding import constant_fold
from torch._inductor.fx_passes.freezing_patterns import freezing_passes

__all__ = [
    "lower_pt2e_quantized_to_x86",
]


def lower_pt2e_quantized_to_x86(
    model: torch.fx.GraphModule,
    example_inputs: tuple[torch.Tensor, ...],
) -> torch.fx.GraphModule:
    """Lower a PT2E-quantized model to the x86 backend.

    Args:
    * `model` (torch.fx.GraphModule): a model quantized by the PT2E quantization flow.
    * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model.

    Return:
    A GraphModule lowered to the x86 backend.
    """

    def _post_autograd_decomp_table():  # type: ignore[no-untyped-def]
        decomp_table = torch.export.default_decompositions()
        # If we are post-autograd, we shouldn't decompose prim ops,
        # so keep only CIA (composite implicit autograd) ops in the table.
        for k in list(decomp_table.keys()):
            if not torch._export.utils._is_cia_op(k):
                del decomp_table[k]
        return decomp_table

    def _node_replace(m):  # type: ignore[no-untyped-def]
        # Replace aten.t(x) with aten.permute(x, [1, 0])
        aten = torch.ops.aten
        g = m.graph
        for node in g.nodes:
            if node.target == aten.t.default:
                with g.inserting_before(node):
                    x = node.args[0]
                    dims = [1, 0]
                    perm_node = g.call_function(aten.permute.default, args=(x, dims))
                node.replace_all_uses_with(perm_node)
                g.erase_node(node)
        g.lint()
        m.recompile()

    # Re-export the quantized model and run only the CIA decompositions
    # (post-autograd graph).
    lowered_model = (
        torch.export.export(model, example_inputs, strict=True)
        .run_decompositions(_post_autograd_decomp_table())
        .module()
    )
    _node_replace(lowered_model)
    # Apply Inductor freezing passes and fold constant subgraphs.
    freezing_passes(lowered_model, example_inputs)
    constant_fold(lowered_model)
    return lowered_model
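
For context, here is a minimal usage sketch of this API following the PT2E x86 quantization flow from the PyTorch tutorials. The `SimpleModel`, the single-batch calibration, and the capture step via `torch.export.export_for_training` are illustrative assumptions, not part of this file; adapt them to your model and PyTorch version.

```python
# Minimal sketch: quantize a toy module with the PT2E flow, then lower it to x86.
# SimpleModel and the calibration data are illustrative assumptions.
import torch
from torch.ao.quantization.pt2e.lowering import lower_pt2e_quantized_to_x86
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
    X86InductorQuantizer,
    get_default_x86_inductor_quantization_config,
)


class SimpleModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(16, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


example_inputs = (torch.randn(4, 16),)

# Capture the model for quantization, as in the PT2E tutorials.
captured = torch.export.export_for_training(
    SimpleModel().eval(), example_inputs
).module()

# Prepare with the default x86 Inductor quantization config, calibrate, convert.
quantizer = X86InductorQuantizer()
quantizer.set_global(get_default_x86_inductor_quantization_config())
prepared = prepare_pt2e(captured, quantizer)
prepared(*example_inputs)  # calibration; use representative data in practice
converted = convert_pt2e(prepared)

# Lower the PT2E-quantized GraphModule to the x86 backend.
lowered = lower_pt2e_quantized_to_x86(converted, example_inputs)
```

Per the code above, the returned GraphModule has already been through the freezing passes and constant folding, so the quantized weights are baked into the graph as constants.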