mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[aoti] Check longlong upperbound for codegening input size check (#156522)
Summary:
Fixes
```
error: integer literal is too large to be represented in any integer type
38979 | if (arg410_1_size[0] > 1171368248680556527362) {
```
Test Plan: ci
Differential Revision: D77057898
Pull Request resolved: https://github.com/pytorch/pytorch/pull/156522
Approved by: https://github.com/jingsh, https://github.com/desertfire
This commit is contained in:
parent
edd9c09e73
commit
dfdd636cfa
|
|
@ -1059,6 +1059,23 @@ class AOTInductorTestsTemplate:
|
||||||
example_inputs = (x, y)
|
example_inputs = (x, y)
|
||||||
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
||||||
|
|
||||||
|
def test_large_dynamic_dim(self):
|
||||||
|
class Model(torch.nn.Module):
|
||||||
|
def __init__(self) -> None:
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
def forward(self, x, y):
|
||||||
|
add_0 = x + y
|
||||||
|
return torch.nn.functional.relu(input=add_0, inplace=False)
|
||||||
|
|
||||||
|
x = torch.randn(128, 2048, device=self.device)
|
||||||
|
y = torch.randn(128, 2048, device=self.device)
|
||||||
|
# Use a dimension that exceeds the maximum value of a C long long (2^63 - 1)
|
||||||
|
dim0_x = Dim("dim0_x", min=1, max=1171368248680556527362)
|
||||||
|
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_x}}
|
||||||
|
example_inputs = (x, y)
|
||||||
|
self.check_model(Model(), example_inputs, dynamic_shapes=dynamic_shapes)
|
||||||
|
|
||||||
@unittest.skipIf(
|
@unittest.skipIf(
|
||||||
not PLATFORM_SUPPORTS_FP8,
|
not PLATFORM_SUPPORTS_FP8,
|
||||||
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
|
"FP8 is only supported on H100+, SM 8.9 and MI300+ devices",
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ctypes
|
||||||
import functools
|
import functools
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
|
|
@ -405,12 +406,15 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
if not math.isinf(sym_range.upper):
|
if not math.isinf(sym_range.upper):
|
||||||
|
# Limit upper bound to max C long long value (2^63 - 1)
|
||||||
|
max_long_long = ctypes.c_longlong(2**63 - 1).value
|
||||||
|
upper_bound = min(sym_range.upper, max_long_long)
|
||||||
self.prefix.splice(
|
self.prefix.splice(
|
||||||
f"""
|
f"""
|
||||||
if ({name}_size[{dim_idx}] > {sym_range.upper}) {{
|
if ({name}_size[{dim_idx}] > {upper_bound}) {{
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, "
|
ss << "{handle_kind}[{idx}]: dim value is too large at {dim_idx}, "
|
||||||
<< "expected to be <= {sym_range.upper}, " << "but got: "
|
<< "expected to be <= {upper_bound}, " << "but got: "
|
||||||
<< {name}_size[{dim_idx}] << "\\n";
|
<< {name}_size[{dim_idx}] << "\\n";
|
||||||
throw std::runtime_error(ss.str());
|
throw std::runtime_error(ss.str());
|
||||||
}}
|
}}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user