Move Sympy printers to torch/utils/_sympy/printers.py (#140597)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597
Approved by: https://github.com/ezyang, https://github.com/anijain2305
This commit is contained in:
Isuru Fernando 2024-11-25 22:22:07 +00:00 committed by PyTorch MergeBot
parent 29ca44839e
commit 44186a0a4e
20 changed files with 589 additions and 579 deletions

View File

@ -580,8 +580,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase):
.check_regex( .check_regex(
"torch.ops._c10d_functional.all_to_all_single.default\\(" "torch.ops._c10d_functional.all_to_all_single.default\\("
"arg\\d+_\\d+, " "arg\\d+_\\d+, "
"\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], " "\\[s\\d+ // \\d, s\\d+ // \\d\\], "
"\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]" "\\[s\\d+ // \\d, s\\d+ // \\d\\]"
) )
.run(code) .run(code)
) )

View File

@ -3570,7 +3570,7 @@ class GraphModule(torch.nn.Module):
"cast_symbool_to_symint_guardless(L['pred']) == 1", "cast_symbool_to_symint_guardless(L['pred']) == 1",
] ]
false_guard_code = [ false_guard_code = [
"Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", "cast_symbool_to_symint_guardless(L['pred']) != 1",
] ]
test_symbool_guards( test_symbool_guards(
f, f,

View File

@ -668,7 +668,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre
"""\ """\
+- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # +- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in #
+- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) +- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)
+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in # +- LAMBDA_GUARD: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in #
+- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 +- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950
) )

View File

@ -10457,7 +10457,7 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match:
==> axioms: values don't match. ==> axioms: values don't match.
> Left: {0 < Mod(s0, 3): False, 0 <= Mod(s0, 3): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Mod(s0, 3) < 0: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False} > Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False}
> Right: {} > Right: {}
==> divisible: values don't match. ==> divisible: values don't match.
> Left: {Mod(s0, 3)} > Left: {Mod(s0, 3)}
@ -10576,7 +10576,7 @@ ShapeEnv not equal: field values don't match:
ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match:
==> axioms: values don't match. ==> axioms: values don't match.
> Left: {0 < PythonMod(u0, 3): False, 0 <= PythonMod(u0, 3): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) < 0: False, PythonMod(u0, 3) <= 0: True} > Left: {(PythonMod(u0, 3)) < 0: False, (PythonMod(u0, 3)) <= 0: True, 0 < (PythonMod(u0, 3)): False, 0 <= (PythonMod(u0, 3)): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False}
> Right: {} > Right: {}
==> deferred_runtime_asserts: values don't match. ==> deferred_runtime_asserts: values don't match.
> Left: {u0: [Eq(PythonMod(u0, 3), 0)]} > Left: {u0: [Eq(PythonMod(u0, 3), 0)]}

View File

@ -3259,9 +3259,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
(torch.tensor(20),), (torch.tensor(20),),
fixes=[ fixes=[
# Could not guard on data-dependent expression Eq((u0//2), 0) # Could not guard on data-dependent expression Eq((u0//2), 0)
"torch._check(((i//2)) != 0)", "torch._check((i // 2) != 0)",
# Could not guard on data-dependent expression Eq((u0//2), 1) # Could not guard on data-dependent expression Eq((u0//2), 1)
"torch._check(((i//2)) != 1)", "torch._check((i // 2) != 1)",
], ],
) )

View File

@ -1426,12 +1426,12 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel,
xoffset = tl.program_id(0) * XBLOCK xoffset = tl.program_id(0) * XBLOCK
xindex = xoffset + tl.arange(0, XBLOCK)[:] xindex = xoffset + tl.arange(0, XBLOCK)[:]
xmask = xindex < xnumel xmask = xindex < xnumel
x0 = xindex % 20 x0 = (xindex % 20)
x1 = (xindex // 20) % 20 x1 = ((xindex // 20) % 20)
x2 = (xindex // 400) x2 = xindex // 400
x3 = xindex x3 = xindex
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last') tmp0 = tl.load(in_ptr0 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last') tmp1 = tl.load(in_ptr1 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last')
tmp2 = tmp0 + tmp1 tmp2 = tmp0 + tmp1
tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950 tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950
) )

View File

