[xpu][fix] [Inductor] Avoid using tl.sqrt_rn on XPU before triton is ready (#165740)

Fixes #165738

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165740
Approved by: https://github.com/etaf, https://github.com/EikanWang, https://github.com/chuanqi129, https://github.com/desertfire
This commit is contained in:
Zhang, Jianyi 2025-10-30 09:24:19 +00:00 committed by PyTorch MergeBot
parent 39e5cdddf7
commit 32920926f0

View File

@ -801,6 +801,9 @@ class TritonPrinter(PythonPrinter):
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
def _helper_sqrt(self, expr: sympy.Expr) -> str: def _helper_sqrt(self, expr: sympy.Expr) -> str:
# work around for https://github.com/pytorch/pytorch/issues/165738
if torch.xpu.is_available():
return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))"
return f"tl.sqrt_rn(({self._print(expr)}).to(tl.float32))" return f"tl.sqrt_rn(({self._print(expr)}).to(tl.float32))"
def _print_FloatPow(self, expr: sympy.Expr) -> str: def _print_FloatPow(self, expr: sympy.Expr) -> str:
@ -1212,6 +1215,9 @@ class TritonOverrides(OpOverrides):
@staticmethod
@maybe_upcast_float32()
def sqrt(x):
    """Return Triton source text that takes the square root of ``x``.

    ``x`` is interpolated verbatim into the generated call string.

    Returns:
        ``libdevice.sqrt(x)`` as text when ``torch.xpu.is_available()`` is
        true, otherwise ``tl.sqrt_rn(x)`` as text.
    """
    # work around for https://github.com/pytorch/pytorch/issues/165738
    # NOTE(review): torch.xpu.is_available() is a process-level check; it does
    # not look at which device the kernel being generated targets. Confirm
    # this is the intended condition on hosts with more than one backend.
    if torch.xpu.is_available():
        return f"libdevice.sqrt({x})"
    return f"tl.sqrt_rn({x})"
@staticmethod @staticmethod