mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix log2, PowByNatural printing (#147592)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147592 Approved by: https://github.com/eellison
This commit is contained in:
parent
aae36929ed
commit
5f62d07ec6
|
|
@ -6569,6 +6569,7 @@ class CommonTemplate:
|
|||
),
|
||||
)
|
||||
|
||||
@skip_if_halide # log2 not implemented for halide
|
||||
def test_log2(self):
|
||||
def fn(x):
|
||||
return torch.log2(x), torch.log2(x + 1) - 2
|
||||
|
|
@ -6587,6 +6588,7 @@ class CommonTemplate:
|
|||
(torch.randn([8, 8]) + 10,),
|
||||
)
|
||||
|
||||
@skip_if_halide # log2 not implemented for halide
|
||||
def test_log_fp64(self):
|
||||
def fn(x):
|
||||
return torch.log(x), torch.log2(x)
|
||||
|
|
@ -10340,6 +10342,15 @@ class CommonTemplate:
|
|||
[x],
|
||||
)
|
||||
|
||||
@skip_if_halide # log2 not yet implemented
|
||||
@skip_if_triton_cpu # log2 implemented only in Dec 2024
|
||||
def test_pow_by_natural_log2_dynamic_shapes(self):
|
||||
@torch.compile(dynamic=True)
|
||||
def fn(x):
|
||||
return x + 2 ** (math.floor(math.log2(x.shape[0]) + 1))
|
||||
|
||||
self.common(fn, [torch.randn(5)])
|
||||
|
||||
def test_setitem_with_int_parameter(self):
|
||||
x = torch.zeros(7, device=self.device)
|
||||
|
||||
|
|
|
|||
|
|
@ -96,6 +96,8 @@ class HalidePrinter(PythonPrinter):
|
|||
assert len(expr.args) == 1
|
||||
return self.cast_index(f"hl.floor({self._print(expr.args[0])})")
|
||||
|
||||
_print_FloorToInt = _print_floor
|
||||
|
||||
def _print_Trunc(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return self.cast_index(f"hl.trunc({self._print(expr.args[0])})")
|
||||
|
|
@ -140,39 +142,42 @@ class HalidePrinter(PythonPrinter):
|
|||
|
||||
def _print_OpaqueUnaryFn_cos(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.cos(({self._print(expr.args[0])})"
|
||||
return f"hl.cos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_cosh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.cosh(({self._print(expr.args[0])})"
|
||||
return f"hl.cosh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_acos(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.acos(({self._print(expr.args[0])})"
|
||||
return f"hl.acos({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sin(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.sin(({self._print(expr.args[0])})"
|
||||
return f"hl.sin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_sinh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.sinh(({self._print(expr.args[0])})"
|
||||
return f"hl.sinh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_asin(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.asin(({self._print(expr.args[0])})"
|
||||
return f"hl.asin({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tan(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.tan(({self._print(expr.args[0])})"
|
||||
return f"hl.tan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_tanh(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.tanh(({self._print(expr.args[0])})"
|
||||
return f"hl.tanh({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_atan(self, expr):
|
||||
assert len(expr.args) == 1
|
||||
return f"hl.atan(({self._print(expr.args[0])})"
|
||||
return f"hl.atan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_log2(self, expr):
|
||||
raise NotImplementedError("log2")
|
||||
|
||||
def _print_FloorDiv(self, expr):
|
||||
if expr.is_integer:
|
||||
|
|
@ -453,6 +458,10 @@ class HalideOverrides(OpOverrides):
|
|||
def log(x):
|
||||
return f"hl.log({x})" # hl.fast_log fails accuracy
|
||||
|
||||
@staticmethod
|
||||
def log2(x):
|
||||
raise NotImplementedError("log2")
|
||||
|
||||
@staticmethod
|
||||
def isinf(x):
|
||||
# workaround https://github.com/halide/Halide/issues/8309
|
||||
|
|
|
|||
|
|
@ -605,7 +605,12 @@ class TritonPrinter(PythonPrinter):
|
|||
f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
|
||||
)
|
||||
|
||||
_print_PowByNatural = _print_FloatPow
|
||||
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
|
||||
if expr.args[0].is_Integer:
|
||||
return f"libdevice.pow({float(expr.args[0])}, {self._print(expr.args[1])})"
|
||||
return (
|
||||
f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
|
||||
)
|
||||
|
||||
def _print_Where(self, expr: sympy.Expr) -> str:
|
||||
c = self.doprint(expr.args[0])
|
||||
|
|
@ -678,6 +683,10 @@ class TritonPrinter(PythonPrinter):
|
|||
assert len(expr.args) == 1
|
||||
return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"
|
||||
|
||||
def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"libdevice.log2(({self._print(expr.args[0])}).to(tl.float32))"
|
||||
|
||||
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -264,6 +264,10 @@ class PythonPrinter(ExprPrinter):
|
|||
assert len(expr.args) == 1
|
||||
return f"math.atan({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"math.log2({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
return f"round({self._print(expr.args[0])})"
|
||||
|
|
@ -351,6 +355,10 @@ class CppPrinter(ExprPrinter):
|
|||
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
|
||||
# use std::pow, that operates on floats
|
||||
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
|
||||
# Implement the special-case of 2**x for now
|
||||
base, exp = expr.args
|
||||
if base == 2:
|
||||
return f"(1 << ({self._print(exp)}))"
|
||||
raise NotImplementedError(
|
||||
f"_print_PowByNatural not implemented for {type(self)}"
|
||||
)
|
||||
|
|
@ -465,6 +473,9 @@ class CppPrinter(ExprPrinter):
|
|||
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
|
||||
return f"std::sqrt({self._print(expr.args[0])})"
|
||||
|
||||
def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
|
||||
return f"std::log2({self._print(expr.args[0])})"
|
||||
|
||||
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
# TODO: dispatch to llrint depending on index type
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user