Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Refactor CPUReproTests to be more vector-length agnostic (#141245)
This replaces the hardcoded assumption of a `256-bit` vector length with a query of `cpu_vec_isa`, and updates the relevant tests to share that logic. The `config.cpp.simdlen != 1` check is also folded into the assertion helper so it is no longer duplicated throughout the test cases. Fixes issues on `128-bit` machines.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141245
Approved by: https://github.com/desertfire, https://github.com/malfet
This commit is contained in:
parent dcd9de79e7
commit 40e27fbcf2
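In short, the tests stop assuming a 256-bit vector unit and instead ask `cpu_vec_isa` which ISA the host actually supports. A minimal sketch of that query, assuming the current layout of the `torch._inductor.cpu_vec_isa` module (the `preferred_simd_width` helper name is hypothetical, not part of the change):

from torch._inductor import cpu_vec_isa

def preferred_simd_width():
    # Hypothetical helper: bit width of the highest-priority vector ISA
    # detected on the host CPU (e.g. 256 for AVX2, 128 for ARM NEON),
    # or None when no valid vector ISA is available.
    isa_list = cpu_vec_isa.valid_vec_isa_list()
    return isa_list[0].bit_width() if isa_list else None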
@@ -69,14 +69,28 @@ requires_vectorization = unittest.skipUnless(
 )
 
 
-def check_metrics_vec_kernel_count(num_expected_vec_kernels):
-    if (
+def _can_check_vec_metrics():
+    return (
         cpu_vec_isa.valid_vec_isa_list()
         and os.getenv("ATEN_CPU_CAPABILITY") != "default"
-    ):
+        and config.cpp.simdlen != 1
+    )
+
+
+def check_metrics_vec_kernel_count(num_expected_vec_kernels):
+    if _can_check_vec_metrics():
         assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels
 
 
+def simd_lengths_to_test():
+    """Returns a minimal list of simd lengths to cover common cases"""
+    simdlens = [None, 1]
+    valid_isa_list = cpu_vec_isa.valid_vec_isa_list()
+    if valid_isa_list:
+        simdlens.append(valid_isa_list[0].bit_width())
+    return simdlens
+
+
 @contextlib.contextmanager
 def set_num_threads(num_threads):
     orig_num_threads = torch.get_num_threads()
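To make the new helper concrete: reproduced standalone, `simd_lengths_to_test` covers the Inductor default (`None`), forced scalar code (`1`), and the detected hardware width. The values in the comments below are illustrative assumptions, not guaranteed outputs:

from torch._inductor import cpu_vec_isa

def simd_lengths_to_test():
    """Returns a minimal list of simd lengths to cover common cases"""
    simdlens = [None, 1]  # None = let Inductor pick, 1 = scalar kernels
    valid_isa_list = cpu_vec_isa.valid_vec_isa_list()
    if valid_isa_list:
        # Highest-priority ISA first: e.g. 256 on AVX2, 128 on NEON
        simdlens.append(valid_isa_list[0].bit_width())
    return simdlens

print(simd_lengths_to_test())  # e.g. [None, 1, 256] on an AVX2 host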
@@ -3289,14 +3303,13 @@ class CPUReproTests(TestCase):
             x = x.view(batchsize, -1, height, width)
             return x.contiguous(memory_format=torch.channels_last)
 
-        for simdlen in (None, 256, 1):
+        for simdlen in simd_lengths_to_test():
             with config.patch({"cpp.simdlen": simdlen}):
                 torch._dynamo.reset()
                 metrics.reset()
                 x = torch.randn(64, 58, 28, 28)
                 self.common(channel_shuffle, (x, 2))
-                if simdlen != 1:
-                    check_metrics_vec_kernel_count(2)
+                check_metrics_vec_kernel_count(2)
 
     @slowTest
     @requires_vectorization
@@ -3324,15 +3337,14 @@ class CPUReproTests(TestCase):
                 return self.fc(y)
 
         x = torch.randn(128, 196, 256)
-        for simdlen in (None, 256, 1):
+        for simdlen in simd_lengths_to_test():
             with config.patch({"cpp.simdlen": simdlen}):
                 for eval_mode in [True, False]:
                     torch._dynamo.reset()
                     metrics.reset()
                     m = Model().eval() if eval_mode else Model()
                     self.common(m, (x,))
-                    if simdlen != 1:
-                        check_metrics_vec_kernel_count(8)
+                    check_metrics_vec_kernel_count(8)
 
     @requires_vectorization
     @config.patch("cpp.enable_tiling_heuristics", False)
@@ -3340,7 +3352,7 @@ class CPUReproTests(TestCase):
         def fn(a):
             return a.t().contiguous()
 
-        for simdlen in (None, 256, 1):
+        for simdlen in simd_lengths_to_test():
             with config.patch({"cpp.simdlen": simdlen}):
                 for dtype in (torch.float, torch.bfloat16):
                     for shape in (
@@ -3356,8 +3368,7 @@ class CPUReproTests(TestCase):
                         metrics.reset()
                         x = torch.randn(shape, dtype=dtype)
                         self.common(fn, (x,))
-                        if simdlen != 1:
-                            check_metrics_vec_kernel_count(2)
+                        check_metrics_vec_kernel_count(2)
 
     @torch._dynamo.config.patch(specialize_int=False)
     def test_slice_scatter_issue122291(self):
@@ -4859,7 +4870,8 @@ class CPUReproTests(TestCase):
             return float32.to(torch.int64)
 
         x = torch.full((32,), -9223372036854775808, dtype=torch.int64)
-        for simdlen in (None, 256):
+
+        for simdlen in simd_lengths_to_test():
             with config.patch({"cpp.simdlen": simdlen}):
                 torch._dynamo.reset()
                 metrics.reset()
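Because the `simdlen != 1` guard now lives inside `_can_check_vec_metrics`, the same metrics call is safe in every configuration the loops visit. A hedged sketch of that behavior outside the test suite (assumes a PyTorch build with Inductor; `fn` is an arbitrary example, and `check_metrics_vec_kernel_count` is the test helper from the first hunk):

import torch
from torch._inductor import config, metrics

def fn(a):
    return a.t().contiguous()

compiled = torch.compile(fn)
with config.patch({"cpp.simdlen": 1}):  # force scalar kernels
    torch._dynamo.reset()
    metrics.reset()
    compiled(torch.randn(128, 128))
    # With simdlen == 1, _can_check_vec_metrics() returns False, so
    # check_metrics_vec_kernel_count(...) becomes a no-op instead of
    # asserting a vector-kernel count that can never be produced.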