mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Refactor CPUReproTests to be more vector-length agnostic (#141245)
This changes the hardcoded assumptions of a `256-bit` vector length to querying from `cpu_vec_isa` and changes relevant tests to share the logic. Also refactored the `config.cpp.simdlen != 1` into the assertion so we stop duplicating it throughout the test cases. Fixes issues on `128-bit` machines. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141245 Approved by: https://github.com/desertfire, https://github.com/malfet
This commit is contained in:
parent
dcd9de79e7
commit
40e27fbcf2
|
|
@ -69,14 +69,28 @@ requires_vectorization = unittest.skipUnless(
|
|||
)
|
||||
|
||||
|
||||
def check_metrics_vec_kernel_count(num_expected_vec_kernels):
|
||||
if (
|
||||
def _can_check_vec_metrics():
    """Report whether generated-vec-kernel metrics are meaningful to assert on.

    Vectorized kernel counts only make sense when a valid vector ISA is
    available, ATen's CPU capability is not forced to the scalar "default"
    path, and SIMD code generation has not been disabled via
    ``config.cpp.simdlen == 1``.
    """
    # Evaluate each precondition; the chained `and` below preserves the
    # original short-circuit return value (e.g. an empty ISA list stays falsy).
    isa_available = cpu_vec_isa.valid_vec_isa_list()
    capability_ok = os.getenv("ATEN_CPU_CAPABILITY") != "default"
    simd_enabled = config.cpp.simdlen != 1
    return isa_available and capability_ok and simd_enabled
|
||||
|
||||
|
||||
def check_metrics_vec_kernel_count(num_expected_vec_kernels):
    """Assert the compiled-kernel metrics recorded the expected number of
    vectorized C++ kernels.

    The check is a no-op when vectorization metrics cannot be trusted on
    this machine (see ``_can_check_vec_metrics``), so callers do not need
    to guard the call themselves.
    """
    if not _can_check_vec_metrics():
        return
    actual = metrics.generated_cpp_vec_kernel_count
    assert actual == num_expected_vec_kernels
|
||||
|
||||
|
||||
def simd_lengths_to_test():
    """Returns a minimal list of simd lengths to cover common cases.

    Always includes ``None`` (auto-detect) and ``1`` (vectorization
    disabled); when a valid vector ISA exists, also includes that ISA's
    bit width so tests exercise the machine's real vector length instead
    of a hardcoded 256.
    """
    lengths = [None, 1]
    isa_list = cpu_vec_isa.valid_vec_isa_list()
    if isa_list:
        # First entry is the preferred/highest-priority ISA on this host.
        lengths.append(isa_list[0].bit_width())
    return lengths
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def set_num_threads(num_threads):
|
||||
orig_num_threads = torch.get_num_threads()
|
||||
|
|
@ -3289,14 +3303,13 @@ class CPUReproTests(TestCase):
|
|||
x = x.view(batchsize, -1, height, width)
|
||||
return x.contiguous(memory_format=torch.channels_last)
|
||||
|
||||
for simdlen in (None, 256, 1):
|
||||
for simdlen in simd_lengths_to_test():
|
||||
with config.patch({"cpp.simdlen": simdlen}):
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
x = torch.randn(64, 58, 28, 28)
|
||||
self.common(channel_shuffle, (x, 2))
|
||||
if simdlen != 1:
|
||||
check_metrics_vec_kernel_count(2)
|
||||
check_metrics_vec_kernel_count(2)
|
||||
|
||||
@slowTest
|
||||
@requires_vectorization
|
||||
|
|
@ -3324,15 +3337,14 @@ class CPUReproTests(TestCase):
|
|||
return self.fc(y)
|
||||
|
||||
x = torch.randn(128, 196, 256)
|
||||
for simdlen in (None, 256, 1):
|
||||
for simdlen in simd_lengths_to_test():
|
||||
with config.patch({"cpp.simdlen": simdlen}):
|
||||
for eval_mode in [True, False]:
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
m = Model().eval() if eval_mode else Model()
|
||||
self.common(m, (x,))
|
||||
if simdlen != 1:
|
||||
check_metrics_vec_kernel_count(8)
|
||||
check_metrics_vec_kernel_count(8)
|
||||
|
||||
@requires_vectorization
|
||||
@config.patch("cpp.enable_tiling_heuristics", False)
|
||||
|
|
@ -3340,7 +3352,7 @@ class CPUReproTests(TestCase):
|
|||
def fn(a):
|
||||
return a.t().contiguous()
|
||||
|
||||
for simdlen in (None, 256, 1):
|
||||
for simdlen in simd_lengths_to_test():
|
||||
with config.patch({"cpp.simdlen": simdlen}):
|
||||
for dtype in (torch.float, torch.bfloat16):
|
||||
for shape in (
|
||||
|
|
@ -3356,8 +3368,7 @@ class CPUReproTests(TestCase):
|
|||
metrics.reset()
|
||||
x = torch.randn(shape, dtype=dtype)
|
||||
self.common(fn, (x,))
|
||||
if simdlen != 1:
|
||||
check_metrics_vec_kernel_count(2)
|
||||
check_metrics_vec_kernel_count(2)
|
||||
|
||||
@torch._dynamo.config.patch(specialize_int=False)
|
||||
def test_slice_scatter_issue122291(self):
|
||||
|
|
@ -4859,7 +4870,8 @@ class CPUReproTests(TestCase):
|
|||
return float32.to(torch.int64)
|
||||
|
||||
x = torch.full((32,), -9223372036854775808, dtype=torch.int64)
|
||||
for simdlen in (None, 256):
|
||||
|
||||
for simdlen in simd_lengths_to_test():
|
||||
with config.patch({"cpp.simdlen": simdlen}):
|
||||
torch._dynamo.reset()
|
||||
metrics.reset()
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user