mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Move Sympy printers to torch/utils/_sympy/printers.py (#140597)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597 Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
parent
29ca44839e
commit
44186a0a4e
|
|
@ -580,8 +580,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
|
||||||
.check_regex(
|
.check_regex(
|
||||||
"torch.ops._c10d_functional.all_to_all_single.default\\("
|
"torch.ops._c10d_functional.all_to_all_single.default\\("
|
||||||
"arg\\d+_\\d+, "
|
"arg\\d+_\\d+, "
|
||||||
"\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], "
|
"\\[s\\d+ // \\d, s\\d+ // \\d\\], "
|
||||||
"\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]"
|
"\\[s\\d+ // \\d, s\\d+ // \\d\\]"
|
||||||
)
|
)
|
||||||
.run(code)
|
.run(code)
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -3570,7 +3570,7 @@ class GraphModule(torch.nn.Module):
|
||||||
"cast_symbool_to_symint_guardless(L['pred']) == 1",
|
"cast_symbool_to_symint_guardless(L['pred']) == 1",
|
||||||
]
|
]
|
||||||
false_guard_code = [
|
false_guard_code = [
|
||||||
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)",
|
"cast_symbool_to_symint_guardless(L['pred']) != 1",
|
||||||
]
|
]
|
||||||
test_symbool_guards(
|
test_symbool_guards(
|
||||||
f,
|
f,
|
||||||
|
|
|
||||||
|
|
@ -668,7 +668,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
|
||||||
"""\
|
"""\
|
||||||
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
|
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
|
||||||
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
|
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
|
||||||
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in #
|
+- LAMBDA_GUARD: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in #
|
||||||
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
|
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -10457,7 +10457,7 @@ ShapeEnv not equal: field values don't match:
|
||||||
ShapeEnv not equal: field values don't match:
|
ShapeEnv not equal: field values don't match:
|
||||||
|
|
||||||
==> axioms: values don't match.
|
==> axioms: values don't match.
|
||||||
> Left: {0 < Mod(s0, 3): False, 0 <= Mod(s0, 3): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Mod(s0, 3) < 0: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
|
> Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
|
||||||
> Right: {}
|
> Right: {}
|
||||||
==> divisible: values don't match.
|
==> divisible: values don't match.
|
||||||
> Left: {Mod(s0, 3)}
|
> Left: {Mod(s0, 3)}
|
||||||
|
|
@ -10576,7 +10576,7 @@ ShapeEnv not equal: field values don't match:
|
||||||
ShapeEnv not equal: field values don't match:
|
ShapeEnv not equal: field values don't match:
|
||||||
|
|
||||||
==> axioms: values don't match.
|
==> axioms: values don't match.
|
||||||
> Left: {0 < PythonMod(u0, 3): False, 0 <= PythonMod(u0, 3): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) < 0: False, PythonMod(u0, 3) <= 0: True}
|
> Left: {(PythonMod(u0, 3)) < 0: False, (PythonMod(u0, 3)) <= 0: True, 0 < (PythonMod(u0, 3)): False, 0 <= (PythonMod(u0, 3)): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False}
|
||||||
> Right: {}
|
> Right: {}
|
||||||
==> deferred_runtime_asserts: values don't match.
|
==> deferred_runtime_asserts: values don't match.
|
||||||
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
|
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]}
|
||||||
|
|
|
||||||
|
|
@ -3259,9 +3259,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||||
(torch.tensor(20),),
|
(torch.tensor(20),),
|
||||||
fixes=[
|
fixes=[
|
||||||
# Could not guard on data-dependent expression Eq((u0//2), 0)
|
# Could not guard on data-dependent expression Eq((u0//2), 0)
|
||||||
"torch._check(((i//2)) != 0)",
|
"torch._check((i // 2) != 0)",
|
||||||
# Could not guard on data-dependent expression Eq((u0//2), 1)
|
# Could not guard on data-dependent expression Eq((u0//2), 1)
|
||||||
"torch._check(((i//2)) != 1)",
|
"torch._check((i // 2) != 1)",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1426,12 +1426,12 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
|
||||||
xoffset = tl.program_id(0) * XBLOCK
|
xoffset = tl.program_id(0) * XBLOCK
|
||||||
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
||||||
xmask = xindex < xnumel
|
xmask = xindex < xnumel
|
||||||
x0 = xindex % 20
|
x0 = (xindex % 20)
|
||||||
x1 = (xindex // 20) % 20
|
x1 = ((xindex // 20) % 20)
|
||||||
x2 = (xindex // 400)
|
x2 = xindex // 400
|
||||||
x3 = xindex
|
x3 = xindex
|
||||||
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
|
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
|
||||||
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
|
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
|
||||||
tmp2 = tmp0 + tmp1
|
tmp2 = tmp0 + tmp1
|
||||||
tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950
|
tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -20,7 +20,9 @@ from torch.testing._internal.common_utils import (
|
||||||
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
|
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
|
||||||
from torch.utils._sympy.functions import (
|
from torch.utils._sympy.functions import (
|
||||||
FloorDiv,
|
FloorDiv,
|
||||||
|
Mod,
|
||||||
ModularIndexing,
|
ModularIndexing,
|
||||||
|
PythonMod,
|
||||||
RoundDecimal,
|
RoundDecimal,
|
||||||
RoundToInt,
|
RoundToInt,
|
||||||
)
|
)
|
||||||
|
|
@ -236,7 +238,7 @@ class TestIndexingSimplification(InductorTestCase):
|
||||||
triton_code = run_and_get_triton_code(f, x)
|
triton_code = run_and_get_triton_code(f, x)
|
||||||
# Make sure the 2 load uses simpified indexing rather than something like
|
# Make sure the 2 load uses simpified indexing rather than something like
|
||||||
# tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)),
|
# tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)),
|
||||||
self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),"))
|
self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + (x2 // 2),"))
|
||||||
if DO_PERF_TEST:
|
if DO_PERF_TEST:
|
||||||
ms = benchmarker.benchmark_gpu(lambda: f(x))
|
ms = benchmarker.benchmark_gpu(lambda: f(x))
|
||||||
print(f"{ms=:.03f}")
|
print(f"{ms=:.03f}")
|
||||||
|
|
@ -313,6 +315,39 @@ class ExprPrinterTests(InductorTestCase):
|
||||||
self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
|
self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
|
||||||
self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")
|
self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")
|
||||||
|
|
||||||
|
def test_print_mod(self):
|
||||||
|
x = sympy.Symbol("x", integer=True)
|
||||||
|
expr = Mod(x - 1, 2)
|
||||||
|
self.assertExpectedInline(pexpr(expr), """((-1) + x) % 2""")
|
||||||
|
self.assertExpectedInline(cexpr(expr), """((-1L) + x) % 2L""")
|
||||||
|
self.assertExpectedInline(texpr(expr), """((-1) + x) % 2""")
|
||||||
|
|
||||||
|
expr = (x - 10) % x
|
||||||
|
self.assertExpectedInline(pexpr(expr), """(-10) % x""")
|
||||||
|
self.assertExpectedInline(cexpr(expr), """(-10L) % x""")
|
||||||
|
self.assertExpectedInline(texpr(expr), """(-10) % x""")
|
||||||
|
|
||||||
|
def test_print_mod_index(self):
|
||||||
|
x = sympy.Symbol("x", integer=True)
|
||||||
|
ks = sympy.Symbol("ks", integer=True)
|
||||||
|
expr = ModularIndexing(x - 10, ks, ks)
|
||||||
|
self.assertExpectedInline(pexpr(expr), """((((-10) + x) // ks) % ks)""")
|
||||||
|
self.assertExpectedInline(
|
||||||
|
cexpr(expr),
|
||||||
|
"""(static_cast<int64_t>(c10::div_floor_integer("""
|
||||||
|
"""static_cast<int64_t>((-10L) + x), static_cast<int64_t>(ks))) % static_cast<int64_t>(ks))""",
|
||||||
|
)
|
||||||
|
self.assertExpectedInline(texpr(expr), """((((-10) + x) // ks) % ks)""")
|
||||||
|
|
||||||
|
def test_print_python_mod(self):
|
||||||
|
x = sympy.Symbol("x", integer=True)
|
||||||
|
expr = PythonMod(x - 10, x)
|
||||||
|
self.assertExpectedInline(pexpr(expr), """((-10) + x) % x""")
|
||||||
|
self.assertExpectedInline(cexpr(expr), """((-10L) + x) % x""")
|
||||||
|
self.assertExpectedInline(
|
||||||
|
texpr(expr), """triton_helpers.remainder_integer((-10) + x, x)"""
|
||||||
|
)
|
||||||
|
|
||||||
@parametrize("ndigits", [-1, 0, 1])
|
@parametrize("ndigits", [-1, 0, 1])
|
||||||
def test_print_round_decimal(self, ndigits):
|
def test_print_round_decimal(self, ndigits):
|
||||||
expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits)
|
expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits)
|
||||||
|
|
@ -330,7 +365,7 @@ class ExprPrinterTests(InductorTestCase):
|
||||||
s1 = sympy.Symbol("s1", integer=True)
|
s1 = sympy.Symbol("s1", integer=True)
|
||||||
s2 = sympy.Symbol("s2", integer=True)
|
s2 = sympy.Symbol("s2", integer=True)
|
||||||
expr = FloorDiv(s1, s2)
|
expr = FloorDiv(s1, s2)
|
||||||
self.assertEqual(pexpr(expr), "(s1 // s2)")
|
self.assertEqual(pexpr(expr), "s1 // s2")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
cexpr(expr),
|
cexpr(expr),
|
||||||
"c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))",
|
"c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))",
|
||||||
|
|
|
||||||
|
|
@ -58,13 +58,11 @@ class TestMemoryPlanning(TestCase):
|
||||||
result, code = run_and_get_cpp_code(compiled, *args)
|
result, code = run_and_get_cpp_code(compiled, *args)
|
||||||
|
|
||||||
FileCheck().check(
|
FileCheck().check(
|
||||||
"pool1 = empty_strided_"
|
"pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )"
|
||||||
+ GPU_TYPE
|
|
||||||
+ "(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
|
|
||||||
).check_next(
|
).check_next(
|
||||||
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
|
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
|
||||||
).check(
|
).check(
|
||||||
"buf1 = alloc_from_pool(pool1, align(4*(s0*s0)),"
|
"buf1 = alloc_from_pool(pool1, align(4*s0*s0),"
|
||||||
).run(
|
).run(
|
||||||
code
|
code
|
||||||
)
|
)
|
||||||
|
|
@ -103,7 +101,7 @@ class TestMemoryPlanning(TestCase):
|
||||||
)
|
)
|
||||||
|
|
||||||
FileCheck().check(
|
FileCheck().check(
|
||||||
"int64_t int_array_2[] = {24L + (align(12L*s0)), };"
|
"int64_t int_array_2[] = {24L + align(12L*s0), };"
|
||||||
).check_next("int64_t int_array_3[] = {1L, };").check_next(
|
).check_next("int64_t int_array_3[] = {1L, };").check_next(
|
||||||
"AtenTensorHandle pool1_handle;"
|
"AtenTensorHandle pool1_handle;"
|
||||||
).check_next(
|
).check_next(
|
||||||
|
|
|
||||||
|
|
@ -487,7 +487,7 @@ class PaddingTest(TestCaseBase):
|
||||||
|
|
||||||
# make sure the load for softmax is aligned
|
# make sure the load for softmax is aligned
|
||||||
self.assertTrue(
|
self.assertTrue(
|
||||||
"tl.load(in_ptr0 + (r1 + (30528*x0))" in forward_wrapper,
|
"tl.load(in_ptr0 + (r1 + 30528*x0)" in forward_wrapper,
|
||||||
f"forward_wrapper: {forward_wrapper}",
|
f"forward_wrapper: {forward_wrapper}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12505,8 +12505,8 @@ if HAS_GPU and not TEST_WITH_ASAN:
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
"\n".join(lines),
|
"\n".join(lines),
|
||||||
"""\
|
"""\
|
||||||
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0)
|
tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r2), rmask, eviction_policy='evict_last', other=0.0)
|
||||||
tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""",
|
tmp1 = tl.load(in_ptr1 + (x3 + 262144*r2), rmask, eviction_policy='evict_first', other=0.0)""",
|
||||||
)
|
)
|
||||||
|
|
||||||
@config.patch("triton.use_block_ptr", True)
|
@config.patch("triton.use_block_ptr", True)
|
||||||
|
|
@ -12538,7 +12538,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
"\n".join(lines),
|
"\n".join(lines),
|
||||||
"""\
|
"""\
|
||||||
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [((511 + XBLOCK) // 512), ((1) * ((1) <= (((511 + XBLOCK) // 512))) + (((511 + XBLOCK) // 512)) * ((((511 + XBLOCK) // 512)) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
|
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
|
||||||
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
|
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -275,7 +275,7 @@ class TritonBlockPointerTest(InductorTestCase):
|
||||||
"\n".join(load_lines),
|
"\n".join(load_lines),
|
||||||
"""\
|
"""\
|
||||||
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
|
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
|
||||||
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
|
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
|
||||||
)
|
)
|
||||||
self.assertExpectedInline(
|
self.assertExpectedInline(
|
||||||
"\n".join(store_lines),
|
"\n".join(store_lines),
|
||||||
|
|
|
||||||
|
|
@ -2792,8 +2792,8 @@ class TestGuardsExpressions(TestCase):
|
||||||
guard_int(sym_int(s0 / 2.0))
|
guard_int(sym_int(s0 / 2.0))
|
||||||
guards = shape_env.produce_guards_expression([s0])
|
guards = shape_env.produce_guards_expression([s0])
|
||||||
|
|
||||||
self.assertIn("ToFloat", guards)
|
self.assertIn("math.trunc(", guards)
|
||||||
self.assertIn("FloatTrueDiv", guards)
|
self.assertIn("float(", guards)
|
||||||
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
||||||
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
|
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
from sympy.printing.printer import Printer
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.fx
|
import torch.fx
|
||||||
|
|
@ -31,6 +30,7 @@ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
|
||||||
from torch.utils import _pytree as pytree
|
from torch.utils import _pytree as pytree
|
||||||
from torch.utils._ordered_set import OrderedSet
|
from torch.utils._ordered_set import OrderedSet
|
||||||
from torch.utils._sympy.numbers import int_oo
|
from torch.utils._sympy.numbers import int_oo
|
||||||
|
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
|
||||||
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
|
||||||
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
|
||||||
|
|
||||||
|
|
@ -609,12 +609,22 @@ class DataTypePropagation:
|
||||||
DataTypePropagation.propagate_loopbody(node._body)
|
DataTypePropagation.propagate_loopbody(node._body)
|
||||||
|
|
||||||
|
|
||||||
# This printer contains rules that are supposed to be generic for both C/C++ and
|
class PythonPrinter(_PythonPrinter):
|
||||||
# Python
|
def doprint(self, expr, *, simplify: bool = True, p=True):
|
||||||
class ExprPrinter(Printer):
|
# TODO: why are people passing strings to the printer here :think:
|
||||||
|
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
||||||
|
expr = V.graph.sizevars.simplify(expr)
|
||||||
|
return super().doprint(expr)
|
||||||
|
|
||||||
|
|
||||||
|
class OpOverrides:
|
||||||
|
def __init__(self, parent):
|
||||||
|
super().__init__()
|
||||||
|
self._parent = parent
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def paren(string):
|
def paren(string: str) -> str:
|
||||||
def all_in_parens(string):
|
def all_in_parens(string: str) -> bool:
|
||||||
if string[0] != "(" or len(string) < 2:
|
if string[0] != "(" or len(string) < 2:
|
||||||
return False
|
return False
|
||||||
count = 1
|
count = 1
|
||||||
|
|
@ -640,260 +650,6 @@ class ExprPrinter(Printer):
|
||||||
return string
|
return string
|
||||||
return f"({string})"
|
return f"({string})"
|
||||||
|
|
||||||
def _print_Relational(self, expr):
|
|
||||||
return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
|
|
||||||
|
|
||||||
def _print_Mul(self, expr):
|
|
||||||
return "*".join(map(self.paren, map(self._print, expr.args)))
|
|
||||||
|
|
||||||
def _print_Add(self, expr):
|
|
||||||
return " + ".join(map(self.paren, map(self._print, expr.args)))
|
|
||||||
|
|
||||||
# NB: this is OK to put here, because Mod is only defined for positive
|
|
||||||
# numbers, and so across C/Python its behavior is consistent
|
|
||||||
def _print_Mod(self, expr):
|
|
||||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
|
||||||
|
|
||||||
def _print_FloatTrueDiv(self, expr):
|
|
||||||
lhs, rhs = expr.args
|
|
||||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
|
||||||
|
|
||||||
def _print_CleanDiv(self, expr):
|
|
||||||
return self._print_FloorDiv(expr)
|
|
||||||
|
|
||||||
def _print_Identity(self, expr):
|
|
||||||
return self._print(expr.args[0])
|
|
||||||
|
|
||||||
def _print_GreaterThan(self, expr):
|
|
||||||
# GreaterThan: >=
|
|
||||||
# StrictlyGreaterThan: >
|
|
||||||
# Go figure...
|
|
||||||
return " >= ".join(map(self.paren, map(self._print, expr.args)))
|
|
||||||
|
|
||||||
# NB: The C implementation is injected into codegen at
|
|
||||||
# torch/_inductor/codegen/wrapper.py
|
|
||||||
def _print_align(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"align({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
# This must be implemented because sympy will collect x * x into Pow(x, 2), without
|
|
||||||
# any explicit intervention. We print it just like x * x, notably, we
|
|
||||||
# never generate sympy.Pow with floats.
|
|
||||||
#
|
|
||||||
# NB: this pow by natural, you should never have used builtin sympy.pow
|
|
||||||
# for FloatPow, and a symbolic exponent should be PowByNatural. These
|
|
||||||
# means exp is guaranteed to be integer.
|
|
||||||
def _print_Pow(self, expr):
|
|
||||||
base, exp = expr.args
|
|
||||||
base = self._print(base)
|
|
||||||
assert exp == int(exp), exp
|
|
||||||
exp = int(exp)
|
|
||||||
assert exp >= 0
|
|
||||||
if exp > 0:
|
|
||||||
return "*".join([self.paren(base)] * exp)
|
|
||||||
return "1"
|
|
||||||
|
|
||||||
# Explicit NotImplemented functions are to prevent default sympy printing
|
|
||||||
# behavior, which will just barf out ToFloat(...) to your IR. The error
|
|
||||||
# message is better here because it tells you which printer class it needs
|
|
||||||
# to go in.
|
|
||||||
|
|
||||||
def _print_ToFloat(self, expr):
|
|
||||||
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
|
|
||||||
|
|
||||||
def _print_Infinity(self, expr):
|
|
||||||
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
|
|
||||||
|
|
||||||
def _print_NegativeInfinity(self, expr):
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"_print_NegativeInfinity not implemented for {type(self)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _print_FloorDiv(self, expr):
|
|
||||||
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
|
|
||||||
|
|
||||||
def _print_PythonMod(self, expr):
|
|
||||||
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
|
|
||||||
|
|
||||||
def _print_IntTrueDiv(self, expr):
|
|
||||||
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
|
|
||||||
|
|
||||||
def _print_PowByNatural(self, expr):
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"_print_PowByNatural not implemented for {type(self)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _print_FloatPow(self, expr):
|
|
||||||
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
|
|
||||||
|
|
||||||
def _print_TruncToInt(self, expr):
|
|
||||||
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
|
|
||||||
|
|
||||||
def _print_RoundToInt(self, expr):
|
|
||||||
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
|
|
||||||
|
|
||||||
def _print_RoundDecimal(self, expr):
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"_print_RoundDecimal not implemented for {type(self)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# NB: Some float operations are INTENTIONALLY not implemented for
|
|
||||||
# printers. You can implement them as a quick unblock, but it is better
|
|
||||||
# to ask yourself why we haven't done this computation in the Tensor
|
|
||||||
# universe instead
|
|
||||||
|
|
||||||
def _print_TruncToFloat(self, expr):
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"_print_TruncToFloat not implemented for {type(self)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def doprint(self, expr, *, simplify: bool = True):
|
|
||||||
# TODO: why are people passing strings to the printer here :think:
|
|
||||||
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
|
||||||
expr = V.graph.sizevars.simplify(expr)
|
|
||||||
return super().doprint(expr)
|
|
||||||
|
|
||||||
|
|
||||||
class PythonPrinter(ExprPrinter):
|
|
||||||
def _print_ToFloat(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"float({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_ModularIndexing(self, expr):
|
|
||||||
x, div, mod = expr.args
|
|
||||||
x = self.paren(self.doprint(x))
|
|
||||||
div = self.paren(self.doprint(div))
|
|
||||||
mod = self.paren(self.doprint(mod))
|
|
||||||
if div != "1":
|
|
||||||
x = f"({x} // {div})"
|
|
||||||
return f"{x} % {mod}"
|
|
||||||
|
|
||||||
def _print_Infinity(self, expr):
|
|
||||||
return "math.inf"
|
|
||||||
|
|
||||||
def _print_NegativeInfinity(self, expr):
|
|
||||||
return "-math.inf"
|
|
||||||
|
|
||||||
# WARNING: this is dangerous for Triton, which has C-style modulus
|
|
||||||
def _print_PythonMod(self, expr):
|
|
||||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
|
||||||
|
|
||||||
# WARNING: this is dangerous for Triton, which has C-style modulus
|
|
||||||
def _print_FloorDiv(self, expr):
|
|
||||||
x, div = expr.args
|
|
||||||
x = self.paren(self.doprint(x))
|
|
||||||
div = self.paren(self.doprint(div))
|
|
||||||
return f"({x} // {div})"
|
|
||||||
|
|
||||||
# WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
|
|
||||||
# does a special algorithm
|
|
||||||
def _print_IntTrueDiv(self, expr):
|
|
||||||
lhs, rhs = expr.args
|
|
||||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
|
||||||
|
|
||||||
def _helper_sqrt(self, expr):
|
|
||||||
return f"math.sqrt({self._print(expr)})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_sqrt(self, expr):
|
|
||||||
return self._helper_sqrt(expr.args[0])
|
|
||||||
|
|
||||||
def _print_FloatPow(self, expr):
|
|
||||||
base, exp = expr.args
|
|
||||||
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
|
|
||||||
|
|
||||||
# TODO: Not sure this works with Triton, even when base/exp are integral
|
|
||||||
def _print_PowByNatural(self, expr):
|
|
||||||
base, exp = expr.args
|
|
||||||
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
|
|
||||||
|
|
||||||
def _print_floor(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.floor({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_FloorToInt(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.floor({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_TruncToInt(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
# This also could have been int(), they'll do the same thing for float
|
|
||||||
return f"math.trunc({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_ceiling(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.ceil({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_CeilToInt(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.ceil({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_Abs(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"abs({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
# NB: It's expected that we've made explicit any promotion in the sympy
|
|
||||||
# expression, so it doesn't matter that Python max/min doesn't perform
|
|
||||||
# promotion
|
|
||||||
def _print_Max(self, expr):
|
|
||||||
assert len(expr.args) >= 2
|
|
||||||
return f"max({', '.join(map(self._print, expr.args))})"
|
|
||||||
|
|
||||||
def _print_Min(self, expr):
|
|
||||||
assert len(expr.args) >= 2
|
|
||||||
return f"min({', '.join(map(self._print, expr.args))})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_cos(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.cos({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_cosh(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.cosh({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_acos(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.acos({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_sin(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.sin({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_sinh(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.sinh({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_asin(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.asin({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_tan(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.tan({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_tanh(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.tanh({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_atan(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"math.atan({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_RoundToInt(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"round({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_RoundDecimal(self, expr):
|
|
||||||
assert len(expr.args) == 2
|
|
||||||
number, ndigits = expr.args
|
|
||||||
assert isinstance(ndigits, sympy.Integer)
|
|
||||||
return f"round({self._print(number)}, {ndigits})"
|
|
||||||
|
|
||||||
|
|
||||||
class OpOverrides:
|
|
||||||
def __init__(self, parent):
|
|
||||||
super().__init__()
|
|
||||||
self._parent = parent
|
|
||||||
|
|
||||||
def __getattr__(self, item):
|
def __getattr__(self, item):
|
||||||
return getattr(self._parent, item)
|
return getattr(self._parent, item)
|
||||||
|
|
||||||
|
|
@ -982,31 +738,31 @@ class OpOverrides:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bitwise_not(x):
|
def bitwise_not(x):
|
||||||
return f"~{ExprPrinter.paren(x)}"
|
return f"~{OpOverrides.paren(x)}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def logical_not(a):
|
def logical_not(a):
|
||||||
return f"{ExprPrinter.paren(a)} == 0"
|
return f"{OpOverrides.paren(a)} == 0"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bitwise_and(x, y):
|
def bitwise_and(x, y):
|
||||||
return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"
|
return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bitwise_or(x, y):
|
def bitwise_or(x, y):
|
||||||
return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"
|
return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bitwise_xor(x, y):
|
def bitwise_xor(x, y):
|
||||||
return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"
|
return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bitwise_left_shift(x, y):
|
def bitwise_left_shift(x, y):
|
||||||
return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"
|
return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def bitwise_right_shift(x, y):
|
def bitwise_right_shift(x, y):
|
||||||
return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"
|
return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def remainder(a, b):
|
def remainder(a, b):
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ import sympy
|
||||||
import torch
|
import torch
|
||||||
from torch._prims_common import is_integer_dtype
|
from torch._prims_common import is_integer_dtype
|
||||||
from torch.utils._ordered_set import OrderedSet
|
from torch.utils._ordered_set import OrderedSet
|
||||||
|
from torch.utils._sympy.printers import CppPrinter as _CppPrinter
|
||||||
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
from torch.utils._sympy.symbol import symbol_is_type, SymT
|
||||||
from torch.utils._sympy.value_ranges import ValueRanges
|
from torch.utils._sympy.value_ranges import ValueRanges
|
||||||
|
|
||||||
|
|
@ -25,7 +26,6 @@ from ..virtualized import ops, OpsValue, V
|
||||||
from .common import (
|
from .common import (
|
||||||
CSEVariable,
|
CSEVariable,
|
||||||
deduce_output_dtype_by_name,
|
deduce_output_dtype_by_name,
|
||||||
ExprPrinter,
|
|
||||||
Kernel,
|
Kernel,
|
||||||
KernelArgs,
|
KernelArgs,
|
||||||
OptimizationContext,
|
OptimizationContext,
|
||||||
|
|
@ -232,212 +232,12 @@ class CppCSEVariable(CSEVariable):
|
||||||
return itervar in self.dependent_itervars
|
return itervar in self.dependent_itervars
|
||||||
|
|
||||||
|
|
||||||
class CppPrinter(ExprPrinter):
|
class CppPrinter(_CppPrinter):
|
||||||
def _print_Integer(self, expr):
|
def doprint(self, expr, *, simplify: bool = True, p=True):
|
||||||
return (
|
# TODO: why are people passing strings to the printer here :think:
|
||||||
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
|
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
|
||||||
)
|
expr = V.graph.sizevars.simplify(expr)
|
||||||
|
return super().doprint(expr)
|
||||||
def _print_Where(self, expr):
|
|
||||||
c = self.paren(self.doprint(expr.args[0]))
|
|
||||||
p = self.paren(self.doprint(expr.args[1]))
|
|
||||||
q = self.paren(self.doprint(expr.args[2]))
|
|
||||||
return f"{c} ? {p} : {q}"
|
|
||||||
|
|
||||||
def _print_ModularIndexing(self, expr):
|
|
||||||
x, div, mod = expr.args
|
|
||||||
x = self.paren(self.doprint(x))
|
|
||||||
if div != 1:
|
|
||||||
div = self.paren(self.doprint(div))
|
|
||||||
if expr.is_integer:
|
|
||||||
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
|
||||||
else:
|
|
||||||
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
|
||||||
mod = self.paren(self.doprint(mod))
|
|
||||||
return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"
|
|
||||||
|
|
||||||
def _print_FloorDiv(self, expr):
|
|
||||||
x, div = expr.args
|
|
||||||
x = self.paren(self.doprint(x))
|
|
||||||
div = self.paren(self.doprint(div))
|
|
||||||
if expr.is_integer:
|
|
||||||
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
|
||||||
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
|
||||||
|
|
||||||
def _print_floor(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
r = f"std::floor({self._print(expr.args[0])})"
|
|
||||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
|
||||||
|
|
||||||
def _print_FloorToInt(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
r = f"std::floor({self._print(expr.args[0])})"
|
|
||||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
|
||||||
|
|
||||||
def _print_TruncToInt(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
r = f"std::trunc({self._print(expr.args[0])})"
|
|
||||||
return f"static_cast<{INDEX_TYPE}>({r})"
|
|
||||||
|
|
||||||
def _print_TruncToFloat(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::trunc({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_ToFloat(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"static_cast<double>({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
# TODO: This is wrong if one of the inputs is negative. This is hard to
|
|
||||||
# tickle though, as the inputs are typically positive (and if we can prove
|
|
||||||
# they are positive, we will have used Mod instead, for which this codegen
|
|
||||||
# is right).
|
|
||||||
def _print_PythonMod(self, expr):
|
|
||||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
|
||||||
|
|
||||||
def _print_CMod(self, expr):
|
|
||||||
return " % ".join(map(self.paren, map(self._print, expr.args)))
|
|
||||||
|
|
||||||
def _print_IntTrueDiv(self, expr):
|
|
||||||
lhs, rhs = expr.args
|
|
||||||
# TODO: This is only accurate up to 2**53
|
|
||||||
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
|
|
||||||
|
|
||||||
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
|
|
||||||
# use std::pow, that operates on floats
|
|
||||||
def _print_PowByNatural(self, expr):
|
|
||||||
raise NotImplementedError(
|
|
||||||
f"_print_PowByNatural not implemented for {type(self)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _print_FloatTrueDiv(self, expr):
|
|
||||||
lhs, rhs = expr.args
|
|
||||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
|
||||||
|
|
||||||
def _print_FloatPow(self, expr):
|
|
||||||
base, exp = expr.args
|
|
||||||
return f"std::pow({self._print(base)}, {self._print(exp)})"
|
|
||||||
|
|
||||||
def _print_Pow(self, expr):
|
|
||||||
# Uses float constants to perform FP div
|
|
||||||
base, exp = expr.args
|
|
||||||
base = self._print(base)
|
|
||||||
|
|
||||||
if exp == 0.5 or exp == -0.5:
|
|
||||||
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
|
|
||||||
if exp.is_integer:
|
|
||||||
exp = int(exp)
|
|
||||||
if exp > 0:
|
|
||||||
r = "*".join([self.paren(base)] * exp)
|
|
||||||
elif exp < 0:
|
|
||||||
r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
|
|
||||||
else: # exp == 0
|
|
||||||
r = "1.0"
|
|
||||||
|
|
||||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
|
||||||
else:
|
|
||||||
# TODO: float vs double
|
|
||||||
return f"std::pow({base}, {float(exp)})"
|
|
||||||
|
|
||||||
def _print_Rational(self, expr):
|
|
||||||
# Uses float constants to perform FP div
|
|
||||||
if expr.q == 1:
|
|
||||||
r = f"{expr.p}"
|
|
||||||
else:
|
|
||||||
r = f"{expr.p}.0/{expr.q}.0"
|
|
||||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
|
||||||
|
|
||||||
def _print_ceiling(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
r = f"std::ceil({self._print(expr.args[0])})"
|
|
||||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
|
||||||
|
|
||||||
def _print_CeilToInt(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
r = f"std::ceil({self._print(expr.args[0])})"
|
|
||||||
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
|
||||||
|
|
||||||
def _print_Min(self, expr):
|
|
||||||
args = [self._print(a) for a in expr.args]
|
|
||||||
if len(args) == 2:
|
|
||||||
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
|
||||||
else:
|
|
||||||
# Initializer list overload
|
|
||||||
il = "{" + ", ".join(args) + "}"
|
|
||||||
return f"std::min({il})"
|
|
||||||
|
|
||||||
def _print_Max(self, expr):
|
|
||||||
args = [self._print(a) for a in expr.args]
|
|
||||||
if len(args) == 2:
|
|
||||||
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
|
||||||
else:
|
|
||||||
# Initializer list overload
|
|
||||||
il = "{" + ", ".join(args) + "}"
|
|
||||||
return f"std::max({il})"
|
|
||||||
|
|
||||||
def _print_Abs(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::abs({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_cos(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::cos({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_cosh(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::cosh({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_acos(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::acos({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_sin(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::sin({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_sinh(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::sinh({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_asin(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::asin({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_tan(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::tan({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_tanh(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::tanh({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_atan(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
return f"std::atan({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_OpaqueUnaryFn_sqrt(self, expr):
|
|
||||||
return f"std::sqrt({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_RoundToInt(self, expr):
|
|
||||||
assert len(expr.args) == 1
|
|
||||||
# TODO: dispatch to llrint depending on index type
|
|
||||||
return f"std::lrint({self._print(expr.args[0])})"
|
|
||||||
|
|
||||||
def _print_RoundDecimal(self, expr):
|
|
||||||
assert len(expr.args) == 2
|
|
||||||
number, ndigits = expr.args
|
|
||||||
if number.is_integer:
|
|
||||||
# ndigits < 0 should have been filtered by the sympy function
|
|
||||||
assert ndigits < 0
|
|
||||||
raise ValueError(
|
|
||||||
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
|
|
||||||
)
|
|
||||||
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"
|
|
||||||
|
|
||||||
def _print_BooleanTrue(self, expr):
|
|
||||||
return "true"
|
|
||||||
|
|
||||||
def _print_BooleanFalse(self, expr):
|
|
||||||
return "false"
|
|
||||||
|
|
||||||
|
|
||||||
# A function to print, useful for printing sympy symbols.
|
# A function to print, useful for printing sympy symbols.
|
||||||
|
|
|
||||||
|
|
@ -185,8 +185,8 @@ class HalidePrinter(PythonPrinter):
|
||||||
return super()._print_FloorDiv(expr)
|
return super()._print_FloorDiv(expr)
|
||||||
|
|
||||||
x, div = expr.args
|
x, div = expr.args
|
||||||
x = self.cast_float(self.paren(self.doprint(x)))
|
x = self.cast_float(self.doprint(x))
|
||||||
div = self.cast_float(self.paren(self.doprint(div)))
|
div = self.cast_float(self.doprint(div))
|
||||||
return self.cast_index(f"hl.floor({x} / {div})")
|
return self.cast_index(f"hl.floor({x} / {div})")
|
||||||
|
|
||||||
def _print_Round(self, expr):
|
def _print_Round(self, expr):
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ from typing import (
|
||||||
)
|
)
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
|
from sympy.printing.precedence import PRECEDENCE
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch._logging
|
import torch._logging
|
||||||
|
|
@ -504,30 +505,30 @@ class TritonPrinter(PythonPrinter):
|
||||||
|
|
||||||
def _print_ToFloat(self, expr):
|
def _print_ToFloat(self, expr):
|
||||||
assert len(expr.args) == 1
|
assert len(expr.args) == 1
|
||||||
return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)"
|
s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5)
|
||||||
|
return f"{s}.to(tl.float64)"
|
||||||
|
|
||||||
def _print_PythonMod(self, expr):
|
def _print_PythonMod(self, expr):
|
||||||
quot, div = expr.args
|
quot, div = expr.args
|
||||||
|
if quot.is_nonnegative and div.is_nonnegative:
|
||||||
|
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||||
quot_s = self._print(quot)
|
quot_s = self._print(quot)
|
||||||
div_s = self._print(div)
|
div_s = self._print(div)
|
||||||
if quot.is_nonnegative and div.is_nonnegative:
|
|
||||||
return f"{self.paren(quot_s)} % {self.paren(div_s)}"
|
|
||||||
return f"triton_helpers.remainder_integer({quot_s}, {div_s})"
|
return f"triton_helpers.remainder_integer({quot_s}, {div_s})"
|
||||||
|
|
||||||
def _print_FloorDiv(self, expr):
|
def _print_FloorDiv(self, expr):
|
||||||
assert expr.is_integer
|
assert expr.is_integer
|
||||||
quot, div = expr.args
|
quot, div = expr.args
|
||||||
|
if quot.is_nonnegative and div.is_nonnegative:
|
||||||
|
return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5)
|
||||||
quot_s = self._print(quot)
|
quot_s = self._print(quot)
|
||||||
div_s = self._print(div)
|
div_s = self._print(div)
|
||||||
if quot.is_nonnegative and div.is_nonnegative:
|
|
||||||
return f"({self.paren(quot_s)} // {self.paren(div_s)})"
|
|
||||||
return f"triton_helpers.div_floor_integer({quot_s}, {div_s})"
|
return f"triton_helpers.div_floor_integer({quot_s}, {div_s})"
|
||||||
|
|
||||||
# TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
|
# TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
|
||||||
# precision algorithm, which we would need to replicate here
|
# precision algorithm, which we would need to replicate here
|
||||||
def _print_IntTrueDiv(self, expr):
|
def _print_IntTrueDiv(self, expr):
|
||||||
lhs, rhs = expr.args
|
return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
|
||||||
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
|
|
||||||
|
|
||||||
# NB: sympy.floor/ceiling produce integers, so we have to do the
|
# NB: sympy.floor/ceiling produce integers, so we have to do the
|
||||||
# conversion to index dtype
|
# conversion to index dtype
|
||||||
|
|
@ -646,7 +647,9 @@ class TritonPrinter(PythonPrinter):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
|
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
|
||||||
)
|
)
|
||||||
return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}"
|
|
||||||
|
number_str = self.parenthesize(number, PRECEDENCE["Mul"])
|
||||||
|
return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}"
|
||||||
|
|
||||||
|
|
||||||
texpr = TritonPrinter().doprint
|
texpr = TritonPrinter().doprint
|
||||||
|
|
|
||||||
|
|
@ -35,13 +35,12 @@ from .autotune_process import (
|
||||||
TritonGPUBenchmarkRequest,
|
TritonGPUBenchmarkRequest,
|
||||||
)
|
)
|
||||||
from .codecache import code_hash, PersistentCache, PyCodeCache
|
from .codecache import code_hash, PersistentCache, PyCodeCache
|
||||||
from .codegen.common import IndentedBuffer, KernelTemplate, WorkspaceArg
|
from .codegen.common import IndentedBuffer, KernelTemplate, OpOverrides, WorkspaceArg
|
||||||
from .codegen.simd_kernel_features import SIMDKernelFeatures
|
from .codegen.simd_kernel_features import SIMDKernelFeatures
|
||||||
from .codegen.triton import (
|
from .codegen.triton import (
|
||||||
gen_common_triton_imports,
|
gen_common_triton_imports,
|
||||||
texpr,
|
texpr,
|
||||||
TritonKernel,
|
TritonKernel,
|
||||||
TritonPrinter,
|
|
||||||
TritonScheduling,
|
TritonScheduling,
|
||||||
)
|
)
|
||||||
from .codegen.triton_utils import config_of, signature_to_meta
|
from .codegen.triton_utils import config_of, signature_to_meta
|
||||||
|
|
@ -562,7 +561,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||||
assert isinstance(val, str)
|
assert isinstance(val, str)
|
||||||
assert isinstance(mask, (str, type(None)))
|
assert isinstance(mask, (str, type(None)))
|
||||||
assert self.template_mask is None
|
assert self.template_mask is None
|
||||||
indices = list(map(TritonPrinter.paren, indices))
|
indices = list(map(OpOverrides.paren, indices))
|
||||||
index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
|
index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
|
||||||
lengths = [
|
lengths = [
|
||||||
V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
|
V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
|
||||||
|
|
@ -630,7 +629,7 @@ class TritonTemplateKernel(TritonKernel):
|
||||||
assert isinstance(name, str)
|
assert isinstance(name, str)
|
||||||
assert isinstance(mask, str)
|
assert isinstance(mask, str)
|
||||||
stride = self.named_input_nodes[name].get_stride()
|
stride = self.named_input_nodes[name].get_stride()
|
||||||
indices = list(map(TritonPrinter.paren, indices))
|
indices = list(map(OpOverrides.paren, indices))
|
||||||
assert len(indices) == len(stride)
|
assert len(indices) == len(stride)
|
||||||
index = " + ".join(
|
index = " + ".join(
|
||||||
f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
|
f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,6 @@ need to make use of these APIs to setup dynamic shapes support appropriately.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import atexit
|
import atexit
|
||||||
import builtins
|
|
||||||
import collections
|
import collections
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
|
|
@ -82,6 +81,7 @@ from torch.utils._sympy.functions import (
|
||||||
PythonMod,
|
PythonMod,
|
||||||
)
|
)
|
||||||
from torch.utils._sympy.numbers import int_oo
|
from torch.utils._sympy.numbers import int_oo
|
||||||
|
from torch.utils._sympy.printers import PythonPrinter
|
||||||
from torch.utils._sympy.singleton_int import SingletonInt
|
from torch.utils._sympy.singleton_int import SingletonInt
|
||||||
from torch.utils._sympy.solve import try_solve
|
from torch.utils._sympy.solve import try_solve
|
||||||
from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT
|
from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT
|
||||||
|
|
@ -109,8 +109,6 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
from sympy import S
|
from sympy import S
|
||||||
from sympy.printing.precedence import PRECEDENCE, precedence
|
|
||||||
from sympy.printing.str import StrPrinter
|
|
||||||
|
|
||||||
|
|
||||||
class GuardOnDataDependentSymNode(RuntimeError):
|
class GuardOnDataDependentSymNode(RuntimeError):
|
||||||
|
|
@ -2032,45 +2030,9 @@ def cast_symbool_to_symint_guardless(
|
||||||
|
|
||||||
|
|
||||||
SYMPY_INTERP = {
|
SYMPY_INTERP = {
|
||||||
"Abs": operator.abs,
|
|
||||||
"Eq": operator.eq,
|
|
||||||
"Ne": operator.ne,
|
|
||||||
"Gt": operator.gt,
|
|
||||||
"Lt": operator.lt,
|
|
||||||
"Le": operator.le,
|
|
||||||
"Ge": operator.ge,
|
|
||||||
"Min": min,
|
|
||||||
"Max": max,
|
|
||||||
"Mod": operator.mod,
|
|
||||||
"PythonMod": operator.mod,
|
|
||||||
"FloorDiv": operator.floordiv,
|
|
||||||
"TrueDiv": operator.truediv,
|
|
||||||
"PowByNatural": operator.pow,
|
|
||||||
"IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense,
|
"IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense,
|
||||||
"floor": math.floor,
|
|
||||||
"ceiling": math.ceil,
|
|
||||||
"FloorToInt": math.floor,
|
|
||||||
"FloatPow": math.pow,
|
|
||||||
"CeilToInt": math.ceil,
|
|
||||||
"cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless,
|
"cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless,
|
||||||
"RoundToInt": builtins.round,
|
"math": math,
|
||||||
"RoundDecimal": builtins.round,
|
|
||||||
"TruncToInt": math.trunc,
|
|
||||||
"IntTrueDiv": operator.truediv,
|
|
||||||
"FloatTrueDiv": operator.truediv,
|
|
||||||
"ToFloat": builtins.float,
|
|
||||||
"OpaqueUnaryFn_cos": math.cos,
|
|
||||||
"OpaqueUnaryFn_cosh": math.cosh,
|
|
||||||
"OpaqueUnaryFn_acos": math.acos,
|
|
||||||
"OpaqueUnaryFn_sin": math.sin,
|
|
||||||
"OpaqueUnaryFn_sinh": math.sinh,
|
|
||||||
"OpaqueUnaryFn_asin": math.asin,
|
|
||||||
"OpaqueUnaryFn_tan": math.tan,
|
|
||||||
"OpaqueUnaryFn_tanh": math.tanh,
|
|
||||||
"OpaqueUnaryFn_atan": math.atan,
|
|
||||||
"OpaqueUnaryFn_sqrt": math.sqrt,
|
|
||||||
"BitwiseFn_bitwise_and": operator.and_,
|
|
||||||
"BitwiseFn_bitwise_or": operator.or_,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2141,12 +2103,12 @@ class RuntimeAssert:
|
||||||
|
|
||||||
|
|
||||||
# Used for printing SymExprs in compile_fx
|
# Used for printing SymExprs in compile_fx
|
||||||
class SymExprPrinter(StrPrinter):
|
class SymExprPrinter(PythonPrinter):
|
||||||
def _print_Float(self, expr: sympy.Float) -> str:
|
def _print_Float(self, expr: sympy.Float) -> str:
|
||||||
return str(float(expr))
|
return str(float(expr))
|
||||||
|
|
||||||
|
|
||||||
class ShapeGuardPrinter(SymExprPrinter):
|
class ShapeGuardPrinter(PythonPrinter):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
symbol_to_source: Mapping[sympy.Symbol, List[Source]],
|
symbol_to_source: Mapping[sympy.Symbol, List[Source]],
|
||||||
|
|
@ -2158,14 +2120,8 @@ class ShapeGuardPrinter(SymExprPrinter):
|
||||||
self.source_ref = source_ref
|
self.source_ref = source_ref
|
||||||
self.var_to_sources = var_to_sources
|
self.var_to_sources = var_to_sources
|
||||||
|
|
||||||
def _print_Not(self, expr: SympyBoolean) -> str:
|
def _print_Float(self, expr: sympy.Float) -> str:
|
||||||
return "not {}".format(self.parenthesize(expr.args[0], PRECEDENCE["Not"]))
|
return str(float(expr))
|
||||||
|
|
||||||
def _print_And(self, expr: SympyBoolean) -> str:
|
|
||||||
return self.stringify(expr.args, " and ", PRECEDENCE["And"])
|
|
||||||
|
|
||||||
def _print_Or(self, expr: SympyBoolean) -> str:
|
|
||||||
return self.stringify(expr.args, " or ", PRECEDENCE["Or"])
|
|
||||||
|
|
||||||
def _print_Symbol(self, expr: sympy.Symbol) -> str:
|
def _print_Symbol(self, expr: sympy.Symbol) -> str:
|
||||||
assert isinstance(expr, sympy.Symbol), str(type(expr))
|
assert isinstance(expr, sympy.Symbol), str(type(expr))
|
||||||
|
|
@ -2191,7 +2147,7 @@ class LoggingShapeGuardPrinter(ShapeGuardPrinter):
|
||||||
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
|
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
|
||||||
|
|
||||||
|
|
||||||
class DynamicDimConstraintPrinter(StrPrinter):
|
class DynamicDimConstraintPrinter(PythonPrinter):
|
||||||
"""
|
"""
|
||||||
Printer for dynamic dim constraints.
|
Printer for dynamic dim constraints.
|
||||||
- Instead of symbol s_k it prints its source t.size()[i]
|
- Instead of symbol s_k it prints its source t.size()[i]
|
||||||
|
|
@ -2216,9 +2172,6 @@ class DynamicDimConstraintPrinter(StrPrinter):
|
||||||
), f"Unknown symbol {expr} created by constraints solver"
|
), f"Unknown symbol {expr} created by constraints solver"
|
||||||
return self.symbol_to_source[expr][0].name()
|
return self.symbol_to_source[expr][0].name()
|
||||||
|
|
||||||
def _print_Relational(self, expr: sympy.core.relational.Relational) -> str:
|
|
||||||
return f"{self.parenthesize(expr.lhs, precedence(expr))} {expr.rel_op} {self.parenthesize(expr.rhs, precedence(expr))}" # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
|
|
||||||
class DimConstraints:
|
class DimConstraints:
|
||||||
"""
|
"""
|
||||||
|
|
@ -6656,7 +6609,7 @@ def _blame_user_code(e: Exception, frame: types.FrameType) -> None:
|
||||||
e.args = (msg,)
|
e.args = (msg,)
|
||||||
|
|
||||||
|
|
||||||
class _PythonPrinter(sympy.printing.str.StrPrinter):
|
class _PythonMsgPrinter(PythonPrinter):
|
||||||
"""
|
"""
|
||||||
Util printer that replaces sympy symbols with their source-level names
|
Util printer that replaces sympy symbols with their source-level names
|
||||||
and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline
|
and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline
|
||||||
|
|
@ -6670,13 +6623,6 @@ class _PythonPrinter(sympy.printing.str.StrPrinter):
|
||||||
def _print_Symbol(self, sym: sympy.Symbol) -> str:
|
def _print_Symbol(self, sym: sympy.Symbol) -> str:
|
||||||
return self.src_map[sym.name][0]
|
return self.src_map[sym.name][0]
|
||||||
|
|
||||||
def _print_Relational(self, expr: sympy.core.relational.Relational) -> str:
|
|
||||||
lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr))
|
|
||||||
assert hasattr(expr, "rel_op")
|
|
||||||
rel_op = expr.rel_op
|
|
||||||
rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr))
|
|
||||||
return f"{lhs} {rel_op} {rhs}"
|
|
||||||
|
|
||||||
|
|
||||||
def _suggest_torch_checks(
|
def _suggest_torch_checks(
|
||||||
e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]]
|
e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]]
|
||||||
|
|
@ -6687,7 +6633,7 @@ def _suggest_torch_checks(
|
||||||
if diff:
|
if diff:
|
||||||
log.warning("Unable to find user code corresponding to {%s}", diff)
|
log.warning("Unable to find user code corresponding to {%s}", diff)
|
||||||
return
|
return
|
||||||
printer = _PythonPrinter(src_map)
|
printer = _PythonMsgPrinter(src_map)
|
||||||
msg = e.args[0]
|
msg = e.args[0]
|
||||||
msg += "\nTo fix the error, insert one of the following checks before this call:"
|
msg += "\nTo fix the error, insert one of the following checks before this call:"
|
||||||
# suggested fixes to resolve `cond`` are to tell the compiler to assume
|
# suggested fixes to resolve `cond`` are to tell the compiler to assume
|
||||||
|
|
|
||||||
|
|
@ -191,7 +191,7 @@ class FloorDiv(sympy.Function):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nargs: Tuple[int, ...] = (2,)
|
nargs: Tuple[int, ...] = (2,)
|
||||||
precedence: int = 50 # precedence of mul # noqa: F811
|
precedence: int = 35 # lower precedence than add
|
||||||
is_integer: bool = True
|
is_integer: bool = True
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -291,6 +291,7 @@ class ModularIndexing(sympy.Function):
|
||||||
|
|
||||||
nargs: Tuple[int, ...] = (3,)
|
nargs: Tuple[int, ...] = (3,)
|
||||||
is_integer: bool = True
|
is_integer: bool = True
|
||||||
|
precedence: int = 35 # lower precedence than add
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(
|
def eval(
|
||||||
|
|
@ -360,6 +361,7 @@ class Where(sympy.Function):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
nargs: Tuple[int, ...] = (3,)
|
nargs: Tuple[int, ...] = (3,)
|
||||||
|
precedence: int = 35 # lower precedence than add
|
||||||
|
|
||||||
def _eval_is_integer(self) -> Optional[bool]:
|
def _eval_is_integer(self) -> Optional[bool]:
|
||||||
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
|
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
|
||||||
|
|
@ -389,6 +391,7 @@ class Where(sympy.Function):
|
||||||
class PythonMod(sympy.Function):
|
class PythonMod(sympy.Function):
|
||||||
nargs: Tuple[int, ...] = (2,)
|
nargs: Tuple[int, ...] = (2,)
|
||||||
|
|
||||||
|
precedence: int = 35 # lower precedence than add
|
||||||
is_integer: bool = True
|
is_integer: bool = True
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -447,6 +450,7 @@ class PythonMod(sympy.Function):
|
||||||
# Generic modulus: only defined on non-negative arguments
|
# Generic modulus: only defined on non-negative arguments
|
||||||
class Mod(sympy.Function):
|
class Mod(sympy.Function):
|
||||||
nargs = (2,)
|
nargs = (2,)
|
||||||
|
precedence: int = 35 # lower precedence than add
|
||||||
|
|
||||||
is_integer = True
|
is_integer = True
|
||||||
is_nonnegative = True
|
is_nonnegative = True
|
||||||
|
|
@ -1014,6 +1018,8 @@ def _safe_pow(base, exponent):
|
||||||
class PowByNatural(sympy.Function):
|
class PowByNatural(sympy.Function):
|
||||||
is_integer = True
|
is_integer = True
|
||||||
|
|
||||||
|
precedence: int = 50 # precedence of mul
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, base, exp):
|
def eval(cls, base, exp):
|
||||||
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
|
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
|
||||||
|
|
@ -1039,6 +1045,8 @@ class PowByNatural(sympy.Function):
|
||||||
class FloatPow(sympy.Function):
|
class FloatPow(sympy.Function):
|
||||||
is_real = True
|
is_real = True
|
||||||
|
|
||||||
|
precedence: int = 60 # precedence of pow
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, base, exp):
|
def eval(cls, base, exp):
|
||||||
# NB: These test sympy.Number, not sympy.Float, because:
|
# NB: These test sympy.Number, not sympy.Float, because:
|
||||||
|
|
@ -1059,6 +1067,8 @@ class FloatPow(sympy.Function):
|
||||||
class FloatTrueDiv(sympy.Function):
|
class FloatTrueDiv(sympy.Function):
|
||||||
is_real = True
|
is_real = True
|
||||||
|
|
||||||
|
precedence: int = 35 # lower precedence than add
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, base, divisor):
|
def eval(cls, base, divisor):
|
||||||
# assert base.is_integer is not True, base
|
# assert base.is_integer is not True, base
|
||||||
|
|
@ -1082,6 +1092,8 @@ class FloatTrueDiv(sympy.Function):
|
||||||
class IntTrueDiv(sympy.Function):
|
class IntTrueDiv(sympy.Function):
|
||||||
is_real = True
|
is_real = True
|
||||||
|
|
||||||
|
precedence: int = 35 # lower precedence than add
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def eval(cls, base, divisor):
|
def eval(cls, base, divisor):
|
||||||
if divisor.is_zero:
|
if divisor.is_zero:
|
||||||
|
|
@ -1254,6 +1266,8 @@ class Identity(sympy.Function):
|
||||||
Prevents expansion and other optimizations
|
Prevents expansion and other optimizations
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
precedence = 10
|
||||||
|
|
||||||
def __repr__(self): # type: ignore[override]
|
def __repr__(self): # type: ignore[override]
|
||||||
return f"Identity({self.args[0]})"
|
return f"Identity({self.args[0]})"
|
||||||
|
|
||||||
|
|
|
||||||
459
torch/utils/_sympy/printers.py
Normal file
459
torch/utils/_sympy/printers.py
Normal file
|
|
@ -0,0 +1,459 @@
|
||||||
|
import sys
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import sympy
|
||||||
|
from sympy.printing.precedence import PRECEDENCE, precedence
|
||||||
|
from sympy.printing.str import StrPrinter
|
||||||
|
|
||||||
|
|
||||||
|
INDEX_TYPE = "int64_t"
|
||||||
|
|
||||||
|
|
||||||
|
# This printer contains rules that are supposed to be generic for both C/C++ and
|
||||||
|
# Python
|
||||||
|
class ExprPrinter(StrPrinter):
|
||||||
|
# override this so that _print_FloorDiv is used
|
||||||
|
printmethod = "_torch_sympystr"
|
||||||
|
|
||||||
|
def _print_Mul(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, "*", precedence(expr))
|
||||||
|
|
||||||
|
def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
|
||||||
|
return self.stringify(expr.args, " + ", precedence(expr))
|
||||||
|
|
||||||
|
def _print_Relational(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr))
|
||||||
|
|
||||||
|
def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " & ", PRECEDENCE["Atom"] - 0.5)
|
||||||
|
|
||||||
|
def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " | ", PRECEDENCE["Atom"] - 0.5)
|
||||||
|
|
||||||
|
# NB: this is OK to put here, because Mod is only defined for positive
|
||||||
|
# numbers, and so across C/Python its behavior is consistent
|
||||||
|
def _print_Mod(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||||
|
|
||||||
|
def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
|
||||||
|
s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
|
||||||
|
return f"({s})"
|
||||||
|
|
||||||
|
def _print_CleanDiv(self, expr: sympy.Expr) -> str:
|
||||||
|
return self._print_FloorDiv(expr)
|
||||||
|
|
||||||
|
def _print_Identity(self, expr: sympy.Expr) -> str:
|
||||||
|
return self._print(expr.args[0])
|
||||||
|
|
||||||
|
# This must be implemented because sympy will collect x * x into Pow(x, 2), without
|
||||||
|
# any explicit intervention. We print it just like x * x, notably, we
|
||||||
|
# never generate sympy.Pow with floats.
|
||||||
|
#
|
||||||
|
# NB: this pow by natural, you should never have used builtin sympy.pow
|
||||||
|
# for FloatPow, and a symbolic exponent should be PowByNatural. These
|
||||||
|
# means exp is guaranteed to be integer.
|
||||||
|
def _print_Pow(self, expr: sympy.Expr) -> str:
|
||||||
|
base, exp = expr.args
|
||||||
|
assert exp == int(exp), exp
|
||||||
|
exp = int(exp)
|
||||||
|
assert exp >= 0
|
||||||
|
if exp > 0:
|
||||||
|
return self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
|
||||||
|
return "1"
|
||||||
|
|
||||||
|
# Explicit NotImplemented functions are to prevent default sympy printing
|
||||||
|
# behavior, which will just barf out ToFloat(...) to your IR. The error
|
||||||
|
# message is better here because it tells you which printer class it needs
|
||||||
|
# to go in.
|
||||||
|
|
||||||
|
def _print_ToFloat(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
|
||||||
|
|
||||||
|
def _print_Infinity(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
|
||||||
|
|
||||||
|
def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"_print_NegativeInfinity not implemented for {type(self)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
|
||||||
|
|
||||||
|
def _print_PythonMod(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
|
||||||
|
|
||||||
|
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
|
||||||
|
|
||||||
|
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"_print_PowByNatural not implemented for {type(self)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _print_FloatPow(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
|
||||||
|
|
||||||
|
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
|
||||||
|
|
||||||
|
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
|
||||||
|
|
||||||
|
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"_print_RoundDecimal not implemented for {type(self)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# NB: Some float operations are INTENTIONALLY not implemented for
|
||||||
|
# printers. You can implement them as a quick unblock, but it is better
|
||||||
|
# to ask yourself why we haven't done this computation in the Tensor
|
||||||
|
# universe instead
|
||||||
|
|
||||||
|
def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"_print_TruncToFloat not implemented for {type(self)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonPrinter(ExprPrinter):
|
||||||
|
def _print_ToFloat(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"float({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_And(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " and ", precedence(expr))
|
||||||
|
|
||||||
|
def _print_Or(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " or ", precedence(expr))
|
||||||
|
|
||||||
|
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
|
||||||
|
x, div, mod = (
|
||||||
|
self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
|
||||||
|
)
|
||||||
|
if div != "1":
|
||||||
|
x = f"({x} // {div})"
|
||||||
|
return f"({x} % {mod})"
|
||||||
|
|
||||||
|
def _print_Infinity(self, expr: sympy.Expr) -> str:
|
||||||
|
return "math.inf"
|
||||||
|
|
||||||
|
def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
|
||||||
|
return "-math.inf"
|
||||||
|
|
||||||
|
# WARNING: this is dangerous for Triton, which has C-style modulus
|
||||||
|
def _print_PythonMod(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||||
|
|
||||||
|
# WARNING: this is dangerous for Triton, which has C-style modulus
|
||||||
|
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
|
||||||
|
x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
|
||||||
|
return f"{x} // {div}"
|
||||||
|
|
||||||
|
# WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
|
||||||
|
# does a special algorithm
|
||||||
|
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
|
||||||
|
|
||||||
|
def _helper_sqrt(self, expr: sympy.Expr) -> str:
|
||||||
|
return f"math.sqrt({self._print(expr)})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
|
||||||
|
return self._helper_sqrt(expr.args[0])
|
||||||
|
|
||||||
|
def _print_FloatPow(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
|
||||||
|
|
||||||
|
# TODO: Not sure this works with Triton, even when base/exp are integral
|
||||||
|
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])
|
||||||
|
|
||||||
|
def _print_floor(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.floor({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.floor({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
# This also could have been int(), they'll do the same thing for float
|
||||||
|
return f"math.trunc({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_ceiling(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.ceil({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_CeilToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.ceil({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_Abs(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"abs({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
# NB: It's expected that we've made explicit any promotion in the sympy
|
||||||
|
# expression, so it doesn't matter that Python max/min doesn't perform
|
||||||
|
# promotion
|
||||||
|
def _print_Max(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) >= 2
|
||||||
|
return f"max({', '.join(map(self._print, expr.args))})"
|
||||||
|
|
||||||
|
def _print_Min(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) >= 2
|
||||||
|
return f"min({', '.join(map(self._print, expr.args))})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.cos({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.cosh({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.acos({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.sin({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.sinh({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.asin({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.tan({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.tanh({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"math.atan({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"round({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 2
|
||||||
|
number, ndigits = expr.args
|
||||||
|
assert isinstance(ndigits, sympy.Integer)
|
||||||
|
return f"round({self._print(number)}, {ndigits})"
|
||||||
|
|
||||||
|
|
||||||
|
class CppPrinter(ExprPrinter):
|
||||||
|
def _print_Integer(self, expr: sympy.Expr) -> str:
|
||||||
|
return (
|
||||||
|
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _print_Where(self, expr: sympy.Expr) -> str:
|
||||||
|
c, p, q = (
|
||||||
|
self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
|
||||||
|
)
|
||||||
|
return f"{c} ? {p} : {q}"
|
||||||
|
|
||||||
|
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
|
||||||
|
x, div, mod = expr.args
|
||||||
|
x = self.doprint(x)
|
||||||
|
if div != 1:
|
||||||
|
div = self.doprint(div)
|
||||||
|
if expr.is_integer:
|
||||||
|
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
||||||
|
else:
|
||||||
|
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
||||||
|
mod = self.doprint(mod)
|
||||||
|
return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"
|
||||||
|
|
||||||
|
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
|
||||||
|
x, div = expr.args
|
||||||
|
x = self.doprint(x)
|
||||||
|
div = self.doprint(div)
|
||||||
|
if expr.is_integer:
|
||||||
|
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
|
||||||
|
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
|
||||||
|
|
||||||
|
def _print_floor(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
r = f"std::floor({self._print(expr.args[0])})"
|
||||||
|
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||||
|
|
||||||
|
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
r = f"std::floor({self._print(expr.args[0])})"
|
||||||
|
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||||
|
|
||||||
|
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
r = f"std::trunc({self._print(expr.args[0])})"
|
||||||
|
return f"static_cast<{INDEX_TYPE}>({r})"
|
||||||
|
|
||||||
|
def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::trunc({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_ToFloat(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"static_cast<double>({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
# TODO: This is wrong if one of the inputs is negative. This is hard to
|
||||||
|
# tickle though, as the inputs are typically positive (and if we can prove
|
||||||
|
# they are positive, we will have used Mod instead, for which this codegen
|
||||||
|
# is right).
|
||||||
|
def _print_PythonMod(self, expr: sympy.Expr) -> str:
|
||||||
|
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||||
|
|
||||||
|
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
|
||||||
|
lhs, rhs = expr.args
|
||||||
|
# TODO: This is only accurate up to 2**53
|
||||||
|
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
|
||||||
|
|
||||||
|
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
|
||||||
|
# use std::pow, that operates on floats
|
||||||
|
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
|
||||||
|
raise NotImplementedError(
|
||||||
|
f"_print_PowByNatural not implemented for {type(self)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _print_FloatPow(self, expr: sympy.Expr) -> str:
|
||||||
|
base, exp = expr.args
|
||||||
|
return f"std::pow({self._print(base)}, {self._print(exp)})"
|
||||||
|
|
||||||
|
def _print_Pow(self, expr: sympy.Expr) -> str:
|
||||||
|
# Uses float constants to perform FP div
|
||||||
|
base, exp = expr.args
|
||||||
|
|
||||||
|
if exp == 0.5 or exp == -0.5:
|
||||||
|
base = self._print(base)
|
||||||
|
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
|
||||||
|
if exp.is_integer:
|
||||||
|
exp = int(exp)
|
||||||
|
if exp > 0:
|
||||||
|
r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
|
||||||
|
elif exp < -1:
|
||||||
|
r = (
|
||||||
|
"1.0/("
|
||||||
|
+ self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"])
|
||||||
|
+ ")"
|
||||||
|
)
|
||||||
|
elif exp == -1:
|
||||||
|
r = "1.0/" + self._print(base)
|
||||||
|
else: # exp == 0
|
||||||
|
r = "1.0"
|
||||||
|
|
||||||
|
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||||
|
else:
|
||||||
|
# TODO: float vs double
|
||||||
|
return f"std::pow({base}, {float(exp)})"
|
||||||
|
|
||||||
|
def _print_Rational(self, expr: sympy.Expr) -> str:
|
||||||
|
# Uses float constants to perform FP div
|
||||||
|
if expr.q == 1:
|
||||||
|
r = f"{expr.p}"
|
||||||
|
else:
|
||||||
|
r = f"{expr.p}.0/{expr.q}.0"
|
||||||
|
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||||
|
|
||||||
|
def _print_ceiling(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
r = f"std::ceil({self._print(expr.args[0])})"
|
||||||
|
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||||
|
|
||||||
|
def _print_CeilToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
r = f"std::ceil({self._print(expr.args[0])})"
|
||||||
|
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
|
||||||
|
|
||||||
|
def _print_Min(self, expr: sympy.Expr) -> str:
|
||||||
|
args = [self._print(a) for a in expr.args]
|
||||||
|
if len(args) == 2:
|
||||||
|
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
||||||
|
else:
|
||||||
|
# Initializer list overload
|
||||||
|
il = "{" + ", ".join(args) + "}"
|
||||||
|
return f"std::min({il})"
|
||||||
|
|
||||||
|
def _print_Max(self, expr: sympy.Expr) -> str:
|
||||||
|
args = [self._print(a) for a in expr.args]
|
||||||
|
if len(args) == 2:
|
||||||
|
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
|
||||||
|
else:
|
||||||
|
# Initializer list overload
|
||||||
|
il = "{" + ", ".join(args) + "}"
|
||||||
|
return f"std::max({il})"
|
||||||
|
|
||||||
|
def _print_Abs(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::abs({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::cos({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::cosh({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::acos({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::sin({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::sinh({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::asin({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::tan({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::tanh({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
return f"std::atan({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
|
||||||
|
return f"std::sqrt({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 1
|
||||||
|
# TODO: dispatch to llrint depending on index type
|
||||||
|
return f"std::lrint({self._print(expr.args[0])})"
|
||||||
|
|
||||||
|
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
|
||||||
|
assert len(expr.args) == 2
|
||||||
|
number, ndigits = expr.args
|
||||||
|
if number.is_integer:
|
||||||
|
# ndigits < 0 should have been filtered by the sympy function
|
||||||
|
assert ndigits < 0
|
||||||
|
raise ValueError(
|
||||||
|
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
|
||||||
|
)
|
||||||
|
number_str = self.parenthesize(number, PRECEDENCE["Mul"])
|
||||||
|
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"
|
||||||
|
|
||||||
|
def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
|
||||||
|
return "true"
|
||||||
|
|
||||||
|
def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
|
||||||
|
return "false"
|
||||||
Loading…
Reference in New Issue
Block a user