[aoti] skip input symbol codegen for sympy expr w/ many symbols (#152579)

Issue was that
- symbol-ids appeared out-of-order w.r.t to the order of the forward inputs
```
def forward(arg0 # [(s3 - 1) + s4, 32], arg1 #[(s3 - 1)] ..)
```
- this causes codegen to fail because it expects all the base symbols `s4,s3` to have been codegen-ed already.
- well, we can skip codegen-ing sympy expr with many symbols e.g. `(s3 - 1) + s4` because `s3` and `s4` will be codegen-ed by other inputs.

```
# for example
s3 = arg1.size(0) + 1
s4 = argN.size(0)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152579
Approved by: https://github.com/jingsh, https://github.com/desertfire
This commit is contained in:
Colin Peppler 2025-05-06 11:11:10 -07:00 committed by PyTorch MergeBot
parent 60ecc560af
commit 81b6920c68
3 changed files with 68 additions and 49 deletions

View File

@ -2155,41 +2155,6 @@ class AOTInductorTestsTemplate:
self.assertTrue(same(result_cpu, result_gpu_0.cpu()))
self.assertTrue(same(result_cpu, result_gpu_1.cpu()))
@requires_multigpu()
def test_load_package_multiple_gpus(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class Model(torch.nn.Module):
def __init__(self, weight):
super().__init__()
self.weight = weight
def forward(self, x, y):
return x + torch.nn.functional.linear(y, self.weight)
weight = torch.randn(10, 10, device=self.device)
inputs = (
torch.randn(10, 10, device=self.device),
torch.randn(10, 10, device=self.device),
)
model = Model(weight).to(device=self.device)
result_ref = model(*inputs)
package_path = AOTIRunnerUtil.compile(model, inputs)
# Load AOT package on gpu:N
device_interface = get_interface_for_device(GPU_TYPE)
for i in range(device_interface.device_count()):
device = torch.device(GPU_TYPE, i)
with device_interface.device(i), torch.no_grad():
model_package = torch._inductor.aoti_load_package(
package_path, device_index=i
)
inputs_on_device = [input.to(device=device) for input in inputs]
result_package = model_package(*inputs_on_device)
self.assertTrue(same(result_ref.cpu(), result_package.cpu()))
def test_reuse_kernel(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
@ -3063,8 +3028,8 @@ class AOTInductorTestsTemplate:
if dynamic:
dim0_xy = Dim("s0", min=2, max=1024)
dynamic_shapes = {
"x": {0: dim0_xy},
"y": {0: dim0_xy},
"x": {0: dim0_xy, 1: None},
"y": {0: dim0_xy, 1: None},
}
example_inputs = (
torch.randn(2, device=self.device),
@ -4513,6 +4478,39 @@ class AOTInductorTestsTemplate:
).run(code)
self.check_model(Model(), example_inputs)
def test_input_codegen_with_sympy_expr(self):
if self.device != GPU_TYPE:
raise unittest.SkipTest("requires GPU")
class MyModel(torch.nn.Module):
def forward(self, getitem_54, getitem_52, getitem_19, values_2, offsets):
bitwise_or = torch.bitwise_or(getitem_54, getitem_52)
combined = torch.cat([getitem_19, values_2], dim=0)
add = combined + bitwise_or
sliced = values_2[:-1] + offsets
return add, sliced
inps = (
torch.randint(0, 1, (240,), device="cuda", dtype=torch.uint8),
torch.randint(0, 1, (240,), device="cuda", dtype=torch.uint8),
torch.randn((192,), device="cuda"),
torch.randn((48,), device="cuda"),
torch.randint(0, 100, (47,), device="cuda", dtype=torch.uint8),
)
dim = torch.export.Dim("dimensionality")
derived_dim = 2 * dim
spec = {
"getitem_54": (Dim.AUTO,), # [s33 + 2*s40 + 1]
"getitem_52": (Dim.AUTO,), # [s33 + 2*s40 + 1]
"getitem_19": (derived_dim,), # [2*s40]
"values_2": (Dim.AUTO,), # [s33 + 1]
"offsets": (Dim.AUTO,), # [s33]
}
self.check_model(MyModel(), inps, dynamic_shapes=spec)
@common_utils.parametrize("mark_unbacked", (True, False))
def test_unbacked_equals_input_size_runtime_assertion(self, mark_unbacked: bool):
# This test checks the unbacked symint runtime assertions, for the following cases:

View File

@ -278,21 +278,18 @@ class CppWrapperCpu(PythonWrapperCodegen):
code.writeline(f"int64_t {sym_or_exp} = {name_fn(base_name)}[{dim}];")
bound_vars.add(sym_or_exp)
elif isinstance(sym_or_exp, sympy.Expr):
free_symbol = None
for sym in sym_or_exp.free_symbols:
if sym not in bound_vars:
if free_symbol is None:
free_symbol = sym
else:
raise AssertionError(
str(sym_or_exp)
+ " contains more than one undefined symbols"
)
if free_symbol is None:
undefined_symbols = [
sym for sym in sym_or_exp.free_symbols if sym not in bound_vars
]
if len(undefined_symbols) != 1:
# Skip if expression contains no symbols or if multiple
# symbols exists since we assume each base symbol is defined
# by other codegen_symbol calls.
return
from torch.utils._sympy.solve import try_solve
free_symbol = undefined_symbols.pop()
base_name = name_fn(base_name)
# Use a size symbol to solve the free symbol
size_symbol = sympy.Symbol(f"{base_name}_{dim}", integer=True)

View File

@ -12,7 +12,7 @@ import operator
import random
import re
import tempfile
from itertools import count
from itertools import chain, count
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import sympy
@ -1651,6 +1651,30 @@ class PythonWrapperCodegen(CodeGen):
for name, value in inputs:
self.codegen_input_symbol_assignment(name, value, bound_vars)
def _verify_input_symbol_assignment(
value: ir.TensorBox,
bound_vars: OrderedSet[sympy.Symbol],
):
for expr in chain.from_iterable([value.get_size(), value.get_stride()]):
if not isinstance(expr, Expr) or isinstance(expr, sympy.Symbol):
continue
undefined_symbols = [
sym for sym in expr.free_symbols if sym not in bound_vars
]
if len(undefined_symbols) > 0:
raise AssertionError(
f"For {expr}, expected {undefined_symbols} to have been codegen-ed."
)
# For inputs with size/strides which contain sympy expressions, we can
# encounter symbols that weren't defined yet. Now, let's check each
# symbol is defined.
for _, value in inputs:
if not isinstance(value, ir.TensorBox):
continue
_verify_input_symbol_assignment(value, bound_vars)
def ensure_size_computed(self, sym: sympy.Symbol):
if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
if sym in self.computed_sizes: