mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[AOTI][CPU] Introduce config.cpp.use_decompose_tanh (#152542)
Summary: Previously D70489427 changed tanh impl to `.tanh()`, and this is causing some meta internal workload perf regression. This diff will introduce a config so we can set it based on need. Differential Revision: D73909371 Pull Request resolved: https://github.com/pytorch/pytorch/pull/152542 Approved by: https://github.com/desertfire
This commit is contained in:
parent
7c63ddd817
commit
6f6acb4128
|
|
@ -1070,6 +1070,23 @@ class CPUReproTests(TestCase):
|
|||
x = torch.randn(1, 3, 64, 64)
|
||||
self.common(Model(), (x,))
|
||||
|
||||
@config.patch("cpp.use_decompose_tanh", "1")
|
||||
def test_tanh_atan2_use_decompose_tanh(self):
|
||||
# https://github.com/pytorch/pytorch/issues/148241
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.shrink = nn.Tanhshrink()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.shrink(x)
|
||||
x = torch.atan2(x, x)
|
||||
return x
|
||||
|
||||
x = torch.randn(1, 3, 64, 64)
|
||||
with self.assertRaises(AssertionError):
|
||||
self.common(Model(), (x,))
|
||||
|
||||
def test_index_propagation_issue_102065(self):
|
||||
def fn(x):
|
||||
x = torch.arange(x.numel())
|
||||
|
|
|
|||
|
|
@ -1420,7 +1420,15 @@ class CppVecOverrides(CppOverrides):
|
|||
|
||||
@staticmethod
|
||||
def tanh(a):
|
||||
return f"{a}.tanh()"
|
||||
if config.cpp.use_decompose_tanh:
|
||||
vec_one = f"decltype({a})(1)"
|
||||
vec_two = f"decltype({a})(2)"
|
||||
vec_minus_two = f"decltype({a})(-2)"
|
||||
return (
|
||||
f"{vec_two} / ({vec_one} + ({vec_minus_two} * {a}).exp()) - {vec_one}"
|
||||
)
|
||||
else:
|
||||
return f"{a}.tanh()"
|
||||
|
||||
@staticmethod
|
||||
def reciprocal(a):
|
||||
|
|
|
|||
|
|
@ -1008,6 +1008,12 @@ class cpp:
|
|||
# enable this feature by their need.
|
||||
enable_concat_linear = False
|
||||
|
||||
# Whether to use decomposed tanh for cpu device
|
||||
# Disable by default due to https://github.com/pytorch/pytorch/issues/148241
|
||||
use_decompose_tanh = (
|
||||
os.environ.get("TORCHINDUCTOR_CPP_USE_DECOMPOSE_TANH", "0") == "1"
|
||||
)
|
||||
|
||||
|
||||
# config specific to codegen/triton.py
|
||||
class triton:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user