pytorch/torch/_inductor/quantized_lowerings.py
sanchitintel f951fcd1d7 Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue (#131887)
## Summary

As part of #125683, this PR modifies the existing CPU GEMM C++ template and micro-kernel template to enable auto-tuning of int8 weight-only-quantized (WoQ) GEMM with the AVX2, AVX512 & AMX ISAs (AMX is only available on 4th-generation Xeon processors & later).

WoQ GEMM takes FP16/BF16 activations, int8 weights, and a scale of the same dtype as the activations.
The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that it does not cache weights). Here, weights are treated as constants & cached, so this implementation is suitable for inference but not for QAT. `scale` is supported as a `mul` epilogue.
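The following is a minimal pure-Python sketch of these reference semantics. It is not the Inductor kernel. It assumes that `x` is an `M x K` activation matrix, that `w` is an `N x K` int8 weight matrix laid out as for `torch.nn.functional.linear` (so the product is `x @ w.T`), and that `scale` is a 1D per-output-channel vector of length `N`. The helper name `woq_int8_linear_reference` is hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def woq_int8_linear_reference(
    x: Matrix, w: List[List[int]], scale: List[float]
) -> Matrix:
    """Hypothetical reference for linear(x, w.to(x.dtype)) * scale.

    x: M x K activations, w: N x K int8 weights (row n is output channel n),
    scale: length-N per-output-channel scale. Returns an M x N matrix.
    """
    n_out = len(w)
    assert len(scale) == n_out, "scale must have one entry per output channel"
    out: Matrix = []
    for row in x:
        out_row = []
        for n in range(n_out):
            # "Dequantize" w[n] by casting to the activation dtype (float here),
            # accumulate the dot product, then apply the per-channel scale.
            acc = sum(a * float(b) for a, b in zip(row, w[n]))
            out_row.append(acc * scale[n])
        out.append(out_row)
    return out


x = [[1.0, 2.0], [3.0, -1.0]]
w = [[2, -1], [0, 4], [1, 1]]  # int8-range values, N=3, K=2
scale = [0.5, 0.25, 2.0]
print(woq_int8_linear_reference(x, w, scale))
# [[0.0, 2.0, 6.0], [3.5, -1.0, 4.0]]
```

`scale` has one entry per output channel. Multiplying the accumulated dot product by `scale[n]` is therefore equivalent to scaling row `n` of the weights first. This is why the scale can be applied after the GEMM as a `mul` epilogue.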

Only BF16 activations are supported in this PR. For FP16 & FP32 activations, the weight is dequantized during the constant-folding pass of freezing. After auto-tuning, a GEMM on the dequantized weight may then perform better with a large `M` dimension than either `torch.ops.aten._weight_int8pack_mm` or the WoQ micro-kernel added in this PR, which dequantizes `w` inside the micro-kernel.
BF16 activations with a large `M` dimension may also benefit from dequantizing `w` beforehand. For now, however, they use the WoQ support in the GEMM templates for auto-tuning. A subsequent PR will add logic to decide whether or not to dequantize weights beforehand.
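The following rough pure-Python sketch makes this trade-off concrete by counting int8-to-float conversions. The helpers `gemm_dequant_ahead` and `gemm_dequant_in_kernel` are hypothetical and are not Inductor code. The sketch assumes a kernel that re-dequantizes the weights for every block of `block_m` rows of the activations, and it ignores `K`/`N` tiling. Under that assumption, the conversion work of in-kernel dequantization grows with `M`. Dequantizing ahead of time pays that cost only once, and both approaches give the same numerical result.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def gemm_dequant_ahead(x: Matrix, w: List[List[int]]) -> Tuple[Matrix, int]:
    """Dequantize all of w once, then run a plain float GEMM (x @ w.T).

    Returns (result, number of int8->float conversions performed).
    """
    w_f = [[float(v) for v in row] for row in w]  # one-time dequantization
    conversions = sum(len(row) for row in w)
    out = [[sum(a * b for a, b in zip(r, wr)) for wr in w_f] for r in x]
    return out, conversions


def gemm_dequant_in_kernel(
    x: Matrix, w: List[List[int]], block_m: int
) -> Tuple[Matrix, int]:
    """Hypothetical blocked GEMM that dequantizes w inside each M-block.

    The weights are converted again for every block of block_m rows of x,
    mimicking a micro-kernel that dequantizes the weights it loads.
    """
    conversions = 0
    out: Matrix = []
    for m0 in range(0, len(x), block_m):
        # Dequantize the weights for this block only (redone per block).
        w_f = [[float(v) for v in row] for row in w]
        conversions += sum(len(row) for row in w)
        for r in x[m0 : m0 + block_m]:
            out.append([sum(a * b for a, b in zip(r, wr)) for wr in w_f])
    return out, conversions


x = [[float(i + k) for k in range(4)] for i in range(8)]  # M=8, K=4
w = [[1, -2, 3, 0], [4, 0, -1, 2]]  # N=2, K=4
ahead, n_ahead = gemm_dequant_ahead(x, w)
in_kernel, n_in_kernel = gemm_dequant_in_kernel(x, w, block_m=2)
assert ahead == in_kernel  # same numerical result either way
print(n_ahead, n_in_kernel)  # 8 32
```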

### Performance
#### AMX
Op-level speedup of the AMX micro-kernel (selected during auto-tuning) over the ATen kernel `torch.ops.aten._weight_int8pack_mm`. Measurements were taken on 32 physical cores of an Intel(R) Xeon(R) Platinum 8468H (4th-generation Xeon, codenamed Sapphire Rapids), with Intel OpenMP & tcmalloc preloaded.

In a few cases with an odd `K`, the implementation added in this PR may not perform as well as the ATen kernel. This is not caused by this PR, though, since `test_linear_amx` exhibits similar data points. In those cases, the AMX micro-kernel may be slower than the AVX512 micro-kernel. If such shapes are used for auto-tuning, either the AVX512 micro-kernel or the ATen kernel would be chosen instead.

Benchmarked with unit tests.
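As described above, auto-tuning benchmarks the candidate kernels for the actual input shapes and keeps the fastest one. The following is a minimal sketch of that selection step. It assumes that candidates are plain Python callables timed with `time.perf_counter`. The helper `pick_fastest` and the two stand-in kernels are hypothetical, and Inductor's own selection logic is more involved.

```python
import statistics
import time
from typing import Callable, Dict, Sequence


def pick_fastest(
    choices: Dict[str, Callable[..., object]],
    args: Sequence[object],
    repeats: int = 5,
) -> str:
    """Hypothetical autotuner: benchmark each candidate, return the fastest.

    Each candidate is run once as a warm-up, then timed `repeats` times;
    the candidate with the lowest median wall-clock time wins.
    """
    if not choices:
        raise ValueError("no candidate kernels to tune over")
    medians: Dict[str, float] = {}
    for name, fn in choices.items():
        fn(*args)  # warm-up run, excluded from the measurement
        samples = []
        for _ in range(repeats):
            start = time.perf_counter()
            fn(*args)
            samples.append(time.perf_counter() - start)
        medians[name] = statistics.median(samples)
    return min(medians, key=medians.__getitem__)


def slow_kernel(n: int) -> None:
    time.sleep(0.002 * n)  # stand-in for a kernel that is slow on this shape


def fast_kernel(n: int) -> None:
    pass  # stand-in for a kernel that is fast on this shape


# Illustrative only: mimics an odd-K shape on which the AMX kernel loses.
choices = {"amx": slow_kernel, "avx512": fast_kernel}
print(pick_fastest(choices, args=(3,)))  # avx512
```

The median is used instead of the mean so that a single noisy run (for example, one interrupted by the OS) does not skew the choice.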

Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442

The AVX512 micro-kernel was disabled to collect data for the AMX micro-kernel.

#### AVX2/AVX512 micro-kernels

Tabular data at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437

### Follow-up
1. An int4 WoQ GEMM micro-kernel will be added in a separate PR.
2. A subsequent PR will add logic to decide whether or not to dequantize weights beforehand.

E2E perf measurement should be done with #131310.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131887
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-08-14 03:14:45 +00:00


# mypy: allow-untyped-defs
import logging

import torch
from torch._inductor.kernel.mm_common import mm_args

from . import config as inductor_config, lowering
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .lowering import expand, register_lowering
from .select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    realize_inputs,
)
from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template


log = logging.getLogger(__name__)

aten__weight_int8pack_mm = ExternKernelChoice(
    torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False
)


quantized = torch.ops.quantized
_quantized = torch.ops._quantized
aten = torch.ops.aten


def register_quantized_ops():
    lowering.add_needs_realized_inputs(
        [
            quantized.max_pool2d,
            _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16,
            _quantized.wrapped_fbgemm_linear_fp16_weight,
        ]
    )

    lowering.make_fallback(quantized.max_pool2d)
    lowering.make_fallback(_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16)
    lowering.make_fallback(_quantized.wrapped_fbgemm_linear_fp16_weight)


def register_woq_mm_ops():
    @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None)
    def int8pack_mm(input, weight, scale, *, layout=None):
        _, _, _, layout, mat1, mat2 = mm_args(
            input, weight, layout=layout, mat2_transposed=True
        )
        assert (
            mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float]
            and mat2.get_dtype() == torch.int8
        )
        aten_layout = layout

        # options to tune from
        choices = (
            [aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)]
            if use_aten_gemm_kernels()
            else []
        )

        # scale is applied as an epilogue, and the scale tensor is expanded (with a view op)
        # for broadcasting, as it's 1D.
        def _mul_epilogue(buf):
            return create_epilogue_with_attr(
                buf, "mul", other=realize_inputs(expand(scale, layout.size))
            )

        if use_cpp_packed_gemm_template(aten_layout, mat1, mat2, mat2_transposed=True):
            CppPackedGemmTemplate.add_choices(
                choices,
                aten_layout,
                [mat1, mat2, scale],
                trans_w=True,
                epilogue_creator=_mul_epilogue,
            )

        if (
            len(choices) == 0
            and inductor_config.autotune_fallback_to_aten
            and not use_aten_gemm_kernels()
        ):
            log.warning("No choices for GEMM, using ATen backend as fallback")
            return aten__weight_int8pack_mm.bind(
                (mat1, mat2, scale), aten_layout
            ).output_node()

        return autotune_select_algorithm(
            "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout
        )