mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[aotinductor] fix std::{min,max} compilation error for sympy expr with multiple args (#150894)
### Compilation error
The issue is that u0 (an unbacked symint) can come from a smaller int dtype e.g. int16, int32.
```
error: no matching function for call to ‘min(int64_t&, short int&)’
759 | call_add_kernel_with_scaling_0(... std::min(100L, s97, u0) ...);
```
### Diff
The fix is to explicitly specify `int64_t` in the std::min template.
```
int64_t s97 = arg0_1_size[0];
int16_t u0_raw; // not a long
auto u0 = u0_raw;
// Before
std::min({100L, s97, u0})
// After
std::min<int64_t>({100L, s97, u0})
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150894
Approved by: https://github.com/desertfire
This commit is contained in:
parent
44ed0c9fbb
commit
5590a0692c
|
|
@ -2414,6 +2414,37 @@ class AOTInductorTestsTemplate:
|
||||||
)
|
)
|
||||||
self.check_model(Model(), example_inputs)
|
self.check_model(Model(), example_inputs)
|
||||||
|
|
||||||
|
@common_utils.parametrize("minmax", [min, max])
def test_sympy_cpp_printer_min_max(self, minmax):
    # Regression test for the std::{min,max} codegen fix (#150894): the C++
    # printer must emit std::min<int64_t>/std::max<int64_t> for initializer
    # lists, because the args can mix int64 symints with smaller-dtype
    # unbacked symints (here an int16), which otherwise fails to compile.
    if self.device != GPU_TYPE:
        raise unittest.SkipTest("requires GPU")

    class Model(torch.nn.Module):
        def forward(self, a, b, ranks):
            n_elements = a.numel()
            out = torch.empty_like(a)
            # backed symint: size of a dynamic dim (marked dynamic below)
            backed = a.size(0)
            # unbacked symint from tensor data; ranks is int16, so this
            # is the smaller-dtype operand the fix targets
            unbacked = int(ranks.max())
            # min/max over backed + unbacked + constant exercises the
            # multi-arg initializer-list std::{min,max} codegen path
            scaling_factor = minmax(backed, unbacked, 100)
            add_kernel_with_scaling[(n_elements,)](
                a,
                b,
                out,
                n_elements,
                scaling_factor,
                BLOCK_SIZE=16,
            )
            return out

    example_inputs = (
        torch.randn(16, device=self.device),
        torch.randn(16, device=self.device),
        # int16 dtype is deliberate — see note on `unbacked` above
        torch.arange(end=4, device=self.device, dtype=torch.int16),
    )
    # make the first dims symbolic so `backed` is a symint, not a constant
    torch._dynamo.mark_dynamic(example_inputs[0], 0)
    torch._dynamo.mark_dynamic(example_inputs[1], 0)
    self.check_model(Model(), example_inputs)
|
||||||
|
|
||||||
@common_utils.parametrize("grid_type", [1, 2, 3])
|
@common_utils.parametrize("grid_type", [1, 2, 3])
|
||||||
@common_utils.parametrize("num_dims", [1, 2])
|
@common_utils.parametrize("num_dims", [1, 2])
|
||||||
@common_utils.parametrize("dynamic", [False, True])
|
@common_utils.parametrize("dynamic", [False, True])
|
||||||
|
|
|
||||||
|
|
@ -432,9 +432,9 @@ class ExprPrinterTests(InductorTestCase):
|
||||||
)
|
)
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
cexpr(expr),
|
cexpr(expr),
|
||||||
f"std::{s}({{x, 2LL*x, 3LL*x}})"
|
f"std::{s}<int64_t>({{x, 2LL*x, 3LL*x}})"
|
||||||
if sys.platform in ["darwin", "win32"]
|
if sys.platform in ["darwin", "win32"]
|
||||||
else f"std::{s}({{x, 2L*x, 3L*x}})",
|
else f"std::{s}<int64_t>({{x, 2L*x, 3L*x}})",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -419,7 +419,7 @@ class CppPrinter(ExprPrinter):
|
||||||
else:
|
else:
|
||||||
# Initializer list overload
|
# Initializer list overload
|
||||||
il = "{" + ", ".join(args) + "}"
|
il = "{" + ", ".join(args) + "}"
|
||||||
return f"std::min({il})"
|
return f"std::min<{INDEX_TYPE}>({il})"
|
||||||
|
|
||||||
def _print_Max(self, expr: sympy.Expr) -> str:
|
def _print_Max(self, expr: sympy.Expr) -> str:
|
||||||
args = [self._print(a) for a in expr.args]
|
args = [self._print(a) for a in expr.args]
|
||||||
|
|
@ -428,7 +428,7 @@ class CppPrinter(ExprPrinter):
|
||||||
else:
|
else:
|
||||||
# Initializer list overload
|
# Initializer list overload
|
||||||
il = "{" + ", ".join(args) + "}"
|
il = "{" + ", ".join(args) + "}"
|
||||||
return f"std::max({il})"
|
return f"std::max<{INDEX_TYPE}>({il})"
|
||||||
|
|
||||||
def _print_Abs(self, expr: sympy.Expr) -> str:
|
def _print_Abs(self, expr: sympy.Expr) -> str:
|
||||||
assert len(expr.args) == 1
|
assert len(expr.args) == 1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user