Revert "[inductor] Make generated kernels deterministic (#143951)"

This reverts commit 79b354ee37.

Reverted https://github.com/pytorch/pytorch/pull/143951 on behalf of https://github.com/wdvr due to failing tests on trunk ([comment](https://github.com/pytorch/pytorch/pull/143951#issuecomment-2564952267))
This commit is contained in:
PyTorch MergeBot 2024-12-30 02:06:38 +00:00
parent cf89127137
commit 1b0d19a2cb
4 changed files with 6 additions and 104 deletions

@@ -40,13 +40,11 @@ from torch._dynamo.testing import (
    skipIfPy312,
)
from torch._dynamo.utils import ifdynstaticdefault
from torch._guards import CompileContext, CompileId
from torch._inductor.aoti_eager import (
    aoti_compile_with_persistent_cache,
    aoti_eager_cache_dir,
    load_aoti_eager_cache,
)
from torch._inductor.codecache import cpp_prefix_path
from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext
from torch._inductor.fx_passes import pad_mm
from torch._inductor.test_case import TestCase as InductorTestCase
@@ -54,7 +52,6 @@ from torch._inductor.utils import (
    add_scheduler_init_hook,
    run_and_get_code,
    run_and_get_cpp_code,
    run_and_get_kernels,
    run_and_get_triton_code,
    run_fw_bw_and_get_code,
)
@@ -6218,99 +6215,6 @@ class CommonTemplate:
            (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),),
        )

    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
    @config.patch(force_disable_caches=True)
    @skip_if_cpp_wrapper("run_and_get_kernels issue")
    def test_deterministic_codegen(self):
        if "cpu" in str(self.device) and config.is_fbcode():
            raise unittest.SkipTest("cpp packaging is wacky in fbcode")

        @torch.compile(fullgraph=True)
        def a(x):
            return x.cos().sin().softmax(-1)

        @torch.compile(fullgraph=True)
        def b(x):
            return x.sin().cos().softmax(-1)

        @torch.compile(fullgraph=True)
        def c(x):
            return x.cos().sin().softmax(-1)

        x = torch.randn(16, 256, device=self.device)
        _, (coda_a0,) = run_and_get_kernels(a, x)
        _, (coda_b0,) = run_and_get_kernels(b, x)
        _, (coda_c0,) = run_and_get_kernels(c, x)
        self.assertEqual(coda_a0, coda_c0)

        # compile in a different order
        torch.compiler.reset()
        _, (coda_c1,) = run_and_get_kernels(c, x)
        _, (coda_a1,) = run_and_get_kernels(a, x)
        _, (coda_b1,) = run_and_get_kernels(b, x)
        self.assertEqual(coda_a0, coda_a1)
        self.assertEqual(coda_b0, coda_b1)
        self.assertEqual(coda_c0, coda_c1)

        # force a different CompileId
        torch.compiler.reset()
        CompileContext_init = CompileContext.__init__
        with patch.object(
            CompileContext,
            "__init__",
            lambda self, _: CompileContext_init(self, CompileId(999, 999)),
        ):
            _, (coda_a2,) = run_and_get_kernels(a, x)
            _, (coda_c2,) = run_and_get_kernels(c, x)
            _, (coda_b2,) = run_and_get_kernels(b, x)
        self.assertEqual(coda_a0, coda_a2)
        self.assertEqual(coda_b0, coda_b2)
        self.assertEqual(coda_c0, coda_c2)

    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
    @config.patch(force_disable_caches=True)
    @skip_if_cpp_wrapper("run_and_get_kernels issue")
    def test_deterministic_codegen_on_graph_break(self):
        if "cpu" in str(self.device) and config.is_fbcode():
            raise unittest.SkipTest("cpp packaging is wacky in fbcode")

        def a(x):
            return x.cos().sin().softmax(-1)

        @torch.compile()
        def b(x):
            x = a(x)
            torch._dynamo.graph_break()
            x = a(x)
            return x

        x = torch.randn(16, 256, device=self.device)
        _, (code0, code1) = run_and_get_kernels(b, x)
        self.assertEqual(code0, code1)

    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
    @config.patch(force_disable_caches=True)
    @skip_if_cpp_wrapper("run_and_get_kernels issue")
    def test_deterministic_codegen_with_suffix(self):
        if "cpu" in str(self.device) and config.is_fbcode():
            raise unittest.SkipTest("cpp packaging is wacky in fbcode")

        @torch.compile(fullgraph=True)
        def a(x):
            return x.cos().sin().softmax(-1)

        @torch.compile(fullgraph=True)
        def b(x, y):
            x = x.cos().sin().softmax(-1)
            x = torch.matmul(x, y)
            return x

        x = torch.randn(16, 256, device=self.device)
        y = torch.randn(256, 256, device=self.device)
        _, (code0,) = run_and_get_kernels(a, x)
        _, (code1,) = run_and_get_kernels(b, x, y)
        self.assertEqual(code0, code1)

    def test_flip(self):
        def fn(x):
            return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2
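
For orientation (not part of the diff): the deleted tests asserted that the kernel source emitted for equivalent functions is byte-for-byte identical across compile order and across different CompileIds. A minimal sketch of the same kind of check using run_and_get_code, which is untouched by this revert; the input shape and the single-file unpacking are assumptions for illustration only:

import torch
from torch._inductor.utils import run_and_get_code

@torch.compile(fullgraph=True)
def f(x):
    return x.cos().sin().softmax(-1)

x = torch.randn(16, 256)
_, (src0,) = run_and_get_code(f, x)   # assumes exactly one generated file
torch.compiler.reset()
_, (src1,) = run_and_get_code(f, x)
# With #143951 reverted, the two sources may differ in embedded metadata
# such as compile_id whenever the CompileId changes between compilations.
print(src0 == src1)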

@@ -3069,6 +3069,7 @@ class TritonKernel(SIMDKernel):
    @staticmethod
    def inductor_meta_common():
        compile_id = torch._guards.CompileContext.current_compile_id()
        inductor_meta = {
            "backend_hash": torch.utils._triton.triton_hash_with_backend(),
            "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(),
@@ -3083,6 +3084,8 @@ class TritonKernel(SIMDKernel):
            "min_split_scan_rblock": config.triton.min_split_scan_rblock,
            "spill_threshold": config.triton.spill_threshold,
            "store_cubin": config.triton.store_cubin,
            "compile_id": str(compile_id) if compile_id else None,
            "is_forward": not V.graph.is_backward,
        }
        if torch.version.hip is not None:
            inductor_meta["is_hip"] = True
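
The two hunks above restore the compile_id and is_forward entries to the metadata that inductor bakes into each generated kernel file (the entries #143951 had removed to keep that source stable across compilations). A rough sketch of what the restored lookup does, using only names visible in the diff; the behaviour outside an active compilation is an assumption:

import torch
from torch._guards import CompileContext

# Sketch: the restored inductor_meta_common() records which compilation
# produced the kernel. Outside torch.compile this is expected to be None.
compile_id = CompileContext.current_compile_id()
meta_fragment = {
    "compile_id": str(compile_id) if compile_id else None,
}
print(meta_fragment)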

@@ -18,6 +18,7 @@ import time
from typing import Any, Container, Dict, List, Optional, Tuple

import torch
from torch._guards import CompileId
from torch.utils._ordered_set import OrderedSet

from ..triton_bundler import TritonBundler
@@ -762,6 +763,8 @@ class CachingAutotuner(KernelInterface):
            log_pt2_compile_event=True,
            metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
            dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
            compile_id=CompileId.from_string(self.inductor_meta.get("compile_id")),
            is_forward=self.inductor_meta.get("is_forward"),
        ):
            timings = {
                launcher: self.bench(launcher, *args, **kwargs)
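
On the autotuner side, compile_id is stored as a string in inductor_meta and parsed back with CompileId.from_string before being attached to the autotuning timing event, as shown in the hunk above. A small round-trip sketch; the "frame_id/frame_compile_id" string format is an assumption based on how CompileIds appear in dynamo logs:

from torch._guards import CompileId

cid = CompileId(3, 0)
as_str = str(cid)                      # e.g. "3/0" (format assumed)
parsed = CompileId.from_string(as_str)
print(parsed == cid)                   # expected to round-trip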

@@ -1477,14 +1477,6 @@ def run_and_get_code(fn, *args, **kwargs) -> Tuple[Any, List[str]]:
    return result, source_codes


def run_and_get_kernels(fn, *args, **kwargs) -> Tuple[Any, List[str]]:
    result, source_codes = run_and_get_code(fn, *args, **kwargs)
    kernels = []
    for code in source_codes:
        kernels.extend(re.findall(r"'''.*?'''", code, re.DOTALL))
    return result, kernels


def run_fw_bw_and_get_code(fn):
    def run_with_backward():
        result = fn()
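
For reference, the removed run_and_get_kernels helper wrapped run_and_get_code and pulled each triple-quoted kernel body out of the generated wrapper with a regex, which is what let the deleted tests compare individual kernels rather than whole files. A usage sketch as it appeared in those tests; it only runs on a tree that still contains the helper, i.e. before this revert:

import torch
from torch._inductor.utils import run_and_get_kernels  # removed by this revert

@torch.compile(fullgraph=True)
def f(x):
    return x.cos().sin().softmax(-1)

x = torch.randn(16, 256)
result, kernels = run_and_get_kernels(f, x)
print(len(kernels))  # number of '''...''' kernel source blocks found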