mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[aoti] skip input symbol codegen for sympy expr w/ many symbols (#152579)
Issue was that - symbol-ids appeared out-of-order w.r.t to the order of the forward inputs ``` def forward(arg0 # [(s3 - 1) + s4, 32], arg1 #[(s3 - 1)] ..) ``` - this causes codegen to fail because it expects all the base symbols `s4,s3` to have been codegen-ed already. - well, we can skip codegen-ing sympy expr with many symbols e.g. `(s3 - 1) + s4` because `s3` and `s4` will be codegen-ed by other inputs. ``` # for example s3 = arg1.size(0) + 1 s4 = argN.size(0) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/152579 Approved by: https://github.com/jingsh, https://github.com/desertfire
This commit is contained in:
parent
60ecc560af
commit
81b6920c68
|
|
@ -2155,41 +2155,6 @@ class AOTInductorTestsTemplate:
|
|||
self.assertTrue(same(result_cpu, result_gpu_0.cpu()))
|
||||
self.assertTrue(same(result_cpu, result_gpu_1.cpu()))
|
||||
|
||||
@requires_multigpu()
|
||||
def test_load_package_multiple_gpus(self):
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self, weight):
|
||||
super().__init__()
|
||||
self.weight = weight
|
||||
|
||||
def forward(self, x, y):
|
||||
return x + torch.nn.functional.linear(y, self.weight)
|
||||
|
||||
weight = torch.randn(10, 10, device=self.device)
|
||||
inputs = (
|
||||
torch.randn(10, 10, device=self.device),
|
||||
torch.randn(10, 10, device=self.device),
|
||||
)
|
||||
model = Model(weight).to(device=self.device)
|
||||
result_ref = model(*inputs)
|
||||
|
||||
package_path = AOTIRunnerUtil.compile(model, inputs)
|
||||
|
||||
# Load AOT package on gpu:N
|
||||
device_interface = get_interface_for_device(GPU_TYPE)
|
||||
for i in range(device_interface.device_count()):
|
||||
device = torch.device(GPU_TYPE, i)
|
||||
with device_interface.device(i), torch.no_grad():
|
||||
model_package = torch._inductor.aoti_load_package(
|
||||
package_path, device_index=i
|
||||
)
|
||||
inputs_on_device = [input.to(device=device) for input in inputs]
|
||||
result_package = model_package(*inputs_on_device)
|
||||
self.assertTrue(same(result_ref.cpu(), result_package.cpu()))
|
||||
|
||||
def test_reuse_kernel(self):
|
||||
class Model(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
|
|
@ -3063,8 +3028,8 @@ class AOTInductorTestsTemplate:
|
|||
if dynamic:
|
||||
dim0_xy = Dim("s0", min=2, max=1024)
|
||||
dynamic_shapes = {
|
||||
"x": {0: dim0_xy},
|
||||
"y": {0: dim0_xy},
|
||||
"x": {0: dim0_xy, 1: None},
|
||||
"y": {0: dim0_xy, 1: None},
|
||||
}
|
||||
example_inputs = (
|
||||
torch.randn(2, device=self.device),
|
||||
|
|
@ -4513,6 +4478,39 @@ class AOTInductorTestsTemplate:
|
|||
).run(code)
|
||||
self.check_model(Model(), example_inputs)
|
||||
|
||||
def test_input_codegen_with_sympy_expr(self):
|
||||
if self.device != GPU_TYPE:
|
||||
raise unittest.SkipTest("requires GPU")
|
||||
|
||||
class MyModel(torch.nn.Module):
|
||||
def forward(self, getitem_54, getitem_52, getitem_19, values_2, offsets):
|
||||
bitwise_or = torch.bitwise_or(getitem_54, getitem_52)
|
||||
combined = torch.cat([getitem_19, values_2], dim=0)
|
||||
add = combined + bitwise_or
|
||||
|
||||
sliced = values_2[:-1] + offsets
|
||||
return add, sliced
|
||||
|
||||
inps = (
|
||||
torch.randint(0, 1, (240,), device="cuda", dtype=torch.uint8),
|
||||
torch.randint(0, 1, (240,), device="cuda", dtype=torch.uint8),
|
||||
torch.randn((192,), device="cuda"),
|
||||
torch.randn((48,), device="cuda"),
|
||||
torch.randint(0, 100, (47,), device="cuda", dtype=torch.uint8),
|
||||
)
|
||||
|
||||
dim = torch.export.Dim("dimensionality")
|
||||
derived_dim = 2 * dim
|
||||
spec = {
|
||||
"getitem_54": (Dim.AUTO,), # [s33 + 2*s40 + 1]
|
||||
"getitem_52": (Dim.AUTO,), # [s33 + 2*s40 + 1]
|
||||
"getitem_19": (derived_dim,), # [2*s40]
|
||||
"values_2": (Dim.AUTO,), # [s33 + 1]
|
||||
"offsets": (Dim.AUTO,), # [s33]
|
||||
}
|
||||
|
||||
self.check_model(MyModel(), inps, dynamic_shapes=spec)
|
||||
|
||||
@common_utils.parametrize("mark_unbacked", (True, False))
|
||||
def test_unbacked_equals_input_size_runtime_assertion(self, mark_unbacked: bool):
|
||||
# This test checks the unbacked symint runtime assertions, for the following cases:
|
||||
|
|
|
|||
|
|
@ -278,21 +278,18 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
code.writeline(f"int64_t {sym_or_exp} = {name_fn(base_name)}[{dim}];")
|
||||
bound_vars.add(sym_or_exp)
|
||||
elif isinstance(sym_or_exp, sympy.Expr):
|
||||
free_symbol = None
|
||||
for sym in sym_or_exp.free_symbols:
|
||||
if sym not in bound_vars:
|
||||
if free_symbol is None:
|
||||
free_symbol = sym
|
||||
else:
|
||||
raise AssertionError(
|
||||
str(sym_or_exp)
|
||||
+ " contains more than one undefined symbols"
|
||||
)
|
||||
if free_symbol is None:
|
||||
undefined_symbols = [
|
||||
sym for sym in sym_or_exp.free_symbols if sym not in bound_vars
|
||||
]
|
||||
if len(undefined_symbols) != 1:
|
||||
# Skip if expression contains no symbols or if multiple
|
||||
# symbols exists since we assume each base symbol is defined
|
||||
# by other codegen_symbol calls.
|
||||
return
|
||||
|
||||
from torch.utils._sympy.solve import try_solve
|
||||
|
||||
free_symbol = undefined_symbols.pop()
|
||||
base_name = name_fn(base_name)
|
||||
# Use a size symbol to solve the free symbol
|
||||
size_symbol = sympy.Symbol(f"{base_name}_{dim}", integer=True)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ import operator
|
|||
import random
|
||||
import re
|
||||
import tempfile
|
||||
from itertools import count
|
||||
from itertools import chain, count
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import sympy
|
||||
|
|
@ -1651,6 +1651,30 @@ class PythonWrapperCodegen(CodeGen):
|
|||
for name, value in inputs:
|
||||
self.codegen_input_symbol_assignment(name, value, bound_vars)
|
||||
|
||||
def _verify_input_symbol_assignment(
|
||||
value: ir.TensorBox,
|
||||
bound_vars: OrderedSet[sympy.Symbol],
|
||||
):
|
||||
for expr in chain.from_iterable([value.get_size(), value.get_stride()]):
|
||||
if not isinstance(expr, Expr) or isinstance(expr, sympy.Symbol):
|
||||
continue
|
||||
|
||||
undefined_symbols = [
|
||||
sym for sym in expr.free_symbols if sym not in bound_vars
|
||||
]
|
||||
if len(undefined_symbols) > 0:
|
||||
raise AssertionError(
|
||||
f"For {expr}, expected {undefined_symbols} to have been codegen-ed."
|
||||
)
|
||||
|
||||
# For inputs with size/strides which contain sympy expressions, we can
|
||||
# encounter symbols that weren't defined yet. Now, let's check each
|
||||
# symbol is defined.
|
||||
for _, value in inputs:
|
||||
if not isinstance(value, ir.TensorBox):
|
||||
continue
|
||||
_verify_input_symbol_assignment(value, bound_vars)
|
||||
|
||||
def ensure_size_computed(self, sym: sympy.Symbol):
|
||||
if isinstance(sym, sympy.Symbol) and symbol_is_type(sym, SymT.PRECOMPUTED_SIZE):
|
||||
if sym in self.computed_sizes:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user