pytorch/torch/_inductor/codegen
leslie-fang-intel 25de671ea8 [Inductor][CPP] Enable Grouped GEMM Template (#143796)
**Summary**
Enable CPP Grouped GEMM fusion, lowering, and the Grouped GEMM Template, following the RFC: https://github.com/pytorch/pytorch/issues/144012

- Support a flexible number of GEMMs
- Share the activation across GEMMs
  - The Grouped GEMM Template itself supports independent activations
  - However, the pattern matcher requires an anchor node, which is the activation shared across the GEMMs
- Each GEMM can have its own weight, but all weights must have the same size
- Each GEMM can have its own bias or None
  - This PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR
- Each GEMM has its own epilogues
  - Epilogue fusion is not yet supported in this PR and will be enabled in a follow-up epilogue fusion PR
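
To make the constraints above concrete, here is a minimal, dependency-free sketch of what a grouped GEMM with a shared activation computes. It is not Inductor's implementation: `matmul_t` and `grouped_linear_reference` are hypothetical names used only for illustration, and plain nested lists stand in for tensors.

```python
# Hypothetical reference for the semantics of a grouped GEMM with a shared
# activation. This is NOT Inductor's implementation; names are illustrative.

def matmul_t(x, w):
    """Compute x @ w^T for nested lists; x is [M][K], w is [N][K]."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in w] for row in x]


def grouped_linear_reference(x, weights):
    """Apply every weight in `weights` to the same activation `x`.

    Mirrors the constraints listed above: one shared activation, any number
    of GEMMs, and weights that all have the same [N][K] shape. Biases and
    epilogues are left out, as in this PR.
    """
    if not weights:
        raise ValueError("need at least one GEMM")
    shape = (len(weights[0]), len(weights[0][0]))
    for w in weights:
        if (len(w), len(w[0])) != shape:
            raise ValueError("all weights must have the same shape")
    # One output per GEMM, in the same order as the weights.
    return tuple(matmul_t(x, w) for w in weights)


x = [[1.0, 2.0], [3.0, 4.0]]    # shared activation, M=2, K=2
w0 = [[1.0, 0.0], [0.0, 1.0]]   # identity weight, N=2
w1 = [[2.0, 0.0], [0.0, 2.0]]   # scales the activation by 2
y0, y1 = grouped_linear_reference(x, [w0, w1])
```

With these inputs, `y0` equals the activation and `y1` is the activation scaled by 2. A real grouped kernel would presumably compute all outputs in one pass over the shared activation; the sketch only pins down the expected results.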

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```

**Example**
Here is an example; the generated code is linked below.
```
import torch

batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
bias = False  # biases are not yet supported by this PR

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        # Two linears with same-shaped weights that share the input x
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=bias)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=bias)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=bias).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```

Generated Code:  https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py

**Next Step**

- Support Epilogue fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
2025-01-14 05:59:07 +00:00
aoti_runtime [AOTI][refactor] Separate header codegen (#138882) 2024-10-27 14:14:27 +00:00
cuda Add instantiation level to CutlassArgs (#144506) 2025-01-10 02:01:40 +00:00
rocm [ROCm][Inductor][CK] hackfix for segfault in addmm op (#144519) 2025-01-10 19:29:14 +00:00
xpu [AOTI XPU] Enable Cpp wraper for Intel GPU. (#135318) 2024-11-26 11:51:32 +00:00
__init__.py
aoti_hipify_utils.py remove allow-untyped-defs from _inductor/codegen/aoti_hipify_utils.py (#143916) 2024-12-27 23:25:37 +00:00
block_analysis.py Migrate from Tuple -> tuple in torch/_inductor (#144264) 2025-01-07 03:27:27 +00:00
common.py Migrate from Tuple -> tuple in torch/_inductor (#144264) 2025-01-07 03:27:27 +00:00
cpp_bmm_template.py [inductor][cpu] Fix bmm b_index for dynamic expressions in inductor autotuner (#143141) 2025-01-05 18:02:37 +00:00
cpp_flex_attention_template.py Remove is_reduced_floating_point from namespace std (#144502) 2025-01-10 03:24:10 +00:00
cpp_gemm_template.py [Inductor][CPP] Enable Grouped GEMM Template (#143796) 2025-01-14 05:59:07 +00:00
cpp_grouped_gemm_template.py [Inductor][CPP] Enable Grouped GEMM Template (#143796) 2025-01-14 05:59:07 +00:00
cpp_micro_gemm.py Simplify & rectify dequantized B buffer loading for AMX GEMM micro-kernel for WoQ int8 case (#140258) 2024-11-22 01:34:06 +00:00
cpp_prefix.h Remove is_reduced_floating_point from namespace std (#144502) 2025-01-10 03:24:10 +00:00
cpp_template_kernel.py [Inductor][CPP] Enable Grouped GEMM Template (#143796) 2025-01-14 05:59:07 +00:00
cpp_template.py [Inductor][CPP] Enable Grouped GEMM Template (#143796) 2025-01-14 05:59:07 +00:00
cpp_utils.py Migrate from Tuple -> tuple in torch/_inductor (#144264) 2025-01-07 03:27:27 +00:00
cpp_wrapper_cpu_array_ref.py Migrate from Tuple -> tuple in torch/_inductor (#144264) 2025-01-07 03:27:27 +00:00
cpp_wrapper_cpu.py Support nanj in inductor (#144064) 2025-01-13 14:29:38 +00:00
cpp_wrapper_gpu.py Migrate from Tuple -> tuple in torch/_inductor (#144264) 2025-01-07 03:27:27 +00:00
cpp.py [Inductor][CPP] Enable Grouped GEMM Template (#143796) 2025-01-14 05:59:07 +00:00
cpu_device_op_overrides.py remove allow-untyped-defs from _inductor/codegen/cpu_device_op_overrides.py (#143881) 2024-12-27 04:10:47 +00:00
cuda_combined_scheduling.py Prologue Fusion (#134532) 2024-12-13 04:18:25 +00:00
debug_utils.py Rename convert_arrayref_tensor_to_tensor to copy_arrayref_tensor_to_tensor (#142182) 2024-12-09 22:23:21 +00:00
halide.py Migrate from Tuple -> tuple in torch/_inductor (#144264) 2025-01-07 03:27:27 +00:00
memory_planning.py [inductor] Replace set by OrderedSet (#138466) 2024-12-13 16:08:45 +00:00
mps_device_op_overrides.py [Inductor] Add MPS device op overrides (#143892) 2024-12-28 02:11:45 +00:00
mps.py [mps/inductor] Add support for round() (#144731) 2025-01-14 05:56:13 +00:00
multi_kernel.py Revert "Use absolute path path.resolve() -> path.absolute() (#129409)" 2025-01-04 14:17:20 +00:00
simd_kernel_features.py Skip L1 cache for single-use buffers (#143115) 2025-01-07 19:35:40 +00:00
simd.py [Inductor] Restrict ND tiling analysis to MemoryDeps (#144497) 2025-01-11 05:16:47 +00:00
triton_combo_kernel.py Migrate from Tuple -> tuple in torch/_inductor (#144264) 2025-01-07 03:27:27 +00:00
triton_split_scan.py [inductor] Replace set by OrderedSet (#138466) 2024-12-13 16:08:45 +00:00
triton_utils.py [inductor] Move V.graph.scheduler.current_device to V.graph.current_device (#138252) 2024-10-18 23:05:54 +00:00
triton.py Skip L1 cache for single-use buffers (#143115) 2025-01-07 19:35:40 +00:00
wrapper.py [Inductor][CPP] Enable Grouped GEMM Template (#143796) 2025-01-14 05:59:07 +00:00