Refactor CPUReproTests to be more vector-length agnostic (#141245)

This replaces the hardcoded assumption of a `256-bit` vector length with a query of `cpu_vec_isa`, and updates the relevant tests to share that logic.
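For reference, querying the available vector width looks roughly like the sketch below (a minimal sketch, assuming `cpu_vec_isa` is importable from `torch._inductor`, as in the test file's imports):

```python
# Sketch: query the widest available vector ISA instead of assuming 256 bits.
# Assumes cpu_vec_isa lives under torch._inductor, as the tests import it.
from torch._inductor import cpu_vec_isa

valid_isa_list = cpu_vec_isa.valid_vec_isa_list()
if valid_isa_list:
    # e.g. 256 on an AVX2 machine, 128 on an Arm NEON machine
    simd_bit_width = valid_isa_list[0].bit_width()
```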

It also folds the `config.cpp.simdlen != 1` check into the assertion helper so it is no longer duplicated throughout the test cases.
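With the check folded in, the per-test loops reduce to the pattern below (taken from the diff; the expected kernel count is just an example here). `check_metrics_vec_kernel_count` silently skips the assertion when vectorization is unavailable or `cpp.simdlen == 1`:

```python
for simdlen in simd_lengths_to_test():
    with config.patch({"cpp.simdlen": simdlen}):
        torch._dynamo.reset()
        metrics.reset()
        self.common(fn, (x,))
        # The expected count varies per test; the check is a no-op when
        # cpp.simdlen == 1 or no vector ISA is available.
        check_metrics_vec_kernel_count(2)
```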

Fixes test failures on machines with a `128-bit` vector length.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141245
Approved by: https://github.com/desertfire, https://github.com/malfet
Chris Sidebottom 2025-01-22 04:24:42 +00:00 committed by PyTorch MergeBot
parent dcd9de79e7
commit 40e27fbcf2


@@ -69,14 +69,28 @@ requires_vectorization = unittest.skipUnless(
 )
 
 
-def check_metrics_vec_kernel_count(num_expected_vec_kernels):
-    if (
+def _can_check_vec_metrics():
+    return (
         cpu_vec_isa.valid_vec_isa_list()
         and os.getenv("ATEN_CPU_CAPABILITY") != "default"
-    ):
+        and config.cpp.simdlen != 1
+    )
+
+
+def check_metrics_vec_kernel_count(num_expected_vec_kernels):
+    if _can_check_vec_metrics():
         assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
 
 
+def simd_lengths_to_test():
+    """Returns a minimal list of simd lengths to cover common cases"""
+    simdlens = [None, 1]
+    valid_isa_list = cpu_vec_isa.valid_vec_isa_list()
+    if valid_isa_list:
+        simdlens.append(valid_isa_list[0].bit_width())
+    return simdlens
+
+
 @contextlib.contextmanager
 def set_num_threads(num_threads):
     orig_num_threads = torch.get_num_threads()
@@ -3289,14 +3303,13 @@ class CPUReproTests(TestCase):
             x = x.view(batchsize, -1, height, width)
             return x.contiguous(memory_format=torch.channels_last)
 
-        for simdlen in (None, 256, 1):
+        for simdlen in simd_lengths_to_test():
             with config.patch({"cpp.simdlen": simdlen}):
                 torch._dynamo.reset()
                 metrics.reset()
                 x = torch.randn(64, 58, 28, 28)
                 self.common(channel_shuffle, (x, 2))
-                if simdlen != 1:
-                    check_metrics_vec_kernel_count(2)
+                check_metrics_vec_kernel_count(2)
 
     @slowTest
     @requires_vectorization
@@ -3324,15 +3337,14 @@ class CPUReproTests(TestCase):
                 return self.fc(y)
 
         x = torch.randn(128, 196, 256)
-        for simdlen in (None, 256, 1):
+        for simdlen in simd_lengths_to_test():
            with config.patch({"cpp.simdlen": simdlen}):
                 for eval_mode in [True, False]:
                     torch._dynamo.reset()
                     metrics.reset()
                     m = Model().eval() if eval_mode else Model()
                     self.common(m, (x,))
-                    if simdlen != 1:
-                        check_metrics_vec_kernel_count(8)
+                    check_metrics_vec_kernel_count(8)
 
     @requires_vectorization
     @config.patch("cpp.enable_tiling_heuristics", False)
@@ -3340,7 +3352,7 @@
         def fn(a):
             return a.t().contiguous()
 
-        for simdlen in (None, 256, 1):
+        for simdlen in simd_lengths_to_test():
             with config.patch({"cpp.simdlen": simdlen}):
                 for dtype in (torch.float, torch.bfloat16):
                     for shape in (
@@ -3356,8 +3368,7 @@
                         metrics.reset()
                         x = torch.randn(shape, dtype=dtype)
                         self.common(fn, (x,))
-                        if simdlen != 1:
-                            check_metrics_vec_kernel_count(2)
+                        check_metrics_vec_kernel_count(2)
 
     @torch._dynamo.config.patch(specialize_int=False)
     def test_slice_scatter_issue122291(self):
@@ -4859,7 +4870,8 @@ class CPUReproTests(TestCase):
             return float32.to(torch.int64)
 
         x = torch.full((32,), -9223372036854775808, dtype=torch.int64)
-        for simdlen in (None, 256):
+
+        for simdlen in simd_lengths_to_test():
             with config.patch({"cpp.simdlen": simdlen}):
                 torch._dynamo.reset()
                 metrics.reset()