@ -20,7 +20,9 @@ from torch.testing._internal.common_utils import (
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU
from torch.utils._sympy.functions import ( from torch.utils._sympy.functions import (
FloorDiv, FloorDiv,
Mod,
ModularIndexing, ModularIndexing,
PythonMod,
RoundDecimal, RoundDecimal,
RoundToInt, RoundToInt,
) )
@ -236,7 +238,7 @@ class TestIndexingSimplification(InductorTestCase):
triton_code = run_and_get_triton_code(f, x) triton_code = run_and_get_triton_code(f, x)
# Make sure the 2 load uses simpified indexing rather than something like # Make sure the 2 load uses simpified indexing rather than something like
# tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)),
self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),")) self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + (x2 // 2),"))
if DO_PERF_TEST: if DO_PERF_TEST:
ms = benchmarker.benchmark_gpu(lambda: f(x)) ms = benchmarker.benchmark_gpu(lambda: f(x))
print(f"{ms=:.03f}") print(f"{ms=:.03f}")
@ -313,6 +315,39 @@ class ExprPrinterTests(InductorTestCase):
self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""")
self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""")
def test_print_mod(self):
x = sympy.Symbol("x", integer=True)
expr = Mod(x - 1, 2)
self.assertExpectedInline(pexpr(expr), """((-1) + x) % 2""")
self.assertExpectedInline(cexpr(expr), """((-1L) + x) % 2L""")
self.assertExpectedInline(texpr(expr), """((-1) + x) % 2""")
expr = (x - 10) % x
self.assertExpectedInline(pexpr(expr), """(-10) % x""")
self.assertExpectedInline(cexpr(expr), """(-10L) % x""")
self.assertExpectedInline(texpr(expr), """(-10) % x""")
def test_print_mod_index(self):
x = sympy.Symbol("x", integer=True)
ks = sympy.Symbol("ks", integer=True)
expr = ModularIndexing(x - 10, ks, ks)
self.assertExpectedInline(pexpr(expr), """((((-10) + x) // ks) % ks)""")
self.assertExpectedInline(
cexpr(expr),
"""(static_cast<int64_t>(c10::div_floor_integer("""
"""static_cast<int64_t>((-10L) + x), static_cast<int64_t>(ks))) % static_cast<int64_t>(ks))""",
)
self.assertExpectedInline(texpr(expr), """((((-10) + x) // ks) % ks)""")
def test_print_python_mod(self):
x = sympy.Symbol("x", integer=True)
expr = PythonMod(x - 10, x)
self.assertExpectedInline(pexpr(expr), """((-10) + x) % x""")
self.assertExpectedInline(cexpr(expr), """((-10L) + x) % x""")
self.assertExpectedInline(
texpr(expr), """triton_helpers.remainder_integer((-10) + x, x)"""
)
@parametrize("ndigits", [-1, 0, 1]) @parametrize("ndigits", [-1, 0, 1])
def test_print_round_decimal(self, ndigits): def test_print_round_decimal(self, ndigits):
expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits) expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits)
@ -330,7 +365,7 @@ class ExprPrinterTests(InductorTestCase):
s1 = sympy.Symbol("s1", integer=True) s1 = sympy.Symbol("s1", integer=True)
s2 = sympy.Symbol("s2", integer=True) s2 = sympy.Symbol("s2", integer=True)
expr = FloorDiv(s1, s2) expr = FloorDiv(s1, s2)
self.assertEqual(pexpr(expr), "(s1 // s2)") self.assertEqual(pexpr(expr), "s1 // s2")
self.assertEqual( self.assertEqual(
cexpr(expr), cexpr(expr),
"c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))", "c10::div_floor_integer(static_cast<int64_t>(s1), static_cast<int64_t>(s2))",

View File

@ -58,13 +58,11 @@ class TestMemoryPlanning(TestCase):
result, code = run_and_get_cpp_code(compiled, *args) result, code = run_and_get_cpp_code(compiled, *args)
FileCheck().check( FileCheck().check(
"pool1 = empty_strided_" "pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )"
+ GPU_TYPE
+ "(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )"
).check_next( ).check_next(
"buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))" "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))"
).check( ).check(
"buf1 = alloc_from_pool(pool1, align(4*(s0*s0))," "buf1 = alloc_from_pool(pool1, align(4*s0*s0),"
).run( ).run(
code code
) )
@ -103,7 +101,7 @@ class TestMemoryPlanning(TestCase):
) )
FileCheck().check( FileCheck().check(
"int64_t int_array_2[] = {24L + (align(12L*s0)), };" "int64_t int_array_2[] = {24L + align(12L*s0), };"
).check_next("int64_t int_array_3[] = {1L, };").check_next( ).check_next("int64_t int_array_3[] = {1L, };").check_next(
"AtenTensorHandle pool1_handle;" "AtenTensorHandle pool1_handle;"
).check_next( ).check_next(

View File

@ -487,7 +487,7 @@ class PaddingTest(TestCaseBase):
# make sure the load for softmax is aligned # make sure the load for softmax is aligned
self.assertTrue( self.assertTrue(
"tl.load(in_ptr0 + (r1 + (30528*x0))" in forward_wrapper, "tl.load(in_ptr0 + (r1 + 30528*x0)" in forward_wrapper,
f"forward_wrapper: {forward_wrapper}", f"forward_wrapper: {forward_wrapper}",
) )

View File

@ -12505,8 +12505,8 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertExpectedInline( self.assertExpectedInline(
"\n".join(lines), "\n".join(lines),
"""\ """\
tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r2), rmask, eviction_policy='evict_last', other=0.0)
tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""", tmp1 = tl.load(in_ptr1 + (x3 + 262144*r2), rmask, eviction_policy='evict_first', other=0.0)""",
) )
@config.patch("triton.use_block_ptr", True) @config.patch("triton.use_block_ptr", True)
@ -12538,7 +12538,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
self.assertExpectedInline( self.assertExpectedInline(
"\n".join(lines), "\n".join(lines),
"""\ """\
tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [((511 + XBLOCK) // 512), ((1) * ((1) <= (((511 + XBLOCK) // 512))) + (((511 + XBLOCK) // 512)) * ((((511 + XBLOCK) // 512)) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK]) tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK])
tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long
) )

View File

@ -275,7 +275,7 @@ class TritonBlockPointerTest(InductorTestCase):
"\n".join(load_lines), "\n".join(load_lines),
"""\ """\
tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0]) tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0])
tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950 tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950
) )
self.assertExpectedInline( self.assertExpectedInline(
"\n".join(store_lines), "\n".join(store_lines),

View File

@ -2792,8 +2792,8 @@ class TestGuardsExpressions(TestCase):
guard_int(sym_int(s0 / 2.0)) guard_int(sym_int(s0 / 2.0))
guards = shape_env.produce_guards_expression([s0]) guards = shape_env.produce_guards_expression([s0])
self.assertIn("ToFloat", guards) self.assertIn("math.trunc(", guards)
self.assertIn("FloatTrueDiv", guards) self.assertIn("float(", guards)
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)])) self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))

View File

