[aotinductor] fix std::{min,max} compilation error for sympy expr with multiple args (#150894)

### Compilation error
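The issue is that u0 (an unbacked symint) can come from a tensor with a smaller int dtype, e.g. int16 or int32. The arguments passed to std::min/std::max then have mixed integer types, so template argument deduction fails.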
```
error: no matching function for call to ‘min(int64_t&, short int&)’
  759 |     call_add_kernel_with_scaling_0(... std::min(100L, s97, u0) ...);
```

### Diff
The fix is to explicitly specify `int64_t` as the template argument of std::min and std::max, so that every element of the initializer list is converted to `int64_t`.
```
int64_t s97 = arg0_1_size[0];
int16_t u0_raw;      # not a long
auto u0 = u0_raw;

# Before
std::min({100L, s97, u0})
# After
std::min<int64_t>({100L, s97, u0})
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150894
Approved by: https://github.com/desertfire
This commit is contained in:
Colin Peppler 2025-04-10 22:34:09 +00:00 committed by PyTorch MergeBot
parent 44ed0c9fbb
commit 5590a0692c
3 changed files with 35 additions and 4 deletions

View File

@ -2414,6 +2414,37 @@ class AOTInductorTestsTemplate:
)
self.check_model(Model(), example_inputs)
@common_utils.parametrize("minmax", [min, max])
def test_sympy_cpp_printer_min_max(self, minmax):
    """Check a model whose kernel argument is a min/max over three values.

    The scaling factor combines a backed size (``a.size(0)``), an unbacked
    value read from an int16 tensor (``int(ranks.max())``) and the constant
    ``100``, then is passed to a user-defined Triton kernel.  ``minmax`` is
    parametrized over the builtin ``min`` and ``max``.
    """
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Model(torch.nn.Module):
        def forward(self, a, b, ranks):
            numel = a.numel()
            result = torch.empty_like(a)
            # Arguments, in order: backed symint, unbacked symint, constant.
            scale = minmax(a.size(0), int(ranks.max()), 100)
            launch_grid = (numel,)
            add_kernel_with_scaling[launch_grid](
                a,
                b,
                result,
                numel,
                scale,
                BLOCK_SIZE=16,
            )
            return result

    lhs = torch.randn(16, device=self.device)
    rhs = torch.randn(16, device=self.device)
    # int16 on purpose: the unbacked value originates from a dtype narrower
    # than int64.
    ranks = torch.arange(end=4, device=self.device, dtype=torch.int16)

    # Make the leading dimension of both float inputs dynamic.
    for tensor in (lhs, rhs):
        torch._dynamo.mark_dynamic(tensor, 0)

    self.check_model(Model(), (lhs, rhs, ranks))
@common_utils.parametrize("grid_type", [1, 2, 3])
@common_utils.parametrize("num_dims", [1, 2])
@common_utils.parametrize("dynamic", [False, True])

View File

@ -432,9 +432,9 @@ class ExprPrinterTests(InductorTestCase):
)
self.assertEqual(
cexpr(expr),
f"std::{s}({{x, 2LL*x, 3LL*x}})"
f"std::{s}<int64_t>({{x, 2LL*x, 3LL*x}})"
if sys.platform in ["darwin", "win32"]
else f"std::{s}({{x, 2L*x, 3L*x}})",
else f"std::{s}<int64_t>({{x, 2L*x, 3L*x}})",
)

View File

@ -419,7 +419,7 @@ class CppPrinter(ExprPrinter):
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::min({il})"
return f"std::min<{INDEX_TYPE}>({il})"
def _print_Max(self, expr: sympy.Expr) -> str:
args = [self._print(a) for a in expr.args]
@ -428,7 +428,7 @@ class CppPrinter(ExprPrinter):
else:
# Initializer list overload
il = "{" + ", ".join(args) + "}"
return f"std::max({il})"
return f"std::max<{INDEX_TYPE}>({il})"
def _print_Abs(self, expr: sympy.Expr) -> str:
assert len(expr.args) == 1