[AOTI][CPU] Introduce config.cpp.use_decompose_tanh (#152542)

Summary: D70489427 previously changed the tanh implementation to `.tanh()`, which is causing a performance regression in some Meta-internal workloads. This diff introduces a config option so the implementation can be selected as needed.

Differential Revision: D73909371

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152542
Approved by: https://github.com/desertfire
This commit is contained in:
Huamin Li 2025-05-01 10:25:27 +00:00 committed by PyTorch MergeBot
parent 7c63ddd817
commit 6f6acb4128
3 changed files with 32 additions and 1 deletions

View File

@ -1070,6 +1070,23 @@ class CPUReproTests(TestCase):
x = torch.randn(1, 3, 64, 64)
self.common(Model(), (x,))
@config.patch("cpp.use_decompose_tanh", True)
def test_tanh_atan2_use_decompose_tanh(self):
    """Tanhshrink followed by atan2 with the decomposed tanh enabled.

    The flag is a boolean, so it is patched with ``True`` rather than the
    environment-variable spelling ``"1"``.

    The check is expected to fail with ``AssertionError`` while the
    decomposed formulation is active; see the issue linked below.
    NOTE(review): presumably ``self.common`` compares the compiled output
    against eager -- confirm against the helper's definition.
    """

    # https://github.com/pytorch/pytorch/issues/148241
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.shrink = nn.Tanhshrink()

        def forward(self, x):
            x = self.shrink(x)
            # Both atan2 operands are the same tensor.
            x = torch.atan2(x, x)
            return x

    x = torch.randn(1, 3, 64, 64)
    with self.assertRaises(AssertionError):
        self.common(Model(), (x,))
def test_index_propagation_issue_102065(self):
def fn(x):
x = torch.arange(x.numel())

View File

@ -1420,7 +1420,15 @@ class CppVecOverrides(CppOverrides):
@staticmethod
def tanh(a):
    """Return the C++ vectorized expression string computing tanh of ``a``.

    Args:
        a: name/expression of the vectorized operand in the generated C++.

    Returns:
        When ``config.cpp.use_decompose_tanh`` is set, the decomposed form
        ``2 / (1 + exp(-2 * a)) - 1`` (mathematically equal to tanh);
        otherwise a call to the operand's own ``.tanh()`` method.
    """
    if config.cpp.use_decompose_tanh:
        # decltype(a)(k) builds each constant with the same C++ type as the
        # operand, so the arithmetic stays in that type.
        vec_one = f"decltype({a})(1)"
        vec_two = f"decltype({a})(2)"
        vec_minus_two = f"decltype({a})(-2)"
        return (
            f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}"
        )
    else:
        return f"{a}.tanh()"
@staticmethod
def reciprocal(a):

View File

@ -1008,6 +1008,12 @@ class cpp:
# enable this feature by their need.
enable_concat_linear = False
# Opt-in switch for the decomposed tanh on the cpu device.
# Enabled only when TORCHINDUCTOR_CPP_USE_DECOMPOSE_TANH is exactly "1"; it stays
# off by default because of https://github.com/pytorch/pytorch/issues/148241
use_decompose_tanh = os.getenv("TORCHINDUCTOR_CPP_USE_DECOMPOSE_TANH", "0") == "1"
# config specific to codegen/triton.py
class triton: