mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
use libdevice for tanh (#90889)
Per title I see slight differences in perf with this implementation, where standalone tanh is slightly slower for a tensor of 4000000 elements (20.4 us instead of 19.4us), other sizes are within noise. @bertmaher could you check if it affects your benchmarks? Pull Request resolved: https://github.com/pytorch/pytorch/pull/90889 Approved by: https://github.com/bertmaher, https://github.com/anijain2305
This commit is contained in:
parent
30edd39bdc
commit
0148809131
|
|
@ -416,6 +416,10 @@ class CppOverrides(OpOverrides):
|
||||||
def expm1(x):
|
def expm1(x):
|
||||||
return f"std::expm1({x})"
|
return f"std::expm1({x})"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tanh(x):
|
||||||
|
return f"std::tanh({x})"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def signbit(x):
|
def signbit(x):
|
||||||
return f"std::signbit({x})"
|
return f"std::signbit({x})"
|
||||||
|
|
|
||||||
|
|
@ -239,6 +239,10 @@ class TritonOverrides(OpOverrides):
|
||||||
def expm1(x):
|
def expm1(x):
|
||||||
return f"tl.libdevice.expm1({x})"
|
return f"tl.libdevice.expm1({x})"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def tanh(x):
|
||||||
|
return f"tl.libdevice.tanh({x})"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def sigmoid(x):
|
def sigmoid(x):
|
||||||
return f"tl.sigmoid({x})"
|
return f"tl.sigmoid({x})"
|
||||||
|
|
|
||||||
|
|
@ -126,11 +126,6 @@ def clamp(x, min=None, max=None):
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@register_decomposition([aten.tanh])
|
|
||||||
def tanh(x):
|
|
||||||
return 2.0 / (1.0 + torch.exp(-2.0 * x)) - 1.0
|
|
||||||
|
|
||||||
|
|
||||||
# TorchInductor-only decomposition. It should not be taken to core.
|
# TorchInductor-only decomposition. It should not be taken to core.
|
||||||
# See https://github.com/pytorch/torchdynamo/pull/1120
|
# See https://github.com/pytorch/torchdynamo/pull/1120
|
||||||
@register_decomposition([aten.floor_divide.default])
|
@register_decomposition([aten.floor_divide.default])
|
||||||
|
|
|
||||||
|
|
@ -3621,6 +3621,11 @@ register_pointwise(
|
||||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
register_pointwise(
|
||||||
|
aten.tanh,
|
||||||
|
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||||
|
)
|
||||||
|
|
||||||
register_pointwise(
|
register_pointwise(
|
||||||
aten.log,
|
aten.log,
|
||||||
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user