# Owner(s): ["module: inductor"]
"""
Tests for custom operation autotuning with PyTorch Inductor.

Validates that custom ops can be registered with multiple CustomOpConfigs, where each
config specifies an optional decomposition function and its associated parameters.
Inductor benchmarks all variants and automatically selects the best performing one.
"""

import torch
from torch._inductor import config
from torch._inductor.kernel.custom_op import (
    CustomOpConfig,
    register_custom_op_autotuning,
)
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import skipIfXpu
from torch.testing._internal.inductor_utils import HAS_GPU


torch.set_float32_matmul_precision("high")
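
# Every test below follows the same pattern: define a custom op with
# torch.library.custom_op, register a fake (meta) implementation for tracing,
# then call register_custom_op_autotuning with a list of CustomOpConfig entries
# so that Inductor benchmarks each variant under torch.compile and keeps the
# fastest one.
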
class TestCustomOpAutoTune(TestCase):
    """Test custom operation autotuning functionality."""

    def setUp(self) -> None:
        """Set up test environment with appropriate device and dtype."""
        super().setUp()
        self.device = "cuda" if HAS_GPU else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32

    def _run_autotune_test(self, op_object, inputs, expected, test_name):
        """Shared test infrastructure for autotuning tests."""

        @torch.compile
        def test_model(*args):
            return op_object(*args)

        torch._dynamo.reset()
        autotune_backends = "TRITON" if self.device == "cuda" else "ATEN"
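
        # max_autotune with benchmark_kernel makes Inductor benchmark each candidate
        # (including every registered CustomOpConfig variant) and pick the fastest;
        # fx_graph_cache is disabled so every compilation in this test re-runs
        # autotuning instead of reusing a cached graph.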
        with config.patch(
            max_autotune=True,
            max_autotune_gemm_backends=autotune_backends,
            fx_graph_cache=False,
            benchmark_kernel=True,
        ):
            compiled_result = test_model(*inputs)

        self.assertEqual(
            compiled_result.shape, expected.shape, f"{test_name} shape mismatch"
        )
        torch.testing.assert_close(
            compiled_result,
            expected,
            rtol=2e-1,
            atol=5e-1,
            msg=f"{test_name} numerical mismatch",
        )

    def _assert_implementations_equivalent(self, decompositions, inputs, op_name):
        """Utility to assert that all implementations produce equivalent results."""
        implementations = [(func.__name__, func) for func in decompositions]
        results = {}
        for name, impl in implementations:
            result = impl(*inputs)
            results[name] = result

            # Basic sanity checks
            self.assertTrue(
                torch.isfinite(result).all(),
                f"{op_name} {name} produced non-finite values",
            )

        # Verify numerical equivalence
        reference_name, reference_result = next(iter(results.items()))
        for name, result in results.items():
            if name != reference_name:
                rtol = 1e-1 if "Approximated" in name else 1e-2
                atol = 1e-1 if "Approximated" in name else 1e-2
                torch.testing.assert_close(
                    result,
                    reference_result,
                    rtol=rtol,
                    atol=atol,
                    msg=f"{op_name} {name} differs from {reference_name}",
                )

    def _create_rmsnorm_inputs(self, batch_size=32, seq_len=2048, hidden_dim=512):
        """Create test inputs for RMSNorm operations."""
        input_tensor = torch.randn(
            batch_size,
            seq_len,
            hidden_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        weight = torch.randn(
            hidden_dim, device=self.device, dtype=self.dtype, requires_grad=False
        )
        return input_tensor, weight

    def _create_mlp_inputs(
        self,
        batch_size=2,
        seq_len=32,
        hidden_dim=512,
        intermediate_dim=1024,
        output_dim=256,
    ):
        """Create test inputs for MLP operations."""
        input_tensor = torch.randn(
            batch_size,
            seq_len,
            hidden_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        gate_weight = torch.randn(
            hidden_dim,
            intermediate_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        up_weight = torch.randn(
            hidden_dim,
            intermediate_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        down_weight = torch.randn(
            intermediate_dim,
            output_dim,
            device=self.device,
            dtype=self.dtype,
            requires_grad=False,
        )
        return input_tensor, gate_weight, up_weight, down_weight

    @skipIfXpu
    def test_rmsnorm_custom_op_autotune_with_dynamic_shape(self):
        """Test RMSNorm autotuning with multiple decomposition variants and dynamic shapes.

        Validates:
        - Multiple decomposition implementations with different computational approaches
        - Dynamic shape handling across multiple compilations
        """
        test_op_name = f"test_lib::rmsnorm_{id(self)}"

        def rmsnorm_decomposition1(
            x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            """Variance-based approach: compute variance then rsqrt."""
            variance = x.pow(2).mean(dim=-1, keepdim=True)
            rstd = torch.rsqrt(variance + eps)
            return x * rstd * weight

        def rmsnorm_decomposition2(
            x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            """Separate normalization and scaling: compute normalized value then scale."""
            x_var = x
            variance = x_var.pow(2).mean(dim=-1, keepdim=True)
            x = x * torch.rsqrt(variance + eps)
            x = x * weight
            return x
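
        # Both decompositions compute x * rsqrt(mean(x**2, dim=-1) + eps) * weight and
        # differ only in how the intermediate multiplications are ordered, so the
        # autotuner's choice between them is purely a performance decision.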

        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_rmsnorm_op(
            input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8
        ) -> torch.Tensor:
            return torch.nn.functional.rms_norm(
                input_tensor, input_tensor.shape[-1:], weight, eps=eps
            )

        @test_rmsnorm_op.register_fake
        def _(input_tensor: torch.Tensor, weight: torch.Tensor, eps: float = 1e-8):
            return torch.empty_like(input_tensor)

        decompositions = [
            rmsnorm_decomposition1,
            rmsnorm_decomposition2,
        ]
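
        # Each CustomOpConfig wraps one decomposition as an autotuning candidate, and
        # input_gen_fns maps argument names to callables that produce the real benchmark
        # tensors (here built with randn_like/ones_like from the compile-time tensors).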

        register_custom_op_autotuning(
            test_rmsnorm_op,
            configs=[CustomOpConfig(decomp) for decomp in decompositions],
            name="test_rmsnorm_autotuned",
            input_gen_fns={
                "x": lambda x: torch.randn_like(x, device=self.device) * 0.02,
                "weight": lambda weight: torch.ones_like(weight, device=self.device),
            },
        )

        # Test multiple shapes to verify dynamic shape handling
        test_shapes = [(2, 16, 128), (8, 32, 256)]

        for i, (batch_size, seq_len, hidden_dim) in enumerate(test_shapes):
            input_tensor, weight = self._create_rmsnorm_inputs(
                batch_size, seq_len, hidden_dim
            )

            # Test numerical equivalence for all decompositions
            self._assert_implementations_equivalent(
                decompositions, (input_tensor, weight), f"RMSNorm_{i}"
            )

            # Test autotuning
            expected = rmsnorm_decomposition1(input_tensor, weight)
            self._run_autotune_test(
                test_rmsnorm_op, (input_tensor, weight), expected, f"RMSNorm_{i}"
            )

    @skipIfXpu
    def test_mlp_custom_op_autotune(self):
        """Test MLP autotuning with a method parameter controlling different decomposition variants.

        Validates parametric tuning where the same decomposition function switches between
        algorithmic approaches based on a method parameter (direct 3D matmuls vs. a
        flattened 2D mm formulation).
        """
        test_op_name = f"test_lib::mlp_{id(self)}"

        def mlp_variants(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int = 0,
        ) -> torch.Tensor:
            """MLP implementation with different computational approaches controlled by the method parameter."""
            if method == 0:
                gate_proj = torch.matmul(input_tensor, gate_weight)
                up_proj = torch.matmul(input_tensor, up_weight)
                gated = torch.relu(gate_proj) * up_proj
                return torch.matmul(gated, down_weight)

            elif method == 1:
                batch_shape = input_tensor.shape[:-1]
                hidden_dim = input_tensor.shape[-1]
                output_dim = down_weight.shape[-1]

                input_2d = input_tensor.view(-1, hidden_dim)

                gate_proj = torch.mm(input_2d, gate_weight)
                up_proj = torch.mm(input_2d, up_weight)

                gated = torch.relu(gate_proj) * up_proj
                output_2d = torch.mm(gated, down_weight)

                return output_2d.view(*batch_shape, output_dim)

        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_mlp_op(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int = 0,
        ) -> torch.Tensor:
            return mlp_variants(
                input_tensor, gate_weight, up_weight, down_weight, method=method
            )

        @test_mlp_op.register_fake
        def _(
            input_tensor: torch.Tensor,
            gate_weight: torch.Tensor,
            up_weight: torch.Tensor,
            down_weight: torch.Tensor,
            method: int = 0,
        ):
            return torch.empty(
                input_tensor.shape[:-1] + (down_weight.shape[-1],),
                device=input_tensor.device,
                dtype=input_tensor.dtype,
            )

        # Use explicit configs with the method parameter as the tuning knob
        register_custom_op_autotuning(
            test_mlp_op,
            configs=[
                CustomOpConfig(method=0),
                CustomOpConfig(method=1),
            ],
            name="test_mlp_autotuned",
            input_gen_fns={
                "input_tensor": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,
                "gate_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
                "up_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
                "down_weight": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.05,
            },
        )

        # Create test inputs
        input_tensor, gate_weight, up_weight, down_weight = self._create_mlp_inputs()

        # Compute the reference result using the default (method=0) variant
        expected = mlp_variants(
            input_tensor, gate_weight, up_weight, down_weight, method=0
        )

        # Test autotuning
        self._run_autotune_test(
            test_mlp_op,
            (input_tensor, gate_weight, up_weight, down_weight),
            expected,
            "MLP",
        )

    def _create_decompose_k_inputs(self, m=256, k=65536, n=1024):
        """Create test inputs for decompose_k matrix multiplication - divisible by all k_splits values."""
        # Ensure k is divisible by all k_splits values: [2, 4, 8, 16, 32, 64, 128]
        k = ((k + 255) // 256) * 256  # Round up to nearest multiple of 256
        a = torch.randn(m, k, device=self.device, dtype=self.dtype, requires_grad=False)
        b = torch.randn(k, n, device=self.device, dtype=self.dtype, requires_grad=False)
        return a, b

    @skipIfXpu
    def test_decompose_k_custom_op_autotune(self):
        """Test decompose_k autotuning with parametric tuning over k_splits values.

        Validates a numerical parameter sweep where k_splits controls how the K dimension
        is decomposed for matrix multiplication (k_splits in [2, 4, 8, 16, 32, 64, 128]).
        """
        test_op_name = f"test_lib::decompose_k_{id(self)}"

        def decompose_k_implementation(
            a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
        ) -> torch.Tensor:
            """Matrix multiply with k-way decomposition - Python implementation."""
            m = a.shape[0]
            n = b.shape[1]
            k = a.shape[1]

            k_parts = k // k_splits
            B = k_splits

            a_reshaped = torch.permute(
                a.reshape(m, B, k_parts), (1, 0, 2)
            )  # [B, m, k_parts]
            b_reshaped = b.reshape(B, k_parts, n)  # [B, k_parts, n]

            result = torch.bmm(a_reshaped, b_reshaped)  # [B, m, n]

            return torch.sum(result, dim=0)  # [m, n]

        @torch.library.custom_op(test_op_name, mutates_args=())
        def test_decompose_k_op(
            a: torch.Tensor, b: torch.Tensor, k_splits: int = 4
        ) -> torch.Tensor:
            """Matrix multiply with k-way decomposition - custom op using the decomposition."""
            return decompose_k_implementation(a, b, k_splits)

        @test_decompose_k_op.register_fake
        def _(a: torch.Tensor, b: torch.Tensor, k_splits: int = 4):
            return torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=a.dtype)

        # Register autotuning with different k_splits values using the decomposition function
        register_custom_op_autotuning(
            test_decompose_k_op,
            configs=[
                CustomOpConfig(k_splits=2),
                CustomOpConfig(k_splits=4),
                CustomOpConfig(k_splits=8),
                CustomOpConfig(k_splits=16),
                CustomOpConfig(k_splits=32),
                CustomOpConfig(k_splits=64),
                CustomOpConfig(k_splits=128),
            ],
            name="test_decompose_k_autotuned",
            input_gen_fns={
                "a": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,
                "b": lambda fake_tensor: torch.randn_like(
                    fake_tensor, device=self.device
                )
                * 0.1,
            },
        )

        a, b = self._create_decompose_k_inputs()
        expected = a @ b
        self._run_autotune_test(test_decompose_k_op, (a, b), expected, "DecomposeK")

    @skipIfXpu
    def test_multi_parameter_tuning(self):
        """Test autotuning with multiple parameters for combinatorial parameter exploration.

        Validates parametric tuning with multiple parameters (scale_mode and chunk_size)
        to test combinatorial exploration of the parameter space.
        """
        test_op_name = f"test_lib::multi_param_{id(self)}"

        def multi_param_scaling(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ) -> torch.Tensor:
            """Different scaling approaches controlled by the scale_mode parameter."""
            if scale_mode == 1:
                # Simple broadcasting
                return x * factor
            elif scale_mode == 2:
                # Process in chunks
                batch_size, seq_len = x.shape[:2]
                chunks = []
                for start in range(0, seq_len, chunk_size):
                    end = min(start + chunk_size, seq_len)
                    chunk = x[:, start:end]
                    chunks.append(chunk * factor)
                return torch.cat(chunks, dim=1)
            elif scale_mode == 3:
                # Using einsum for scaling
                return torch.einsum("...i,i->...i", x, factor)

        @torch.library.custom_op(test_op_name, mutates_args=())
        def multi_param_op(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ) -> torch.Tensor:
            return multi_param_scaling(x, factor, scale_mode, chunk_size)

        @multi_param_op.register_fake
        def _(
            x: torch.Tensor,
            factor: torch.Tensor,
            scale_mode: int = 1,
            chunk_size: int = 16,
        ):
            return torch.empty_like(x)

        # Use explicit configs with scale_mode and chunk_size parameters as tuning knobs
        register_custom_op_autotuning(
            multi_param_op,
            configs=[
                CustomOpConfig(scale_mode=1),  # Broadcast
                CustomOpConfig(scale_mode=2, chunk_size=16),  # Chunked 16
                CustomOpConfig(scale_mode=2, chunk_size=32),  # Chunked 32
                CustomOpConfig(scale_mode=3),  # Einsum
            ],
            name="multi_param_autotuned",
            input_gen_fns={
                "x": lambda t: torch.randn_like(t, device=self.device) * 0.1,
                "factor": lambda t: torch.ones(
                    t.shape[-1], device=self.device, dtype=t.dtype
                ),
            },
        )

        # Create test inputs
        test_x = torch.randn(4, 64, 128, device=self.device, dtype=self.dtype)
        test_factor = torch.ones(128, device=self.device, dtype=self.dtype) * 2.0

        # Verify numerical equivalence across all approaches
        expected_result = test_x * test_factor

        # Test each scale_mode variant
        configs = [
            (1, 16),  # broadcast, chunk_size ignored
            (2, 16),  # chunked with size 16
            (2, 32),  # chunked with size 32
            (3, 16),  # einsum, chunk_size ignored
        ]

        for scale_mode, chunk_size in configs:
            result = multi_param_scaling(
                test_x, test_factor, scale_mode=scale_mode, chunk_size=chunk_size
            )
            torch.testing.assert_close(
                result,
                expected_result,
                rtol=1e-5,
                atol=1e-5,
                msg=f"scale_mode {scale_mode} with chunk_size {chunk_size} not equivalent to expected",
            )

        # Test autotuning
        self._run_autotune_test(
            multi_param_op, (test_x, test_factor), expected_result, "MultiParam"
        )


if __name__ == "__main__":
    run_tests()