@ -23,7 +23,6 @@ from typing import (
) )
import sympy import sympy
from sympy.printing.printer import Printer
import torch import torch
import torch.fx import torch.fx
@ -31,6 +30,7 @@ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree from torch.utils import _pytree as pytree
from torch.utils._ordered_set import OrderedSet from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter
from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT
from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges
@ -609,12 +609,22 @@ class DataTypePropagation:
DataTypePropagation.propagate_loopbody(node._body) DataTypePropagation.propagate_loopbody(node._body)
# This printer contains rules that are supposed to be generic for both C/C++ and class PythonPrinter(_PythonPrinter):
# Python def doprint(self, expr, *, simplify: bool = True, p=True):
class ExprPrinter(Printer): # TODO: why are people passing strings to the printer here :think:
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
expr = V.graph.sizevars.simplify(expr)
return super().doprint(expr)
class OpOverrides:
def __init__(self, parent):
super().__init__()
self._parent = parent
@staticmethod @staticmethod
def paren(string): def paren(string: str) -> str:
def all_in_parens(string): def all_in_parens(string: str) -> bool:
if string[0] != "(" or len(string) < 2: if string[0] != "(" or len(string) < 2:
return False return False
count = 1 count = 1
@ -640,260 +650,6 @@ class ExprPrinter(Printer):
return string return string
return f"({string})" return f"({string})"
def _print_Relational(self, expr):
return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))
def _print_Mul(self, expr):
return "*".join(map(self.paren, map(self._print, expr.args)))
def _print_Add(self, expr):
return " + ".join(map(self.paren, map(self._print, expr.args)))
# NB: this is OK to put here, because Mod is only defined for positive
# numbers, and so across C/Python its behavior is consistent
def _print_Mod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_FloatTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _print_CleanDiv(self, expr):
return self._print_FloorDiv(expr)
def _print_Identity(self, expr):
return self._print(expr.args[0])
def _print_GreaterThan(self, expr):
# GreaterThan: >=
# StrictlyGreaterThan: >
# Go figure...
return " >= ".join(map(self.paren, map(self._print, expr.args)))
# NB: The C implementation is injected into codegen at
# torch/_inductor/codegen/wrapper.py
def _print_align(self, expr):
assert len(expr.args) == 1
return f"align({self._print(expr.args[0])})"
# This must be implemented because sympy will collect x * x into Pow(x, 2), without
# any explicit intervention. We print it just like x * x, notably, we
# never generate sympy.Pow with floats.
#
# NB: this pow by natural, you should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. These
# means exp is guaranteed to be integer.
def _print_Pow(self, expr):
base, exp = expr.args
base = self._print(base)
assert exp == int(exp), exp
exp = int(exp)
assert exp >= 0
if exp > 0:
return "*".join([self.paren(base)] * exp)
return "1"
# Explicit NotImplemented functions are to prevent default sympy printing
# behavior, which will just barf out ToFloat(...) to your IR. The error
# message is better here because it tells you which printer class it needs
# to go in.
def _print_ToFloat(self, expr):
raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}")
def _print_Infinity(self, expr):
raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}")
def _print_NegativeInfinity(self, expr):
raise NotImplementedError(
f"_print_NegativeInfinity not implemented for {type(self)}"
)
def _print_FloorDiv(self, expr):
raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")
def _print_PythonMod(self, expr):
raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}")
def _print_IntTrueDiv(self, expr):
raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}")
def _print_PowByNatural(self, expr):
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatPow(self, expr):
raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}")
def _print_TruncToInt(self, expr):
raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}")
def _print_RoundToInt(self, expr):
raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}")
def _print_RoundDecimal(self, expr):
raise NotImplementedError(
f"_print_RoundDecimal not implemented for {type(self)}"
)
# NB: Some float operations are INTENTIONALLY not implemented for
# printers. You can implement them as a quick unblock, but it is better
# to ask yourself why we haven't done this computation in the Tensor
# universe instead
def _print_TruncToFloat(self, expr):
raise NotImplementedError(
f"_print_TruncToFloat not implemented for {type(self)}"
)
def doprint(self, expr, *, simplify: bool = True):
# TODO: why are people passing strings to the printer here :think:
if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
expr = V.graph.sizevars.simplify(expr)
return super().doprint(expr)
class PythonPrinter(ExprPrinter):
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"float({self._print(expr.args[0])})"
def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
mod = self.paren(self.doprint(mod))
if div != "1":
x = f"({x} // {div})"
return f"{x} % {mod}"
def _print_Infinity(self, expr):
return "math.inf"
def _print_NegativeInfinity(self, expr):
return "-math.inf"
# WARNING: this is dangerous for Triton, which has C-style modulus
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
# WARNING: this is dangerous for Triton, which has C-style modulus
def _print_FloorDiv(self, expr):
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
return f"({x} // {div})"
# WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
# does a special algorithm
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _helper_sqrt(self, expr):
return f"math.sqrt({self._print(expr)})"
def _print_OpaqueUnaryFn_sqrt(self, expr):
return self._helper_sqrt(expr.args[0])
def _print_FloatPow(self, expr):
base, exp = expr.args
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
# TODO: Not sure this works with Triton, even when base/exp are integral
def _print_PowByNatural(self, expr):
base, exp = expr.args
return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}"
def _print_floor(self, expr):
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_FloorToInt(self, expr):
assert len(expr.args) == 1
return f"math.floor({self._print(expr.args[0])})"
def _print_TruncToInt(self, expr):
assert len(expr.args) == 1
# This also could have been int(), they'll do the same thing for float
return f"math.trunc({self._print(expr.args[0])})"
def _print_ceiling(self, expr):
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_CeilToInt(self, expr):
assert len(expr.args) == 1
return f"math.ceil({self._print(expr.args[0])})"
def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"abs({self._print(expr.args[0])})"
# NB: It's expected that we've made explicit any promotion in the sympy
# expression, so it doesn't matter that Python max/min doesn't perform
# promotion
def _print_Max(self, expr):
assert len(expr.args) >= 2
return f"max({', '.join(map(self._print, expr.args))})"
def _print_Min(self, expr):
assert len(expr.args) >= 2
return f"min({', '.join(map(self._print, expr.args))})"
def _print_OpaqueUnaryFn_cos(self, expr):
assert len(expr.args) == 1
return f"math.cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr):
assert len(expr.args) == 1
return f"math.cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr):
assert len(expr.args) == 1
return f"math.acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr):
assert len(expr.args) == 1
return f"math.sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr):
assert len(expr.args) == 1
return f"math.sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr):
assert len(expr.args) == 1
return f"math.asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr):
assert len(expr.args) == 1
return f"math.tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr):
assert len(expr.args) == 1
return f"math.tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr):
assert len(expr.args) == 1
return f"math.atan({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr):
assert len(expr.args) == 1
return f"round({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr):
assert len(expr.args) == 2
number, ndigits = expr.args
assert isinstance(ndigits, sympy.Integer)
return f"round({self._print(number)}, {ndigits})"
class OpOverrides:
def __init__(self, parent):
super().__init__()
self._parent = parent
def __getattr__(self, item): def __getattr__(self, item):
return getattr(self._parent, item) return getattr(self._parent, item)
@ -982,31 +738,31 @@ class OpOverrides:
@staticmethod @staticmethod
def bitwise_not(x): def bitwise_not(x):
return f"~{ExprPrinter.paren(x)}" return f"~{OpOverrides.paren(x)}"
@staticmethod @staticmethod
def logical_not(a): def logical_not(a):
return f"{ExprPrinter.paren(a)} == 0" return f"{OpOverrides.paren(a)} == 0"
@staticmethod @staticmethod
def bitwise_and(x, y): def bitwise_and(x, y):
return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}"
@staticmethod @staticmethod
def bitwise_or(x, y): def bitwise_or(x, y):
return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}"
@staticmethod @staticmethod
def bitwise_xor(x, y): def bitwise_xor(x, y):
return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}"
@staticmethod @staticmethod
def bitwise_left_shift(x, y): def bitwise_left_shift(x, y):
return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}"
@staticmethod @staticmethod
def bitwise_right_shift(x, y): def bitwise_right_shift(x, y):
return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
@staticmethod @staticmethod
def remainder(a, b): def remainder(a, b):

View File

