From e18439113e3e9260977e2bc52b83fa83a6fb30b0 Mon Sep 17 00:00:00 2001 From: Menglu Yu Date: Sat, 21 Sep 2024 06:35:50 +0000 Subject: [PATCH] [PT2][Inductor][Optmus] fix test_pad_mm_bf16 and reland to fix long computation kernel (#136349) Summary: see D62220158 Test Plan: ``` buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:pad_mm -- --exact 'caffe2/test/inductor:pad_mm - test_pad_mm_bf16 (caffe2.test.inductor.test_pad_mm.PadMMTest)' --run-disabled ``` ### H100 Buck UI: https://www.internalfb.com/buck2/e5d85802-cab7-41a5-aacc-95f541796a99 Test UI: https://www.internalfb.com/intern/testinfra/testrun/9570149258587374 Network: Up: 9.1KiB Down: 0B (reSessionID-b339b51b-6a0e-4347-9414-1ba38f26a5d0) Jobs completed: 9. Time elapsed: 1:15.7s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3) Tests finished: Pass 1. Fail 0. Fatal 0. Skip 1. Build failure 0 ### A100 Buck UI: https://www.internalfb.com/buck2/1082ad6e-56b0-4eb5-8092-ce507ca9a70e Test UI: https://www.internalfb.com/intern/testinfra/testrun/8444249533824784 Network: Up: 9.2KiB Down: 0B (reSessionID-2b3056ac-f29e-4de4-b6f5-9d994acf566b) Jobs completed: 9. Time elapsed: 1:36.9s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3) Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0 # E2E see D62220158 Differential Revision: D63040455 Pull Request resolved: https://github.com/pytorch/pytorch/pull/136349 Approved by: https://github.com/dshi7 --- torch/_inductor/fx_passes/pad_mm.py | 23 +++++++++++++++++++++++ torch/_inductor/fx_passes/split_cat.py | 1 + 2 files changed, 24 insertions(+) diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py index 87450b34e7e..a00bad39747 100644 --- a/torch/_inductor/fx_passes/pad_mm.py +++ b/torch/_inductor/fx_passes/pad_mm.py @@ -364,6 +364,23 @@ def should_pad(key: str, ori_time, pad_time) -> bool: return should_pad +def should_pad_mm_bf16(dtype, M, N, K): + # always force pad for mm with bf16 when the following are satisfied to avoid perf regression + large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[ + "pad_aten_mm_pass" + ].get("k_threshold_to_pad", 8388608) + if ( + dtype is torch.bfloat16 + and K > M + and K > N + and N % 2 == 1 + and K >= large_k_threshold_to_pad + and torch.cuda.get_device_capability() < (9, 0) + ): # doesnt repro on h100s: + return True + return False + + def should_pad_bench( match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None ) -> bool: @@ -410,6 +427,12 @@ def should_pad_bench( if torch._inductor.config.force_shape_pad: return True + if ( + "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options + and should_pad_mm_bf16(mat1.dtype, m, n, k) + ): + return True + if not has_triton(): return False diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index f850ecf6008..194d1d6dbaa 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -65,6 +65,7 @@ post_grad_pass_names = [ "decompose_mm_pass", "unbind_stack_aten_pass", "shape_padding_multiplier", + "pad_aten_mm_pass", ] for pass_name in pre_grad_pass_names: