import torch

from torch._inductor.constant_folding import constant_fold
from torch._inductor.fx_passes.freezing_patterns import freezing_passes


__all__ = [
    "lower_pt2e_quantized_to_x86",
]


def lower_pt2e_quantized_to_x86(
    model: torch.fx.GraphModule,
    example_inputs: tuple[torch.Tensor, ...],
) -> torch.fx.GraphModule:
    """Lower a PT2E-quantized model to the x86 backend.

    Args:
    * `model` (torch.fx.GraphModule): a model quantized by the PT2E quantization flow.
    * `example_inputs` (tuple[torch.Tensor, ...]): example inputs for the model.

    Return:
        A GraphModule lowered to the x86 backend.
    """

    def _post_autograd_decomp_table():  # type: ignore[no-untyped-def]
        decomp_table = torch.export.default_decompositions()

        # If we are post-autograd, we shouldn't decompose prim ops.
        for k in list(decomp_table.keys()):
            if not torch._export.utils._is_cia_op(k):
                del decomp_table[k]

        return decomp_table

    def _node_replace(m):  # type: ignore[no-untyped-def]
        # Replace aten.t(x) with aten.permute(x, [1, 0])
        aten = torch.ops.aten
        g = m.graph
        for node in g.nodes:
            if node.target == aten.t.default:
                with g.inserting_before(node):
                    x = node.args[0]
                    dims = [1, 0]
                    perm_node = g.call_function(aten.permute.default, args=(x, dims))
                node.replace_all_uses_with(perm_node)
                g.erase_node(node)

        g.lint()
        m.recompile()

    lowered_model = (
        torch.export.export(model, example_inputs, strict=True)
        .run_decompositions(_post_autograd_decomp_table())
        .module()
    )
    _node_replace(lowered_model)
    freezing_passes(lowered_model, example_inputs)
    constant_fold(lowered_model)
    return lowered_model
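

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the original file).
    # It shows how a model produced by the PT2E quantization flow might be fed
    # into `lower_pt2e_quantized_to_x86`. It assumes a PyTorch build that
    # provides `torch.export.export_for_training`, `prepare_pt2e`/`convert_pt2e`,
    # and `X86InductorQuantizer`; the toy model and the single random
    # calibration batch below are hypothetical stand-ins.
    import torch.nn as nn
    from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import (
        X86InductorQuantizer,
        get_default_x86_inductor_quantization_config,
    )

    class _ToyModel(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.conv = nn.Conv2d(3, 8, 3)
            self.relu = nn.ReLU()

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.relu(self.conv(x))

    example_inputs = (torch.randn(1, 3, 32, 32),)

    # Export and quantize with the PT2E flow.
    exported = torch.export.export_for_training(
        _ToyModel().eval(), example_inputs
    ).module()
    quantizer = X86InductorQuantizer()
    quantizer.set_global(get_default_x86_inductor_quantization_config())
    prepared = prepare_pt2e(exported, quantizer)
    prepared(*example_inputs)  # calibrate observers on (random) example data
    converted = convert_pt2e(prepared)

    # Lower the PT2E-quantized model to the x86 backend.
    lowered = lower_pt2e_quantized_to_x86(converted, example_inputs)
    print(lowered.graph)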