[AOTInductor] Fix autotuning code's codegen (#150522)

Summary:
The codegen used to generate `tmp_arg_{index}` as the name of a temporary arg, where {index} was the arg's position in the caller. We changed the codegen logic so that previously generated example tensors can be reused and are only deleted once an arg is no longer used. With that change, {index} has to be unique: otherwise different kernel calls could reuse the same "tmp_arg_{index}" name string while it refers to different args.
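
To illustrate the naming issue (a minimal sketch, not the actual generated code): with position-based indices, the first tensor arg of every kernel call maps to `tmp_arg_0`, so once example tensors are cached and reused across calls, two different args can collide under the same name. A single wrapper-wide counter sidesteps this:

```python
# Hypothetical illustration of the naming scheme, not the real codegen output.
#
# Before: index = position of the arg in the kernel call, so names collide
# across kernels once example tensors are cached and reused:
#   add_kernel call: arg 0 -> "tmp_arg_0"
#   sub_kernel call: arg 0 -> "tmp_arg_0"   # same name, different arg!
#
# After: one monotonically increasing counter per wrapper.
class _NameGen:
    def __init__(self) -> None:
        self._idx = 0

    def next_tmp_arg(self) -> str:
        name = f"tmp_arg_{self._idx}"
        self._idx += 1
        return name


gen = _NameGen()
print(gen.next_tmp_arg())  # tmp_arg_0 -> first new example tensor (add_kernel's arg)
print(gen.next_tmp_arg())  # tmp_arg_1 -> next new example tensor (sub_kernel's arg), no collision
```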

Test Plan: `python test/inductor/test_aot_inductor.py -k test_autotuning_args_reuse`

Differential Revision: D72297084

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150522
Approved by: https://github.com/desertfire, https://github.com/22quinn
Mu-Chu Lee 2025-04-03 00:08:19 +00:00 committed by PyTorch MergeBot
parent 24f50653c8
commit f363fe616d
3 changed files with 69 additions and 4 deletions

View File

@@ -70,6 +70,7 @@ if HAS_GPU:
         add_kernel_with_tma_2d,
         mul2_inplace_kernel,
         strange_config_matmul_kernel,
+        sub_kernel_autotuned,
     )
 
 if IS_WINDOWS and IS_CI:
@@ -4662,6 +4663,42 @@ class AOTInductorTestsTemplate:
             model, example_inputs, "aoti_torch_clone_preserve_strides", 0
         )
 
+    def test_autotuning_args_reuse(self):
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                x_out = torch.empty_strided(
+                    (x.size()[0], x.size()[1]), (x.size()[1], 1), device=GPU_TYPE
+                )
+                x_out = torch.permute(x_out, [0, 1])
+                add_kernel_autotuned[(4,)](x, x, x_out, 16)
+
+                y_out = torch.empty_strided(
+                    (y.size()[0], y.size()[1]), (y.size()[1], 1), device=GPU_TYPE
+                )
+                y_out = torch.permute(y_out, [0, 1])
+                add_kernel_autotuned[(64,)](y, y, y_out, 64)
+
+                sub_kernel_autotuned[(4,)](x, x, x_out, 16)
+                return x_out, y_out
+
+        example_inputs = (
+            torch.randn(4, 4, device=GPU_TYPE),
+            torch.randn(8, 8, device=GPU_TYPE),
+        )
+        dim0_x = Dim("dim0_x", min=1, max=2048)
+        dim0_y = Dim("dim0_y", min=1, max=2048)
+        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
+
+        self.check_model(
+            Model(),
+            example_inputs,
+            dynamic_shapes=dynamic_shapes,
+            options={"max_autotune": True},
+        )
+
     @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
     def test_stft(self):
         N_FFT = 400
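
For readers less familiar with the `Dim`/`dynamic_shapes` plumbing exercised above: outside the test harness, a roughly equivalent flow goes through the public export and AOTInductor packaging APIs. The sketch below is illustrative only; `ToyModel` is a stand-in (the test compiles its Triton-kernel `Model` through `self.check_model`, which wraps similar steps), and it assumes a CUDA device is available:

```python
# Hedged sketch: declare dynamic batch dims and AOT-compile with max_autotune.
import torch
import torch._inductor
from torch.export import Dim


class ToyModel(torch.nn.Module):  # stand-in for the test's Triton-kernel Model
    def forward(self, x, y):
        return x + x, y + y


example_inputs = (
    torch.randn(4, 4, device="cuda"),
    torch.randn(8, 8, device="cuda"),
)
dim0_x = Dim("dim0_x", min=1, max=2048)
dim0_y = Dim("dim0_y", min=1, max=2048)
dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}

ep = torch.export.export(ToyModel().cuda(), example_inputs, dynamic_shapes=dynamic_shapes)
pkg_path = torch._inductor.aoti_compile_and_package(
    ep, inductor_configs={"max_autotune": True}
)
compiled = torch._inductor.aoti_load_package(pkg_path)
out_x, out_y = compiled(*example_inputs)
```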

View File

@@ -620,6 +620,7 @@ class PythonWrapperCodegen(CodeGen):
         # Map key is the kernel argument name; value is a tuple of the resulting example
         # tensor name with the kernel where that tensor was most recently used.
         self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {}
+        self.kernel_autotune_tmp_arg_idx: int = 0
         # If the generated source code is exactly the same, reuse the
         # pre-existing kernel for it
         self.src_to_kernel: dict[str, str] = {}
@@ -1991,7 +1992,7 @@
         return [wrap_arg(arg) for arg in call_args]
 
-    def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None):
+    def generate_example_arg_value(self, arg, arg_type, raw_arg=None):
         if isinstance(arg_type, torch_dtype):
             if isinstance(raw_arg, ir.TMADescriptor):
                 # first we generate the underlying buffer
@@ -2004,8 +2005,9 @@
                 assert raw_arg is not None, (
                     "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
                 )
-                buf_name = f"tmp_arg_{index}"
+                buf_name = f"tmp_arg_{self.kernel_autotune_tmp_arg_idx}"
                 buf = raw_arg
+                self.kernel_autotune_tmp_arg_idx += 1
 
             size = tuple(
                 V.graph.sizevars.atomically_apply_size_hint(
@@ -2182,13 +2184,13 @@
                         arg_str = arg
                     elif arg not in self.kernel_autotune_example_args:
                         arg_str = self.generate_example_arg_value(
-                            arg, arg_type, raw_arg, i
+                            arg, arg_type, raw_arg
                         )
                     else:
                         arg_str = self.kernel_autotune_example_args[arg][0]
                     self.kernel_autotune_example_args[arg] = (arg_str, kernel_name)
                 else:
-                    arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i)
+                    arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg)
                 all_args.append(arg_str if key is None else f"{key}={arg_str}")
 
             self.kernel_autotune_calls.writeline(
View File

@@ -117,6 +117,32 @@ if has_triton():
         output = x + y
         tl.store(out_ptr + offsets, output, mask=mask)
 
+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def sub_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x - y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
     @triton.autotune(
         configs=[
             triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),