Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "[inductor] Make generated kernels deterministic (#143951)"
This reverts commit 79b354ee37.
Reverted https://github.com/pytorch/pytorch/pull/143951 on behalf of https://github.com/wdvr due to failing tests on trunk ([comment](https://github.com/pytorch/pytorch/pull/143951#issuecomment-2564952267))
parent: cf89127137
commit: 1b0d19a2cb
@@ -40,13 +40,11 @@ from torch._dynamo.testing import (
     skipIfPy312,
 )
 from torch._dynamo.utils import ifdynstaticdefault
-from torch._guards import CompileContext, CompileId
 from torch._inductor.aoti_eager import (
     aoti_compile_with_persistent_cache,
     aoti_eager_cache_dir,
     load_aoti_eager_cache,
 )
-from torch._inductor.codecache import cpp_prefix_path
 from torch._inductor.codegen.common import DataTypePropagation, OptimizationContext
 from torch._inductor.fx_passes import pad_mm
 from torch._inductor.test_case import TestCase as InductorTestCase
@@ -54,7 +52,6 @@ from torch._inductor.utils import (
     add_scheduler_init_hook,
     run_and_get_code,
     run_and_get_cpp_code,
-    run_and_get_kernels,
     run_and_get_triton_code,
     run_fw_bw_and_get_code,
 )
@@ -6218,99 +6215,6 @@ class CommonTemplate:
             (torch.arange(-1e-5, 1e-5, 1e-7).to(dtype=dtype),),
         )
 
-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
-    @config.patch(force_disable_caches=True)
-    @skip_if_cpp_wrapper("run_and_get_kernels issue")
-    def test_deterministic_codegen(self):
-        if "cpu" in str(self.device) and config.is_fbcode():
-            raise unittest.SkipTest("cpp packaging is wacky in fbcode")
-
-        @torch.compile(fullgraph=True)
-        def a(x):
-            return x.cos().sin().softmax(-1)
-
-        @torch.compile(fullgraph=True)
-        def b(x):
-            return x.sin().cos().softmax(-1)
-
-        @torch.compile(fullgraph=True)
-        def c(x):
-            return x.cos().sin().softmax(-1)
-
-        x = torch.randn(16, 256, device=self.device)
-        _, (coda_a0,) = run_and_get_kernels(a, x)
-        _, (coda_b0,) = run_and_get_kernels(b, x)
-        _, (coda_c0,) = run_and_get_kernels(c, x)
-        self.assertEqual(coda_a0, coda_c0)
-
-        # compile in a different order
-        torch.compiler.reset()
-        _, (coda_c1,) = run_and_get_kernels(c, x)
-        _, (coda_a1,) = run_and_get_kernels(a, x)
-        _, (coda_b1,) = run_and_get_kernels(b, x)
-        self.assertEqual(coda_a0, coda_a1)
-        self.assertEqual(coda_b0, coda_b1)
-        self.assertEqual(coda_c0, coda_c1)
-
-        # force a different CompileId
-        torch.compiler.reset()
-        CompileContext_init = CompileContext.__init__
-        with patch.object(
-            CompileContext,
-            "__init__",
-            lambda self, _: CompileContext_init(self, CompileId(999, 999)),
-        ):
-            _, (coda_a2,) = run_and_get_kernels(a, x)
-            _, (coda_c2,) = run_and_get_kernels(c, x)
-            _, (coda_b2,) = run_and_get_kernels(b, x)
-        self.assertEqual(coda_a0, coda_a2)
-        self.assertEqual(coda_b0, coda_b2)
-        self.assertEqual(coda_c0, coda_c2)
-
-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
-    @config.patch(force_disable_caches=True)
-    @skip_if_cpp_wrapper("run_and_get_kernels issue")
-    def test_deterministic_codegen_on_graph_break(self):
-        if "cpu" in str(self.device) and config.is_fbcode():
-            raise unittest.SkipTest("cpp packaging is wacky in fbcode")
-
-        def a(x):
-            return x.cos().sin().softmax(-1)
-
-        @torch.compile()
-        def b(x):
-            x = a(x)
-            torch._dynamo.graph_break()
-            x = a(x)
-            return x
-
-        x = torch.randn(16, 256, device=self.device)
-        _, (code0, code1) = run_and_get_kernels(b, x)
-        self.assertEqual(code0, code1)
-
-    @patch.object(cpp_prefix_path, "cache_clear", lambda: None)
-    @config.patch(force_disable_caches=True)
-    @skip_if_cpp_wrapper("run_and_get_kernels issue")
-    def test_deterministic_codegen_with_suffix(self):
-        if "cpu" in str(self.device) and config.is_fbcode():
-            raise unittest.SkipTest("cpp packaging is wacky in fbcode")
-
-        @torch.compile(fullgraph=True)
-        def a(x):
-            return x.cos().sin().softmax(-1)
-
-        @torch.compile(fullgraph=True)
-        def b(x, y):
-            x = x.cos().sin().softmax(-1)
-            x = torch.matmul(x, y)
-            return x
-
-        x = torch.randn(16, 256, device=self.device)
-        y = torch.randn(256, 256, device=self.device)
-        _, (code0,) = run_and_get_kernels(a, x)
-        _, (code1,) = run_and_get_kernels(b, x, y)
-        self.assertEqual(code0, code1)
-
     def test_flip(self):
         def fn(x):
             return torch.flip(x, (-1,)), torch.flip(x, (0, 2)) - 2
@@ -3069,6 +3069,7 @@ class TritonKernel(SIMDKernel):
 
     @staticmethod
     def inductor_meta_common():
+        compile_id = torch._guards.CompileContext.current_compile_id()
         inductor_meta = {
             "backend_hash": torch.utils._triton.triton_hash_with_backend(),
             "are_deterministic_algorithms_enabled": torch.are_deterministic_algorithms_enabled(),
@@ -3083,6 +3084,8 @@ class TritonKernel(SIMDKernel):
             "min_split_scan_rblock": config.triton.min_split_scan_rblock,
             "spill_threshold": config.triton.spill_threshold,
             "store_cubin": config.triton.store_cubin,
+            "compile_id": str(compile_id) if compile_id else None,
+            "is_forward": not V.graph.is_backward,
         }
         if torch.version.hip is not None:
             inductor_meta["is_hip"] = True
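For context only (not part of the commit): the two hunks above restore per-compile fields to TritonKernel.inductor_meta_common(), which is what #143951 had stripped out to make codegen deterministic. A minimal sketch, with invented values, of why those fields make otherwise identical kernels compare unequal:

# Illustrative sketch only; the dict values below are made up, not real Inductor output.
# Two compiles of the same graph differ solely in the per-compile fields that
# this revert puts back into inductor_meta_common().
meta_first_compile = {"backend_hash": "abc123", "compile_id": "0/0", "is_forward": True}
meta_second_compile = {"backend_hash": "abc123", "compile_id": "1/0", "is_forward": True}

# inductor_meta is embedded in the generated kernel source, so the two source
# strings differ even when the kernel bodies themselves are identical.
assert meta_first_compile != meta_second_compile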
@@ -18,6 +18,7 @@ import time
 from typing import Any, Container, Dict, List, Optional, Tuple
 
 import torch
+from torch._guards import CompileId
 from torch.utils._ordered_set import OrderedSet
 
 from ..triton_bundler import TritonBundler
@@ -762,6 +763,8 @@ class CachingAutotuner(KernelInterface):
             log_pt2_compile_event=True,
             metadata={"kernel_name": self.inductor_meta.get("kernel_name")},
             dynamo_compile_runtime_column_us="runtime_triton_autotune_time_us",
+            compile_id=CompileId.from_string(self.inductor_meta.get("compile_id")),
+            is_forward=self.inductor_meta.get("is_forward"),
         ):
             timings = {
                 launcher: self.bench(launcher, *args, **kwargs)
@@ -1477,14 +1477,6 @@ def run_and_get_code(fn, *args, **kwargs) -> Tuple[Any, List[str]]:
     return result, source_codes
 
 
-def run_and_get_kernels(fn, *args, **kwargs) -> Tuple[Any, List[str]]:
-    result, source_codes = run_and_get_code(fn, *args, **kwargs)
-    kernels = []
-    for code in source_codes:
-        kernels.extend(re.findall(r"'''.*?'''", code, re.DOTALL))
-    return result, kernels
-
-
 def run_fw_bw_and_get_code(fn):
     def run_with_backward():
         result = fn()
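For context only (not part of the commit): a sketch of how the removed determinism tests used the run_and_get_kernels helper deleted above. It assumes a tree that still contains the helper (i.e. before this revert lands) and that each compiled function lowers to a single Triton kernel.

import torch
from torch._inductor.utils import run_and_get_kernels  # helper removed by this revert

@torch.compile(fullgraph=True)
def f(x):
    return x.cos().sin().softmax(-1)

@torch.compile(fullgraph=True)
def g(x):
    # same body as f, but compiled later, so it gets a different CompileId
    return x.cos().sin().softmax(-1)

x = torch.randn(16, 256)
_, (f_kernel,) = run_and_get_kernels(f, x)
_, (g_kernel,) = run_and_get_kernels(g, x)

# With #143951 applied the two kernel sources were identical; after this revert
# the compile_id embedded in the kernel metadata can make them differ.
print(f_kernel == g_kernel)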