From 44186a0a4ee9a52fc0f4c886d94b9f146bbda58a Mon Sep 17 00:00:00 2001 From: Isuru Fernando Date: Mon, 25 Nov 2024 22:22:07 +0000 Subject: [PATCH] Move Sympy printers to torch/utils/_sympy/printers.py (#140597) Pull Request resolved: https://github.com/pytorch/pytorch/pull/140597 Approved by: https://github.com/ezyang, https://github.com/anijain2305 --- test/distributed/test_inductor_collectives.py | 4 +- test/dynamo/test_export.py | 2 +- test/dynamo/test_logging.py | 2 +- test/dynamo/test_misc.py | 4 +- test/export/test_export.py | 4 +- test/inductor/test_cuda_repro.py | 10 +- test/inductor/test_indexing.py | 39 +- test/inductor/test_memory_planning.py | 8 +- test/inductor/test_padding.py | 2 +- test/inductor/test_torchinductor.py | 6 +- .../test_torchinductor_strided_blocks.py | 2 +- test/test_dynamic_shapes.py | 4 +- torch/_inductor/codegen/common.py | 290 +---------- torch/_inductor/codegen/cpp_utils.py | 214 +------- torch/_inductor/codegen/halide.py | 4 +- torch/_inductor/codegen/triton.py | 19 +- torch/_inductor/select_algorithm.py | 7 +- torch/fx/experimental/symbolic_shapes.py | 72 +-- torch/utils/_sympy/functions.py | 16 +- torch/utils/_sympy/printers.py | 459 ++++++++++++++++++ 20 files changed, 589 insertions(+), 579 deletions(-) create mode 100644 torch/utils/_sympy/printers.py diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 8f0928fe91e..92a2fd6ee2c 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -580,8 +580,8 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase): .check_regex( "torch.ops._c10d_functional.all_to_all_single.default\\(" "arg\\d+_\\d+, " - "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\], " - "\\[\\(s\\d+ // \\d\\), \\(s\\d+ // \\d\\)\\]" + "\\[s\\d+ // \\d, s\\d+ // \\d\\], " + "\\[s\\d+ // \\d, s\\d+ // \\d\\]" ) .run(code) ) diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 7c34b09579b..78d6e43a20b 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -3570,7 +3570,7 @@ class GraphModule(torch.nn.Module): "cast_symbool_to_symint_guardless(L['pred']) == 1", ] false_guard_code = [ - "Ne(cast_symbool_to_symint_guardless(L['pred']), 1)", + "cast_symbool_to_symint_guardless(L['pred']) != 1", ] test_symbool_guards( f, diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 35c1f916d2b..f014e393973 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -668,7 +668,7 @@ TRACE FX call mul from test_logging.py:N in fn (LoggingTests.test_trace_call_pre """\ +- LAMBDA_GUARD: L['x'].size()[0] == 2*L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # #:# in # +- LAMBDA_GUARD: L['y'].size()[0] == L['z'].size()[0] # duck sizing added this equality because these variables had the same size 3 (to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False) -+- LAMBDA_GUARD: Eq(Mod(2*L['z'].size()[0], 3), 0) # if x.size(0) % 3 == 0: # #:# in # #:# in # ++- LAMBDA_GUARD: ((2*L['z'].size()[0]) % 3) == 0 # if x.size(0) % 3 == 0: # #:# in # #:# in # +- LAMBDA_GUARD: 2 <= L['z'].size()[0] # return x + torch.cat([y, z]) # #:# in # (user code shown is first use of this value--the guard itself is not due user code but due to 0/1 specialization in the framework; to avoid specialization try torch._dynamo.mark_unbacked(tensor, dim))""", # noqa: B950 ) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 
823c8f3d193..8e9610cad5b 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -10457,7 +10457,7 @@ ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match: ==> axioms: values don't match. - > Left: {0 < Mod(s0, 3): False, 0 <= Mod(s0, 3): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Mod(s0, 3) < 0: False, Mod(s0, 3) <= 0: True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False} + > Left: {(Mod(s0, 3)) < 0: False, (Mod(s0, 3)) <= 0: True, 0 < (Mod(s0, 3)): False, 0 <= (Mod(s0, 3)): True, Eq(0, Mod(s0, 3)): True, Eq(Mod(s0, 3), 0): True, Ne(0, Mod(s0, 3)): False, Ne(Mod(s0, 3), 0): False} > Right: {} ==> divisible: values don't match. > Left: {Mod(s0, 3)} @@ -10576,7 +10576,7 @@ ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match: ==> axioms: values don't match. - > Left: {0 < PythonMod(u0, 3): False, 0 <= PythonMod(u0, 3): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False, PythonMod(u0, 3) < 0: False, PythonMod(u0, 3) <= 0: True} + > Left: {(PythonMod(u0, 3)) < 0: False, (PythonMod(u0, 3)) <= 0: True, 0 < (PythonMod(u0, 3)): False, 0 <= (PythonMod(u0, 3)): True, Eq(0, PythonMod(u0, 3)): True, Eq(PythonMod(u0, 3), 0): True, Ne(0, PythonMod(u0, 3)): False, Ne(PythonMod(u0, 3), 0): False} > Right: {} ==> deferred_runtime_asserts: values don't match. > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} diff --git a/test/export/test_export.py b/test/export/test_export.py index b96e96ca161..50437616d42 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -3259,9 +3259,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x): (torch.tensor(20),), fixes=[ # Could not guard on data-dependent expression Eq((u0//2), 0) - "torch._check(((i//2)) != 0)", + "torch._check((i // 2) != 0)", # Could not guard on data-dependent expression Eq((u0//2), 1) - "torch._check(((i//2)) != 1)", + "torch._check((i // 2) != 1)", ], ) diff --git a/test/inductor/test_cuda_repro.py b/test/inductor/test_cuda_repro.py index 64188c8eb7f..81842775bc6 100644 --- a/test/inductor/test_cuda_repro.py +++ b/test/inductor/test_cuda_repro.py @@ -1426,12 +1426,12 @@ def triton_poi_fused_add_reflection_pad2d_0(in_ptr0, in_ptr1, out_ptr0, xnumel, xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel - x0 = xindex % 20 - x1 = (xindex // 20) % 20 - x2 = (xindex // 400) + x0 = (xindex % 20) + x1 = ((xindex // 20) % 20) + x2 = xindex // 400 x3 = xindex - tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last') - tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last') + tmp0 = tl.load(in_ptr0 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last') + tmp1 = tl.load(in_ptr1 + (99 + ((-1)*tl_math.abs((-9) + tl_math.abs((-5) + x0))) + ((-10)*tl_math.abs((-9) + tl_math.abs((-5) + x1))) + 100*x2), xmask, eviction_policy='evict_last') tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950 ) diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index d56a42b8252..53f625e24bf 100644 --- 
a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -20,7 +20,9 @@ from torch.testing._internal.common_utils import ( from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_CPU, HAS_GPU from torch.utils._sympy.functions import ( FloorDiv, + Mod, ModularIndexing, + PythonMod, RoundDecimal, RoundToInt, ) @@ -236,7 +238,7 @@ class TestIndexingSimplification(InductorTestCase): triton_code = run_and_get_triton_code(f, x) # Make sure the 2 load uses simpified indexing rather than something like # tl.load(in_ptr0 + ((5504*x1) + (x0 // 2)), - self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + ((x2 // 2)),")) + self.assertEqual(2, triton_code.count("tl.load(in_ptr0 + (x2 // 2),")) if DO_PERF_TEST: ms = benchmarker.benchmark_gpu(lambda: f(x)) print(f"{ms=:.03f}") @@ -313,6 +315,39 @@ class ExprPrinterTests(InductorTestCase): self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") + def test_print_mod(self): + x = sympy.Symbol("x", integer=True) + expr = Mod(x - 1, 2) + self.assertExpectedInline(pexpr(expr), """((-1) + x) % 2""") + self.assertExpectedInline(cexpr(expr), """((-1L) + x) % 2L""") + self.assertExpectedInline(texpr(expr), """((-1) + x) % 2""") + + expr = (x - 10) % x + self.assertExpectedInline(pexpr(expr), """(-10) % x""") + self.assertExpectedInline(cexpr(expr), """(-10L) % x""") + self.assertExpectedInline(texpr(expr), """(-10) % x""") + + def test_print_mod_index(self): + x = sympy.Symbol("x", integer=True) + ks = sympy.Symbol("ks", integer=True) + expr = ModularIndexing(x - 10, ks, ks) + self.assertExpectedInline(pexpr(expr), """((((-10) + x) // ks) % ks)""") + self.assertExpectedInline( + cexpr(expr), + """(static_cast(c10::div_floor_integer(""" + """static_cast((-10L) + x), static_cast(ks))) % static_cast(ks))""", + ) + self.assertExpectedInline(texpr(expr), """((((-10) + x) // ks) % ks)""") + + def test_print_python_mod(self): + x = sympy.Symbol("x", integer=True) + expr = PythonMod(x - 10, x) + self.assertExpectedInline(pexpr(expr), """((-10) + x) % x""") + self.assertExpectedInline(cexpr(expr), """((-10L) + x) % x""") + self.assertExpectedInline( + texpr(expr), """triton_helpers.remainder_integer((-10) + x, x)""" + ) + @parametrize("ndigits", [-1, 0, 1]) def test_print_round_decimal(self, ndigits): expr = RoundDecimal(sympy.Symbol("x", integer=True) / 2, ndigits) @@ -330,7 +365,7 @@ class ExprPrinterTests(InductorTestCase): s1 = sympy.Symbol("s1", integer=True) s2 = sympy.Symbol("s2", integer=True) expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(s1 // s2)") + self.assertEqual(pexpr(expr), "s1 // s2") self.assertEqual( cexpr(expr), "c10::div_floor_integer(static_cast(s1), static_cast(s2))", diff --git a/test/inductor/test_memory_planning.py b/test/inductor/test_memory_planning.py index c7431324b55..765a51d6b17 100644 --- a/test/inductor/test_memory_planning.py +++ b/test/inductor/test_memory_planning.py @@ -58,13 +58,11 @@ class TestMemoryPlanning(TestCase): result, code = run_and_get_cpp_code(compiled, *args) FileCheck().check( - "pool1 = empty_strided_" - + GPU_TYPE - + "(((4*s0*s1) + (align(4*(s0*s0))), ), (1, )" + "pool1 = empty_strided_" + GPU_TYPE + "((4*s0*s1 + align(4*s0*s0), ), (1, )" ).check_next( "buf0 = alloc_from_pool(pool1, 0, torch.float32, (s0, s0), (s0, 1))" ).check( - "buf1 = alloc_from_pool(pool1, align(4*(s0*s0))," + "buf1 = alloc_from_pool(pool1, align(4*s0*s0)," ).run( code ) @@ -103,7 +101,7 @@ class 
TestMemoryPlanning(TestCase): ) FileCheck().check( - "int64_t int_array_2[] = {24L + (align(12L*s0)), };" + "int64_t int_array_2[] = {24L + align(12L*s0), };" ).check_next("int64_t int_array_3[] = {1L, };").check_next( "AtenTensorHandle pool1_handle;" ).check_next( diff --git a/test/inductor/test_padding.py b/test/inductor/test_padding.py index fd976f69d93..b320bf482d9 100644 --- a/test/inductor/test_padding.py +++ b/test/inductor/test_padding.py @@ -487,7 +487,7 @@ class PaddingTest(TestCaseBase): # make sure the load for softmax is aligned self.assertTrue( - "tl.load(in_ptr0 + (r1 + (30528*x0))" in forward_wrapper, + "tl.load(in_ptr0 + (r1 + 30528*x0)" in forward_wrapper, f"forward_wrapper: {forward_wrapper}", ) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 24dc8de068e..bc7ef79c958 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -12505,8 +12505,8 @@ if HAS_GPU and not TEST_WITH_ASAN: self.assertExpectedInline( "\n".join(lines), """\ - tmp0 = tl.load(in_ptr0 + (x1 + (512*x0) + (262144*r2)), rmask, eviction_policy='evict_last', other=0.0) - tmp1 = tl.load(in_ptr1 + (x3 + (262144*r2)), rmask, eviction_policy='evict_first', other=0.0)""", + tmp0 = tl.load(in_ptr0 + (x1 + 512*x0 + 262144*r2), rmask, eviction_policy='evict_last', other=0.0) + tmp1 = tl.load(in_ptr1 + (x3 + 262144*r2), rmask, eviction_policy='evict_first', other=0.0)""", ) @config.patch("triton.use_block_ptr", True) @@ -12538,7 +12538,7 @@ if HAS_GPU and not TEST_WITH_ASAN: self.assertExpectedInline( "\n".join(lines), """\ - tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [((511 + XBLOCK) // 512), ((1) * ((1) <= (((511 + XBLOCK) // 512))) + (((511 + XBLOCK) // 512)) * ((((511 + XBLOCK) // 512)) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK]) + tmp0 = tl.reshape(tl.broadcast_to(tl.load(block_ptr0, boundary_check=[2], padding_option='zero', eviction_policy='evict_last')[:, None, :, :], [(511 + XBLOCK) // 512, ((1) * ((1) <= ((511 + XBLOCK) // 512)) + ((511 + XBLOCK) // 512) * (((511 + XBLOCK) // 512) < (1))), ((512) * ((512) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (512))), RBLOCK]), [XBLOCK, RBLOCK]) tmp1 = tl.load(block_ptr1, boundary_check=[1], padding_option='zero', eviction_policy='evict_first')""", # noqa: B950 line too long ) diff --git a/test/inductor/test_torchinductor_strided_blocks.py b/test/inductor/test_torchinductor_strided_blocks.py index f4cdedda1b2..db2d24e9917 100644 --- a/test/inductor/test_torchinductor_strided_blocks.py +++ b/test/inductor/test_torchinductor_strided_blocks.py @@ -275,7 +275,7 @@ class TritonBlockPointerTest(InductorTestCase): "\n".join(load_lines), """\ tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[64], strides=[1], block_shape=[XBLOCK], order=[0], offsets=[xoffset]), boundary_check=[0]) - tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[((7 + XBLOCK) // 8)], order=[0], offsets=[(xoffset // 8)]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [((7 + XBLOCK) // 8), ((1) * ((1) <= (((7 + XBLOCK) // 8))) + (((7 + XBLOCK) // 8)) * ((((7 + XBLOCK) // 8)) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950 + tmp1 = tl.reshape(tl.broadcast_to(tl.load(tl.make_block_ptr(in_ptr1, shape=[8], strides=[8], block_shape=[(7 + XBLOCK) // 8], order=[0], 
offsets=[xoffset // 8]), boundary_check=[0], eviction_policy='evict_last')[:, None, None], [(7 + XBLOCK) // 8, ((1) * ((1) <= ((7 + XBLOCK) // 8)) + ((7 + XBLOCK) // 8) * (((7 + XBLOCK) // 8) < (1))), ((8) * ((8) <= (XBLOCK)) + (XBLOCK) * ((XBLOCK) < (8)))]), [XBLOCK])""", # noqa: B950 ) self.assertExpectedInline( "\n".join(store_lines), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 857f732164a..46eaab3f3b3 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -2792,8 +2792,8 @@ class TestGuardsExpressions(TestCase): guard_int(sym_int(s0 / 2.0)) guards = shape_env.produce_guards_expression([s0]) - self.assertIn("ToFloat", guards) - self.assertIn("FloatTrueDiv", guards) + self.assertIn("math.trunc(", guards) + self.assertIn("float(", guards) self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)])) diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 0222fae2403..dd1bbbee440 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -23,7 +23,6 @@ from typing import ( ) import sympy -from sympy.printing.printer import Printer import torch import torch.fx @@ -31,6 +30,7 @@ from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from torch.utils import _pytree as pytree from torch.utils._ordered_set import OrderedSet from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.printers import PythonPrinter as _PythonPrinter from torch.utils._sympy.symbol import free_symbol_is_type, symbol_is_type, SymT from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, ValueRanges @@ -609,12 +609,22 @@ class DataTypePropagation: DataTypePropagation.propagate_loopbody(node._body) -# This printer contains rules that are supposed to be generic for both C/C++ and -# Python -class ExprPrinter(Printer): +class PythonPrinter(_PythonPrinter): + def doprint(self, expr, *, simplify: bool = True, p=True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) + + +class OpOverrides: + def __init__(self, parent): + super().__init__() + self._parent = parent + @staticmethod - def paren(string): - def all_in_parens(string): + def paren(string: str) -> str: + def all_in_parens(string: str) -> bool: if string[0] != "(" or len(string) < 2: return False count = 1 @@ -640,260 +650,6 @@ class ExprPrinter(Printer): return string return f"({string})" - def _print_Relational(self, expr): - return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) - - def _print_Mul(self, expr): - return "*".join(map(self.paren, map(self._print, expr.args))) - - def _print_Add(self, expr): - return " + ".join(map(self.paren, map(self._print, expr.args))) - - # NB: this is OK to put here, because Mod is only defined for positive - # numbers, and so across C/Python its behavior is consistent - def _print_Mod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_FloatTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - - def _print_CleanDiv(self, expr): - return self._print_FloorDiv(expr) - - def _print_Identity(self, expr): - return self._print(expr.args[0]) - - def _print_GreaterThan(self, expr): - # 
GreaterThan: >= - # StrictlyGreaterThan: > - # Go figure... - return " >= ".join(map(self.paren, map(self._print, expr.args))) - - # NB: The C implementation is injected into codegen at - # torch/_inductor/codegen/wrapper.py - def _print_align(self, expr): - assert len(expr.args) == 1 - return f"align({self._print(expr.args[0])})" - - # This must be implemented because sympy will collect x * x into Pow(x, 2), without - # any explicit intervention. We print it just like x * x, notably, we - # never generate sympy.Pow with floats. - # - # NB: this pow by natural, you should never have used builtin sympy.pow - # for FloatPow, and a symbolic exponent should be PowByNatural. These - # means exp is guaranteed to be integer. - def _print_Pow(self, expr): - base, exp = expr.args - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - assert exp >= 0 - if exp > 0: - return "*".join([self.paren(base)] * exp) - return "1" - - # Explicit NotImplemented functions are to prevent default sympy printing - # behavior, which will just barf out ToFloat(...) to your IR. The error - # message is better here because it tells you which printer class it needs - # to go in. - - def _print_ToFloat(self, expr): - raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") - - def _print_Infinity(self, expr): - raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") - - def _print_NegativeInfinity(self, expr): - raise NotImplementedError( - f"_print_NegativeInfinity not implemented for {type(self)}" - ) - - def _print_FloorDiv(self, expr): - raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") - - def _print_PythonMod(self, expr): - raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") - - def _print_IntTrueDiv(self, expr): - raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") - - def _print_PowByNatural(self, expr): - raise NotImplementedError( - f"_print_PowByNatural not implemented for {type(self)}" - ) - - def _print_FloatPow(self, expr): - raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") - - def _print_TruncToInt(self, expr): - raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") - - def _print_RoundToInt(self, expr): - raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") - - def _print_RoundDecimal(self, expr): - raise NotImplementedError( - f"_print_RoundDecimal not implemented for {type(self)}" - ) - - # NB: Some float operations are INTENTIONALLY not implemented for - # printers. 
You can implement them as a quick unblock, but it is better - # to ask yourself why we haven't done this computation in the Tensor - # universe instead - - def _print_TruncToFloat(self, expr): - raise NotImplementedError( - f"_print_TruncToFloat not implemented for {type(self)}" - ) - - def doprint(self, expr, *, simplify: bool = True): - # TODO: why are people passing strings to the printer here :think: - if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): - expr = V.graph.sizevars.simplify(expr) - return super().doprint(expr) - - -class PythonPrinter(ExprPrinter): - def _print_ToFloat(self, expr): - assert len(expr.args) == 1 - return f"float({self._print(expr.args[0])})" - - def _print_ModularIndexing(self, expr): - x, div, mod = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - mod = self.paren(self.doprint(mod)) - if div != "1": - x = f"({x} // {div})" - return f"{x} % {mod}" - - def _print_Infinity(self, expr): - return "math.inf" - - def _print_NegativeInfinity(self, expr): - return "-math.inf" - - # WARNING: this is dangerous for Triton, which has C-style modulus - def _print_PythonMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - # WARNING: this is dangerous for Triton, which has C-style modulus - def _print_FloorDiv(self, expr): - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"({x} // {div})" - - # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python - # does a special algorithm - def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - - def _helper_sqrt(self, expr): - return f"math.sqrt({self._print(expr)})" - - def _print_OpaqueUnaryFn_sqrt(self, expr): - return self._helper_sqrt(expr.args[0]) - - def _print_FloatPow(self, expr): - base, exp = expr.args - return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" - - # TODO: Not sure this works with Triton, even when base/exp are integral - def _print_PowByNatural(self, expr): - base, exp = expr.args - return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" - - def _print_floor(self, expr): - assert len(expr.args) == 1 - return f"math.floor({self._print(expr.args[0])})" - - def _print_FloorToInt(self, expr): - assert len(expr.args) == 1 - return f"math.floor({self._print(expr.args[0])})" - - def _print_TruncToInt(self, expr): - assert len(expr.args) == 1 - # This also could have been int(), they'll do the same thing for float - return f"math.trunc({self._print(expr.args[0])})" - - def _print_ceiling(self, expr): - assert len(expr.args) == 1 - return f"math.ceil({self._print(expr.args[0])})" - - def _print_CeilToInt(self, expr): - assert len(expr.args) == 1 - return f"math.ceil({self._print(expr.args[0])})" - - def _print_Abs(self, expr): - assert len(expr.args) == 1 - return f"abs({self._print(expr.args[0])})" - - # NB: It's expected that we've made explicit any promotion in the sympy - # expression, so it doesn't matter that Python max/min doesn't perform - # promotion - def _print_Max(self, expr): - assert len(expr.args) >= 2 - return f"max({', '.join(map(self._print, expr.args))})" - - def _print_Min(self, expr): - assert len(expr.args) >= 2 - return f"min({', '.join(map(self._print, expr.args))})" - - def _print_OpaqueUnaryFn_cos(self, expr): - assert len(expr.args) == 1 - return f"math.cos({self._print(expr.args[0])})" - - def 
_print_OpaqueUnaryFn_cosh(self, expr): - assert len(expr.args) == 1 - return f"math.cosh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_acos(self, expr): - assert len(expr.args) == 1 - return f"math.acos({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sin(self, expr): - assert len(expr.args) == 1 - return f"math.sin({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sinh(self, expr): - assert len(expr.args) == 1 - return f"math.sinh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_asin(self, expr): - assert len(expr.args) == 1 - return f"math.asin({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_tan(self, expr): - assert len(expr.args) == 1 - return f"math.tan({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_tanh(self, expr): - assert len(expr.args) == 1 - return f"math.tanh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_atan(self, expr): - assert len(expr.args) == 1 - return f"math.atan({self._print(expr.args[0])})" - - def _print_RoundToInt(self, expr): - assert len(expr.args) == 1 - return f"round({self._print(expr.args[0])})" - - def _print_RoundDecimal(self, expr): - assert len(expr.args) == 2 - number, ndigits = expr.args - assert isinstance(ndigits, sympy.Integer) - return f"round({self._print(number)}, {ndigits})" - - -class OpOverrides: - def __init__(self, parent): - super().__init__() - self._parent = parent - def __getattr__(self, item): return getattr(self._parent, item) @@ -982,31 +738,31 @@ class OpOverrides: @staticmethod def bitwise_not(x): - return f"~{ExprPrinter.paren(x)}" + return f"~{OpOverrides.paren(x)}" @staticmethod def logical_not(a): - return f"{ExprPrinter.paren(a)} == 0" + return f"{OpOverrides.paren(a)} == 0" @staticmethod def bitwise_and(x, y): - return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}" + return f"{OpOverrides.paren(x)} & {OpOverrides.paren(y)}" @staticmethod def bitwise_or(x, y): - return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}" + return f"{OpOverrides.paren(x)} | {OpOverrides.paren(y)}" @staticmethod def bitwise_xor(x, y): - return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}" + return f"{OpOverrides.paren(x)} ^ {OpOverrides.paren(y)}" @staticmethod def bitwise_left_shift(x, y): - return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}" + return f"{OpOverrides.paren(x)} << {OpOverrides.paren(y)}" @staticmethod def bitwise_right_shift(x, y): - return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}" + return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}" @staticmethod def remainder(a, b): diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 3aa17c287f9..4a62f929fec 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -13,6 +13,7 @@ import sympy import torch from torch._prims_common import is_integer_dtype from torch.utils._ordered_set import OrderedSet +from torch.utils._sympy.printers import CppPrinter as _CppPrinter from torch.utils._sympy.symbol import symbol_is_type, SymT from torch.utils._sympy.value_ranges import ValueRanges @@ -25,7 +26,6 @@ from ..virtualized import ops, OpsValue, V from .common import ( CSEVariable, deduce_output_dtype_by_name, - ExprPrinter, Kernel, KernelArgs, OptimizationContext, @@ -232,212 +232,12 @@ class CppCSEVariable(CSEVariable): return itervar in self.dependent_itervars -class CppPrinter(ExprPrinter): - def _print_Integer(self, expr): - return ( - f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L" - 
) - - def _print_Where(self, expr): - c = self.paren(self.doprint(expr.args[0])) - p = self.paren(self.doprint(expr.args[1])) - q = self.paren(self.doprint(expr.args[2])) - return f"{c} ? {p} : {q}" - - def _print_ModularIndexing(self, expr): - x, div, mod = expr.args - x = self.paren(self.doprint(x)) - if div != 1: - div = self.paren(self.doprint(div)) - if expr.is_integer: - x = f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" - else: - x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" - mod = self.paren(self.doprint(mod)) - return f"static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod})" - - def _print_FloorDiv(self, expr): - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - if expr.is_integer: - return f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" - return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" - - def _print_floor(self, expr): - assert len(expr.args) == 1 - r = f"std::floor({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_FloorToInt(self, expr): - assert len(expr.args) == 1 - r = f"std::floor({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_TruncToInt(self, expr): - assert len(expr.args) == 1 - r = f"std::trunc({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" - - def _print_TruncToFloat(self, expr): - assert len(expr.args) == 1 - return f"std::trunc({self._print(expr.args[0])})" - - def _print_ToFloat(self, expr): - assert len(expr.args) == 1 - return f"static_cast({self._print(expr.args[0])})" - - # TODO: This is wrong if one of the inputs is negative. This is hard to - # tickle though, as the inputs are typically positive (and if we can prove - # they are positive, we will have used Mod instead, for which this codegen - # is right). - def _print_PythonMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_CMod(self, expr): - return " % ".join(map(self.paren, map(self._print, expr.args))) - - def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - # TODO: This is only accurate up to 2**53 - return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" - - # TODO: PowByNatural: we need to implement our own int-int pow. 
Do NOT - # use std::pow, that operates on floats - def _print_PowByNatural(self, expr): - raise NotImplementedError( - f"_print_PowByNatural not implemented for {type(self)}" - ) - - def _print_FloatTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" - - def _print_FloatPow(self, expr): - base, exp = expr.args - return f"std::pow({self._print(base)}, {self._print(exp)})" - - def _print_Pow(self, expr): - # Uses float constants to perform FP div - base, exp = expr.args - base = self._print(base) - - if exp == 0.5 or exp == -0.5: - return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" - if exp.is_integer: - exp = int(exp) - if exp > 0: - r = "*".join([self.paren(base)] * exp) - elif exp < 0: - r = "1.0/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - r = "1.0" - - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - else: - # TODO: float vs double - return f"std::pow({base}, {float(exp)})" - - def _print_Rational(self, expr): - # Uses float constants to perform FP div - if expr.q == 1: - r = f"{expr.p}" - else: - r = f"{expr.p}.0/{expr.q}.0" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_ceiling(self, expr): - assert len(expr.args) == 1 - r = f"std::ceil({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_CeilToInt(self, expr): - assert len(expr.args) == 1 - r = f"std::ceil({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - - def _print_Min(self, expr): - args = [self._print(a) for a in expr.args] - if len(args) == 2: - return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" - else: - # Initializer list overload - il = "{" + ", ".join(args) + "}" - return f"std::min({il})" - - def _print_Max(self, expr): - args = [self._print(a) for a in expr.args] - if len(args) == 2: - return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" - else: - # Initializer list overload - il = "{" + ", ".join(args) + "}" - return f"std::max({il})" - - def _print_Abs(self, expr): - assert len(expr.args) == 1 - return f"std::abs({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_cos(self, expr): - assert len(expr.args) == 1 - return f"std::cos({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_cosh(self, expr): - assert len(expr.args) == 1 - return f"std::cosh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_acos(self, expr): - assert len(expr.args) == 1 - return f"std::acos({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sin(self, expr): - assert len(expr.args) == 1 - return f"std::sin({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sinh(self, expr): - assert len(expr.args) == 1 - return f"std::sinh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_asin(self, expr): - assert len(expr.args) == 1 - return f"std::asin({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_tan(self, expr): - assert len(expr.args) == 1 - return f"std::tan({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_tanh(self, expr): - assert len(expr.args) == 1 - return f"std::tanh({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_atan(self, expr): - assert len(expr.args) == 1 - return f"std::atan({self._print(expr.args[0])})" - - def _print_OpaqueUnaryFn_sqrt(self, expr): - return 
f"std::sqrt({self._print(expr.args[0])})" - - def _print_RoundToInt(self, expr): - assert len(expr.args) == 1 - # TODO: dispatch to llrint depending on index type - return f"std::lrint({self._print(expr.args[0])})" - - def _print_RoundDecimal(self, expr): - assert len(expr.args) == 2 - number, ndigits = expr.args - if number.is_integer: - # ndigits < 0 should have been filtered by the sympy function - assert ndigits < 0 - raise ValueError( - f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." - ) - return f"static_cast(std::nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits})" - - def _print_BooleanTrue(self, expr): - return "true" - - def _print_BooleanFalse(self, expr): - return "false" +class CppPrinter(_CppPrinter): + def doprint(self, expr, *, simplify: bool = True, p=True): + # TODO: why are people passing strings to the printer here :think: + if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): + expr = V.graph.sizevars.simplify(expr) + return super().doprint(expr) # A function to print, useful for printing sympy symbols. diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 5efb93902da..dd30412d169 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -185,8 +185,8 @@ class HalidePrinter(PythonPrinter): return super()._print_FloorDiv(expr) x, div = expr.args - x = self.cast_float(self.paren(self.doprint(x))) - div = self.cast_float(self.paren(self.doprint(div))) + x = self.cast_float(self.doprint(x)) + div = self.cast_float(self.doprint(div)) return self.cast_index(f"hl.floor({x} / {div})") def _print_Round(self, expr): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 0eb75a1fe9a..b717380bf01 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -27,6 +27,7 @@ from typing import ( ) import sympy +from sympy.printing.precedence import PRECEDENCE import torch import torch._logging @@ -504,30 +505,30 @@ class TritonPrinter(PythonPrinter): def _print_ToFloat(self, expr): assert len(expr.args) == 1 - return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" + s = self.parenthesize(expr.args[0], PRECEDENCE["Atom"] - 0.5) + return f"{s}.to(tl.float64)" def _print_PythonMod(self, expr): quot, div = expr.args + if quot.is_nonnegative and div.is_nonnegative: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) quot_s = self._print(quot) div_s = self._print(div) - if quot.is_nonnegative and div.is_nonnegative: - return f"{self.paren(quot_s)} % {self.paren(div_s)}" return f"triton_helpers.remainder_integer({quot_s}, {div_s})" def _print_FloorDiv(self, expr): assert expr.is_integer quot, div = expr.args + if quot.is_nonnegative and div.is_nonnegative: + return self.stringify(expr.args, " // ", PRECEDENCE["Atom"] - 0.5) quot_s = self._print(quot) div_s = self._print(div) - if quot.is_nonnegative and div.is_nonnegative: - return f"({self.paren(quot_s)} // {self.paren(div_s)})" return f"triton_helpers.div_floor_integer({quot_s}, {div_s})" # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher # precision algorithm, which we would need to replicate here def _print_IntTrueDiv(self, expr): - lhs, rhs = expr.args - return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) # NB: sympy.floor/ceiling produce integers, so we have to do the # conversion to 
index dtype @@ -646,7 +647,9 @@ class TritonPrinter(PythonPrinter): raise ValueError( f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." ) - return f"libdevice.nearbyint(1e{ndigits} * {self.paren(self._print(number))}) * 1e{-ndigits}" + + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"libdevice.nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits}" texpr = TritonPrinter().doprint diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 8b07f6528e7..8986a774d89 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -35,13 +35,12 @@ from .autotune_process import ( TritonGPUBenchmarkRequest, ) from .codecache import code_hash, PersistentCache, PyCodeCache -from .codegen.common import IndentedBuffer, KernelTemplate, WorkspaceArg +from .codegen.common import IndentedBuffer, KernelTemplate, OpOverrides, WorkspaceArg from .codegen.simd_kernel_features import SIMDKernelFeatures from .codegen.triton import ( gen_common_triton_imports, texpr, TritonKernel, - TritonPrinter, TritonScheduling, ) from .codegen.triton_utils import config_of, signature_to_meta @@ -562,7 +561,7 @@ class TritonTemplateKernel(TritonKernel): assert isinstance(val, str) assert isinstance(mask, (str, type(None))) assert self.template_mask is None - indices = list(map(TritonPrinter.paren, indices)) + indices = list(map(OpOverrides.paren, indices)) index_symbols = [sympy.Symbol(x, integer=True) for x in indices] lengths = [ V.graph.sizevars.simplify(s) for s in self.output_node.get_size() @@ -630,7 +629,7 @@ class TritonTemplateKernel(TritonKernel): assert isinstance(name, str) assert isinstance(mask, str) stride = self.named_input_nodes[name].get_stride() - indices = list(map(TritonPrinter.paren, indices)) + indices = list(map(OpOverrides.paren, indices)) assert len(indices) == len(stride) index = " + ".join( f"{texpr(self.rename_indexing(s))} * {i}" for s, i in zip(stride, indices) diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 1b4f65cb2d9..84e62f34312 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -10,7 +10,6 @@ need to make use of these APIs to setup dynamic shapes support appropriately. 
""" import atexit -import builtins import collections import functools import inspect @@ -82,6 +81,7 @@ from torch.utils._sympy.functions import ( PythonMod, ) from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.printers import PythonPrinter from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.solve import try_solve from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT @@ -109,8 +109,6 @@ log = logging.getLogger(__name__) import sympy from sympy import S -from sympy.printing.precedence import PRECEDENCE, precedence -from sympy.printing.str import StrPrinter class GuardOnDataDependentSymNode(RuntimeError): @@ -2032,45 +2030,9 @@ def cast_symbool_to_symint_guardless( SYMPY_INTERP = { - "Abs": operator.abs, - "Eq": operator.eq, - "Ne": operator.ne, - "Gt": operator.gt, - "Lt": operator.lt, - "Le": operator.le, - "Ge": operator.ge, - "Min": min, - "Max": max, - "Mod": operator.mod, - "PythonMod": operator.mod, - "FloorDiv": operator.floordiv, - "TrueDiv": operator.truediv, - "PowByNatural": operator.pow, "IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense, - "floor": math.floor, - "ceiling": math.ceil, - "FloorToInt": math.floor, - "FloatPow": math.pow, - "CeilToInt": math.ceil, "cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless, - "RoundToInt": builtins.round, - "RoundDecimal": builtins.round, - "TruncToInt": math.trunc, - "IntTrueDiv": operator.truediv, - "FloatTrueDiv": operator.truediv, - "ToFloat": builtins.float, - "OpaqueUnaryFn_cos": math.cos, - "OpaqueUnaryFn_cosh": math.cosh, - "OpaqueUnaryFn_acos": math.acos, - "OpaqueUnaryFn_sin": math.sin, - "OpaqueUnaryFn_sinh": math.sinh, - "OpaqueUnaryFn_asin": math.asin, - "OpaqueUnaryFn_tan": math.tan, - "OpaqueUnaryFn_tanh": math.tanh, - "OpaqueUnaryFn_atan": math.atan, - "OpaqueUnaryFn_sqrt": math.sqrt, - "BitwiseFn_bitwise_and": operator.and_, - "BitwiseFn_bitwise_or": operator.or_, + "math": math, } @@ -2141,12 +2103,12 @@ class RuntimeAssert: # Used for printing SymExprs in compile_fx -class SymExprPrinter(StrPrinter): +class SymExprPrinter(PythonPrinter): def _print_Float(self, expr: sympy.Float) -> str: return str(float(expr)) -class ShapeGuardPrinter(SymExprPrinter): +class ShapeGuardPrinter(PythonPrinter): def __init__( self, symbol_to_source: Mapping[sympy.Symbol, List[Source]], @@ -2158,14 +2120,8 @@ class ShapeGuardPrinter(SymExprPrinter): self.source_ref = source_ref self.var_to_sources = var_to_sources - def _print_Not(self, expr: SympyBoolean) -> str: - return "not {}".format(self.parenthesize(expr.args[0], PRECEDENCE["Not"])) - - def _print_And(self, expr: SympyBoolean) -> str: - return self.stringify(expr.args, " and ", PRECEDENCE["And"]) - - def _print_Or(self, expr: SympyBoolean) -> str: - return self.stringify(expr.args, " or ", PRECEDENCE["Or"]) + def _print_Float(self, expr: sympy.Float) -> str: + return str(float(expr)) def _print_Symbol(self, expr: sympy.Symbol) -> str: assert isinstance(expr, sympy.Symbol), str(type(expr)) @@ -2191,7 +2147,7 @@ class LoggingShapeGuardPrinter(ShapeGuardPrinter): super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) -class DynamicDimConstraintPrinter(StrPrinter): +class DynamicDimConstraintPrinter(PythonPrinter): """ Printer for dynamic dim constraints. 
- Instead of symbol s_k it prints its source t.size()[i] @@ -2216,9 +2172,6 @@ class DynamicDimConstraintPrinter(StrPrinter): ), f"Unknown symbol {expr} created by constraints solver" return self.symbol_to_source[expr][0].name() - def _print_Relational(self, expr: sympy.core.relational.Relational) -> str: - return f"{self.parenthesize(expr.lhs, precedence(expr))} {expr.rel_op} {self.parenthesize(expr.rhs, precedence(expr))}" # type: ignore[attr-defined] - class DimConstraints: """ @@ -6656,7 +6609,7 @@ def _blame_user_code(e: Exception, frame: types.FrameType) -> None: e.args = (msg,) -class _PythonPrinter(sympy.printing.str.StrPrinter): +class _PythonMsgPrinter(PythonPrinter): """ Util printer that replaces sympy symbols with their source-level names and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline @@ -6670,13 +6623,6 @@ class _PythonPrinter(sympy.printing.str.StrPrinter): def _print_Symbol(self, sym: sympy.Symbol) -> str: return self.src_map[sym.name][0] - def _print_Relational(self, expr: sympy.core.relational.Relational) -> str: - lhs = self.parenthesize(expr.lhs, sympy.printing.precedence.precedence(expr)) - assert hasattr(expr, "rel_op") - rel_op = expr.rel_op - rhs = self.parenthesize(expr.rhs, sympy.printing.precedence.precedence(expr)) - return f"{lhs} {rel_op} {rhs}" - def _suggest_torch_checks( e: GuardOnDataDependentSymNode, src_map: DefaultDict[str, List[str]] @@ -6687,7 +6633,7 @@ def _suggest_torch_checks( if diff: log.warning("Unable to find user code corresponding to {%s}", diff) return - printer = _PythonPrinter(src_map) + printer = _PythonMsgPrinter(src_map) msg = e.args[0] msg += "\nTo fix the error, insert one of the following checks before this call:" # suggested fixes to resolve `cond`` are to tell the compiler to assume diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 49f82dfc700..eeb8f10a888 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -191,7 +191,7 @@ class FloorDiv(sympy.Function): """ nargs: Tuple[int, ...] = (2,) - precedence: int = 50 # precedence of mul # noqa: F811 + precedence: int = 35 # lower precedence than add is_integer: bool = True @property @@ -291,6 +291,7 @@ class ModularIndexing(sympy.Function): nargs: Tuple[int, ...] = (3,) is_integer: bool = True + precedence: int = 35 # lower precedence than add @classmethod def eval( @@ -360,6 +361,7 @@ class Where(sympy.Function): """ nargs: Tuple[int, ...] = (3,) + precedence: int = 35 # lower precedence than add def _eval_is_integer(self) -> Optional[bool]: return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] @@ -389,6 +391,7 @@ class Where(sympy.Function): class PythonMod(sympy.Function): nargs: Tuple[int, ...] 
= (2,) + precedence: int = 35 # lower precedence than add is_integer: bool = True @classmethod @@ -447,6 +450,7 @@ class PythonMod(sympy.Function): # Generic modulus: only defined on non-negative arguments class Mod(sympy.Function): nargs = (2,) + precedence: int = 35 # lower precedence than add is_integer = True is_nonnegative = True @@ -1014,6 +1018,8 @@ def _safe_pow(base, exponent): class PowByNatural(sympy.Function): is_integer = True + precedence: int = 50 # precedence of mul + @classmethod def eval(cls, base, exp): if isinstance(base, sympy.Integer) and isinstance(exp, sympy.Integer): @@ -1039,6 +1045,8 @@ class PowByNatural(sympy.Function): class FloatPow(sympy.Function): is_real = True + precedence: int = 60 # precedence of pow + @classmethod def eval(cls, base, exp): # NB: These test sympy.Number, not sympy.Float, because: @@ -1059,6 +1067,8 @@ class FloatPow(sympy.Function): class FloatTrueDiv(sympy.Function): is_real = True + precedence: int = 35 # lower precedence than add + @classmethod def eval(cls, base, divisor): # assert base.is_integer is not True, base @@ -1082,6 +1092,8 @@ class FloatTrueDiv(sympy.Function): class IntTrueDiv(sympy.Function): is_real = True + precedence: int = 35 # lower precedence than add + @classmethod def eval(cls, base, divisor): if divisor.is_zero: @@ -1254,6 +1266,8 @@ class Identity(sympy.Function): Prevents expansion and other optimizations """ + precedence = 10 + def __repr__(self): # type: ignore[override] return f"Identity({self.args[0]})" diff --git a/torch/utils/_sympy/printers.py b/torch/utils/_sympy/printers.py new file mode 100644 index 00000000000..5f64487b253 --- /dev/null +++ b/torch/utils/_sympy/printers.py @@ -0,0 +1,459 @@ +import sys +from typing import Optional + +import sympy +from sympy.printing.precedence import PRECEDENCE, precedence +from sympy.printing.str import StrPrinter + + +INDEX_TYPE = "int64_t" + + +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python +class ExprPrinter(StrPrinter): + # override this so that _print_FloorDiv is used + printmethod = "_torch_sympystr" + + def _print_Mul(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, "*", precedence(expr)) + + def _print_Add(self, expr: sympy.Expr, order: Optional[str] = None) -> str: + return self.stringify(expr.args, " + ", precedence(expr)) + + def _print_Relational(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, f" {expr.rel_op} ", precedence(expr)) + + def _print_BitwiseFn_bitwise_and(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " & ", PRECEDENCE["Atom"] - 0.5) + + def _print_BitwiseFn_bitwise_or(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " | ", PRECEDENCE["Atom"] - 0.5) + + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent + def _print_Mod(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + + def _print_FloatTrueDiv(self, expr: sympy.Expr) -> str: + s = self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + return f"({s})" + + def _print_CleanDiv(self, expr: sympy.Expr) -> str: + return self._print_FloorDiv(expr) + + def _print_Identity(self, expr: sympy.Expr) -> str: + return self._print(expr.args[0]) + + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. 
+ # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + def _print_Pow(self, expr: sympy.Expr) -> str: + base, exp = expr.args + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return self.stringify([base] * exp, "*", PRECEDENCE["Mul"]) + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. 
You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + + +class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"float({self._print(expr.args[0])})" + + def _print_And(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " and ", precedence(expr)) + + def _print_Or(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " or ", precedence(expr)) + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = ( + self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args + ) + if div != "1": + x = f"({x} // {div})" + return f"({x} % {mod})" + + def _print_Infinity(self, expr: sympy.Expr) -> str: + return "math.inf" + + def _print_NegativeInfinity(self, expr: sympy.Expr) -> str: + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = (self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args) + return f"{x} // {div}" + + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " / ", PRECEDENCE["Atom"] - 0.5) + + def _helper_sqrt(self, expr: sympy.Expr) -> str: + return f"math.sqrt({self._print(expr)})" + + def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: + return self._helper_sqrt(expr.args[0]) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"]) + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " ** ", PRECEDENCE["Pow"]) + + def _print_floor(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float + return f"math.trunc({self._print(expr.args[0])})" + + def _print_ceiling(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"abs({self._print(expr.args[0])})" + + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion + def _print_Max(self, expr: sympy.Expr) -> str: + assert len(expr.args) >= 2 + return f"max({', '.join(map(self._print, expr.args))})" + + def _print_Min(self, expr: sympy.Expr) -> str: + assert len(expr.args) >= 2 + return f"min({', 
'.join(map(self._print, expr.args))})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"math.atan({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"round({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + assert isinstance(ndigits, sympy.Integer) + return f"round({self._print(number)}, {ndigits})" + + +class CppPrinter(ExprPrinter): + def _print_Integer(self, expr: sympy.Expr) -> str: + return ( + f"{int(expr)}LL" if sys.platform in ["darwin", "win32"] else f"{int(expr)}L" + ) + + def _print_Where(self, expr: sympy.Expr) -> str: + c, p, q = ( + self.parenthesize(arg, PRECEDENCE["Atom"] - 0.5) for arg in expr.args + ) + return f"{c} ? 
{p} : {q}" + + def _print_ModularIndexing(self, expr: sympy.Expr) -> str: + x, div, mod = expr.args + x = self.doprint(x) + if div != 1: + div = self.doprint(div) + if expr.is_integer: + x = f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" + else: + x = f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + mod = self.doprint(mod) + return f"(static_cast<{INDEX_TYPE}>({x}) % static_cast<{INDEX_TYPE}>({mod}))" + + def _print_FloorDiv(self, expr: sympy.Expr) -> str: + x, div = expr.args + x = self.doprint(x) + div = self.doprint(div) + if expr.is_integer: + return f"c10::div_floor_integer(static_cast({x}), static_cast({div}))" + return f"c10::div_floor_floating(static_cast({x}), static_cast({div}))" + + def _print_floor(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_FloorToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_TruncToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::trunc({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"static_cast({self._print(expr.args[0])})" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). + def _print_PythonMod(self, expr: sympy.Expr) -> str: + return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5) + + def _print_IntTrueDiv(self, expr: sympy.Expr) -> str: + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow. 
Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr: sympy.Expr) -> str: + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr: sympy.Expr) -> str: + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" + + def _print_Pow(self, expr: sympy.Expr) -> str: + # Uses float constants to perform FP div + base, exp = expr.args + + if exp == 0.5 or exp == -0.5: + base = self._print(base) + return f"std::sqrt({base})" if exp == 0.5 else f"1.0/std::sqrt({base})" + if exp.is_integer: + exp = int(exp) + if exp > 0: + r = self.stringify([base] * exp, "*", PRECEDENCE["Mul"]) + elif exp < -1: + r = ( + "1.0/(" + + self.stringify([base] * abs(exp), "*", PRECEDENCE["Mul"]) + + ")" + ) + elif exp == -1: + r = "1.0/" + self._print(base) + else: # exp == 0 + r = "1.0" + + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + else: + # TODO: float vs double + return f"std::pow({base}, {float(exp)})" + + def _print_Rational(self, expr: sympy.Expr) -> str: + # Uses float constants to perform FP div + if expr.q == 1: + r = f"{expr.p}" + else: + r = f"{expr.p}.0/{expr.q}.0" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_ceiling(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_CeilToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_Min(self, expr: sympy.Expr) -> str: + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::min(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::min({il})" + + def _print_Max(self, expr: sympy.Expr) -> str: + args = [self._print(a) for a in expr.args] + if len(args) == 2: + return f"std::max(static_cast<{INDEX_TYPE}>({args[0]}), static_cast<{INDEX_TYPE}>({args[1]}))" + else: + # Initializer list overload + il = "{" + ", ".join(args) + "}" + return f"std::max({il})" + + def _print_Abs(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::abs({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::cos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_cosh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::cosh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_acos(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::acos({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::sin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sinh(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::sinh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_asin(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::asin({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::tan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_tanh(self, expr: sympy.Expr) -> str: + assert len(expr.args) 
== 1 + return f"std::tanh({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_atan(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + return f"std::atan({self._print(expr.args[0])})" + + def _print_OpaqueUnaryFn_sqrt(self, expr: sympy.Expr) -> str: + return f"std::sqrt({self._print(expr.args[0])})" + + def _print_RoundToInt(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 1 + # TODO: dispatch to llrint depending on index type + return f"std::lrint({self._print(expr.args[0])})" + + def _print_RoundDecimal(self, expr: sympy.Expr) -> str: + assert len(expr.args) == 2 + number, ndigits = expr.args + if number.is_integer: + # ndigits < 0 should have been filtered by the sympy function + assert ndigits < 0 + raise ValueError( + f"For integer inputs, only non-negative ndigits are currently supported, but got {ndigits}." + ) + number_str = self.parenthesize(number, PRECEDENCE["Mul"]) + return f"static_cast(std::nearbyint(1e{ndigits} * {number_str}) * 1e{-ndigits})" + + def _print_BooleanTrue(self, expr: sympy.Expr) -> str: + return "true" + + def _print_BooleanFalse(self, expr: sympy.Expr) -> str: + return "false"