Fix log2, PowByNatural printing (#147592)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147592
Approved by: https://github.com/eellison
This commit is contained in:
Isuru Fernando 2025-03-03 20:33:01 +00:00 committed by PyTorch MergeBot
parent aae36929ed
commit 5f62d07ec6
4 changed files with 50 additions and 10 deletions

View File

@ -6569,6 +6569,7 @@ class CommonTemplate:
),
)
@skip_if_halide # log2 not implemented for halide
def test_log2(self):
def fn(x):
return torch.log2(x), torch.log2(x + 1) - 2
@ -6587,6 +6588,7 @@ class CommonTemplate:
(torch.randn([8, 8]) + 10,),
)
@skip_if_halide # log2 not implemented for halide
def test_log_fp64(self):
def fn(x):
return torch.log(x), torch.log2(x)
@ -10340,6 +10342,15 @@ class CommonTemplate:
[x],
)
@skip_if_halide # log2 not yet implemented
@skip_if_triton_cpu # log2 implemented only in Dec 2024
def test_pow_by_natural_log2_dynamic_shapes(self):
@torch.compile(dynamic=True)
def fn(x):
return x + 2 ** (math.floor(math.log2(x.shape[0]) + 1))
self.common(fn, [torch.randn(5)])
def test_setitem_with_int_parameter(self):
x = torch.zeros(7, device=self.device)

View File

@ -96,6 +96,8 @@ class HalidePrinter(PythonPrinter):
assert len(expr.args) == 1
return self.cast_index(f"hl.floor({self._print(expr.args[0])})")
_print_FloorToInt = _print_floor
def _print_Trunc(self, expr):
assert len(expr.args) == 1
return self.cast_index(f"hl.trunc({self._print(expr.args[0])})")
@ -140,39 +142,42 @@ class HalidePrinter(PythonPrinter):
def _print_OpaqueUnaryFn_cos(self, expr):
assert len(expr.args) == 1
return f"hl.cos(({self._print(expr.args[0])})"
return f"hl.cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr):
assert len(expr.args) == 1
return f"hl.cosh(({self._print(expr.args[0])})"
return f"hl.cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr):
assert len(expr.args) == 1
return f"hl.acos(({self._print(expr.args[0])})"
return f"hl.acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr):
assert len(expr.args) == 1
return f"hl.sin(({self._print(expr.args[0])})"
return f"hl.sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr):
assert len(expr.args) == 1
return f"hl.sinh(({self._print(expr.args[0])})"
return f"hl.sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr):
assert len(expr.args) == 1
return f"hl.asin(({self._print(expr.args[0])})"
return f"hl.asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr):
assert len(expr.args) == 1
return f"hl.tan(({self._print(expr.args[0])})"
return f"hl.tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr):
assert len(expr.args) == 1
return f"hl.tanh(({self._print(expr.args[0])})"
return f"hl.tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr):
assert len(expr.args) == 1
return f"hl.atan(({self._print(expr.args[0])})"
return f"hl.atan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_log2(self, expr):
raise NotImplementedError("log2")
def _print_FloorDiv(self, expr):
if expr.is_integer:
@ -453,6 +458,10 @@ class HalideOverrides(OpOverrides):
def log(x):
return f"hl.log({x})" # hl.fast_log fails accuracy
@staticmethod
def log2(x):
raise NotImplementedError("log2")
@staticmethod
def isinf(x):
# workaround https://github.com/halide/Halide/issues/8309

View File

@ -605,7 +605,12 @@ class TritonPrinter(PythonPrinter):
f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
)
_print_PowByNatural = _print_FloatPow
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
if expr.args[0].is_Integer:
return f"libdevice.pow({float(expr.args[0])}, {self._print(expr.args[1])})"
return (
f"libdevice.pow({self._print(expr.args[0])}, {self._print(expr.args[1])})"
)
def _print_Where(self, expr: sympy.Expr) -> str:
c = self.doprint(expr.args[0])
@ -678,6 +683,10 @@ class TritonPrinter(PythonPrinter):
assert len(expr.args) == 1
return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))"
def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"libdevice.log2(({self._print(expr.args[0])}).to(tl.float32))"
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return (

View File

@ -264,6 +264,10 @@ class PythonPrinter(ExprPrinter):
assert len(expr.args) == 1
return f"math.atan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"math.log2({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"round({self._print(expr.args[0])})"
@ -351,6 +355,10 @@ class CppPrinter(ExprPrinter):
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
# use std::pow, that operates on floats
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
# Implement the special-case of 2**x for now
base, exp = expr.args
if base == 2:
return f"(1 << ({self._print(exp)}))"
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
@ -465,6 +473,9 @@ class CppPrinter(ExprPrinter):
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
return f"std::sqrt({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_log2(self, expr: sympy.Expr) -> str:
return f"std::log2({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
# TODO: dispatch to llrint depending on index type