[aoti] Check longlong upperbound for codegening input size check (#156522)

Summary:
Fixes
```
error: integer literal is too large to be represented in any integer type
 38979 |     if (arg410_1_size[0] > 1171368248680556527362) {
```

Test Plan: ci

Differential Revision: D77057898

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156522
Approved by: https://github.com/jingsh, https://github.com/desertfire
This commit is contained in:
Colin Peppler 2025-06-23 20:38:30 +00:00 committed by PyTorch MergeBot
parent edd9c09e73
commit dfdd636cfa
2 changed files with 23 additions and 2 deletions

View File

@ -1059,6 +1059,23 @@ class AOTInductorTestsTemplate:
example_inputs = (x, y)
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
def test_large_dynamic_dim(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
add_0 = x + y
return torch.nn.functional.relu(input=add_0, inplace=False)
x = torch.randn(128, 2048, device=self.device)
y = torch.randn(128, 2048, device=self.device)
# Use a dimension that exceeds the maximum value of a C long long (2^63 - 1)
dim0_x = Dim("dim0_x", min=1, max=1171368248680556527362)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
example_inputs = (x, y)
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
@unittest.skipIf(
not PLATFORM_SUPPORTS_FP8,
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
from __future__ import annotations
import ctypes
import functools
import math
import os
@ -405,12 +406,15 @@ class CppWrapperCpu(PythonWrapperCodegen):
"""
)
if not math.isinf(sym_range.upper):
# Limit upper bound to max C long long value (2^63 - 1)
max_long_long = ctypes.c_longlong(2**63 - 1).value
upper_bound = min(sym_range.upper, max_long_long)
self.prefix.splice(
f"""
if ({name}_size[{dim_idx}] > {sym_range.upper}) {{
if ({name}_size[{dim_idx}] > {upper_bound}) {{
std::stringstream ss;
ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, "
<< "expected to be <= {sym_range.upper}, " << "but got: "
<< "expected to be <= {upper_bound}, " << "but got: "
<< {name}_size[{dim_idx}] << "\\n";
throw std::runtime_error(ss.str());
}}