[AOTInductor] Fix autotuning code's codegen (#150522)
Summary:
Codegen used to generate temporary args named `tmp_arg_{index}`, where `index` was the argument's position at the call site.
We changed the codegen logic so that previously generated example tensors can be reused and are only deleted once an arg is no longer used. With that change, `{index}` must be unique: different kernels could otherwise emit the same "tmp_arg_{index}" name string while referring to different args.
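For intuition, here is a minimal standalone sketch of the naming scheme (the `positional_names` helper and `TmpArgNamer` class below are made-up illustrations, not the actual wrapper codegen): position-based names repeat across different kernel calls, while a single monotonically increasing counter keeps every generated temporary unique.

```python
# Position-based names: two different kernel calls both name their first
# temporary "tmp_arg_0", so a cache keyed by generated name would alias them.
def positional_names(call_args):
    return [f"tmp_arg_{i}" for i, _ in enumerate(call_args)]

print(positional_names(["x", "x_out"]))  # ['tmp_arg_0', 'tmp_arg_1']
print(positional_names(["y", "y_out"]))  # ['tmp_arg_0', 'tmp_arg_1']  <- collides

# Counter-based names: one monotonically increasing index makes every
# generated temporary unique across all autotuning calls.
class TmpArgNamer:
    def __init__(self):
        self.idx = 0

    def next_name(self):
        name = f"tmp_arg_{self.idx}"
        self.idx += 1
        return name

namer = TmpArgNamer()
print([namer.next_name() for _ in ("x", "x_out")])  # ['tmp_arg_0', 'tmp_arg_1']
print([namer.next_name() for _ in ("y", "y_out")])  # ['tmp_arg_2', 'tmp_arg_3']
```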
Test Plan: `python test/inductor/test_aot_inductor.py -k test_autotuning_args_reuse`
Differential Revision: D72297084
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150522
Approved by: https://github.com/desertfire, https://github.com/22quinn
This commit is contained in:
parent 24f50653c8
commit f363fe616d
@@ -70,6 +70,7 @@ if HAS_GPU:
         add_kernel_with_tma_2d,
         mul2_inplace_kernel,
         strange_config_matmul_kernel,
+        sub_kernel_autotuned,
     )

 if IS_WINDOWS and IS_CI:
@@ -4662,6 +4663,42 @@ class AOTInductorTestsTemplate:
             model, example_inputs, "aoti_torch_clone_preserve_strides", 0
         )

+    def test_autotuning_args_reuse(self):
+        if self.device != GPU_TYPE:
+            raise unittest.SkipTest("requires GPU")
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y):
+                x_out = torch.empty_strided(
+                    (x.size()[0], x.size()[1]), (x.size()[1], 1), device=GPU_TYPE
+                )
+                x_out = torch.permute(x_out, [0, 1])
+                add_kernel_autotuned[(4,)](x, x, x_out, 16)
+
+                y_out = torch.empty_strided(
+                    (y.size()[0], y.size()[1]), (y.size()[1], 1), device=GPU_TYPE
+                )
+                y_out = torch.permute(y_out, [0, 1])
+                add_kernel_autotuned[(64,)](y, y, y_out, 64)
+
+                sub_kernel_autotuned[(4,)](x, x, x_out, 16)
+
+                return x_out, y_out
+
+        example_inputs = (
+            torch.randn(4, 4, device=GPU_TYPE),
+            torch.randn(8, 8, device=GPU_TYPE),
+        )
+        dim0_x = Dim("dim0_x", min=1, max=2048)
+        dim0_y = Dim("dim0_y", min=1, max=2048)
+        dynamic_shapes = {"x": {0: dim0_x}, "y": {0: dim0_y}}
+        self.check_model(
+            Model(),
+            example_inputs,
+            dynamic_shapes=dynamic_shapes,
+            options={"max_autotune": True},
+        )
+
     @unittest.skipIf(IS_FBCODE, "Not runnable in fbcode")
     def test_stft(self):
         N_FFT = 400
@@ -620,6 +620,7 @@ class PythonWrapperCodegen(CodeGen):
         # Map key is the kernel argument name; value is a tuple of the resulting example
         # tensor name with the kernel where that tensor was most recently used.
         self.kernel_autotune_example_args: dict[str, tuple[str, str]] = {}
+        self.kernel_autotune_tmp_arg_idx: int = 0
         # If the generated source code is exactly the same, reuse the
         # pre-existing kernel for it
         self.src_to_kernel: dict[str, str] = {}
@@ -1991,7 +1992,7 @@ class PythonWrapperCodegen(CodeGen):

         return [wrap_arg(arg) for arg in call_args]

-    def generate_example_arg_value(self, arg, arg_type, raw_arg=None, index=None):
+    def generate_example_arg_value(self, arg, arg_type, raw_arg=None):
         if isinstance(arg_type, torch_dtype):
             if isinstance(raw_arg, ir.TMADescriptor):
                 # first we generate the underlying buffer
@@ -2004,8 +2005,9 @@ class PythonWrapperCodegen(CodeGen):
                 assert raw_arg is not None, (
                     "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
                 )
-                buf_name = f"tmp_arg_{index}"
+                buf_name = f"tmp_arg_{self.kernel_autotune_tmp_arg_idx}"
                 buf = raw_arg
+                self.kernel_autotune_tmp_arg_idx += 1

             size = tuple(
                 V.graph.sizevars.atomically_apply_size_hint(
@@ -2182,13 +2184,13 @@ class PythonWrapperCodegen(CodeGen):
                     arg_str = arg
                 elif arg not in self.kernel_autotune_example_args:
                     arg_str = self.generate_example_arg_value(
-                        arg, arg_type, raw_arg, i
+                        arg, arg_type, raw_arg
                     )
                 else:
                     arg_str = self.kernel_autotune_example_args[arg][0]
                 self.kernel_autotune_example_args[arg] = (arg_str, kernel_name)
             else:
-                arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg, i)
+                arg_str = self.generate_example_arg_value(arg, arg_type, raw_arg)
             all_args.append(arg_str if key is None else f"{key}={arg_str}")

         self.kernel_autotune_calls.writeline(
@@ -117,6 +117,32 @@ if has_triton():
         output = x + y
         tl.store(out_ptr + offsets, output, mask=mask)

+    @triton.autotune(
+        configs=[
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 128}, num_stages=4, num_warps=4),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=3, num_warps=8),
+            triton.Config({"BLOCK_SIZE": 64}, num_stages=4, num_warps=4),
+        ],
+        key=[],
+    )
+    @triton.jit
+    def sub_kernel_autotuned(
+        in_ptr0,
+        in_ptr1,
+        out_ptr,
+        n_elements,
+        BLOCK_SIZE: "tl.constexpr",
+    ):
+        pid = tl.program_id(axis=0)
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+        x = tl.load(in_ptr0 + offsets, mask=mask)
+        y = tl.load(in_ptr1 + offsets, mask=mask)
+        output = x - y
+        tl.store(out_ptr + offsets, output, mask=mask)
+
     @triton.autotune(
         configs=[
             triton.Config({"BLOCK_SIZE": 16}, num_stages=2, num_warps=2),