[PT2][Inductor][Optimus] fix test_pad_mm_bf16 and reland to fix long computation kernel (#136349)

Summary: see D62220158

Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:pad_mm -- --exact 'caffe2/test/inductor:pad_mm - test_pad_mm_bf16 (caffe2.test.inductor.test_pad_mm.PadMMTest)' --run-disabled
```

### H100

Buck UI: https://www.internalfb.com/buck2/e5d85802-cab7-41a5-aacc-95f541796a99
Test UI: https://www.internalfb.com/intern/testinfra/testrun/9570149258587374
Network: Up: 9.1KiB  Down: 0B  (reSessionID-b339b51b-6a0e-4347-9414-1ba38f26a5d0)
Jobs completed: 9. Time elapsed: 1:15.7s.
Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3)
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 1. Build failure 0

### A100

Buck UI: https://www.internalfb.com/buck2/1082ad6e-56b0-4eb5-8092-ce507ca9a70e
Test UI: https://www.internalfb.com/intern/testinfra/testrun/8444249533824784
Network: Up: 9.2KiB  Down: 0B  (reSessionID-2b3056ac-f29e-4de4-b6f5-9d994acf566b)
Jobs completed: 9. Time elapsed: 1:36.9s.
Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3)
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

# E2E

see D62220158

Differential Revision: D63040455

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136349
Approved by: https://github.com/dshi7
Author: Menglu Yu
Date: 2024-09-21 06:35:50 +00:00
Committed by: PyTorch MergeBot
Commit: e18439113e (parent: 02871461f7)
2 changed files with 24 additions and 0 deletions


```
@@ -364,6 +364,23 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
     return should_pad
 
 
+def should_pad_mm_bf16(dtype, M, N, K):
+    # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
+    large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
+        "pad_aten_mm_pass"
+    ].get("k_threshold_to_pad", 8388608)
+    if (
+        dtype is torch.bfloat16
+        and K > M
+        and K > N
+        and N % 2 == 1
+        and K >= large_k_threshold_to_pad
+        and torch.cuda.get_device_capability() < (9, 0)
+    ):  # doesn't repro on H100s
+        return True
+    return False
+
+
 def should_pad_bench(
     match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
 ) -> bool:
```
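
For reference, here is the same predicate restated as a standalone function. This is a sketch for illustration only: `qualifies_for_forced_pad` and its default argument are not part of the patch, which reads the threshold from the `pad_aten_mm_pass` config entry instead.

```
import torch

# Standalone restatement of should_pad_mm_bf16 for illustration only; the
# patched code reads the threshold from
# torch._inductor.config.post_grad_fusion_options["pad_aten_mm_pass"].
def qualifies_for_forced_pad(dtype, M, N, K, k_threshold=8388608):
    return (
        dtype is torch.bfloat16
        and K > M                 # K dominates both output dims
        and K > N
        and N % 2 == 1            # only the odd-N (misaligned) case is padded
        and K >= k_threshold      # only very large K (default 2**23) regresses
        and torch.cuda.get_device_capability() < (9, 0)  # pre-H100 GPUs only
    )
```

On an A100 (capability `(8, 0)`), `qualifies_for_forced_pad(torch.bfloat16, 16, 15, 8388608)` returns `True`; on an H100 it returns `False`, matching the inline comment that the regression does not reproduce there.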
```
@@ -410,6 +427,12 @@ def should_pad_bench(
         if torch._inductor.config.force_shape_pad:
             return True
 
+        if (
+            "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
+            and should_pad_mm_bf16(mat1.dtype, m, n, k)
+        ):
+            return True
+
         if not has_triton():
             return False
 
```
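
A rough end-to-end sketch of exercising this gate follows. The shapes and tolerances are assumptions chosen to satisfy the new predicate, and it requires a CUDA device; it is not the fbcode test above.

```
import torch
import torch._inductor.config as inductor_config

# Opting in: the mere presence of the "pad_aten_mm_pass" key enables the gate;
# we rely here on the 8388608 default for k_threshold_to_pad.
inductor_config.post_grad_fusion_options["pad_aten_mm_pass"] = {}

def f(a, b):
    return torch.mm(a, b)

if torch.cuda.is_available():
    M, N, K = 16, 15, 8388608  # odd N, K > M, K > N, K >= default threshold
    a = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
    out_compiled = torch.compile(f)(a, b)
    # Loose tolerances: bf16 accumulation over K = 2**23 is noisy, and the
    # padded kernel need not match eager bit-for-bit.
    torch.testing.assert_close(out_compiled, f(a, b), atol=0.5, rtol=0.05)
```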


```
@@ -65,6 +65,7 @@ post_grad_pass_names = [
     "decompose_mm_pass",
     "unbind_stack_aten_pass",
     "shape_padding_multiplier",
+    "pad_aten_mm_pass",
 ]
 
 for pass_name in pre_grad_pass_names:
```
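
Registering the pass name above is what makes the config lookup in `should_pad_mm_bf16` well-formed. A minimal sketch of overriding its one knob from user code (the key and knob names come from the diff; the override value is illustrative):

```
import torch._inductor.config as inductor_config

# Lower the threshold so padding kicks in from K >= 2**22 instead of the
# 2**23 default that should_pad_mm_bf16 falls back to.
inductor_config.post_grad_fusion_options["pad_aten_mm_pass"] = {
    "k_threshold_to_pad": 4194304,
}
```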