## Summary

As part of #125683, this PR modifies the existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with the AVX2, AVX512 & AMX ISAs (the latter is only available on 4th-generation Xeon processors & beyond). WoQ GEMM takes FP16/BF16 activations, int8 weights, and a scale of the same dtype as the activations. The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that it does not cache the weights). Weights are treated as constant & cached, so this implementation is suitable for inference, not QAT. `scale` is supported as a `mul` epilogue.

Only BF16 activations are supported in this PR because, for FP16 & FP32, the weight is dequantized during the constant-folding pass of freezing, and after auto-tuning, performance with a large `M` dimension may be better than either `torch.ops.aten._weight_int8pack_mm` or the WoQ micro-kernel support introduced in this PR (which dequantizes `w` within the micro-kernel). While even BF16 activations with a large `M` dimension may benefit from dequantizing `w` beforehand, for now they use the WoQ support in the GEMM templates for auto-tuning; a subsequent PR will add logic for deciding whether or not to dequantize weights beforehand.

### Performance

#### AMX

Op-level speedup due to the AMX micro-kernel (selected during auto-tuning) on 32 physical cores of an Intel(R) Xeon(R) Platinum 8468H (4th-generation Xeon, codenamed Sapphire Rapids) vs. the ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded. In a few cases with an odd `K`, the implementation added in this PR may not perform as well as the ATen kernel; this is unrelated to this PR, though, since `test_linear_amx` exhibits similar data points. In those cases, the AMX micro-kernel might be slower than the AVX512 micro-kernel, so if such sets of shapes are used for auto-tuning, either the AVX512 micro-kernel implementation or the ATen kernel would be chosen instead. Benchmarked with unit tests; the AVX512 micro-kernel was disabled to collect data for the AMX micro-kernel. Tabular data: https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442

#### AVX2/AVX512 micro-kernels

Tabular data: https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437

### Follow-up

1. An int4 WoQ GEMM micro-kernel will be added in a separate PR.
2. A subsequent PR will add logic for deciding whether or not to dequantize weights beforehand. E2E perf measurement should be done with #131310.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131887
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
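As a concrete illustration of the semantics above, the sketch below compares the reference formula with the existing ATen op. The shapes, the random test data, and the use of `torch.testing.assert_close` are illustrative assumptions, not taken from the PR:

```python
import torch

# Illustrative shapes, chosen small for readability (not taken from the PR).
M, K, N = 4, 64, 32
x = torch.randn(M, K, dtype=torch.bfloat16)              # BF16 activations
w = torch.randint(-128, 127, (N, K), dtype=torch.int8)   # int8 weight, stored as (N, K)
scale = torch.rand(N, dtype=torch.bfloat16)              # per-output-channel scale

# Reference semantics: dequantize the weight, run a regular linear, apply the scale.
ref = torch.nn.functional.linear(x, w.to(x.dtype)) * scale

# The existing ATen kernel computes the same thing (without caching the weight).
out = torch.ops.aten._weight_int8pack_mm(x, w, scale)

# Both paths round to BF16, so they should agree within BF16 tolerances.
torch.testing.assert_close(out, ref)
```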
The Inductor lowering-registration code shown below (93 lines, 2.9 KiB, Python):
```python
# mypy: allow-untyped-defs
import logging

import torch

from torch._inductor.kernel.mm_common import mm_args

from . import config as inductor_config, lowering
from .codegen.cpp_gemm_template import CppPackedGemmTemplate
from .codegen.cpp_utils import create_epilogue_with_attr
from .lowering import expand, register_lowering
from .select_algorithm import (
    autotune_select_algorithm,
    ExternKernelChoice,
    realize_inputs,
)
from .utils import use_aten_gemm_kernels, use_cpp_packed_gemm_template


log = logging.getLogger(__name__)

aten__weight_int8pack_mm = ExternKernelChoice(
    torch._weight_int8pack_mm, "at::_weight_int8pack_mm", has_out_variant=False
)


quantized = torch.ops.quantized
_quantized = torch.ops._quantized
aten = torch.ops.aten


def register_quantized_ops():
    lowering.add_needs_realized_inputs(
        [
            quantized.max_pool2d,
            _quantized.wrapped_fbgemm_pack_gemm_matrix_fp16,
            _quantized.wrapped_fbgemm_linear_fp16_weight,
        ]
    )

    lowering.make_fallback(quantized.max_pool2d)
    lowering.make_fallback(_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16)
    lowering.make_fallback(_quantized.wrapped_fbgemm_linear_fp16_weight)


def register_woq_mm_ops():
    @register_lowering(aten._weight_int8pack_mm, type_promotion_kind=None)
    def int8pack_mm(input, weight, scale, *, layout=None):
        _, _, _, layout, mat1, mat2 = mm_args(
            input, weight, layout=layout, mat2_transposed=True
        )
        assert (
            mat1.get_dtype() in [torch.bfloat16, torch.float16, torch.float]
            and mat2.get_dtype() == torch.int8
        )
        aten_layout = layout

        # options to tune from
        choices = (
            [aten__weight_int8pack_mm.bind((mat1, mat2, scale), aten_layout)]
            if use_aten_gemm_kernels()
            else []
        )

        # scale is applied as an epilogue, and the scale tensor is expanded (with a view op)
        # for broadcasting, as it's 1D.
        def _mul_epilogue(buf):
            return create_epilogue_with_attr(
                buf, "mul", other=realize_inputs(expand(scale, layout.size))
            )

        if use_cpp_packed_gemm_template(aten_layout, mat1, mat2, mat2_transposed=True):
            CppPackedGemmTemplate.add_choices(
                choices,
                aten_layout,
                [mat1, mat2, scale],
                trans_w=True,
                epilogue_creator=_mul_epilogue,
            )

        if (
            len(choices) == 0
            and inductor_config.autotune_fallback_to_aten
            and not use_aten_gemm_kernels()
        ):
            log.warning("No choices for GEMM, using ATen backend as fallback")
            return aten__weight_int8pack_mm.bind(
                (mat1, mat2, scale), aten_layout
            ).output_node()

        return autotune_select_algorithm(
            "_weight_int8pack_mm", choices, [mat1, mat2, scale], aten_layout
        )
```
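For context, here is a hedged end-to-end sketch of how the lowering registered above could be exercised: a module whose forward calls `torch.ops.aten._weight_int8pack_mm` is compiled with Inductor's max-autotune and freezing enabled, so the `CppPackedGemmTemplate` choices can compete with the ATen extern kernel. The module, shapes, and `torch.compile` options are assumptions for illustration, not taken from the PR:

```python
import torch


class WoqLinear(torch.nn.Module):
    """Toy module that calls the int8 weight-only-quantized matmul directly."""

    def __init__(self, K, N):
        super().__init__()
        # Constant int8 weight and per-channel scale, registered as buffers so
        # freezing can treat them as constants that the packed template caches.
        self.register_buffer("w", torch.randint(-128, 127, (N, K), dtype=torch.int8))
        self.register_buffer("scale", torch.rand(N, dtype=torch.bfloat16))

    def forward(self, x):
        return torch.ops.aten._weight_int8pack_mm(x, self.w, self.scale)


m = WoqLinear(K=64, N=32).eval()
x = torch.randn(4, 64, dtype=torch.bfloat16)

# max_autotune lets Inductor benchmark the registered GEMM choices against the
# ATen extern kernel; freezing constant-folds the buffers (assumed options).
compiled = torch.compile(m, options={"max_autotune": True, "freezing": True})
with torch.no_grad():
    out = compiled(x)
```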