@ -13,6 +13,7 @@ import sympy
import torch import torch
from torch._prims_common import is_integer_dtype from torch._prims_common import is_integer_dtype
from torch.utils._ordered_set import OrderedSet from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.printers import CppPrinter as _CppPrinter
from torch.utils._sympy.symbol import symbol_is_type, SymT from torch.utils._sympy.symbol import symbol_is_type, SymT
from torch.utils._sympy.value_ranges import ValueRanges from torch.utils._sympy.value_ranges import ValueRanges
@ -25,7 +26,6 @@ from ..virtualized import ops, OpsValue, V
from .common import ( from .common import (
CSEVariable, CSEVariable,
deduce_output_dtype_by_name, deduce_output_dtype_by_name,
ExprPrinter,
Kernel, Kernel,
KernelArgs, KernelArgs,
OptimizationContext, OptimizationContext,
@ -232,212 +232,12 @@ class CppCSEVariable(CSEVariable):
return itervar in self.dependent_itervars return itervar in self.dependent_itervars
class CppPrinter(ExprPrinter): class CppPrinter(_CppPrinter):
def _print_Integer(self, expr): def doprint(self, expr, *, simplify: bool = True, p=True):
return ( # TODO: why are people passing strings to the printer here :think:
f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L" if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"):
) expr = V.graph.sizevars.simplify(expr)
return super().doprint(expr)
def _print_Where(self, expr):
c = self.paren(self.doprint(expr.args[0]))
p = self.paren(self.doprint(expr.args[1]))
q = self.paren(self.doprint(expr.args[2]))
return f"{c} ? {p} : {q}"
def _print_ModularIndexing(self, expr):
x, div, mod = expr.args
x = self.paren(self.doprint(x))
if div != 1:
div = self.paren(self.doprint(div))
if expr.is_integer:
x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
else:
x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
mod = self.paren(self.doprint(mod))
return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})"
def _print_FloorDiv(self, expr):
x, div = expr.args
x = self.paren(self.doprint(x))
div = self.paren(self.doprint(div))
if expr.is_integer:
return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
def _print_floor(self, expr):
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_FloorToInt(self, expr):
assert len(expr.args) == 1
r = f"std::floor({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_TruncToInt(self, expr):
assert len(expr.args) == 1
r = f"std::trunc({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})"
def _print_TruncToFloat(self, expr):
assert len(expr.args) == 1
return f"std::trunc({self._print(expr.args[0])})"
def _print_ToFloat(self, expr):
assert len(expr.args) == 1
return f"static_cast<double>({self._print(expr.args[0])})"
# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right).
def _print_PythonMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_CMod(self, expr):
return " % ".join(map(self.paren, map(self._print, expr.args)))
def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args
# TODO: This is only accurate up to 2**53
return f"static_cast<double>({self._print(lhs)}) / static_cast<double>({self._print(rhs)})"
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
# use std::pow, that operates on floats
def _print_PowByNatural(self, expr):
raise NotImplementedError(
f"_print_PowByNatural not implemented for {type(self)}"
)
def _print_FloatTrueDiv(self, expr):
lhs, rhs = expr.args
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
def _print_FloatPow(self, expr):
base, exp = expr.args
return f"std::pow({self._print(base)}, {self._print(exp)})"
def _print_Pow(self, expr):
# Uses float constants to perform FP div
base, exp = expr.args
base = self._print(base)
if exp == 0.5 or exp == -0.5:
return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})"
if exp.is_integer:
exp = int(exp)
if exp > 0:
r = "*".join([self.paren(base)] * exp)
elif exp < 0:
r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp)))
else: # exp == 0
r = "1.0"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
else:
# TODO: float vs double
return f"std::pow({base}, {float(exp)})"
def _print_Rational(self, expr):
# Uses float constants to perform FP div
if expr.q == 1:
r = f"{expr.p}"
else:
r = f"{expr.p}.0/{expr.q}.0"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_ceiling(self, expr):
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_CeilToInt(self, expr):
assert len(expr.args) == 1
r = f"std::ceil({self._print(expr.args[0])})"
return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
def _print_Min(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::min({il})"
def _print_Max(self, expr):
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::max({il})"
def _print_Abs(self, expr):
assert len(expr.args) == 1
return f"std::abs({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cos(self, expr):
assert len(expr.args) == 1
return f"std::cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr):
assert len(expr.args) == 1
return f"std::cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr):
assert len(expr.args) == 1
return f"std::acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr):
assert len(expr.args) == 1
return f"std::sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr):
assert len(expr.args) == 1
return f"std::sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr):
assert len(expr.args) == 1
return f"std::asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr):
assert len(expr.args) == 1
return f"std::tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr):
assert len(expr.args) == 1
return f"std::tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr):
assert len(expr.args) == 1
return f"std::atan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sqrt(self, expr):
return f"std::sqrt({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr):
assert len(expr.args) == 1
# TODO: dispatch to llrint depending on index type
return f"std::lrint({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr):
assert len(expr.args) == 2
number, ndigits = expr.args
if number.is_integer:
# ndigits < 0 should have been filtered by the sympy function
assert ndigits < 0
raise ValueError(
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
)
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})"
def _print_BooleanTrue(self, expr):
return "true"
def _print_BooleanFalse(self, expr):
return "false"
# A function to print, useful for printing sympy symbols. # A function to print, useful for printing sympy symbols.

View File

@ -185,8 +185,8 @@ class HalidePrinter(PythonPrinter):
return super()._print_FloorDiv(expr) return super()._print_FloorDiv(expr)
x, div = expr.args x, div = expr.args
x = self.cast_float(self.paren(self.doprint(x))) x = self.cast_float(self.doprint(x))
div = self.cast_float(self.paren(self.doprint(div))) div = self.cast_float(self.doprint(div))
return self.cast_index(f"hl.floor({x} / {div})") return self.cast_index(f"hl.floor({x} / {div})")
def _print_Round(self, expr): def _print_Round(self, expr):

View File

@ -27,6 +27,7 @@ from typing import (
) )
import sympy import sympy
from sympy.printing.precedence import PRECEDENCE
import torch import torch
import torch._logging import torch._logging
@ -504,30 +505,30 @@ class TritonPrinter(PythonPrinter):
def _print_ToFloat(self, expr): def _print_ToFloat(self, expr):
assert len(expr.args) == 1 assert len(expr.args) == 1
return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5)
return f"{s}.to(tl.float64)"
def _print_PythonMod(self, expr): def _print_PythonMod(self, expr):
quot, div = expr.args quot, div = expr.args
if quot.is_nonnegative and div.is_nonnegative:
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
quot_s = self._print(quot) quot_s = self._print(quot)
div_s = self._print(div) div_s = self._print(div)
if quot.is_nonnegative and div.is_nonnegative:
return f"{self.paren(quot_s)} % {self.paren(div_s)}"
return f"triton_helpers.remainder_integer({quot_s}, {div_s})" return f"triton_helpers.remainder_integer({quot_s}, {div_s})"
def _print_FloorDiv(self, expr): def _print_FloorDiv(self, expr):
assert expr.is_integer assert expr.is_integer
quot, div = expr.args quot, div = expr.args
if quot.is_nonnegative and div.is_nonnegative:
return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5)
quot_s = self._print(quot) quot_s = self._print(quot)
div_s = self._print(div) div_s = self._print(div)
if quot.is_nonnegative and div.is_nonnegative:
return f"({self.paren(quot_s)} // {self.paren(div_s)})"
return f"triton_helpers.div_floor_integer({quot_s}, {div_s})" return f"triton_helpers.div_floor_integer({quot_s}, {div_s})"
# TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher
# precision algorithm, which we would need to replicate here # precision algorithm, which we would need to replicate here
def _print_IntTrueDiv(self, expr): def _print_IntTrueDiv(self, expr):
lhs, rhs = expr.args return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}"
# NB: sympy.floor/ceiling produce integers, so we have to do the # NB: sympy.floor/ceiling produce integers, so we have to do the
# conversion to index dtype # conversion to index dtype
@ -646,7 +647,9 @@ class TritonPrinter(PythonPrinter):
raise ValueError( raise ValueError(
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
) )
return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}"
number_str = self.parenthesize(number, PRECEDENCE["Mul"])
return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}"
texpr = TritonPrinter().doprint texpr = TritonPrinter().doprint

View File

