diff --git a/torch/_inductor/fx_passes/pad_mm.py b/torch/_inductor/fx_passes/pad_mm.py
index 87450b34e7e..a00bad39747 100644
--- a/torch/_inductor/fx_passes/pad_mm.py
+++ b/torch/_inductor/fx_passes/pad_mm.py
@@ -364,6 +364,23 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
     return should_pad
 
 
+def should_pad_mm_bf16(dtype, M, N, K):
+    # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
+    large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
+        "pad_aten_mm_pass"
+    ].get("k_threshold_to_pad", 8388608)
+    if (
+        dtype is torch.bfloat16
+        and K > M
+        and K > N
+        and N % 2 == 1
+        and K >= large_k_threshold_to_pad
+        and torch.cuda.get_device_capability() < (9, 0)
+    ):  # doesn't repro on H100s
+        return True
+    return False
+
+
 def should_pad_bench(
     match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
 ) -> bool:
@@ -410,6 +427,12 @@ def should_pad_bench(
     if torch._inductor.config.force_shape_pad:
         return True
 
+    if (
+        "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
+        and should_pad_mm_bf16(mat1.dtype, m, n, k)
+    ):
+        return True
+
     if not has_triton():
         return False
 
diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py
index f850ecf6008..194d1d6dbaa 100644
--- a/torch/_inductor/fx_passes/split_cat.py
+++ b/torch/_inductor/fx_passes/split_cat.py
@@ -65,6 +65,7 @@ post_grad_pass_names = [
     "decompose_mm_pass",
     "unbind_stack_aten_pass",
     "shape_padding_multiplier",
+    "pad_aten_mm_pass",
 ]
 
 for pass_name in pre_grad_pass_names:
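
For context (not part of the diff): the new branch in should_pad_bench only fires when "pad_aten_mm_pass" is registered in the post-grad fusion options. Below is a minimal sketch of how a user might enable the pass; the pass name and the "k_threshold_to_pad" option are taken from the diff above, and the threshold value shown is simply the default that should_pad_mm_bf16 would use anyway.

```python
# Sketch: enabling the pad_aten_mm_pass added in this diff.
import torch

# Register the pass in the post-grad fusion options. "k_threshold_to_pad" is optional;
# should_pad_mm_bf16 falls back to 8388608 (8 * 1024 * 1024) if it is not provided.
torch._inductor.config.post_grad_fusion_options = {
    "pad_aten_mm_pass": {"k_threshold_to_pad": 8388608},
}

# With this set, should_pad_bench forces padding for bf16 matmuls where K > M, K > N,
# N is odd, K >= k_threshold_to_pad, and the GPU is pre-Hopper (compute capability < (9, 0)).
```

Without the "pad_aten_mm_pass" entry in post_grad_fusion_options, should_pad_bench falls through to its existing benchmark-based padding heuristic.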