diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py index 7c3405e370c..587ad83cafe 100644 --- a/test/inductor/test_aot_inductor.py +++ b/test/inductor/test_aot_inductor.py @@ -1059,6 +1059,23 @@ class AOTInductorTestsTemplate: example_inputs = (x, y) self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + def test_large_dynamic_dim(self): + class Model(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, y): + add_0 = x + y + return torch.nn.functional.relu(input=add_0, inplace=False) + + x = torch.randn(128, 2048, device=self.device) + y = torch.randn(128, 2048, device=self.device) + # Use a dimension that exceeds the maximum value of a C long long (2^63 - 1) + dim0_x = Dim("dim0_x", min=1, max=1171368248680556527362) + dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}} + example_inputs = (x, y) + self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes) + @unittest.skipIf( not PLATFORM_SUPPORTS_FP8, "FP8 is only supported on H100+, SM 8.9 and MI300+ devices", diff --git a/torch/_inductor/codegen/cpp_wrapper_cpu.py b/torch/_inductor/codegen/cpp_wrapper_cpu.py index bdaf74952ce..08c6a586e98 100644 --- a/torch/_inductor/codegen/cpp_wrapper_cpu.py +++ b/torch/_inductor/codegen/cpp_wrapper_cpu.py @@ -1,6 +1,7 @@ # mypy: allow-untyped-defs from __future__ import annotations +import ctypes import functools import math import os @@ -405,12 +406,15 @@ class CppWrapperCpu(PythonWrapperCodegen): """ ) if not math.isinf(sym_range.upper): + # Limit upper bound to max C long long value (2^63 - 1) + max_long_long = ctypes.c_longlong(2**63 - 1).value + upper_bound = min(sym_range.upper, max_long_long) self.prefix.splice( f""" - if ({name}_size[{dim_idx}] > {sym_range.upper}) {{ + if ({name}_size[{dim_idx}] > {upper_bound}) {{ std::stringstream ss; ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, " - << "expected to be <= {sym_range.upper}, " << "but got: " + << "expected to be <= {upper_bound}, " << "but got: " << {name}_size[{dim_idx}] << "\\n"; throw std::runtime_error(ss.str()); }}