pytorch/test/inductor/test_custom_op_autotune.py
# Owner(s): ["module: inductor"]
"""
Tests for custom operation autotuning with PyTorch Inductor.
Validates that custom ops can be registered with multiple CustomOpConfigs, where each
config specifies an optional decomposition function and its associated parameters.
Inductor benchmarks all variants and automatically selects the best performing one.
"""
import torch
from torch._inductor import config
from torch._inductor.kernel.custom_op import (
CustomOpConfig,
register_custom_op_autotuning,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU
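# Every test below follows the same registration pattern. A condensed sketch
# (names such as my_op / decomp_fn are placeholders, not real APIs):
#
#     @torch.library.custom_op("test_lib::my_op", mutates_args=())
#     def my_op(x: torch.Tensor) -> torch.Tensor: ...
#
#     @my_op.register_fake
#     def _(x: torch.Tensor): ...
#
#     register_custom_op_autotuning(
#         my_op,
#         configs=[CustomOpConfig(decomp_fn), CustomOpConfig(param=value)],
#         name="my_op_autotuned",
#         input_gen_fns={"x": lambda fake: torch.randn_like(fake)},
#     )
#
# Inductor benchmarks each config during compilation and picks the fastest variant.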
torch.set_float32_matmul_precision("high")
class TestCustomOpAutoTune(TestCase):
"""Test custom operation autotuning functionality."""
def setUp(self) -> None:
"""Set up test environment with appropriate device and dtype."""
super().setUp()
self.device = "cuda" if HAS_GPU else "cpu"
self.dtype = torch.float16 if self.device == "cuda" else torch.float32
def _run_autotune_test(self, op_object, inputs, expected, test_name):
"""Shared test infrastructure for autotuning tests."""
@torch.compile
def test_model(*args):
return op_object(*args)
torch._dynamo.reset()
autotune_backends = "TRITON" if self.device == "cuda" else "ATEN"
        with config.patch(
            max_autotune=True,
            max_autotune_gemm_backends=autotune_backends,
            # Disable the FX graph cache so each compilation re-runs autotuning.
            fx_graph_cache=False,
            benchmark_kernel=True,
        ):
compiled_result = test_model(*inputs)
self.assertEqual(
compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
)
torch.testing.assert_close(
compiled_result,
expected,
rtol=2e-1,
atol=5e-1,
msg=f"{test_name} numerical mismatch",
)
def _assert_implementations_equivalent(self, decompositions, inputs, op_name):
"""Utility to assert that all implementations produce equivalent results."""
implementations = [(func.__name__, func) for func in decompositions]
results = {}
for name, impl in implementations:
result = impl(*inputs)
results[name] = result
# Basic sanity checks
self.assertTrue(
torch.isfinite(result).all(),
f"{op_name} {name} produced non-finite values",
)
# Verify numerical equivalence
reference_name, reference_result = next(iter(results.items()))
for name, result in results.items():
if name != reference_name:
                # Looser tolerances for implementations whose name marks them as approximate
                rtol = 1e-1 if "Approximated" in name else 1e-2
                atol = 1e-1 if "Approximated" in name else 1e-2
torch.testing.assert_close(
result,
reference_result,
rtol=rtol,
atol=atol,
msg=f"{op_name} {name} differs from {reference_name}",
)
def _create_rmsnorm_inputs(self, batch_size=32, seq_len=2048, hidden_dim=512):
"""Create test inputs for RMSNorm operations."""
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
weight = torch.randn(
hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
)
return input_tensor, weight
def _create_mlp_inputs(
self,
batch_size=2,
seq_len=32,
hidden_dim=512,
intermediate_dim=1024,
output_dim=256,
):
"""Create test inputs for MLP operations."""
input_tensor = torch.randn(
batch_size,
seq_len,
hidden_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
gate_weight = torch.randn(
hidden_dim,
intermediate_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
up_weight = torch.randn(
hidden_dim,
intermediate_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
down_weight = torch.randn(
intermediate_dim,
output_dim,
device=self.device,
dtype=self.dtype,
requires_grad=False,
)
return input_tensor, gate_weight, up_weight, down_weight
@skipIfXpu
def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
"""Test RMSNorm autotuning with multiple decomposition variants and dynamic shapes.
Validates:
- Multiple decomposition implementations with different computational approaches
- Dynamic shape handling across multiple compilations
"""
test_op_name = f"test_lib::rmsnorm_{id(self)}"
def rmsnorm_decomposition1(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""Variance-based approach: compute variance then rsqrt."""
variance = x.pow(2).mean(dim=-1, keepdim=True)
rstd = torch.rsqrt(variance + eps)
return x * rstd * weight
def rmsnorm_decomposition2(
x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
"""Separate normalization and scaling: compute normalized value then scale."""
x_var = x
variance = x_var.pow(2).mean(dim=-1, keepdim=True)
x = x * torch.rsqrt(variance + eps)
x = x * weight
return x
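        # Both variants compute the same RMSNorm:
        #   y = x * rsqrt(mean(x**2, dim=-1) + eps) * weight
        # They differ only in how the intermediate steps are expressed, which can lead
        # Inductor to generate (and benchmark) different kernels for each decomposition.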
@torch.library.custom_op(test_op_name, mutates_args=())
def test_rmsnorm_op(
input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
) -> torch.Tensor:
return torch.nn.functional.rms_norm(
input_tensor, input_tensor.shape[-1:], weight, eps=eps
)
@test_rmsnorm_op.register_fake
def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
return torch.empty_like(input_tensor)
decompositions = [
rmsnorm_decomposition1,
rmsnorm_decomposition2,
]
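        # input_gen_fns controls the concrete tensors used when benchmarking each
        # variant: every callable receives the fake tensor seen at compile time and
        # returns a real tensor of matching shape/dtype with controlled magnitudes.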
register_custom_op_autotuning(
test_rmsnorm_op,
configs=[CustomOpConfig(decomp) for decomp in decompositions],
name="test_rmsnorm_autotuned",
input_gen_fns={
"x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
"weight": lambda weight: torch.ones_like(weight, device=self.device),
},
)
# Test multiple shapes to verify dynamic shape handling
test_shapes = [(2, 16, 128), (8, 32, 256)]
for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
input_tensor, weight = self._create_rmsnorm_inputs(
batch_size, seq_len, hidden_dim
)
# Test numerical equivalence for all decompositions
self._assert_implementations_equivalent(
decompositions, (input_tensor, weight), f"RMSNorm_{i}"
)
# Test autotuning
expected = rmsnorm_decomposition1(input_tensor, weight)
self._run_autotune_test(
test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}"
)
@skipIfXpu
def test_mlp_custom_op_autotune(self):
"""Test MLP autotuning with method parameter controlling different decomposition variants.
Validates parametric tuning where the same decomposition function uses different
algorithmic approaches based on a method parameter (standard matmul, batched mm, fused weights).
"""
test_op_name = f"test_lib::mlp_{id(self)}"
def mlp_variants(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
"""MLP implementation with different computational approaches controlled by method parameter."""
if method == 0:
gate_proj = torch.matmul(input_tensor, gate_weight)
up_proj = torch.matmul(input_tensor, up_weight)
gated = torch.relu(gate_proj) * up_proj
return torch.matmul(gated, down_weight)
elif method == 1:
batch_shape = input_tensor.shape[:-1]
hidden_dim = input_tensor.shape[-1]
output_dim = down_weight.shape[-1]
input_2d = input_tensor.view(-1, hidden_dim)
gate_proj = torch.mm(input_2d, gate_weight)
up_proj = torch.mm(input_2d, up_weight)
gated = torch.relu(gate_proj) * up_proj
output_2d = torch.mm(gated, down_weight)
                return output_2d.view(*batch_shape, output_dim)
            else:
                raise ValueError(f"Unknown method: {method}")
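        # Both methods compute relu(x @ gate_w) * (x @ up_w) @ down_w; method 1 simply
        # flattens the leading batch dimensions so the projections run as plain 2D mms.
        # The results match, but Inductor may lower and tune each form differently.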
@torch.library.custom_op(test_op_name, mutates_args=())
def test_mlp_op(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
) -> torch.Tensor:
return mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=method
)
@test_mlp_op.register_fake
def _(
input_tensor: torch.Tensor,
gate_weight: torch.Tensor,
up_weight: torch.Tensor,
down_weight: torch.Tensor,
method: int = 0,
):
return torch.empty(
input_tensor.shape[:-1] + (down_weight.shape[-1],),
device=input_tensor.device,
dtype=input_tensor.dtype,
)
# Use explicit config with method parameter as tuning knob
register_custom_op_autotuning(
test_mlp_op,
configs=[
CustomOpConfig(method=0),
CustomOpConfig(method=1),
],
name="test_mlp_autotuned",
input_gen_fns={
"input_tensor": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"gate_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"up_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
"down_weight": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.05,
},
)
# Create test inputs
input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()
        # Reference result computed with the standard matmul variant (method=0)
expected = mlp_variants(
input_tensor, gate_weight, up_weight, down_weight, method=0
)
# Test autotuning
self._run_autotune_test(
test_mlp_op,
(input_tensor, gate_weight, up_weight, down_weight),
expected,
"MLP",
)
def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
"""Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
# Ensure k is divisible by all k_splits values: [2, 32, 64, 128, 256]
k = ((k + 255) // 256) * 256 # Round up to nearest multiple of 256
a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False)
b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False)
return a, b
@skipIfXpu
def test_decompose_k_custom_op_autotune(self):
"""Test decompose_k autotuning with parametric tuning for k_splits values.
        Validates a numerical parameter sweep where k_splits controls how the K dimension
        is decomposed for matrix multiplication (k_splits in [2, 4, 8, 16, 32, 64, 128]).
"""
test_op_name = f"test_lib::decompose_k_{id(self)}"
def decompose_k_implementation(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition - Python implementation."""
m = a.shape[0]
n = b.shape[1]
k = a.shape[1]
k_parts = k // k_splits
B = k_splits
a_reshaped = torch.permute(
a.reshape(m, B, k_parts), (1, 0, 2)
) # [B, m, k_parts]
b_reshaped = b.reshape(B, k_parts, n) # [B, k_parts, n]
result = torch.bmm(a_reshaped, b_reshaped) # [B, m, n]
return torch.sum(result, dim=0) # [m, n]
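        # Worked example with the default test shapes (m=256, k=65536, n=1024) and k_splits=4:
        #   a [256, 65536] -> reshape + permute -> [4, 256, 16384]
        #   b [65536, 1024] -> reshape          -> [4, 16384, 1024]
        #   bmm -> [4, 256, 1024]; summing over the split dim gives [256, 1024] == a @ b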
@torch.library.custom_op(test_op_name, mutates_args=())
def test_decompose_k_op(
a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
) -> torch.Tensor:
"""Matrix multiply with k-way decomposition - custom op using the decomposition."""
return decompose_k_implementation(a, b, k_splits)
@test_decompose_k_op.register_fake
def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)
# Register autotuning with different k_splits values using decomposition function
register_custom_op_autotuning(
test_decompose_k_op,
configs=[
CustomOpConfig(k_splits=2),
CustomOpConfig(k_splits=4),
CustomOpConfig(k_splits=8),
CustomOpConfig(k_splits=16),
CustomOpConfig(k_splits=32),
CustomOpConfig(k_splits=64),
CustomOpConfig(k_splits=128),
],
name="test_decompose_k_autotuned",
input_gen_fns={
"a": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
"b": lambda fake_tensor: torch.randn_like(
fake_tensor, device=self.device
)
* 0.1,
},
)
a, b = self._create_decompose_k_inputs()
expected = a @ b
self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK")
@skipIfXpu
def test_multi_parameter_tuning(self):
"""Test autotuning with multiple parameters for combinatorial parameter exploration.
Validates parametric tuning with multiple parameters (scale_mode and chunk_size)
to test combinatorial exploration of the parameter space.
"""
test_op_name = f"test_lib::multi_param_{id(self)}"
def multi_param_scaling(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
) -> torch.Tensor:
"""Different scaling approaches controlled by scale_mode parameter."""
if scale_mode == 1:
# Simple broadcasting
return x * factor
elif scale_mode == 2:
# Process in chunks
batch_size, seq_len = x.shape[:2]
chunks = []
for start in range(0, seq_len, chunk_size):
end = min(start + chunk_size, seq_len)
chunk = x[:, start:end]
chunks.append(chunk * factor)
return torch.cat(chunks, dim=1)
        elif scale_mode == 3:
            # Using einsum for scaling (equivalent to a broadcasting multiply)
            return torch.einsum("...i,i->...i", x, factor)
        else:
            raise ValueError(f"Unknown scale_mode: {scale_mode}")
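        # All three modes compute the same elementwise scaling (x * factor); they differ
        # only in how the work is expressed, giving Inductor distinct graphs to benchmark.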
@torch.library.custom_op(test_op_name, mutates_args=())
def multi_param_op(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
) -> torch.Tensor:
return multi_param_scaling(x, factor, scale_mode, chunk_size)
@multi_param_op.register_fake
def _(
x: torch.Tensor,
factor: torch.Tensor,
scale_mode: int = 1,
chunk_size: int = 16,
):
return torch.empty_like(x)
# Use explicit configs with scale_mode and chunk_size parameters as tuning knobs
register_custom_op_autotuning(
multi_param_op,
configs=[
CustomOpConfig(scale_mode=1), # Broadcast
CustomOpConfig(scale_mode=2, chunk_size=16), # Chunked 16
CustomOpConfig(scale_mode=2, chunk_size=32), # Chunked 32
CustomOpConfig(scale_mode=3), # Einsum
],
name="multi_param_autotuned",
input_gen_fns={
"x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
"factor": lambda t: torch.ones(
t.shape[-1], device=self.device, dtype=t.dtype
),
},
)
# Create test inputs
test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype)
test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0
# Verify numerical equivalence across all approaches
expected_result = test_x * test_factor
# Test each scale_mode variant
configs = [
(1, 16), # broadcast, chunk_size ignored
(2, 16), # chunked with size 16
(2, 32), # chunked with size 32
(3, 16), # einsum, chunk_size ignored
]
for scale_mode, chunk_size in configs:
result = multi_param_scaling(
test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size
)
torch.testing.assert_close(
result,
expected_result,
rtol=1e-5,
atol=1e-5,
msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected",
)
# Test autotuning
self._run_autotune_test(
multi_param_op, (test_x, test_factor), expected_result, "MultiParam"
)
if __name__ == "__main__":
run_tests()