mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix flaky test_inductor_multiple_specializations (#159264)
Summary: This test was using do_bench, so it was flaky performance is non-deterministic. Test Plan: buck test 'fbcode//mode/opt' fbcode//caffe2/test/inductor:compile_subprocess -- --exact 'caffe2/test/inductor:compile_subprocess - test_inductor_multiple_specializations_cuda (caffe2.test.inductor.test_compile_subprocess.GPUTests)' --run-disabled Rollback Plan: Differential Revision: D79098692 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159264 Approved by: https://github.com/jingsh
This commit is contained in:
parent
27ae72036d
commit
6de24135e5
|
|
@ -10553,8 +10553,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||||
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
|
||||||
)
|
)
|
||||||
def test_inductor_multiple_specializations(self):
|
def test_inductor_multiple_specializations(self):
|
||||||
from triton.testing import do_bench
|
|
||||||
|
|
||||||
@torch.compile(
|
@torch.compile(
|
||||||
options={
|
options={
|
||||||
"max_autotune": True,
|
"max_autotune": True,
|
||||||
|
|
@ -10569,7 +10567,7 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||||
m = 16
|
m = 16
|
||||||
k = 1280
|
k = 1280
|
||||||
dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
|
dynamic_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
|
||||||
dynamic_specialized_a = torch.randn(m, k, device=GPU_TYPE, dtype=torch.bfloat16)
|
dynamic_specialized_a = dynamic_a.clone()
|
||||||
b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
|
b = torch.randn(k, m, device=GPU_TYPE, dtype=torch.bfloat16)
|
||||||
torch._dynamo.decorators.mark_dynamic(
|
torch._dynamo.decorators.mark_dynamic(
|
||||||
dynamic_a,
|
dynamic_a,
|
||||||
|
|
@ -10584,12 +10582,10 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
||||||
b,
|
b,
|
||||||
1,
|
1,
|
||||||
)
|
)
|
||||||
dynamic = do_bench(lambda: inductor_matmul(dynamic_a, b))
|
dynamic = inductor_matmul(dynamic_a, b)
|
||||||
torch._dynamo.reset()
|
torch._dynamo.reset()
|
||||||
dynamic_specialized = do_bench(
|
dynamic_specialized = inductor_matmul(dynamic_specialized_a, b)
|
||||||
lambda: inductor_matmul(dynamic_specialized_a, b)
|
self.assertEqual(dynamic, dynamic_specialized)
|
||||||
)
|
|
||||||
self.assertGreaterEqual(dynamic, dynamic_specialized)
|
|
||||||
|
|
||||||
@requires_gpu()
|
@requires_gpu()
|
||||||
def test_stride_preservation_with_stride_modifying_fx_pass(self):
|
def test_stride_preservation_with_stride_modifying_fx_pass(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user