[PT2][Inductor][Optimus] fix test_pad_mm_bf16 and reland to fix long computation kernel (#136349)
Summary: see D62220158

Test Plan:
```
buck2 test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:pad_mm -- --exact 'caffe2/test/inductor:pad_mm - test_pad_mm_bf16 (caffe2.test.inductor.test_pad_mm.PadMMTest)' --run-disabled
```

### H100

Buck UI: https://www.internalfb.com/buck2/e5d85802-cab7-41a5-aacc-95f541796a99
Test UI: https://www.internalfb.com/intern/testinfra/testrun/9570149258587374
Network: Up: 9.1KiB Down: 0B (reSessionID-b339b51b-6a0e-4347-9414-1ba38f26a5d0)
Jobs completed: 9. Time elapsed: 1:15.7s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3)
Tests finished: Pass 1. Fail 0. Fatal 0. Skip 1. Build failure 0

### A100

Buck UI: https://www.internalfb.com/buck2/1082ad6e-56b0-4eb5-8092-ce507ca9a70e
Test UI: https://www.internalfb.com/intern/testinfra/testrun/8444249533824784
Network: Up: 9.2KiB Down: 0B (reSessionID-2b3056ac-f29e-4de4-b6f5-9d994acf566b)
Jobs completed: 9. Time elapsed: 1:36.9s. Cache hits: 0%. Commands: 3 (cached: 0, remote: 0, local: 3)
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

# E2E

see D62220158

Differential Revision: D63040455

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136349
Approved by: https://github.com/dshi7
This commit is contained in:
parent 02871461f7
commit e18439113e
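The diff below gates the new behavior behind Inductor's post-grad fusion options. As a hedged usage sketch, inferred from the diff rather than shown in the commit: registering a `pad_aten_mm_pass` entry in `torch._inductor.config.post_grad_fusion_options` is what opts a model into the forced padding, and the `k_threshold_to_pad` key (with its 8388608 default) is the value the added pass reads.

```python
# Hedged sketch, inferred from the diff rather than any documented API:
# registering "pad_aten_mm_pass" in post_grad_fusion_options opts into the
# forced bf16 mm padding; "k_threshold_to_pad" overrides the 8388608 default
# the pass falls back to.
import torch

torch._inductor.config.post_grad_fusion_options["pad_aten_mm_pass"] = {
    "k_threshold_to_pad": 8388608,  # force-pad only when K is at least this large
}

@torch.compile
def mm(a, b):
    return a @ b
```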
```diff
@@ -364,6 +364,23 @@ def should_pad(key: str, ori_time, pad_time) -> bool:
     return should_pad
 
 
+def should_pad_mm_bf16(dtype, M, N, K):
+    # always force pad for mm with bf16 when the following are satisfied to avoid perf regression
+    large_k_threshold_to_pad = torch._inductor.config.post_grad_fusion_options[
+        "pad_aten_mm_pass"
+    ].get("k_threshold_to_pad", 8388608)
+    if (
+        dtype is torch.bfloat16
+        and K > M
+        and K > N
+        and N % 2 == 1
+        and K >= large_k_threshold_to_pad
+        and torch.cuda.get_device_capability() < (9, 0)
+    ):  # doesn't repro on H100s
+        return True
+    return False
+
+
 def should_pad_bench(
     match, mat1: Tensor, mat2: Tensor, op, input: Optional[Tensor] = None
 ) -> bool:
```
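For concreteness, here is an illustration-only restatement of the gate above as a standalone function (not part of the commit). The device capability is passed in as a parameter so the conditions can be checked without a GPU; the real code queries `torch.cuda.get_device_capability()`.

```python
# Illustration only (not from the commit): the gate fires for a bf16 mm whose
# reduction dim K dominates both output dims, N is odd, K exceeds the
# threshold, and the GPU is below compute capability 9.0 (i.e. pre-H100).
import torch

def hits_force_pad(dtype, M, N, K, threshold=8388608, capability=(8, 0)):
    return (
        dtype is torch.bfloat16
        and K > M
        and K > N
        and N % 2 == 1
        and K >= threshold
        and capability < (9, 0)
    )

print(hits_force_pad(torch.bfloat16, M=16, N=41, K=9_000_000))                     # True (A100-class)
print(hits_force_pad(torch.bfloat16, M=16, N=41, K=9_000_000, capability=(9, 0)))  # False (H100)
```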
```diff
@@ -410,6 +427,12 @@ def should_pad_bench(
     if torch._inductor.config.force_shape_pad:
         return True
 
+    if (
+        "pad_aten_mm_pass" in torch._inductor.config.post_grad_fusion_options
+        and should_pad_mm_bf16(mat1.dtype, m, n, k)
+    ):
+        return True
+
     if not has_triton():
         return False
 
```
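Note the placement: the new check runs after the `force_shape_pad` escape hatch but before the `has_triton()` gate, so the forced bf16 padding decision is made even when Triton is unavailable and no padding benchmark would otherwise run.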
```diff
@@ -65,6 +65,7 @@ post_grad_pass_names = [
     "decompose_mm_pass",
     "unbind_stack_aten_pass",
     "shape_padding_multiplier",
+    "pad_aten_mm_pass",
 ]
 
 for pass_name in pre_grad_pass_names:
```
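Adding `pad_aten_mm_pass` to `post_grad_pass_names` registers it as a recognized post-grad pass; presumably this is what lets the `post_grad_fusion_options["pad_aten_mm_pass"]` entry consulted in `pad_mm` be configured without tripping the pass-name validation that follows this list.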