Isuru Fernando 2025-10-15 19:07:47 +00:00 committed by PyTorch MergeBot
parent 783da8b8e7
commit 99b32a6750
3 changed files with 31 additions and 1 deletions

View File

@ -8423,6 +8423,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
self.assertEqual(fn(x[0:]), x[16:][:16])
self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])
def test_index_float_zero(self):
def fn(arg0, arg1, arg2):
t1 = torch.tanh(arg0)
t2 = t1.clone()
t2.fill_(arg1.item())
t3 = torch.clamp(t2, 0, arg2.size(0) - 1).to(torch.long)
return torch.nn.functional.embedding(t3, arg2)
arg0 = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device)
arg1 = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device)
arg2 = torch.rand([256, 88], dtype=torch.float16, device=self.device)
cfn = torch.compile(fullgraph=True, dynamic=True)(fn)
self.assertEqual(fn(arg0, arg1, arg2), cfn(arg0, arg1, arg2))
# from GPT2ForSequenceClassification
@skip_if_gpu_halide
def test_index_tensor(self):

View File

@ -141,6 +141,15 @@ class MetalExprPrinter(ExprPrinter_):
x = self.doprint(expr.args[0])
return f"static_cast<float>({x})"
def _print_Float(self, expr: sympy.Expr) -> str:
if expr.is_integer:
# sympy considers 0.0 to be integer, but triton doesn't.
# this workaround prints the float as an integer
# xref: https://github.com/sympy/sympy/issues/26620
return str(int(expr))
else:
return str(expr)
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
x = self.doprint(expr.args[0])

View File

@ -736,7 +736,12 @@ class TritonPrinter(PythonPrinter):
)
def _print_Float(self, expr: sympy.Expr) -> str:
if config.is_fbcode() and torch.version.hip:
if expr.is_integer:
# sympy considers 0.0 to be integer, but triton doesn't.
# this workaround prints the float as an integer
# xref: https://github.com/sympy/sympy/issues/26620
ret = str(int(expr))
elif config.is_fbcode() and torch.version.hip:
ret = f"{expr}"
else:
ret = f"tl.full([], {expr}, tl.float64)"