@ -35,13 +35,12 @@ from .autotune_process import (
TritonGPUBenchmarkRequest, TritonGPUBenchmarkRequest,
) )
from .codecache import code_hash, PersistentCache, PyCodeCache from .codecache import code_hash, PersistentCache, PyCodeCache
from .codegen.common import IndentedBuffer, KernelTemplate, WorkspaceArg from .codegen.common import IndentedBuffer, KernelTemplate, OpOverrides, WorkspaceArg
from .codegen.simd_kernel_features import SIMDKernelFeatures from .codegen.simd_kernel_features import SIMDKernelFeatures
from .codegen.triton import ( from .codegen.triton import (
gen_common_triton_imports, gen_common_triton_imports,
texpr, texpr,
TritonKernel, TritonKernel,
TritonPrinter,
TritonScheduling, TritonScheduling,
) )
from .codegen.triton_utils import config_of, signature_to_meta from .codegen.triton_utils import config_of, signature_to_meta
@ -562,7 +561,7 @@ class TritonTemplateKernel(TritonKernel):
assert isinstance(val, str) assert isinstance(val, str)
assert isinstance(mask, (str, type(None))) assert isinstance(mask, (str, type(None)))
assert self.template_mask is None assert self.template_mask is None
indices = list(map(TritonPrinter.paren, indices)) indices = list(map(OpOverrides.paren, indices))
index_symbols = [sympy.Symbol(x, integer=True) for x in indices] index_symbols = [sympy.Symbol(x, integer=True) for x in indices]
lengths = [ lengths = [
V.graph.sizevars.simplify(s) for s in self.output_node.get_size() V.graph.sizevars.simplify(s) for s in self.output_node.get_size()
@ -630,7 +629,7 @@ class TritonTemplateKernel(TritonKernel):
assert isinstance(name, str) assert isinstance(name, str)
assert isinstance(mask, str) assert isinstance(mask, str)
stride = self.named_input_nodes[name].get_stride() stride = self.named_input_nodes[name].get_stride()
indices = list(map(TritonPrinter.paren, indices)) indices = list(map(OpOverrides.paren, indices))
assert len(indices) == len(stride) assert len(indices) == len(stride)
index = " + ".join( index = " + ".join(
f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices)

View File

@ -10,7 +10,6 @@ need to make use of these APIs to setup dynamic shapes support appropriately.
""" """
import atexit import atexit
import builtins
import collections import collections
import functools import functools
import inspect import inspect
@ -82,6 +81,7 @@ from torch.utils._sympy.functions import (
PythonMod, PythonMod,
) )
from torch.utils._sympy.numbers import int_oo from torch.utils._sympy.numbers import int_oo
from torch.utils._sympy.printers import PythonPrinter
from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.singleton_int import SingletonInt
from torch.utils._sympy.solve import try_solve from torch.utils._sympy.solve import try_solve
from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT
@ -109,8 +109,6 @@ log = logging.getLogger(__name__)
import sympy import sympy
from sympy import S from sympy import S
from sympy.printing.precedence import PRECEDENCE, precedence
from sympy.printing.str import StrPrinter
class GuardOnDataDependentSymNode(RuntimeError): class GuardOnDataDependentSymNode(RuntimeError):
@ -2032,45 +2030,9 @@ def cast_symbool_to_symint_guardless(
SYMPY_INTERP = { SYMPY_INTERP = {
"Abs": operator.abs,
"Eq": operator.eq,
"Ne": operator.ne,
"Gt": operator.gt,
"Lt": operator.lt,
"Le": operator.le,
"Ge": operator.ge,
"Min": min,
"Max": max,
"Mod": operator.mod,
"PythonMod": operator.mod,
"FloorDiv": operator.floordiv,
"TrueDiv": operator.truediv,
"PowByNatural": operator.pow,
"IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense, "IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense,
"floor": math.floor,
"ceiling": math.ceil,
"FloorToInt": math.floor,
"FloatPow": math.pow,
"CeilToInt": math.ceil,
"cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless, "cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless,
"RoundToInt": builtins.round, "math": math,
"RoundDecimal": builtins.round,
"TruncToInt": math.trunc,
"IntTrueDiv": operator.truediv,
"FloatTrueDiv": operator.truediv,
"ToFloat": builtins.float,
"OpaqueUnaryFn_cos": math.cos,
"OpaqueUnaryFn_cosh": math.cosh,
"OpaqueUnaryFn_acos": math.acos,
"OpaqueUnaryFn_sin": math.sin,
"OpaqueUnaryFn_sinh": math.sinh,
"OpaqueUnaryFn_asin": math.asin,
"OpaqueUnaryFn_tan": math.tan,
"OpaqueUnaryFn_tanh": math.tanh,
"OpaqueUnaryFn_atan": math.atan,
"OpaqueUnaryFn_sqrt": math.sqrt,
"BitwiseFn_bitwise_and": operator.and_,
"BitwiseFn_bitwise_or": operator.or_,
} }
@ -2141,12 +2103,12 @@ class RuntimeAssert:
# Used for printing SymExprs in compile_fx # Used for printing SymExprs in compile_fx
class SymExprPrinter(StrPrinter): class SymExprPrinter(PythonPrinter):
def _print_Float(self, expr: sympy.Float) -> str: def _print_Float(self, expr: sympy.Float) -> str:
return str(float(expr)) return str(float(expr))
class ShapeGuardPrinter(SymExprPrinter): class ShapeGuardPrinter(PythonPrinter):
def __init__( def __init__(
self, self,
symbol_to_source: Mapping[sympy.Symbol, List[Source]], symbol_to_source: Mapping[sympy.Symbol, List[Source]],
@ -2158,14 +2120,8 @@ class ShapeGuardPrinter(SymExprPrinter):
self.source_ref = source_ref self.source_ref = source_ref
self.var_to_sources = var_to_sources self.var_to_sources = var_to_sources
def _print_Not(self, expr: SympyBoolean) -> str: def _print_Float(self, expr: sympy.Float) -> str:
return "not {}".format(self.parenthesize(expr.args[0], PRECEDENCE["Not"])) return str(float(expr))
def _print_And(self, expr: SympyBoolean) -> str:
return self.stringify(expr.args, " and ", PRECEDENCE["And"])
def _print_Or(self, expr: SympyBoolean) -> str:
return self.stringify(expr.args, " or ", PRECEDENCE["Or"])
def _print_Symbol(self, expr: sympy.Symbol) -> str: def _print_Symbol(self, expr: sympy.Symbol) -> str:
assert isinstance(expr, sympy.Symbol), str(type(expr)) assert isinstance(expr, sympy.Symbol), str(type(expr))
@ -2191,7 +2147,7 @@ class LoggingShapeGuardPrinter(ShapeGuardPrinter):
super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) super().__init__(var_to_sources, lambda n: n.name(), var_to_sources)
class DynamicDimConstraintPrinter(StrPrinter): class DynamicDimConstraintPrinter(PythonPrinter):
""" """
Printer for dynamic dim constraints. Printer for dynamic dim constraints.
- Instead of symbol s_k it prints its source t.size()[i] - Instead of symbol s_k it prints its source t.size()[i]
@ -2216,9 +2172,6 @@ class DynamicDimConstraintPrinter(StrPrinter):
), f"Unknown symbol {expr} created by constraints solver" ), f"Unknown symbol {expr} created by constraints solver"
return self.symbol_to_source[expr][0].name() return self.symbol_to_source[expr][0].name()
def _print_Relational(self, expr: sympy.core.relational.Relational) -> str:
return f"{self.parenthesize(expr.lhs, precedence(expr))} {expr.rel_op} {self.parenthesize(expr.rhs, precedence(expr))}" # type: ignore[attr-defined]
class DimConstraints: class DimConstraints:
""" """
@ -6656,7 +6609,7 @@ def _blame_user_code(e: Exception, frame: types.FrameType) -> None:
e.args = (msg,) e.args = (msg,)
class _PythonPrinter(sympy.printing.str.StrPrinter): class _PythonMsgPrinter(PythonPrinter):
""" """
Util printer that replaces sympy symbols with their source-level names Util printer that replaces sympy symbols with their source-level names
and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline
@ -6670,13 +6623,6 @@ class _PythonPrinter(sympy.printing.str.StrPrinter):
def _print_Symbol(self, sym: sympy.Symbol) -> str: def _print_Symbol(self, sym: sympy.Symbol) -> str:
return self.src_map[sym.name][0] return self.src_map[sym.name][0]
def _print_Relational(self, expr: sympy.core.relational.Relational) -> str:
lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr))
assert hasattr(expr, "rel_op")
rel_op = expr.rel_op
rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr))
return f"{lhs} {rel_op} {rhs}"
def _suggest_torch_checks( def _suggest_torch_checks(
e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]] e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]]
@ -6687,7 +6633,7 @@ def _suggest_torch_checks(
if diff: if diff:
log.warning("Unable to find user code corresponding to {%s}", diff) log.warning("Unable to find user code corresponding to {%s}", diff)
return return
printer = _PythonPrinter(src_map) printer = _PythonMsgPrinter(src_map)
msg = e.args[0] msg = e.args[0]
msg += "\nTo fix the error, insert one of the following checks before this call:" msg += "\nTo fix the error, insert one of the following checks before this call:"
# suggested fixes to resolve `cond`` are to tell the compiler to assume # suggested fixes to resolve `cond`` are to tell the compiler to assume

View File

@ -191,7 +191,7 @@ class FloorDiv(sympy.Function):
""" """
nargs: Tuple[int, ...] = (2,) nargs: Tuple[int, ...] = (2,)
precedence: int = 50 # precedence of mul # noqa: F811 precedence: int = 35 # lower precedence than add
is_integer: bool = True is_integer: bool = True
@property @property
@ -291,6 +291,7 @@ class ModularIndexing(sympy.Function):
nargs: Tuple[int, ...] = (3,) nargs: Tuple[int, ...] = (3,)
is_integer: bool = True is_integer: bool = True
precedence: int = 35 # lower precedence than add
@classmethod @classmethod
def eval( def eval(
@ -360,6 +361,7 @@ class Where(sympy.Function):
""" """
nargs: Tuple[int, ...] = (3,) nargs: Tuple[int, ...] = (3,)
precedence: int = 35 # lower precedence than add
def _eval_is_integer(self) -> Optional[bool]: def _eval_is_integer(self) -> Optional[bool]:
return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined]
@ -389,6 +391,7 @@ class Where(sympy.Function):
class PythonMod(sympy.Function): class PythonMod(sympy.Function):
nargs: Tuple[int, ...] = (2,) nargs: Tuple[int, ...] = (2,)
precedence: int = 35 # lower precedence than add
is_integer: bool = True is_integer: bool = True
@classmethod @classmethod
@ -447,6 +450,7 @@ class PythonMod(sympy.Function):
# Generic modulus: only defined on non-negative arguments # Generic modulus: only defined on non-negative arguments
class Mod(sympy.Function): class Mod(sympy.Function):
nargs = (2,) nargs = (2,)
precedence: int = 35 # lower precedence than add
is_integer = True is_integer = True
is_nonnegative = True is_nonnegative = True
@ -1014,6 +1018,8 @@ def _safe_pow(base, exponent):
class PowByNatural(sympy.Function): class PowByNatural(sympy.Function):
is_integer = True is_integer = True
precedence: int = 50 # precedence of mul
@classmethod @classmethod
def eval(cls, base, exp): def eval(cls, base, exp):
if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer): if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer):
@ -1039,6 +1045,8 @@ class PowByNatural(sympy.Function):
class FloatPow(sympy.Function): class FloatPow(sympy.Function):
is_real = True is_real = True
precedence: int = 60 # precedence of pow
@classmethod @classmethod
def eval(cls, base, exp): def eval(cls, base, exp):
# NB: These test sympy.Number, not sympy.Float, because: # NB: These test sympy.Number, not sympy.Float, because:
@ -1059,6 +1067,8 @@ class FloatPow(sympy.Function):
class FloatTrueDiv(sympy.Function): class FloatTrueDiv(sympy.Function):
is_real = True is_real = True
precedence: int = 35 # lower precedence than add
@classmethod @classmethod
def eval(cls, base, divisor): def eval(cls, base, divisor):
# assert base.is_integer is not True, base # assert base.is_integer is not True, base
@ -1082,6 +1092,8 @@ class FloatTrueDiv(sympy.Function):
class IntTrueDiv(sympy.Function): class IntTrueDiv(sympy.Function):
is_real = True is_real = True
precedence: int = 35 # lower precedence than add
@classmethod @classmethod
def eval(cls, base, divisor): def eval(cls, base, divisor):
if divisor.is_zero: if divisor.is_zero:
@ -1254,6 +1266,8 @@ class Identity(sympy.Function):
Prevents expansion and other optimizations Prevents expansion and other optimizations
""" """
precedence = 10
def __repr__(self): # type: ignore[override] def __repr__(self): # type: ignore[override]
return f"Identity({self.args[0]})" return f"Identity({self.args[0]})"

