Commit Graph

74 Commits

Author SHA1 Message Date
Maggie Moss
9940e894ea Fix pyrefly ignore syntax in _inductor (#166247)
Ensures that pyrefly ignore comments suppress only the intended error code.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166247
Approved by: https://github.com/oulgen
2025-10-27 02:48:42 +00:00
Yuanyuan Chen
fb64da0791 [2/N] Use "is" in python type comparison (#165142)
This is a follow-up of #165037. It is generally recommended to use `is`/`is not` to compare types. This series of changes applies that suggestion across the code base, with the aim of eventually enabling the related linter checks.
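For illustration, the kind of comparison this series prefers (the example class is mine, not from the PR):

```python
class Foo:
    pass

x = Foo()

# Preferred: compare type objects by identity
if type(x) is Foo:
    print("exact Foo")

# Discouraged: equality comparison of types; flagged by linters (e.g. E721)
if type(x) == Foo:
    print("also matches, but linters complain")
```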

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165142
Approved by: https://github.com/albanD
2025-10-10 15:36:44 +00:00
Maggie Moss
9944cac6e6 Add suppressions to torch/_inductor (#165062)
Adds suppressions so that pyrefly will typecheck cleanly: https://github.com/pytorch/pytorch/issues/163283

Split this directory into two PRs to keep them from being too large.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165062
Approved by: https://github.com/oulgen, https://github.com/mlazos
2025-10-09 20:34:20 +00:00
Yuanyuan Chen
f7ab8a2710 [1/N] Fix ruff warnings (#164333)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164333
Approved by: https://github.com/albanD
2025-10-01 16:48:32 +00:00
wengshiy
c8c221c0b3 [Inductor][Float8] Add float8_e4m3fn into assertion dtype list. (#157684)
Fix an assertion issue.
Add float8_e4m3fn into the allowed dtype list.
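For illustration, a hedged sketch of the kind of dtype allow-list the fix extends (names here are illustrative, not the actual `_inductor` code):

```python
import torch

# Assumed/illustrative allow-list; the fix adds torch.float8_e4m3fn to such a list.
_ALLOWED_DTYPES = (torch.float32, torch.bfloat16, torch.float16, torch.float8_e4m3fn)

def _check_dtype(t: torch.Tensor) -> None:
    assert t.dtype in _ALLOWED_DTYPES, f"unsupported dtype: {t.dtype}"

_check_dtype(torch.empty(2, 2, dtype=torch.float8_e4m3fn))  # passes after the fix
```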

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157684
Approved by: https://github.com/Xia-Weiwen, https://github.com/leslie-fang-intel, https://github.com/jansel
2025-07-15 06:02:01 +00:00
Zeina Migeed
4f5be56612 [Pyrefly][Refactor] Replace dict() calls with literal dict syntax for improved readability (#157735)
There are 31 places I spotted which construct literal dictionaries.

This PR refactors dictionary construction by replacing `dict(...)` calls with literal `{...}` syntax where applicable.
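For illustration, the shape of the rewrite (the example values are mine):

```python
# Before: building a fixed mapping through the dict() builtin
opts = dict(device="cpu", dtype="float32")

# After: equivalent dict literal; avoids a name lookup and a call, and reads as data
opts = {"device": "cpu", "dtype": "float32"}
```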

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157735
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2025-07-08 18:10:33 +00:00
Tom Ritchford
e3afbb0362 [inductor] Add typing to _inductor/ir.py (#149958)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149958
Approved by: https://github.com/Skylion007
2025-06-30 15:56:35 +00:00
Xuehai Pan
6ff6630375 [BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156313
Approved by: https://github.com/jingsh
2025-06-23 02:57:12 +00:00
PyTorch MergeBot
f1331f3f1b Revert "[BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)"
This reverts commit 3627270bdf.

Reverted https://github.com/pytorch/pytorch/pull/156313 on behalf of https://github.com/atalman due to export/test_torchbind.py::TestCompileTorchbind::test_compile_error_on_input_aliasing_contents_backend_aot_eager [GH job link](https://github.com/pytorch/pytorch/actions/runs/15804799771/job/44548489912) [HUD commit link](c95f7fa874) ([comment](https://github.com/pytorch/pytorch/pull/156313#issuecomment-2994171213))
2025-06-22 12:31:57 +00:00
Xuehai Pan
3627270bdf [BE][3/16] fix typos in torch/ (torch/_inductor/) (#156313)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156313
Approved by: https://github.com/jingsh
2025-06-22 08:43:09 +00:00
Nicolas Macchioni
68996dc183 [BE][2/X] Phase out usage of use_max_autotune() (#155848)
See #155847 for context

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155848
Approved by: https://github.com/masnesral
2025-06-18 01:18:09 +00:00
PyTorch MergeBot
7e4c097b07 Revert "[inductor] Add typing to _inductor/ir.py (#149958)"
This reverts commit 529e0357c6.

Reverted https://github.com/pytorch/pytorch/pull/149958 on behalf of https://github.com/malfet due to Looks like it broke inductor_torchbind tests, due to more graphbreaks, see b0fbbef136/1 ([comment](https://github.com/pytorch/pytorch/pull/149958#issuecomment-2949583209))
2025-06-06 15:19:16 +00:00
Tom Ritchford
529e0357c6 [inductor] Add typing to _inductor/ir.py (#149958)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149958
Approved by: https://github.com/Skylion007
2025-06-06 14:15:01 +00:00
Aaron Gokaslan
3555ebb63d [BE]: Update ruff to 0.11.8 (#153249)
Fixes a ton of false negatives throughout the codebase. RUFF also properly validates NOQA comments now and most of the changes are fixing typos there or removing filewide flake8 suppressions that were also silencing ruff issues.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153249
Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/seemethere
2025-05-12 18:30:52 +00:00
Xia, Weiwen
ac9fcd6346 [Inductor][CPU] bug fix for int8 GEMM compensation epilogue (#152408)
Fixes #152398

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152408
Approved by: https://github.com/leslie-fang-intel
2025-05-05 08:26:47 +00:00
Anthony Shoumikhin
e2f9759bd0 Fix broken URLs (#152237)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152237
Approved by: https://github.com/huydhn, https://github.com/malfet
2025-04-27 09:56:42 +00:00
leslie-fang-intel
d70490ecfe [Inductor][CPP] Optimize the epilogue for int8 GEMM Template (#152000)
**Summary**
For the int8 GEMM template, the micro GEMM calculates in u8s8s32 and we do the scale/zero-point compensation in the epilogue. In general, it is calculated as:
```
temp = micro_gemm_output * x_scale * w_scale
temp = temp - (x_scale * w_scale * x_zp) * sum(w, 0)
```
For the case when `x_scale`, `w_scale`, and `x_zp` are constant, we can pre-calculate the compensation to save runtime computation.
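A minimal sketch of that compensation math (shapes and values are illustrative; this is not the template's code):

```python
import torch

M, K, N = 4, 8, 6
x_q = torch.randint(0, 255, (M, K), dtype=torch.uint8)    # u8 activation
w_q = torch.randint(-128, 127, (K, N), dtype=torch.int8)  # s8 weight
x_scale, w_scale, x_zp = 0.02, 0.01, 3                    # constants in this case

micro_gemm_output = x_q.to(torch.int32) @ w_q.to(torch.int32)  # u8s8s32 micro GEMM

# Epilogue computed at runtime (the "Before" case):
temp = micro_gemm_output * x_scale * w_scale
temp = temp - (x_scale * w_scale * x_zp) * w_q.to(torch.int32).sum(0)

# With constant x_scale/w_scale/x_zp, the second term can be pre-calculated once:
compensation = (x_scale * w_scale * x_zp) * w_q.to(torch.int32).sum(0)
y = micro_gemm_output * x_scale * w_scale - compensation
torch.testing.assert_close(y, temp)
```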

**Performance**
Tested with 4 cores of XEON-5 and shapes from the ViT model.
Before
```
GEMM(M=197,N=768,K=768) compile: 0.0939 ms (2.48 TOPS, 18.13 GB/s)
GEMM(M=197,N=3072,K=768) compile: 0.4275 ms (2.17 TOPS, 13.90 GB/s)
GEMM(M=197,N=768,K=3072) compile: 0.2677 ms (3.47 TOPS, 22.20 GB/s)
GEMM(M=1,N=1000,K=768) compile: 0.0148 ms (0.10 TOPS, 99.10 GB/s)
```

After
```
GEMM(M=197,N=768,K=768) compile: 0.0597 ms (3.90 TOPS, 28.53 GB/s)
GEMM(M=197,N=3072,K=768) compile: 0.2126 ms (4.37 TOPS, 27.95 GB/s)
GEMM(M=197,N=768,K=3072) compile: 0.2282 ms (4.07 TOPS, 26.04 GB/s)
GEMM(M=1,N=1000,K=768) compile: 0.0149 ms (0.10 TOPS, 98.71 GB/s)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152000
Approved by: https://github.com/Xia-Weiwen, https://github.com/CaoE, https://github.com/jansel
2025-04-24 23:36:00 +00:00
Xia, Weiwen
246f3b6530 [Quant][PT2E][X86] enable qconv1d-relu fusion (#150751)
**Summary**
As the title.
- The `conv1d - relu` pattern will be annotated by the `X86InductorQuantizer`.
- The pattern will be fused as `qconv_pointwise` during lowering.
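A rough end-to-end sketch of the flow that exercises this pattern (the PT2E capture and quantization entry points shown are the standard ones at the time of writing and may differ across versions; freezing is assumed to be enabled for the lowering-time fusion):

```python
import torch
import torch.ao.quantization.quantizer.x86_inductor_quantizer as xiq
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e

class ConvRelu(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv1d(3, 8, kernel_size=3)

    def forward(self, x):
        return torch.nn.functional.relu(self.conv(x))

example_inputs = (torch.randn(1, 3, 16),)
m = torch.export.export_for_training(ConvRelu().eval(), example_inputs).module()

quantizer = xiq.X86InductorQuantizer()
quantizer.set_global(xiq.get_default_x86_inductor_quantization_config())

m = prepare_pt2e(m, quantizer)   # conv1d + relu is annotated here
m(*example_inputs)               # calibration
m = convert_pt2e(m)

# Lowering with freezing is what fuses the pattern into qconv_pointwise.
torch._inductor.config.freezing = True
with torch.no_grad():
    torch.compile(m)(*example_inputs)
```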

**Test plan**
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_qconv1d_relu_cpu
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150751
Approved by: https://github.com/jerryzh168, https://github.com/leslie-fang-intel
2025-04-09 14:42:02 +00:00
sanchitintel
1c544a9ddd [Inductor-CPP] If all of the activation scale dims are 1, make it a 0D tensor (#147033)
For int8 dynamically quantized activation & int8 quantized weights, add a workaround for an indexing issue in the epilogue creator that expected an empty index (i.e., a 0D tensor) when the activation scale was sized [1, 1], by converting the scale into a 0D tensor.

The issue was discovered while running LLaMA2 quantized with torchao's `int8_dynamic_activation_int8_weight` quantization on CPU with max-autotune enabled (although this error would've occurred regardless).

The final hidden-states tensor that is the activation into the LM head has shape `[batch_size, sequence_length, hidden_dim]` during decoding. When decoding one token at a time with batch size 1, the sequence length is 1. The activation scale is shaped `[1, 1]` (reshaped from `[1, 1, 1]`). However, the Inductor epilogue creator expects a 0D tensor in this case (my guess is that the corresponding logic in Inductor expects a 0D tensor if a tensor has only one element, even if it's 1D).
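In tensor terms, the workaround amounts to something like the following (illustrative only, not the actual epilogue-creator code):

```python
import torch

act_scale = torch.tensor([[0.05]])      # shape [1, 1], reshaped from [1, 1, 1]
act_scale_0d = act_scale.reshape(())    # shape [], the 0D tensor the epilogue expects
assert act_scale_0d.ndim == 0 and act_scale_0d.item() == act_scale.item()
```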

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147033
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel
2025-03-03 18:32:27 +00:00
Xuehai Pan
1cb4e2df65 [BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel
2025-02-28 13:33:19 +00:00
ZhiweiYan-96
59915b8dec [Intel GPU] qlinear at XPU backend (#133307)
# Motivation
The PR is intended to enable `onednn.qlinear` and `onednn.qlinear_unary` at Intel GPU.

We register the qlinear ops at the C++ backend via `TORCH_LIBRARY_IMPL`; the ops this PR registers include `onednn::qlinear_pointwise`, `onednn::qlinear_pointwise.tensor`, and `onednn::qlinear_prepack`. The prepack op transposes the weight to fit oneDNN's weight-layout requirement and achieve higher performance.

Also, we remove the limitation of the corresponding annotation method in  the `XPUInductorQuantizer` (`torch/ao/quantization/quantizer/xpu_inductor_quantizer.py`) to allow GPU linear conversion.

We add the kChar (`torch.int8`) dtype in `torch/_inductor/fx_passes/quantization` and `torch/_inductor/mkldnn_ir.py`, as signed int8 is the default INT8 data type on the GPU side.

We verified the ops through UTs and e2e model testing on models such as ResNet18 and ResNet50.

# UT verification
```
 DNNL_VERBOSE=0 TORCH_COMPILE_DEBUG=0 python test/inductor/test_mkldnn_pattern_matcher.py -v  \
     -k test_qlinear_xpu \
     -k test_qlinear_relu_xpu \
     -k test_qlinear_gelu_xpu
```

# Runtime exemplification
Here is the oneDNN verbose log collected by running the above UTs
```
//pure int8 gemm
onednn_verbose,primitive,exec,gpu:0,matmul,jit:gemm:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 dst_s8::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32+dst:0:s32,,2x4:4x3,0.187988
// post-relu fusion
onednn_verbose,primitive,exec,gpu:0,matmul,jit:gemm:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 bia_f32::blocked:ab::f0_mask2 dst_f32::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_relu,,2x4:4x4,0.115234
// post-gelu fusion
onednn_verbose,primitive,exec,gpu:0,matmul,jit:gemm:any,undef,src_s8::blocked:ab::f0 wei_s8::blocked:ab::f0 dst_f32::blocked:ab::f0,attr-scratchpad:user attr-scales:src0:0:f32+dst:0:f32+wei:2:f32 attr-zero-points:src0:0:s32 attr-post-ops:eltwise_gelu_tanh,,2x4:4x4,0.170898

```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133307
Approved by: https://github.com/liangan1, https://github.com/guangyey, https://github.com/EikanWang, https://github.com/jerryzh168

Co-authored-by: guangyey <guangye.yu@intel.com>
2025-02-18 04:02:42 +00:00
sanchitintel
46851022ff [Inductor][CPU] Add auto-tuning support for da8w8 sym act sym wgt GEMM (#143187)
## Summary

Templated `int8xint8->int32` GEMM that uses the AMX ISA (present on Intel Xeon Gen 4 & above). Any epilogues such as weight scale, activation scale, and bias are applied per output block in a fused manner.
Performs well for large values of the `M` dimension (assuming canonical dimensions [`M, K`] and [`K, N`] for the activation & weight matrices'/tensors' sizes) when the activation is quantized per-token.
Also supports the SmoothQuant GEMM pattern when the activation is quantized per-tensor (scalar scale) or per-token (a vector scale is applied as an epilogue in this case).

Also increased GEMM template coverage of the uint8-activation, int8-weight GEMM UTs for the case when the activation zero point is a 1D tensor (the existing implementation only accepted 0D tensors). However, some such UTs have to be explicitly enabled with the `max-autotune` Inductor config.
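The fused computation can be sketched as below (shapes, per-token scales, and names are illustrative, not the template internals):

```python
import torch

M, K, N = 32, 64, 48
x_q = torch.randint(-128, 127, (M, K), dtype=torch.int8)   # symmetric per-token int8 activation
w_q = torch.randint(-128, 127, (K, N), dtype=torch.int8)   # symmetric int8 weight
x_scale = torch.rand(M, 1)                                  # per-token scale (a scalar when per-tensor)
w_scale = torch.rand(1, N)                                  # per-channel weight scale
bias = torch.rand(N)

acc = x_q.to(torch.int32) @ w_q.to(torch.int32)             # int8 x int8 -> int32 GEMM (AMX in the template)
y = acc.to(torch.float32) * x_scale * w_scale + bias        # scales and bias applied as fused epilogues
```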

## Performance data

The templated codegened fused GEMM with M=32, K=4096, N=14336 used in LLaMA3 exhibits more than 2x perf-gain compared to oneDNN qlinear + mul (for activation's scale) with 48 cores of one socket of Xeon SP 4th gen Platinum 8468 when per-token quantization is used.

For M=1, K=4096, N=14336, regardless of whether per-tensor quantization was used for activation or per-token, the perf gain was more than 3x.

Intel OpenMP & libtcmalloc had been preloaded. All cores used by the workload corresponded to distinct physical cores.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143187
Approved by: https://github.com/jansel, https://github.com/leslie-fang-intel, https://github.com/jgong5

Co-authored-by: Leslie Fang <leslie.fang@intel.com>
2025-01-22 02:27:53 +00:00
Aaron Orenstein
bac62341eb PEP585 update - torch/_inductor (#145198)
See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145198
Approved by: https://github.com/bobrenjc93
2025-01-21 21:04:33 +00:00
leslie-fang-intel
25de671ea8 [Inductor][CPP] Enable Grouped GEMM Template (#143796)
**Summary**
Enable the CPP Grouped GEMM Fusion, lowering and Grouped GEMM Template following the RFC: https://github.com/pytorch/pytorch/issues/144012

- Support flexible number of GEMMs
- Share activation across GEMMs
  - The Grouped GEMM Template supports independent activations
  - However, the pattern matcher requires an anchor node, which serves as the shared activation across GEMMs
- Each GEMM can have a unique weight but same sizes
- Each GEMM can have a unique bias or None
  - Current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM has its own epilogues
  - Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue fusion PR

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```

**Example**
Here is the example and generated code
```
batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=False).to(dtype=dtype).eval()  # bias is unused by this example module
        cm = torch.compile(m)
        act_res = cm(input)
```

Generated Code:  https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py
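For the template to be picked, GEMM autotuning has to be on; a minimal sketch of the usual knobs (assuming the standard Inductor config names, which may vary by version):

```python
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"  # let the C++ GEMM templates compete
inductor_config.freezing = True  # constant-weight prepacking generally requires freezing

lin0 = torch.nn.Linear(512, 1024, bias=False).bfloat16().eval()
lin1 = torch.nn.Linear(512, 1024, bias=False).bfloat16().eval()

@torch.compile
def grouped(x):
    # two GEMMs sharing one activation -> candidate for the grouped GEMM template
    return lin0(x), lin1(x)

with torch.no_grad():
    grouped(torch.randn(4, 512, dtype=torch.bfloat16))
```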

**Next Step**

- Support Epilogue fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
2025-01-14 05:59:07 +00:00
Aaron Orenstein
45ef3309e3 [BE] typing for decorators (#144161)
Summary:
Untyped decorators strip annotations from the decorated items.

- _compile
- _inductor/fx_passes/post_grad
- _inductor/lowering
- _library/custom_ops
- _meta_registrations
- _ops
- _refs/nn/functional
- ao/quantization/quantizer/xnnpack_quantizer_utils
- distributed/_composable/contract
- fx/experimental/graph_gradual_typechecker
- fx/experimental/migrate_gradual_types/constraint_generator
- optim/optimizer
- signal/windows/windows
- testing/_internal/common_device_type
- torch/_inductor/decomposition
- utils/flop_counter

Test Plan: unit tests

Differential Revision: D62302684

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144161
Approved by: https://github.com/Skylion007, https://github.com/albanD
2025-01-04 16:40:09 +00:00
Mitchell, Frost
20f24e3fbd [inductor][cpp] Add BMM kernel template for autotuning (#129772)
This PR adds the Cpp template for BMM, for FP32, FP16, and BF16. See #125683 for more background.

1.  Adds the `CppBmmTemplate` class, which inherits from `CppPackedGemmTemplate`. Given a number of worker threads `num_threads` and a batch size `B`, it executes the GEMM kernel as follows: for the first `B - (B % num_threads)` batch inputs, run one sub-GEMM problem per thread; for the remaining `B % num_threads` sub-GEMMs, execute each subproblem using the parallelized GEMM kernel.
To manage this code, the `GEMM_TEMPLATE` from `CppPackedGemmTemplate` is rendered twice: once for a single thread and once including the parallel OMP pragma.
2. Adapts `CppPackedGemmTemplate` to allow for a child class. The `GEMM_TEMPLATE` is separated into different strings to allow rendering by the child class. Slicing/indexing are adapted to allow for 3D BMM inputs. Additional methods `get_options()` and `_get_params_for_choices()` are added to reduce code duplication.
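The batch partitioning in item 1 boils down to the following arithmetic (a sketch mirroring the generated `cpp_bmm` entry point shown below; the values are mine):

```python
B, num_threads = 50, 48
B_single_thread_block = (B // num_threads) * num_threads   # 48: one sub-gemm per OMP thread
# batches [0, B_single_thread_block)  -> single-threaded sub-gemm each, parallelized over the batch dim
# batches [B_single_thread_block, B)  -> remaining 2 sub-gemms, each run with the parallel GEMM kernel
```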

The BMM within the `dlrm` benchmark has a single input buffer which is used for both the X and W inputs. This is currently not supported in this PR.

### Performance
On Granite/Sapphire Rapids, cpp_bmm template code uses AMX which requires an expensive transpose operation so the BMM op is rarely selected as faster than the existing external bmm kernel. As a result, speedup on SPR is identical with and without BMM code. Pass rate matches the rates for main exactly.

#### Test Summary on Granite Rapids
Test Scenario | Comp Item | Config | Compiler | torchbench | huggingface | timm_models
-- | -- | -- | -- | -- | -- | --
Single Socket Multi-Threads | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
  |   | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
  | Geomean Speedup | gemm autotune | inductor | 2.15x | 1.91x | 2.52x
  |   | bmm + gemm autotune | inductor | 2.15x | 1.96x | 2.53x
Single Core Single-Thread | Pass Rate | gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
  |   | bmm + gemm autotune | inductor | 91%, 73/80 | 100%, 46/46 | 100%, 61/61
  | Geomean Speedup | inductor_locally_benchmark_586 | inductor | 2.43x | 1.56x | 2.60x
  |   | inductor_locally_benchmark_585 | inductor | 2.45x | 1.56x | 2.63x

This is not the case on an older Skylake Xeon machine.
For the BMM ops contained in torchbench models, bmm performance improves by 1.10-2.64x.

#### BF16 28-core Skylake Xeon
| Model | Inductor | GemmAutotune | Gemm+BMM Autotune |
|--------|--------|--------|--------|
| BERT_pytorch | 1.233x | 2.597x | 2.608x |
| hf_DistilBert | 1.128x | 2.242x | 2.368x |
| hf_Reformer | 1.124x | 1.419x | 1.590x |
| hf_T5_base | 1.012x | 1.257x | 1.382x |
| hf_T5_large | 1.085x | 2.228x | 2.345x |

## Example BMM Code
```
#include <c10/util/Unroll.h>
#include <torch/csrc/inductor/aoti_torch/c/shim.h>

template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_32_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 32 / 16, 2, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
        _tile_loadd(2, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(3, C + 16 * ldc + 16, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
        _tile_zero(2);
        _tile_zero(3);
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
        _tile_stream_loadd(4, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(6, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 4, 6);
        _tile_loadd(7, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(1, 4, 7);
        _tile_stream_loadd(5, A + 16 * lda + k, lda * sizeof(bfloat16));
        _tile_dpbf16ps(2, 5, 6);
        _tile_dpbf16ps(3, 5, 7);
    };

    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }

    auto store_c = [&]() {
    // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
        _tile_stored(2, C + 16 * ldc + 0, ldc * sizeof(float));
        _tile_stored(3, C + 16 * ldc + 16, ldc * sizeof(float));
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 32 / 16, 2, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}

template <bool accum>
inline void cpp_bmm_micro_gemm_amx_kernel_16_2(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    uint8_t tilecfg_rows
) {
    // TODO(jgong5): add prefetch hint for A, B, C
    auto loadconfig = [](const amx_tilecfg& cfg) {
        _tile_loadconfig(&cfg);
    };
    const auto last_k_offset = K / 32 * 32;
    const auto tail_k_size = K - last_k_offset;
    if C10_LIKELY (last_k_offset > 0) {
        amx_state.configure(tilecfg_rows, 64, 16 / 16, 2, loadconfig);
    } else {
        amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
    }
    auto load_c = [&]() {
        _tile_loadd(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_loadd(1, C + 0 * ldc + 16, ldc * sizeof(float));
    };
    auto zero_c = [&]() {
        _tile_zero(0);
        _tile_zero(1);
    };

    if constexpr (accum) {
        load_c();
    } else {
        zero_c();
    }

    auto compute = [&](int k) {
        _tile_stream_loadd(2, A + 0 * lda + k, lda * sizeof(bfloat16));
        _tile_loadd(3, B + k * ldb + 0, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(0, 2, 3);
        _tile_loadd(4, B + k * ldb + 32, ldb * 2 * sizeof(bfloat16));
        _tile_dpbf16ps(1, 2, 4);
    };

    #pragma GCC unroll 4
    for (int k = 0; k < last_k_offset; k += 32) {
        compute(k);
    }

    auto store_c = [&]() {
    // store to C
        _tile_stored(0, C + 0 * ldc + 0, ldc * sizeof(float));
        _tile_stored(1, C + 0 * ldc + 16, ldc * sizeof(float));
    };

    // TODO(jgong5): move tail k computation to separate loopnest to save tile configuration overhead
    if C10_UNLIKELY (tail_k_size > 0) {
        if C10_LIKELY (last_k_offset > 0) {
            store_c();
            amx_state.configure(tilecfg_rows, tail_k_size * sizeof(bfloat16), 16 / 16, 2, loadconfig);
            load_c();
        }
        compute(last_k_offset);
    }

    store_c();
}

template <bool accum>
inline void cpp_bmm_micro_gemm(
    AMXState& amx_state,
    const bfloat16* __restrict__ A,
    const bfloat16* __restrict__ B,
    float* __restrict__ C,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc
) {
    AOTI_TORCH_CHECK(N % 32 == 0, "N dimension must be multiple of 32");
    AOTI_TORCH_CHECK(K % 2 == 0, "K dimension must be multiple of 2");
    // TODO(jgong5): loop unroll for M and N
    for (int64_t n = 0; n < N; n += 32) {
        for (int64_t m = 0; m < M; m += 32) {
            int64_t block_m = std::min<int64_t>(M - m, 32);
            int64_t m_tail = m;
            if (block_m >= 32) {
                cpp_bmm_micro_gemm_amx_kernel_32_2<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= 32;
                m_tail += 32;
            }
            else
            if (block_m >= 16) {
                cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
                    amx_state,
                    A + m * lda,
                    B + n,
                    C + m * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    16
                );
                block_m -= 16;
                m_tail += 16;
            }
            if (block_m > 0) {
                cpp_bmm_micro_gemm_amx_kernel_16_2<accum>(
                    amx_state,
                    A + m_tail * lda,
                    B + n,
                    C + m_tail * ldc + n,
                    K,
                    lda,
                    ldb,
                    ldc,
                    block_m
                );
            }
        }
    }
}
void threaded_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index)
{

    constexpr int64_t num_threads = 48;
    constexpr int64_t N = 64;
    constexpr int64_t K = 96;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
    constexpr int64_t M = static_cast<int64_t>(384L);
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = 1;
    constexpr int64_t Nt_blocks = 1;
    constexpr int64_t Kt_blocks = 3;
    constexpr int64_t Mc_blocks = 1;
    constexpr int64_t Nc_blocks = 1;
    constexpr int64_t Kc_blocks = 3;
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;

    // make sure all partitions are assigned
    AOTI_TORCH_CHECK(
        Mt_blocks * Nt_blocks * Kt_blocks * 48 >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
    #pragma omp parallel num_threads(48)
    {
        const int tid = omp_get_thread_num();
        const int64_t k_group_id = tid / num_Kt_blocks;
        const int64_t k_slice_id = tid % num_Kt_blocks;
        const int64_t n_group_id = k_group_id / num_Nt_blocks;
        const int64_t n_slice_id = k_group_id % num_Nt_blocks;
        const int64_t k_block_start = k_slice_id * Kt_blocks;
        const int64_t k_block_end = std::min(k_block_start + Kt_blocks, Kr_blocks);
        const int64_t n_block_start = n_slice_id * Nt_blocks;
        const int64_t n_block_end = std::min(n_block_start + Nt_blocks, Nr_blocks);
        const int64_t m_block_start = std::min(n_group_id * Mt_blocks, Mr_blocks);
        const int64_t m_block_end = std::min(m_block_start + Mt_blocks, Mr_blocks);
        const int64_t num_Mc_blocks_per_thread = (m_block_end - m_block_start + Mc_blocks - 1) / Mc_blocks;
        AMXState amx_state;
        auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        } else {
                            cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        }
                    }
                }
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
                            }
                            for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                            }
                        }
                    }

                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
void single_thread_mm(const bfloat16* X, const bfloat16* W, bfloat16* Y, const int64_t ks_b_index)
{

    constexpr int64_t num_threads = 1;
    constexpr int64_t N = 64;
    constexpr int64_t K = 96;
    constexpr int64_t Mr = 32;
    constexpr int64_t Nr = 32;
    constexpr int64_t Kr = 32;
    constexpr int64_t Nr_blocks = (N + Nr - 1) / Nr;
    constexpr int64_t Kr_blocks = (K + Kr - 1) / Kr;
    constexpr int64_t M = static_cast<int64_t>(384L);
    constexpr int64_t Mr_blocks = (M + Mr - 1) / Mr;
    constexpr int64_t Mt_blocks = 12;
    constexpr int64_t Nt_blocks = 2;
    constexpr int64_t Kt_blocks = 3;
    constexpr int64_t Mc_blocks = 12;
    constexpr int64_t Nc_blocks = 1;
    constexpr int64_t Kc_blocks = 3;
    constexpr int64_t num_Mc_blocks = (Mr_blocks + Mc_blocks - 1) / Mc_blocks;
    constexpr int64_t num_Nc_blocks = (Nr_blocks + Nc_blocks - 1) / Nc_blocks;
    constexpr int64_t num_Mt_blocks = (Mr_blocks + Mt_blocks - 1) / Mt_blocks;
    constexpr int64_t num_Nt_blocks = (Nr_blocks + Nt_blocks - 1) / Nt_blocks;
    constexpr int64_t num_Kt_blocks = (Kr_blocks + Kt_blocks - 1) / Kt_blocks;

    // make sure all partitions are assigned
    AOTI_TORCH_CHECK(
        Mt_blocks * Nt_blocks * Kt_blocks * 1 >= Mr_blocks * Nr_blocks * Kr_blocks,
        "Not all partitions are assigned."
    );
    {
        constexpr int tid = 0;
        constexpr int64_t k_group_id = 0;
        constexpr int64_t k_slice_id = 0;
        constexpr int64_t n_group_id = 0;
        constexpr int64_t n_slice_id = 0;
        constexpr int64_t m_block_start = 0;
        constexpr int64_t n_block_start = 0;
        constexpr int64_t n_block_end = Nr_blocks;
        constexpr int64_t k_block_start = 0;
        constexpr int64_t k_block_end = Kr_blocks;
        constexpr int64_t num_Mc_blocks_per_thread = num_Mc_blocks;
        constexpr int64_t m_block_end = Mr_blocks;
        AMXState amx_state;
        auto _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); auto local_acc_buf = _local_acc_buf.get();
        for (int64_t mc_block_id = 0; mc_block_id < num_Mc_blocks_per_thread; mc_block_id++) {
            const int64_t my_mc_block_id = (mc_block_id + n_slice_id) % num_Mc_blocks_per_thread;
            const int64_t mc = m_block_start + my_mc_block_id * Mc_blocks;
            const int64_t m_start = mc * Mr;
            const int64_t m_end = std::min(std::min(mc + Mc_blocks, m_block_end) * Mr, M);
            const int64_t m_size = m_end - m_start;
            for (int64_t nc = n_block_start; nc < n_block_end; nc += Nc_blocks) {
                const int64_t n_start = nc * Nr;
                const int64_t n_end = std::min(std::min(nc + Nc_blocks, n_block_end) * Nr, N);
                const int64_t n_size = n_end - n_start;
                // NB: assume we pad N, nc_block_end won't exceed padded N here.
                const int64_t nc_block_end = std::min(nc + Nc_blocks, n_block_end);
                if (_local_acc_buf == nullptr) { _local_acc_buf = std::make_unique<float[]>(static_cast<int64_t>(Mc_blocks*Mr*Nc_blocks*Nr)); local_acc_buf = _local_acc_buf.get(); }
                for (int64_t kc = k_block_start; kc < k_block_end; kc += Kc_blocks) {
                    int64_t k_start = kc * Kr;
                    int64_t k_end = std::min(std::min(kc + Kc_blocks, k_block_end) * Kr, K);
                    for (int64_t nci = nc; nci < nc_block_end; nci++) {
                        if (kc == k_block_start) {
                            cpp_bmm_micro_gemm<static_cast<bool>(false)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        } else {
                            cpp_bmm_micro_gemm<static_cast<bool>(true)>(
                                amx_state,
                                &(X[static_cast<int64_t>(k_start + (96L*m_start) + (36864L*ks_b_index))]),
                                &(W[static_cast<int64_t>((32L*k_start) + (3072L*nci) + (6144L*ks_b_index))]),
                                &(local_acc_buf[static_cast<int64_t>((Nr*nci) + ((-1L)*Nr*nc))]),
                                static_cast<int64_t>(m_end + ((-1L)*m_start)),
                                static_cast<int64_t>(Nr),
                                static_cast<int64_t>(k_end + ((-1L)*k_start)),
                                static_cast<int64_t>(96L),
                                static_cast<int64_t>(32L),
                                static_cast<int64_t>(Nc_blocks*Nr)
                            );

                        }
                    }
                }
                {
                    {
                        #pragma GCC ivdep
                        for(int64_t x0=static_cast<int64_t>(0L); x0<static_cast<int64_t>(m_end + ((-1L)*m_start)); x0+=static_cast<int64_t>(1L))
                        {
                            for(int64_t x1=static_cast<int64_t>(0L); x1<static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1+=static_cast<int64_t>(16L))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(16));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(16));
                            }
                            for(int64_t x1=static_cast<int64_t>(16L*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))); x1<static_cast<int64_t>(n_end + ((-1L)*n_start)); x1+=(static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))) == 0 ? 1 : static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L)))))))
                            {
                                auto tmp0 = at::vec::Vectorized<float>::loadu(local_acc_buf + static_cast<int64_t>(x1 + (Nc_blocks*Nr*x0)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                                auto tmp1 = at::vec::convert<bfloat16>(tmp0);
                                tmp1.store(Y + static_cast<int64_t>(n_start + x1 + (64L*m_start) + (64L*x0) + (24576L*ks_b_index)), static_cast<int64_t>(n_end + ((-1L)*n_start) + ((-16L)*(c10::div_floor_integer(static_cast<int64_t>((n_end + ((-1L)*n_start))), static_cast<int64_t>(16L))))));
                            }
                        }
                    }

                }
            }
        }
        amx_state.release([]() { _tile_release(); });
    }
}
extern "C"
void cpp_bmm(const bfloat16* X, const bfloat16* W, bfloat16* Y)
{
    const int64_t B = static_cast<int64_t>(5L);
    constexpr int64_t num_threads = 48;
    int64_t B_single_thread_block = (B / num_threads) * num_threads;

    #pragma omp parallel for num_threads(48)
    for (int64_t b_start = 0; b_start < B_single_thread_block; ++b_start) {
        single_thread_mm(X, W, Y, b_start);
    }
    for (int64_t b_start = B_single_thread_block; b_start < B; ++b_start) {
        threaded_mm(X, W, Y, b_start);
    }
}
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129772
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-12-06 04:54:00 +00:00
Jason Ansel
808f0f656d [inductor] Refactor MutableBox to make IRNode typing easier (#140895)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140895
Approved by: https://github.com/ezyang, https://github.com/Skylion007
2024-11-20 19:50:46 +00:00
Wu, Chunyuan
d7411c0cc1 [AOTI] add C shim for QConvPointWise (#138540)
This PR adds C shim for `QConvPointWisePT2E` and `QConvPointWiseBinaryPT2E` similar to https://github.com/pytorch/pytorch/pull/138439. Besides that, we aligned the implementation of `qconv_pointwise` with `qlinear_pointwise` in the following aspects:
1. The parameter orders of `qconv_pointwise` and `qlinear_pointwise` were quite different, so we aligned the schema of `qconv_pointwise` to have a parameter order similar to `qlinear_pointwise` for consistency.
2. We now always convert `x_scale` and `x_zero_point` to Tensors, just like in the lowering of `qlinear_pointwise`. This avoids the need to create two separate C APIs (one for `double x_scale` and `int64_t x_zero_point`, and another for `Tensor` versions); instead, we only need one API for `Tensor`-based `x_scale` and `x_zero_point`. If we later add dynamic quantization for qconv (which will use `Tensor` for `x_scale` and `x_zero_point`), we can reuse the code from this PR and won't need to change the C shim layer API.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138540
Approved by: https://github.com/jgong5, https://github.com/desertfire
ghstack dependencies: #138691, #138806
2024-10-31 02:03:01 +00:00
sanchitintel
f951fcd1d7 Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue (#131887)
## Summary

As part of #125683, this PR modifies existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with AVX2, AVX512 & AMX ISAs (the latter is only available on Xeon 4th generation & beyond).

WoQ GEMM takes FP16/BF16 activations, int8 weights, and scale of the same dtype as activations.
The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that weights are not cached by it). Weights will be considered constant & cached, so this implementation is suitable for inference, and not QAT. `scale` is supported as a `mul` epilogue.
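In code, the reference semantics look like the following (shapes are illustrative):

```python
import torch

M, K, N = 8, 64, 32
x = torch.randn(M, K, dtype=torch.bfloat16)              # BF16 activation
w = torch.randint(-128, 127, (N, K), dtype=torch.int8)   # int8 weight, treated as a cached constant
scale = torch.rand(N, dtype=torch.bfloat16)              # per-channel scale, same dtype as x

# Dequantize-on-the-fly linear with the scale applied as a `mul` epilogue; this is the
# reference the templated WoQ micro-kernels are auto-tuned against.
y = torch.nn.functional.linear(x, w.to(x.dtype)) * scale
```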

Only BF16 activations have been supported in this PR because for FP16 & FP32, weight is dequantized during constant-folding pass of freezing, and then after auto-tuning, performance with a large `M` dimension may be better than either torch.ops.aten._weight_int8pack_mm, or the WoQ micro-kernel support introduced in this PR, which dequantizes `w` within the micro-kernel.
While even BF16 activations with a large `M` dimension may benefit from dequantizing `w` beforehand, for now, they would  use WoQ support in GEMM templates for auto-tuning, and then a subsequent PR would add logic for deciding whether or not to dequantize weights beforehand.

### Performance
#### AMX
Op-level speedup due to AMX micro-kernel (selected during auto-tuning) on 32 physical cores of Intel(R) Xeon(R) Platinum 8468H (of Xeon 4th generation series, codenamed Sapphire Rapids) vs. ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded.

In a few cases with an odd `K`, the implementation being added in this PR may not perform as well as the ATen kernel, which is unrelated to this PR, though, since `test_linear_amx` also exhibits similar datapoints. In those cases, the AMX micro-kernel might be slower than AVX512 micro-kernel, so if such sets of shapes are used for auto-tuning, either the AVX512 micro-kernel implementation, or the ATen kernel would be chosen instead.

Benchmarked with unit-tests.

Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442

The AVX512 micro-kernel was disabled to collect data for AMX micro-kernel.

#### AVX2/AVX512 micro-kernels

Tabular data at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437

### Follow-up
1. int4 WoQ GEMM micro-kernel will also be added in a separate PR.
2. A subsequent PR would add logic for deciding whether or not to dequantize weights beforehand.

E2E perf measurement should be done with #131310.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131887
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-08-14 03:14:45 +00:00
PyTorch MergeBot
89670d5bdd Revert "Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue (#131887)"
This reverts commit 8fbd7d92a8.

Reverted https://github.com/pytorch/pytorch/pull/131887 on behalf of https://github.com/fbgheith due to breaking internal builds ([comment](https://github.com/pytorch/pytorch/pull/131887#issuecomment-2285082401))
2024-08-12 23:45:46 +00:00
sanchitintel
8fbd7d92a8 Inductor-CPU WoQ int8 GEMM micro-kernel with scale epilogue (#131887)
## Summary

As part of #125683, this PR modifies existing CPU GEMM cpp template & micro-kernel template to enable int8 WoQ GEMM auto-tuning with AVX2, AVX512 & AMX ISAs (the latter is only available on Xeon 4th generation & beyond).

WoQ GEMM takes FP16/BF16 activations, int8 weights, and scale of the same dtype as activations.
The operation is equivalent to `torch.nn.functional.linear(x, w.to(x.dtype)) * scale`, which is essentially what the ATen op `torch.ops.aten._weight_int8pack_mm` currently does (except that weights are not cached by it). Weights will be considered constant & cached, so this implementation is suitable for inference, and not QAT. `scale` is supported as a `mul` epilogue.

Only BF16 activations have been supported in this PR because for FP16 & FP32, weight is dequantized during constant-folding pass of freezing, and then after auto-tuning, performance with a large `M` dimension may be better than either torch.ops.aten._weight_int8pack_mm, or the WoQ micro-kernel support introduced in this PR, which dequantizes `w` within the micro-kernel.
While even BF16 activations with a large `M` dimension may benefit from dequantizing `w` beforehand, for now, they would  use WoQ support in GEMM templates for auto-tuning, and then a subsequent PR would add logic for deciding whether or not to dequantize weights beforehand.

### Performance
#### AMX
Op-level speedup due to AMX micro-kernel (selected during auto-tuning) on 32 physical cores of Intel(R) Xeon(R) Platinum 8468H (of Xeon 4th generation series, codenamed Sapphire Rapids) vs. ATen kernel `torch.ops.aten._weight_int8pack_mm`. Intel OpenMP & tcmalloc were preloaded.

In a few cases with an odd `K`, the implementation being added in this PR may not perform as well as the ATen kernel, which is unrelated to this PR, though, since `test_linear_amx` also exhibits similar datapoints. In those cases, the AMX micro-kernel might be slower than AVX512 micro-kernel, so if such sets of shapes are used for auto-tuning, either the AVX512 micro-kernel implementation, or the ATen kernel would be chosen instead.

Benchmarked with unit-tests.

Tabular data at https://gist.github.com/sanchitintel/294811a86c8ff6b867c668ae2107c405?permalink_comment_id=5142442#gistcomment-5142442

The AVX512 micro-kernel was disabled to collect data for AMX micro-kernel.

#### AVX2/AVX512 micro-kernels

Tabular data at https://gist.github.com/sanchitintel/52b5fa9c66f791be19e48e2aa6423dc4?permalink_comment_id=5142437#gistcomment-5142437

### Follow-up
1. int4 WoQ GEMM micro-kernel will also be added in a separate PR.
2. A subsequent PR would add logic for deciding whether or not to dequantize weights beforehand.

E2E perf measurement should be done with #131310.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131887
Approved by: https://github.com/jgong5, https://github.com/leslie-fang-intel, https://github.com/jansel
2024-08-10 02:01:04 +00:00
Bin Bao
fd874b799f [AOTI][refactor] Update MKLDNN ops cpp wrapper support (#132367)
Summary: Set op_overload for MKLDNN ops so that cpp_kernel_name and python_kernel_name are constructed from there. This is an important step towards supporting those MKLDNN ops in the ABI-compatible mode, because we will need to read the schema from op_overload to generate the correct fallback op call in C++.

Differential Revision: [D60909798](https://our.internmc.facebook.com/intern/diff/D60909798)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132367
Approved by: https://github.com/leslie-fang-intel, https://github.com/angelayi
2024-08-08 03:02:29 +00:00
PyTorch MergeBot
4684b8e9d7 Revert "[BE] typing for decorators - _inductor/lowering (#131574)"
This reverts commit b2cbcf710b.

Reverted https://github.com/pytorch/pytorch/pull/131574 on behalf of https://github.com/clee2000 due to breaking lint internally D60265575 ([comment](https://github.com/pytorch/pytorch/pull/131572#issuecomment-2254328359))
2024-07-28 03:29:32 +00:00
Aaron Orenstein
b2cbcf710b [BE] typing for decorators - _inductor/lowering (#131574)
See #131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131574
Approved by: https://github.com/oulgen, https://github.com/zou3519
ghstack dependencies: #131568, #131569, #131570, #131571, #131572, #131573
2024-07-25 22:24:19 +00:00
Aaron Orenstein
5a0068cc69 [BE] mypy: disallow untyped decorators (#131428)
Untyped decorators strip the types from the functions they decorate, so even if the underlying function is fully typed, its callers get no benefit from the type annotations.
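A minimal illustration (not the PR's code) of the difference, using `ParamSpec` (Python 3.10+, or `typing_extensions` on older versions):

```python
from typing import Callable, ParamSpec, TypeVar

P = ParamSpec("P")
R = TypeVar("R")

def untyped(fn):                                    # strips types: callers of the result see Any
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)
    return wrapper

def typed(fn: Callable[P, R]) -> Callable[P, R]:    # preserves the wrapped signature
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        return fn(*args, **kwargs)
    return wrapper

@typed
def add(a: int, b: int) -> int:
    return a + b

total: int = add(1, 2)   # still checked; with @untyped this would silently become Any
```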

Step 1 - Enable the error and override in all the offending files.

#131429

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131428
Approved by: https://github.com/justinchuby, https://github.com/oulgen
2024-07-23 21:50:55 +00:00
Xuehai Pan
b6d477fd56 [BE][Easy][16/19] enforce style for empty lines in import segments in torch/_i*/ (#129768)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129768
Approved by: https://github.com/jansel
2024-07-20 16:20:58 +00:00
leslie-fang-intel
37e3c60897 [Inductor][CPP] Remove redundant INT8-specific logic in the INT8 GEMM template (#129470)
**Summary**
Remove redundant INT8-specific logic in the INT8 GEMM template to unify the code structure with FP32/BF16/FP16 GEMM Template.

**Test Plan**
```
numactl -C 56-111 -m 1 python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129470
Approved by: https://github.com/jgong5
ghstack dependencies: #128825, #129048, #129049, #129103, #129220, #129221
2024-07-02 13:15:15 +00:00
leslie-fang-intel
a796358330 [Inductor][CPP] Enable Quantized Linear GEMM Template with Binary Fusion (#129103)
**Summary**
Based on the previous PR, add the config to support quantized linear binary (with optional unary) post-op fusion.

- Activation dtype: uint8
- Weight dtype: int8
- Output dtype: float32/bfloat16/uint8
- Post Op Fusion: with binary and optional[Unary] post operator fusion

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_with_pointwise_binary
```

**Next Step**
- [✓] Unary post op fusion
- [✓] Int8 output
- [✓] Binary Fusion
- [ ] AMX int8 MicroGEMM Kernel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129103
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048, #129049
2024-07-02 12:45:10 +00:00
leslie-fang-intel
86e2d16ba0 [Inductor][Quant] Change the schema of QLinear Binary (#129049)
**Summary**
We change the schema of QLinear Binary, so it will be easier to enable the corresponding gemm template.

- The extra input of the binary post-op is a tensor which needs to be an input node for autotuning, so we move it in front of `output_scale`, which is a scalar.
- We also move it in front of `bias`, since `bias` is an optional tensor for this fusion, while `other` is required for linear binary fusion.

**Test Plan**
```
python -u -m pytest -s -v test/quantization/core/test_quantized_op.py -k qlinear
python -u -m pytest -s -v test/inductor/test_mkldnn_pattern_matcher.py -k qlinear
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129049
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825, #129048
2024-07-02 12:36:38 +00:00
leslie-fang-intel
1a689ea38c [Inductor][CPP] Enable Quantized Linear GEMM Template with INT8 output and Unary Post Op (#129048)
**Summary**
Based on the previous PR, add the config to support int8 output and unary post-op fusion with `ReLU` and `GeLU`.

- Activation dtype: uint8
- Weight dtype: int8
- Output dtype: float32/bfloat16/uint8
- Post Op Fusion: with unary post operator fusion

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_with_pointwise
```

**Next Step**
- [✓] Unary post op fusion
- [✓] Int8 output
- [ ] Binary Fusion
- [ ] AMX int8 MicroGEMM Kernel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129048
Approved by: https://github.com/jgong5, https://github.com/jansel
ghstack dependencies: #128825
2024-06-30 09:53:55 +00:00
leslie-fang-intel
35a197defa [Inductor][CPP] Enable Quantized Linear GEMM Template with FP32 output (#128825)
**Summary**
Support the int8 GEMM template with a reference MicroInt8GEMM kernel for the following case:

- Activation dtype: uint8
- Weight dtype: int8
- Output dtype: float32/bfloat16
- Post Op Fusion: without unary post operator fusion

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_with_pointwise
```

**Next Step**
- [ ] Unary post op fusion
- [ ] Int8 output
- [ ] Binary Fusion
- [ ] AMX int8 MicroGEMM Kernel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128825
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-06-30 09:45:43 +00:00
leslie-fang-intel
304c934572 Move MKLDNN Specific IR to Separate File (#126504)
**Summary**
Following the discussion in https://github.com/pytorch/pytorch/pull/122593#discussion_r1604144782, move the Inductor MKLDNN-specific IRs to a separate file.

Co-authored-by: Isuru Fernando <ifernando@quansight.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126504
Approved by: https://github.com/desertfire, https://github.com/jgong5
ghstack dependencies: #126841, #126940
2024-06-18 09:29:13 +00:00
Jiong Gong
1fd2cd26a0 [inductor][cpp] support bf16/fp16 gemm template epilogue fusion (#126545)
As part of #125683, this PR adds epilogue fusion support for bf16/fp16 gemms. The key changes are as follows:
1. bf16 linear w/ epilogue fusion of some ops was originally supported via ATen oneDNN linear pointwise ops. In order to match the ATen op semantics, in-template epilogue support is added to the cpp gemm template so that we would have: "gemm + in-template epilogues -> template buffer". If the template is chosen for codegen, the in-template epilogues will be concatenated with the out-of-template epilogues that are appended during the scheduling.
2. Support bf16/fp16 legalization for `codegen_loop_bodies` which is used to generate the epilogue loops.
3. We used to leverage the in-place buffer mechanism to handle the in-place buffers in the epilogue codegen, in particular, for the reuses for output buffers of GEMM, template and epilogues. This is not correct since the output buffer is an "output" not an "in-place" buffer of the template kernel itself. Now, we use a dedicated "aliases" dict to manage such buffer reuses and the intermediate aliasing buffers are removed after codegen.
4. Add `localize_buffer` method to `LocalBufferScope` to allow the replacement of a global buffer with a local one in the given inductor IR nodes. This helps the fused loops to work on smaller-sized local buffers for better data locality.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126545
Approved by: https://github.com/jansel
2024-06-13 09:46:22 +00:00
Jiong Gong
1edcb31d34 [RELAND][inductor][cpp] bf16/fp16 gemm template computed with fp32 (#128472)
reland for https://github.com/pytorch/pytorch/pull/126068

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128472
Approved by: https://github.com/desertfire
2024-06-12 08:37:16 +00:00
Aaron Orenstein
ea614fb2b1 Flip default value for mypy disallow_untyped_defs [2/11] (#127839)
See #127836 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127839
Approved by: https://github.com/oulgen
2024-06-08 18:23:08 +00:00
PyTorch MergeBot
029b3ec775 Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)"
This reverts commit dae33a4961.

Reverted https://github.com/pytorch/pytorch/pull/126068 on behalf of https://github.com/PaliC due to failing internal tests ([comment](https://github.com/pytorch/pytorch/pull/126068#issuecomment-2141992307))
2024-05-31 12:33:25 +00:00
Jiong Gong
dae33a4961 [inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)
As part of #125683, this PR adds the initial bf16/fp16 gemm template support with micro-gemm implemented with fused type casting and fp32 computation. It doesn't provide epilogue fusion support yet which will be added in the next PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126068
Approved by: https://github.com/jansel
ghstack dependencies: #124021, #126019
2024-05-29 11:15:41 +00:00
Jiong Gong
cef776bcd1 [inductor][cpp] GEMM template (infra and fp32) (#124021)
This PR adds the Cpp template infrastructure and the initial FP32 gemm template. See RFC https://github.com/pytorch/pytorch/issues/125683 for more background info.
1. Cpp template infrastructure
Similar template abstractions as the CUTLASS template, i.e., `CppTemplate`, `CppTemplateKernel`, `CppTemplateBuffer`. The MicroGemm micro-kernel abstraction that can be used by Cpp GEMM templates.
2. Initial FP32 gemm template
This involves a GEMM template implementation `CppPackedGemmTemplate` that supports GEMM with constant weight (`B`) requiring `N` to be a multiple of register blocking, while allowing static or dynamic sizes for the `M` (batch dim) of `A`. The `B` matrix would be prepacked. This is a typical setting for inference workloads. The template handles the thread decomposition (via `thread_blocking`) and cache blocking (via `cache_blocking`). Then it invokes `CppMicroGemm` which handles register blocking, instruction selection, and other CPU architecture-specific optimizations. A `CppMicroGemmFP32Vec` micro-kernel implementation is provided for fp32 matmuls implemented with the ATen vec abstraction.
3. Correctness and performance
The changes have been validated with fp32 inference on the three benchmark suites (torchbench, huggingface and timm_models) with both static and dynamic shapes. Since it is an initial implementation, we are still working on further performance improvements with follow-up PRs, including optimizations in kernels as well as fusions. The perf gains are only observed on a select number of models compared to the ATen kernels, which are implemented with MKL. The perf gains are more obvious with dynamic shapes since MKL only supports packed gemm for static shapes. Below are details.

Static shapes
| Benchmark | torchbench | huggingface | timm_models |
|------------|-------------|--------------|--------------|
| Multi-threaded (baseline) | 1.47x | 1.36x | 1.91x |
| Multi-threaded (max-autotune) | 1.47x | 1.36x | 1.92x |
| Single-threaded (baseline) | 1.56x | 1.19x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.52x |

Key models being sped up:
drq: 1.14x
soft_act: 1.12
cait_m36_384: 1.18x

Dynamic shapes
| Benchmark | torchbench | huggingface | timm_models |
| --- | --- | --- | --- |
| Multi-threaded (baseline) | 1.43x | 1.28x | 1.85x |
| Multi-threaded (max-autotune) | 1.47x | 1.28x | 1.85x |
| Single-threaded (baseline) | 1.55x | 1.20x | 1.51x |
| Single-threaded (max-autotune) | 1.56x | 1.19x | 1.53x |

Key models being sped up:
BERT_pytorch: 1.22x
pyhpc_turbulent: 1.13x
soft_actor_critic: 1.77x
BlenderbotForCausalLM: 1.09x
cait_m36_384: 1.17x

Differential Revision: [D57585365](https://our.internmc.facebook.com/intern/diff/D57585365)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124021
Approved by: https://github.com/jansel
2024-05-29 07:37:41 +00:00
PyTorch MergeBot
4608971f7a Revert "[inductor][cpp] GEMM template (infra and fp32) (#124021)"
This reverts commit 0d1e228550.

Reverted https://github.com/pytorch/pytorch/pull/124021 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2133002071))
2024-05-27 09:01:45 +00:00
PyTorch MergeBot
68fddebf84 Revert "[inductor][cpp] bf16/fp16 gemm template computed with fp32 w/o epilogue fusion (#126068)"
This reverts commit 4aa43d11f3.

Reverted https://github.com/pytorch/pytorch/pull/126068 on behalf of https://github.com/DanilBaibak due to Break internal build ([comment](https://github.com/pytorch/pytorch/pull/124021#issuecomment-2133002071))
2024-05-27 09:01:45 +00:00