From 0148809131f494b842baf50d1f392f7404b87b44 Mon Sep 17 00:00:00 2001
From: Natalia Gimelshein
Date: Tue, 20 Dec 2022 02:11:53 +0000
Subject: [PATCH] use libdevice for tanh (#90889)

Per title. I see slight differences in perf with this implementation,
where standalone tanh is slightly slower for a tensor of 4000000
elements (20.4 us instead of 19.4 us); other sizes are within noise.
@bertmaher could you check if it affects your benchmarks?

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90889
Approved by: https://github.com/bertmaher, https://github.com/anijain2305
---
 torch/_inductor/codegen/cpp.py    | 4 ++++
 torch/_inductor/codegen/triton.py | 4 ++++
 torch/_inductor/decomposition.py  | 5 -----
 torch/_inductor/lowering.py       | 5 +++++
 4 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 474bd5861b7..da0230059e3 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -416,6 +416,10 @@ class CppOverrides(OpOverrides):
     def expm1(x):
         return f"std::expm1({x})"
 
+    @staticmethod
+    def tanh(x):
+        return f"std::tanh({x})"
+
     @staticmethod
     def signbit(x):
         return f"std::signbit({x})"
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index 0269fdffe73..9f0ea5ddf9a 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -239,6 +239,10 @@ class TritonOverrides(OpOverrides):
     def expm1(x):
         return f"tl.libdevice.expm1({x})"
 
+    @staticmethod
+    def tanh(x):
+        return f"tl.libdevice.tanh({x})"
+
     @staticmethod
     def sigmoid(x):
         return f"tl.sigmoid({x})"
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index c9cc7a2acaf..d7efc50db82 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -126,11 +126,6 @@ def clamp(x, min=None, max=None):
     return x
 
 
-@register_decomposition([aten.tanh])
-def tanh(x):
-    return 2.0 / (1.0 + torch.exp(-2.0 * x)) - 1.0
-
-
 # TorchInductor-only decomposition. It should not be taken to core.
 # See https://github.com/pytorch/torchdynamo/pull/1120
 @register_decomposition([aten.floor_divide.default])
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 46b7290fb73..530e09142b1 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3621,6 +3621,11 @@ register_pointwise(
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
 )
 
+register_pointwise(
+    aten.tanh,
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
+)
+
 register_pointwise(
     aten.log,
     type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
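
A minimal sketch of how the standalone-tanh timing quoted in the commit message
could be measured, assuming a CUDA device and a PyTorch build that provides
torch.compile; the harness that produced the 20.4 us / 19.4 us numbers is not
part of this patch, so the tensor size, dtype, and timing method below are
assumptions for illustration only.

# Rough micro-benchmark for standalone tanh lowered through TorchInductor.
# After this patch, the GPU kernel calls tl.libdevice.tanh instead of the
# exp-based decomposition that was removed from decomposition.py.
import torch
from torch.utils.benchmark import Timer

x = torch.randn(4_000_000, device="cuda")  # element count quoted in the PR

compiled_tanh = torch.compile(torch.tanh)  # compiled via TorchInductor
compiled_tanh(x)  # warm-up call so compilation happens outside the timing

timer = Timer(
    stmt="fn(x)",
    globals={"fn": compiled_tanh, "x": x},
)
# torch.utils.benchmark.Timer performs warm-ups and synchronizes CUDA
# around measurements, unlike a raw timeit loop.
print(timer.blocked_autorange())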