diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index ff04091fafa..68d900d2060 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -8423,6 +8423,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         self.assertEqual(fn(x[0:]), x[16:][:16])
         self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])
 
+    def test_index_float_zero(self):
+        def fn(arg0, arg1, arg2):
+            t1 = torch.tanh(arg0)
+            t2 = t1.clone()
+            t2.fill_(arg1.item())
+            t3 = torch.clamp(t2, 0, arg2.size(0) - 1).to(torch.long)
+            return torch.nn.functional.embedding(t3, arg2)
+
+        arg0 = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device)
+        arg1 = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device)
+        arg2 = torch.rand([256, 88], dtype=torch.float16, device=self.device)
+
+        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
+
+        self.assertEqual(fn(arg0, arg1, arg2), cfn(arg0, arg1, arg2))
+
     # from GPT2ForSequenceClassification
     @skip_if_gpu_halide
     def test_index_tensor(self):
diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py
index f68c241ca83..790ea9bb90d 100644
--- a/torch/_inductor/codegen/mps.py
+++ b/torch/_inductor/codegen/mps.py
@@ -141,6 +141,15 @@ class MetalExprPrinter(ExprPrinter_):
         x = self.doprint(expr.args[0])
         return f"static_cast<float>({x})"
 
+    def _print_Float(self, expr: sympy.Expr) -> str:
+        if expr.is_integer:
+            # sympy considers 0.0 to be an integer, but Metal doesn't.
+            # This workaround prints the float as an integer.
+            # xref: https://github.com/sympy/sympy/issues/26620
+            return str(int(expr))
+        else:
+            return str(expr)
+
     def _print_FloorToInt(self, expr: sympy.Expr) -> str:
         assert len(expr.args) == 1
         x = self.doprint(expr.args[0])
diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py
index c24cde56358..910c1441c05 100644
--- a/torch/_inductor/codegen/triton.py
+++ b/torch/_inductor/codegen/triton.py
@@ -736,7 +736,12 @@ class TritonPrinter(PythonPrinter):
         )
 
     def _print_Float(self, expr: sympy.Expr) -> str:
-        if config.is_fbcode() and torch.version.hip:
+        if expr.is_integer:
+            # sympy considers 0.0 to be an integer, but Triton doesn't.
+            # This workaround prints the float as an integer.
+            # xref: https://github.com/sympy/sympy/issues/26620
+            ret = str(int(expr))
+        elif config.is_fbcode() and torch.version.hip:
             ret = f"{expr}"
         else:
             ret = f"tl.full([], {expr}, tl.float64)"
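
For context, a minimal sketch of the sympy behavior both `_print_Float` overrides guard against (illustrative only, not part of the patch; see the xref'd sympy issue):

```python
import sympy

# sympy reports integer-valued Floats as integers
# (https://github.com/sympy/sympy/issues/26620), so a constant
# like 0.0 reaches the printers flagged as an integer while its
# default printed form is still a float literal.
zero = sympy.Float(0.0)
assert zero.is_integer        # True, despite being a sympy Float
assert str(int(zero)) == "0"  # the form the patched printers emit instead
```

Gating on `expr.is_integer` confines the workaround to integer-valued constants, leaving the existing fbcode/HIP and `tl.full(...)` paths untouched for genuinely fractional floats.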