mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[xpu][fix] [Inductor] Avoid using tl.sqrt_rn on XPU before triton is ready (#165740)
Fixes #165738 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165740 Approved by: https://github.com/etaf, https://github.com/EikanWang, https://github.com/chuanqi129, https://github.com/desertfire
This commit is contained in:
parent
39e5cdddf7
commit
32920926f0
|
|
@ -801,6 +801,9 @@ class TritonPrinter(PythonPrinter):
|
|||
return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})"
|
||||
|
||||
def _helper_sqrt(self, expr: sympy.Expr) -> str:
|
||||
# work around for https://github.com/pytorch/pytorch/issues/165738
|
||||
if torch.xpu.is_available():
|
||||
return f"libdevice.sqrt(({self._print(expr)}).to(tl.float32))"
|
||||
return f"tl.sqrt_rn(({self._print(expr)}).to(tl.float32))"
|
||||
|
||||
def _print_FloatPow(self, expr: sympy.Expr) -> str:
|
||||
|
|
@ -1212,6 +1215,9 @@ class TritonOverrides(OpOverrides):
|
|||
@staticmethod
|
||||
@maybe_upcast_float32()
|
||||
def sqrt(x):
|
||||
# work around for https://github.com/pytorch/pytorch/issues/165738
|
||||
if torch.xpu.is_available():
|
||||
return f"libdevice.sqrt({x})"
|
||||
return f"tl.sqrt_rn({x})"
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user