View File

@ -0,0 +1,459 @@
import sys
from typing import Optional
import sympy
from sympy.printing.precedence import PRECEDENCE, precedence
from sympy.printing.str import StrPrinter
INDEX_TYPE = "int64_t"
# This printer contains rules that are supposed to be generic for both C/C++ and
# Python
class ExprPrinter(StrPrinter):
    """Base printer with the rules shared by the C/C++ and Python printers.

    Operations that have no language-neutral spelling raise
    ``NotImplementedError`` here, so a concrete printer has to provide them
    instead of silently falling back to sympy's default output.
    """

    # override this so that _print_FloorDiv is used
    printmethod = "_torch_sympystr"

    def _print_Mul(self, expr: sympy.Expr) -> str:
        """Join the factors with ``*`` at the expression's own precedence."""
        level = precedence(expr)
        return self.stringify(expr.args, "*", level)

    def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str:
        """Join the terms with `` + ``; ``order`` is accepted but not used."""
        level = precedence(expr)
        return self.stringify(expr.args, " + ", level)

    def _print_Relational(self, expr: sympy.Expr) -> str:
        """Join both sides with the relation's own operator (``expr.rel_op``)."""
        separator = f" {expr.rel_op} "
        return self.stringify(expr.args, separator, precedence(expr))

    def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str:
        """Join the operands with `` & ``."""
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " & ", level)

    def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str:
        """Join the operands with `` | ``."""
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " | ", level)

    # NB: this is OK to put here, because Mod is only defined for positive
    # numbers, and so across C/Python its behavior is consistent
    def _print_Mod(self, expr: sympy.Expr) -> str:
        """Join the operands with `` % ``."""
        level = PRECEDENCE["Atom"] - 0.5
        return self.stringify(expr.args, " % ", level)

    def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str:
        """Join the operands with `` / `` and wrap the result in parentheses."""
        quotient = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)
        return f"({quotient})"

    def _print_CleanDiv(self, expr: sympy.Expr) -> str:
        """Delegate to ``_print_FloorDiv``."""
        return self._print_FloorDiv(expr)

    def _print_Identity(self, expr: sympy.Expr) -> str:
        """Print the wrapped argument only."""
        (inner, *_) = expr.args
        return self._print(inner)

    # This must be implemented because sympy will collect x * x into Pow(x, 2), without
    # any explicit intervention.  We print it just like x * x, notably, we
    # never generate sympy.Pow with floats.
    #
    # NB: this is pow by natural; you should never have used builtin sympy.pow
    # for FloatPow, and a symbolic exponent should be PowByNatural.  This
    # means exp is guaranteed to be an integer.
    def _print_Pow(self, expr: sympy.Expr) -> str:
        """Print a non-negative integer power as repeated multiplication."""
        operand, power = expr.args
        repeat = int(power)
        assert power == repeat, power
        assert repeat >= 0
        if repeat == 0:
            return "1"
        return self.stringify([operand] * repeat, "*", PRECEDENCE["Mul"])

    # Explicit NotImplemented functions are to prevent default sympy printing
    # behavior, which will just barf out ToFloat(...) to your IR.  The error
    # message is better here because it tells you which printer class it needs
    # to go in.

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        message = f"_print_ToFloat not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        message = f"_print_Infinity not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        message = f"_print_NegativeInfinity not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        message = f"_print_FloorDiv not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        message = f"_print_PythonMod not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        message = f"_print_IntTrueDiv not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        message = f"_print_PowByNatural not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        message = f"_print_FloatPow not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        message = f"_print_TruncToInt not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        message = f"_print_RoundToInt not implemented for {type(self)}"
        raise NotImplementedError(message)

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        message = f"_print_RoundDecimal not implemented for {type(self)}"
        raise NotImplementedError(message)

    # NB: Some float operations are INTENTIONALLY not implemented for
    # printers.  You can implement them as a quick unblock, but it is better
    # to ask yourself why we haven't done this computation in the Tensor
    # universe instead

    def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
        message = f"_print_TruncToFloat not implemented for {type(self)}"
        raise NotImplementedError(message)
