[inductor] print 0.0 as 0 for triton (#164291)
Fixes https://github.com/pytorch/pytorch/issues/164157
Fixes https://github.com/pytorch/pytorch/issues/164086

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164291
Approved by: https://github.com/bobrenjc93, https://github.com/mlazos
This commit is contained in:
parent 6a5a436624
commit bbb7d2270b
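
Background, as a minimal standalone sketch (not part of the patch): sympy reports an integral-valued Float such as 0.0 as an integer, while still printing it with a trailing .0, so the old printers could emit a float literal where the generated Triton (or Metal) code needs an integer.

    import sympy

    # sympy treats a Float with no fractional part as integer-valued
    # (xref: https://github.com/sympy/sympy/issues/26620), even though its
    # printed form still looks like a float literal.
    zero = sympy.Float(0.0)
    print(zero.is_integer)  # True
    print(str(zero))        # a float-style literal, e.g. "0.0"
    print(str(int(zero)))   # "0" -- the form the patched printers emit instead
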
@@ -8423,6 +8423,22 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        self.assertEqual(fn(x[0:]), x[16:][:16])
        self.assertEqual(fn(x[128:]), x[128 + 16 :][:16])

    def test_index_float_zero(self):
        def fn(arg0, arg1, arg2):
            t1 = torch.tanh(arg0)
            t2 = t1.clone()
            t2.fill_(arg1.item())
            t3 = torch.clamp(t2, 0, arg2.size(0) - 1).to(torch.long)
            return torch.nn.functional.embedding(t3, arg2)

        arg0 = torch.randint(0, 1000, [47], dtype=torch.int64, device=self.device)
        arg1 = torch.randint(0, 1000, [], dtype=torch.int64, device=self.device)
        arg2 = torch.rand([256, 88], dtype=torch.float16, device=self.device)

        cfn = torch.compile(fullgraph=True, dynamic=True)(fn)

        self.assertEqual(fn(arg0, arg1, arg2), cfn(arg0, arg1, arg2))

    # from GPT2ForSequenceClassification
    @skip_if_gpu_halide
    def test_index_tensor(self):

@@ -13,6 +13,7 @@ import pytest

import torch
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.inductor_utils import HAS_CUDA_AND_TRITON


class TestFuzzerCompileIssues(TestCase):

@@ -220,67 +221,6 @@ class TestFuzzerCompileIssues(TestCase):
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #164086")
    def test_fuzzer_issue_164086(self):
        torch.manual_seed(0)

        def foo(arg0, arg1, arg2, arg3, arg4, arg5):
            t0 = arg0  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t1 = torch.tanh(
                t0
            )  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t2 = t1.clone()
            t2.zero_()  # size=(42, 56), stride=(42, 1), dtype=int64, device=cuda
            t3 = (
                arg1  # size=(50000, 128), stride=(50000, 1), dtype=float16, device=cuda
            )
            t4 = arg2  # size=(46, 128), stride=(46, 1), dtype=float16, device=cuda
            t5 = torch.nn.functional.linear(
                t3, t4
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t6 = arg3  # size=(50000, 4, 46), stride=(184, 46, 1), dtype=float16, device=cuda
            t7 = t6.max(
                dim=1
            ).values  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t8 = arg4  # size=(25786, 46), stride=(46, 1), dtype=float16, device=cuda
            t9 = arg5  # size=(24214, 46), stride=(46, 1), dtype=float16, device=cuda
            t10 = torch.cat(
                [t8, t9], dim=0
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t11 = torch.pow(
                torch.pow(torch.pow(torch.pow(t5, t7), t10), t5), t7
            )  # size=(50000, 46), stride=(50000, 1), dtype=float16, device=cuda
            t12 = torch.nn.functional.embedding(
                torch.clamp(t2, 0, t11.size(0) - 1).to(torch.long), t11
            )  # size=(42, 56, 46), stride=(2576, 46, 1), dtype=float16, device=cuda
            output = t12
            return output

        arg0 = torch.randint(0, 1000, [42, 56], dtype=torch.int64, device="cuda")
        arg1 = torch.rand(
            [50000, 128], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg2 = torch.rand(
            [46, 128], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg3 = torch.rand(
            [50000, 4, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg4 = torch.rand(
            [25786, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )
        arg5 = torch.rand(
            [24214, 46], dtype=torch.float16, device="cuda", requires_grad=True
        )

        out_eager = foo(arg0, arg1, arg2, arg3, arg4, arg5)
        out_eager.sum().backward()
        print("Eager Success! ✅")
        compiled_foo = torch.compile(foo, fullgraph=True, dynamic=True)
        out_compiled = compiled_foo(arg0, arg1, arg2, arg3, arg4, arg5)
        out_compiled.sum().backward()
        print("Compile Success! ✅")

    @pytest.mark.xfail(reason="Issue #163877")
    def test_fuzzer_issue_163877(self):
        torch.manual_seed(0)

@@ -141,6 +141,15 @@ class MetalExprPrinter(ExprPrinter_):
        x = self.doprint(expr.args[0])
        return f"static_cast<float>({x})"

    def _print_Float(self, expr: sympy.Expr) -> str:
        if expr.is_integer:
            # sympy considers 0.0 to be integer, but Metal doesn't.
            # this workaround prints the float as an integer
            # xref: https://github.com/sympy/sympy/issues/26620
            return str(int(expr))
        else:
            return str(expr)

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        assert len(expr.args) == 1
        x = self.doprint(expr.args[0])

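For illustration only, a rough standalone analogue of the Metal branch added above (the helper name below is made up, not part of MetalExprPrinter):

    import sympy

    def metal_float_repr(expr: sympy.Float) -> str:
        # Same idea as the new _print_Float above: integral floats
        # (0.0 included) are emitted as plain integers so the generated
        # Metal source never sees "0.0" in an integer position.
        if expr.is_integer:
            return str(int(expr))
        return str(expr)

    print(metal_float_repr(sympy.Float(0.0)))  # "0"
    print(metal_float_repr(sympy.Float(2.5)))  # sympy's float repr, e.g. "2.50000000000000"
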
@@ -745,7 +745,12 @@ class TritonPrinter(PythonPrinter):
        )

    def _print_Float(self, expr: sympy.Expr) -> str:
        if expr.is_integer:
            # sympy considers 0.0 to be integer, but triton doesn't.
            # this workaround prints the float as an integer
            # xref: https://github.com/sympy/sympy/issues/26620
            ret = str(int(expr))
        elif config.is_fbcode() and torch.version.hip:
            ret = f"{expr}"
        else:
            ret = f"tl.full([], {expr}, tl.float64)"
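
Similarly, a rough sketch of the patched Triton branch (illustrative helper only; the real method also consults config.is_fbcode() and torch.version.hip, modeled here as a plain flag):

    import sympy

    def triton_float_repr(expr: sympy.Float, fbcode_hip: bool = False) -> str:
        # New behavior: integral floats become bare "0", "1", ... instead of
        # the old tl.full([], 0.0, tl.float64) wrapper, which could put a
        # float64 value where the kernel expects an integer (e.g. an index).
        if expr.is_integer:
            return str(int(expr))
        if fbcode_hip:
            return f"{expr}"
        return f"tl.full([], {expr}, tl.float64)"

    print(triton_float_repr(sympy.Float(0.0)))  # "0"
    print(triton_float_repr(sympy.Float(0.5)))  # wraps the value in tl.full(...)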