mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] print 0.0 as 0 for triton (#164291)
Fixes https://github.com/pytorch/pytorch/issues/164157 Fixes https://github.com/pytorch/pytorch/issues/164086 Pull Request resolved: https://github.com/pytorch/pytorch/pull/164291 Approved by: https://github.com/bobrenjc93
This commit is contained in:
parent
783da8b8e7
commit
99b32a6750
|
|
@ -8423,6 +8423,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
|
|||
self.assertEqual(fn(x[0:]), x[16:][:16])
|
||||
self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])
|
||||
|
||||
def test_index_float_zero(self):
    """Regression test: a fill value obtained via ``.item()`` that lands on
    0.0 must still produce valid integer indexing after clamp/cast.
    Compares eager execution against a fully-compiled dynamic graph."""

    def embed_from_scalar(src, fill_scalar, table):
        activated = torch.tanh(src)
        filled = activated.clone()
        filled.fill_(fill_scalar.item())
        # Clamp into the embedding-table row range before the long cast.
        indices = torch.clamp(filled, 0, table.size(0) - 1).to(torch.long)
        return torch.nn.functional.embedding(indices, table)

    src = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device)
    fill_scalar = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device)
    table = torch.rand([256, 88], dtype=torch.float16, device=self.device)

    compiled = torch.compile(embed_from_scalar, fullgraph=True, dynamic=True)

    self.assertEqual(
        embed_from_scalar(src, fill_scalar, table),
        compiled(src, fill_scalar, table),
    )
|
||||
|
||||
# from GPT2ForSequenceClassification
|
||||
@skip_if_gpu_halide
|
||||
def test_index_tensor(self):
|
||||
|
|
|
|||
|
|
@ -141,6 +141,15 @@ class MetalExprPrinter(ExprPrinter_):
|
|||
x = self.doprint(expr.args[0])
|
||||
return f"static_cast<float>({x})"
|
||||
|
||||
def _print_Float(self, expr: sympy.Expr) -> str:
|
||||
if expr.is_integer:
|
||||
# sympy considers 0.0 to be integer, but triton doesn't.
|
||||
# this workaround prints the float as an integer
|
||||
# xref: https://github.com/sympy/sympy/issues/26620
|
||||
return str(int(expr))
|
||||
else:
|
||||
return str(expr)
|
||||
|
||||
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
|
||||
assert len(expr.args) == 1
|
||||
x = self.doprint(expr.args[0])
|
||||
|
|
|
|||
|
|
@ -736,7 +736,12 @@ class TritonPrinter(PythonPrinter):
|
|||
)
|
||||
|
||||
def _print_Float(self, expr: sympy.Expr) -> str:
|
||||
if config.is_fbcode() and torch.version.hip:
|
||||
if expr.is_integer:
|
||||
# sympy considers 0.0 to be integer, but triton doesn't.
|
||||
# this workaround prints the float as an integer
|
||||
# xref: https://github.com/sympy/sympy/issues/26620
|
||||
ret = str(int(expr))
|
||||
elif config.is_fbcode() and torch.version.hip:
|
||||
ret = f"{expr}"
|
||||
else:
|
||||
ret = f"tl.full([], {expr}, tl.float64)"
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user