class PythonPrinter(ExprPrinter):
    """Printer that renders sympy expressions as Python source text.

    Boolean connectives use Python keywords (``and`` / ``or`` / ``not``) and
    math functions are spelled through the ``math`` module or builtins.
    """

    def _print_ToFloat(self, expr: sympy.Expr) -> str:
        """Emit ``float(<arg>)``."""
        assert len(expr.args) == 1
        return f"float({self._print(expr.args[0])})"

    def _print_Not(self, expr: sympy.Expr) -> str:
        """Emit Python ``not <arg>``.

        Without this override the inherited StrPrinter rule prints ``~<arg>``,
        which is bitwise inversion in Python rather than boolean negation.
        """
        return "not {}".format(self.parenthesize(expr.args[0], PRECEDENCE["Not"]))

    def _print_And(self, expr: sympy.Expr) -> str:
        """Join the operands with `` and ``."""
        return self.stringify(expr.args, " and ", precedence(expr))

    def _print_Or(self, expr: sympy.Expr) -> str:
        """Join the operands with `` or ``."""
        return self.stringify(expr.args, " or ", precedence(expr))

    def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
        """Emit ``((x // div) % mod)``, dropping the ``// div`` when div prints as ``1``."""
        x, div, mod = (
            self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args
        )
        if div != "1":
            x = f"({x} // {div})"
        return f"({x} % {mod})"

    def _print_Infinity(self, expr: sympy.Expr) -> str:
        """Emit ``math.inf``."""
        return "math.inf"

    def _print_NegativeInfinity(self, expr: sympy.Expr) -> str:
        """Emit ``-math.inf``."""
        return "-math.inf"

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_PythonMod(self, expr: sympy.Expr) -> str:
        """Join the operands with `` % ``."""
        return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)

    # WARNING: this is dangerous for Triton, which has C-style modulus
    def _print_FloorDiv(self, expr: sympy.Expr) -> str:
        """Emit ``x // div`` with both operands parenthesized unless atomic."""
        x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args)
        return f"{x} // {div}"

    # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python
    # does a special algorithm
    def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
        """Join the operands with `` / ``."""
        return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5)

    def _helper_sqrt(self, expr: sympy.Expr) -> str:
        """Emit ``math.sqrt(<expr>)``; kept separate so subclasses can override it."""
        return f"math.sqrt({self._print(expr)})"

    def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
        """Print a sqrt node through ``_helper_sqrt``."""
        return self._helper_sqrt(expr.args[0])

    def _print_FloatPow(self, expr: sympy.Expr) -> str:
        """Join the operands with `` ** ``."""
        return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])

    # TODO: Not sure this works with Triton, even when base/exp are integral
    def _print_PowByNatural(self, expr: sympy.Expr) -> str:
        """Join the operands with `` ** ``."""
        return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"])

    def _print_floor(self, expr: sympy.Expr) -> str:
        """Emit ``math.floor(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_FloorToInt(self, expr: sympy.Expr) -> str:
        """Emit ``math.floor(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_TruncToInt(self, expr: sympy.Expr) -> str:
        """Emit ``math.trunc(<arg>)``."""
        assert len(expr.args) == 1
        # This also could have been int(), they'll do the same thing for float
        return f"math.trunc({self._print(expr.args[0])})"

    def _print_ceiling(self, expr: sympy.Expr) -> str:
        """Emit ``math.ceil(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_CeilToInt(self, expr: sympy.Expr) -> str:
        """Emit ``math.ceil(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr: sympy.Expr) -> str:
        """Emit ``abs(<arg>)``."""
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    # NB: It's expected that we've made explicit any promotion in the sympy
    # expression, so it doesn't matter that Python max/min doesn't perform
    # promotion
    def _print_Max(self, expr: sympy.Expr) -> str:
        """Emit ``max(a, b, ...)`` over the printed arguments."""
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr: sympy.Expr) -> str:
        """Emit ``min(a, b, ...)`` over the printed arguments."""
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
        """Emit ``math.cos(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
        """Emit ``math.cosh(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
        """Emit ``math.acos(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
        """Emit ``math.sin(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
        """Emit ``math.sinh(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
        """Emit ``math.asin(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
        """Emit ``math.tan(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
        """Emit ``math.tanh(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
        """Emit ``math.atan(<arg>)``."""
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_RoundToInt(self, expr: sympy.Expr) -> str:
        """Emit ``round(<arg>)``."""
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
        """Emit ``round(<number>, <ndigits>)``; ``ndigits`` must be a sympy Integer."""
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"
class CppPrinter(ExprPrinter):
def _print_Integer(self, expr: sympy.Expr) -> str:
    """Emit an integer literal with an ``LL`` suffix on darwin/win32 and ``L`` elsewhere."""
    value = int(expr)
    if sys.platform in ["darwin", "win32"]:
        return f"{value}LL"
    return f"{value}L"
def _print_Where(self, expr: sympy.Expr) -> str:
    """Emit a C++ ternary ``cond ? a : b`` from the three arguments.

    Each argument is parenthesized unless it is atomic.
    """
    level = PRECEDENCE["Atom"] - 0.5
    condition, if_true, if_false = [
        self.parenthesize(arg, level) for arg in expr.args
    ]
    return f"{condition} ? {if_true} : {if_false}"
def _print_ModularIndexing(self, expr: sympy.Expr) -> str:
    """Emit ``(x / div) % mod`` with both ``%`` operands cast to ``INDEX_TYPE``.

    The division is emitted through ``c10::div_floor_integer`` (integer
    expressions) or ``c10::div_floor_floating`` (otherwise) and is skipped
    entirely when ``div`` equals 1.
    """
    numerator, divisor, modulus = expr.args
    x = self.doprint(numerator)
    if divisor != 1:
        div = self.doprint(divisor)
        if expr.is_integer:
            x = f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
        else:
            x = f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
    mod = self.doprint(modulus)
    return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))"
def _print_FloorDiv(self, expr: sympy.Expr) -> str:
    """Emit a floor division through the ``c10::div_floor_*`` helpers.

    Integer expressions use ``c10::div_floor_integer`` with ``int64_t``
    casts; everything else uses ``c10::div_floor_floating`` with ``double``
    casts.
    """
    numerator, divisor = expr.args
    x = self.doprint(numerator)
    div = self.doprint(divisor)
    if not expr.is_integer:
        return f"c10::div_floor_floating(static_cast<double>({x}), static_cast<double>({div}))"
    return f"c10::div_floor_integer(static_cast<int64_t>({x}), static_cast<int64_t>({div}))"
def _print_floor(self, expr: sympy.Expr) -> str:
    """Emit ``std::floor(<arg>)``, cast to ``INDEX_TYPE`` for integer expressions."""
    assert len(expr.args) == 1
    r = f"std::floor({self._print(expr.args[0])})"
    if expr.is_integer:
        return f"static_cast<{INDEX_TYPE}>({r})"
    return r
def _print_FloorToInt(self, expr: sympy.Expr) -> str:
    """Emit ``std::floor(<arg>)``, cast to ``INDEX_TYPE`` for integer expressions."""
    assert len(expr.args) == 1
    r = f"std::floor({self._print(expr.args[0])})"
    if expr.is_integer:
        return f"static_cast<{INDEX_TYPE}>({r})"
    return r
def _print_TruncToInt(self, expr: sympy.Expr) -> str:
    """Emit ``std::trunc(<arg>)`` and always cast the result to ``INDEX_TYPE``."""
    assert len(expr.args) == 1
    r = f"std::trunc({self._print(expr.args[0])})"
    return f"static_cast<{INDEX_TYPE}>({r})"
def _print_TruncToFloat(self, expr: sympy.Expr) -> str:
    """Emit ``std::trunc(<arg>)`` with no integer cast."""
    assert len(expr.args) == 1
    operand = self._print(expr.args[0])
    return "std::trunc({})".format(operand)
def _print_ToFloat(self, expr: sympy.Expr) -> str:
    """Emit ``static_cast<double>(<arg>)``."""
    assert len(expr.args) == 1
    operand = self._print(expr.args[0])
    return "static_cast<double>({})".format(operand)
# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right).
def _print_PythonMod(self, expr: sympy.Expr) -> str:
    """Join the operands with `` % ``, parenthesizing any non-atomic operand."""
    level = PRECEDENCE["Atom"] - 0.5
    return self.stringify(expr.args, " % ", level)
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
    """Emit a division in which both printed operands are cast to ``double``."""
    numerator, denominator = expr.args
    numerator_text = self._print(numerator)
    denominator_text = self._print(denominator)
    # TODO: This is only accurate up to 2**53
    return f"static_cast<double>({numerator_text}) / static_cast<double>({denominator_text})"
# TODO: PowByNatural: we need to implement our own int-int pow. Do NOT
# use std::pow, that operates on floats
def _print_PowByNatural(self, expr: sympy.Expr) -> str:
    """Always raise: the C++ printer has no integer power implementation yet.

    Raises:
        NotImplementedError: unconditionally, naming the printer's type.
    """
    message = f"_print_PowByNatural not implemented for {type(self)}"
    raise NotImplementedError(message)
def _print_FloatPow(self, expr: sympy.Expr) -> str:
    """Emit a ``std::pow(base, exp)`` call from the two printed operands."""
    base_expr, exponent_expr = expr.args
    base_text = self._print(base_expr)
    exponent_text = self._print(exponent_expr)
    return f"std::pow({base_text}, {exponent_text})"
def _print_Pow(self, expr: sympy.Expr) -> str:
    """Emit C++ for a ``sympy.Pow`` node.

    * exponent ``0.5`` / ``-0.5``: ``std::sqrt`` or its reciprocal;
    * integer exponent: repeated multiplication (reciprocal of it when
      negative, the literal ``1.0`` when zero), cast to ``INDEX_TYPE`` when
      the whole expression is an integer;
    * anything else: ``std::pow`` with the exponent converted to ``float``.

    The base is always rendered by this printer, and a compound base in the
    ``exp == -1`` case is parenthesized so that ``(a + b)**-1`` becomes
    ``1.0/(a + b)`` rather than ``1.0/a + b``.
    """
    # Uses float constants to perform FP div
    base, exp = expr.args

    if exp == 0.5 or exp == -0.5:
        printed = self._print(base)
        return (
            f"std::sqrt({printed})" if exp == 0.5 else f"1.0/std::sqrt({printed})"
        )
    if exp.is_integer:
        exp = int(exp)
        if exp > 0:
            r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"])
        elif exp < -1:
            r = (
                "1.0/("
                + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"])
                + ")"
            )
        elif exp == -1:
            # Division binds tighter than +/-, so anything that is not an
            # atom needs its own parentheses in the denominator.
            r = "1.0/" + self.parenthesize(base, PRECEDENCE["Atom"] - 0.5)
        else:  # exp == 0
            r = "1.0"

        return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r
    else:
        # TODO: float vs double
        # Print the base with this printer; interpolating the raw sympy
        # object would use sympy's default str() instead of C++ syntax.
        return f"std::pow({self._print(base)}, {float(exp)})"
def _print_Rational(self, expr: sympy.Expr) -> str:
    """Emit a rational as ``p`` (denominator 1) or as ``p.0/q.0``.

    The text is wrapped in a cast to ``INDEX_TYPE`` when ``expr.is_integer``
    is truthy.
    """
    # Uses float constants to perform FP div
    r = f"{expr.p}" if expr.q == 1 else f"{expr.p}.0/{expr.q}.0"
    if expr.is_integer:
        return f"static_cast<{INDEX_TYPE}>({r})"
    return r
def _print_ceiling(self, expr: sympy.Expr) -> str:
    """Emit ``std::ceil(<arg>)``, cast to ``INDEX_TYPE`` for integer expressions."""
    assert len(expr.args) == 1
    r = f"std::ceil({self._print(expr.args[0])})"
    if expr.is_integer:
        return f"static_cast<{INDEX_TYPE}>({r})"
    return r
def _print_CeilToInt(self, expr: sympy.Expr) -> str:
    """Emit ``std::ceil(<arg>)``, cast to ``INDEX_TYPE`` for integer expressions."""
    assert len(expr.args) == 1
    r = f"std::ceil({self._print(expr.args[0])})"
    if expr.is_integer:
        return f"static_cast<{INDEX_TYPE}>({r})"
    return r
def _print_Min(self, expr: sympy.Expr) -> str:
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::min({il})"
def _print_Max(self, expr: sympy.Expr) -> str:
args = [self._print(a) for a in expr.args]
if len(args) == 2:
return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))"
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::max({il})"
def _print_Abs(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::abs({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::cos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::cosh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::acos({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::sin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::sinh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::asin({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::tan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::tanh({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
return f"std::atan({self._print(expr.args[0])})"
def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str:
return f"std::sqrt({self._print(expr.args[0])})"
def _print_RoundToInt(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1
# TODO: dispatch to llrint depending on index type
return f"std::lrint({self._print(expr.args[0])})"
def _print_RoundDecimal(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 2
number, ndigits = expr.args
if number.is_integer:
# ndigits < 0 should have been filtered by the sympy function
assert ndigits < 0
raise ValueError(
f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}."
)
number_str = self.parenthesize(number, PRECEDENCE["Mul"])
return f"static_cast<double>(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})"
def _print_BooleanTrue(self, expr: sympy.Expr) -> str:
return "true"
def _print_BooleanFalse(self, expr: sympy.Expr) -> str:
return "false"