**Summary**

Enable the CPP Grouped GEMM fusion, lowering, and Grouped GEMM template following the RFC: https://github.com/pytorch/pytorch/issues/144012

- Supports a flexible number of GEMMs (see the first sketch after this description).
- The activation is shared across GEMMs.
  - The Grouped GEMM template itself supports independent activations.
  - However, the pattern matcher requires an anchor node, which is the activation shared across the GEMMs.
- Each GEMM can have a unique weight, but all weights must have the same sizes (see the second sketch after this description).
- Each GEMM can have a unique bias or no bias.
  - The current PR does not yet support biases; this will be addressed in a follow-up epilogue fusion PR.
- Each GEMM has its own epilogues.
  - Epilogue fusion is not yet supported in this PR and will be enabled in an upcoming follow-up epilogue fusion PR.

**Test Plan**

```
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear
python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_grouped_linear_invalid
python -u -m pytest -s -v test/inductor/test_cpu_cpp_wrapper.py -k test_grouped_linear
```

**Example**

Here is an example and the generated code:

```
import torch

batch_size = 4
in_features = 512
out_features = 1024
dtype = torch.bfloat16
bias = False  # biases are not yet supported by the grouped GEMM template

class M(torch.nn.Module):
    def __init__(self, bias):
        super().__init__()
        self.linear0 = torch.nn.Linear(in_features, out_features, bias=False)
        self.linear1 = torch.nn.Linear(in_features, out_features, bias=False)

    def forward(self, x):
        # Both linears consume the same activation `x`, which serves as the
        # anchor node for the grouped GEMM pattern matcher.
        return self.linear0(x), self.linear1(x)

if __name__ == "__main__":
    with torch.no_grad():
        input = torch.randn(batch_size, in_features, dtype=dtype)
        m = M(bias=bias).to(dtype=dtype).eval()
        cm = torch.compile(m)
        act_res = cm(input)
```

Generated Code: https://gist.github.com/leslie-fang-intel/ed2e8d23aeb3586eb504feeace692e16#file-grouped-gemm-generated-code-py

**Next Step**

- Support epilogue fusion

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143796
Approved by: https://github.com/jgong5, https://github.com/jansel
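To make the grouping constraints concrete, here is a minimal sketch (not taken from the PR) of a module with three GEMMs that share a single activation and use distinct weights of identical shape, which is the pattern the lowering is meant to group. The `GroupedLinears` name and the `max_autotune` / `max_autotune_gemm_backends` settings are assumptions about how the CPP template becomes eligible for selection; see `test/inductor/test_cpu_select_algorithm.py` for the authoritative setup.

```
# Hypothetical sketch: N GEMMs sharing one activation, each with its own
# weight of the same (out_features, in_features) size and no bias.
# The inductor config flags below are an assumption, not taken from the PR.
import torch
import torch._inductor.config as inductor_config

inductor_config.max_autotune = True                      # assumed prerequisite
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"  # assumed prerequisite

in_features, out_features, num_gemms = 512, 1024, 3

class GroupedLinears(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Unique weights, identical sizes, no bias.
        self.linears = torch.nn.ModuleList(
            torch.nn.Linear(in_features, out_features, bias=False)
            for _ in range(num_gemms)
        )

    def forward(self, x):
        # One shared activation `x` feeds every GEMM; this shared node is
        # the anchor the pattern matcher keys on.
        return [linear(x) for linear in self.linears]

with torch.no_grad():
    m = GroupedLinears().to(torch.bfloat16).eval()
    x = torch.randn(4, in_features, dtype=torch.bfloat16)
    outs = torch.compile(m)(x)
```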
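Conversely, the following is a hedged sketch of a case that should fall outside the pattern, mirroring what `test_grouped_linear_invalid` presumably exercises: the two linears consume the same activation, but their weights have different sizes, so by the "unique weight but same sizes" constraint they should be lowered as separate GEMMs rather than grouped. The `NotGroupable` module is hypothetical and not taken from the PR's tests.

```
# Hypothetical counter-example (an assumption, not from the PR): the weights
# differ in size, so these GEMMs should not be grouped by the pattern matcher.
import torch

class NotGroupable(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear0 = torch.nn.Linear(512, 1024, bias=False)  # 512 -> 1024
        self.linear1 = torch.nn.Linear(512, 2048, bias=False)  # 512 -> 2048 (different size)

    def forward(self, x):
        return self.linear0(x), self.linear1(x)

with torch.no_grad():
    m = NotGroupable().to(torch.bfloat16).eval()
    x = torch.randn(4, 512, dtype=torch.bfloat16)
    out0, out1 = torch.compile(m)(x)
```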
| Name |
|---|
| aoti_runtime |
| cuda |
| rocm |
| xpu |
| __init__.py |
| aoti_hipify_utils.py |
| block_analysis.py |
| common.py |
| cpp_bmm_template.py |
| cpp_flex_attention_template.py |
| cpp_gemm_template.py |
| cpp_grouped_gemm_template.py |
| cpp_micro_gemm.py |
| cpp_prefix.h |
| cpp_template_kernel.py |
| cpp_template.py |
| cpp_utils.py |
| cpp_wrapper_cpu_array_ref.py |
| cpp_wrapper_cpu.py |
| cpp_wrapper_gpu.py |
| cpp.py |
| cpu_device_op_overrides.py |
| cuda_combined_scheduling.py |
| debug_utils.py |
| halide.py |
| memory_planning.py |
| mps_device_op_overrides.py |
| mps.py |
| multi_kernel.py |
| simd_kernel_features.py |
| simd.py |
| triton_combo_kernel.py |
| triton_split_scan.py |
| triton_utils.py |
| triton.py |
| wrapper.py |