mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ARM64][CUDA] skip string pattern matching in test_workspace_allocation_error (#149236)
`unwind()` on ARM64 seems to elide the strings of interest Pull Request resolved: https://github.com/pytorch/pytorch/pull/149236 Approved by: https://github.com/malfet, https://github.com/eellison, https://github.com/BoyuanFeng
This commit is contained in:
parent
bfee141666
commit
6048d88afe
|
|
@ -30,6 +30,7 @@ from torch.testing import FileCheck
|
|||
from torch.testing._internal.common_cuda import TEST_MULTIGPU
|
||||
from torch.testing._internal.common_utils import (
|
||||
instantiate_parametrized_tests,
|
||||
IS_ARM64,
|
||||
IS_CI,
|
||||
IS_LINUX,
|
||||
IS_WINDOWS,
|
||||
|
|
@ -1402,14 +1403,15 @@ if HAS_CUDA:
|
|||
foo(*inps)
|
||||
except Exception as e:
|
||||
thrown = True
|
||||
self.assertTrue(
|
||||
"at::cuda::blas::gemm<float>" in str(e)
|
||||
or "at::cuda::blas::gemm_internal_cublas<float>" in str(e)
|
||||
)
|
||||
self.assertTrue(
|
||||
"getCurrentCUDABlasHandle" in str(e)
|
||||
or "getNewWorkspace" in str(e)
|
||||
)
|
||||
if not IS_ARM64:
|
||||
self.assertTrue(
|
||||
"at::cuda::blas::gemm<float>" in str(e)
|
||||
or "at::cuda::blas::gemm_internal_cublas<float>" in str(e)
|
||||
)
|
||||
self.assertTrue(
|
||||
"getCurrentCUDABlasHandle" in str(e)
|
||||
or "getNewWorkspace" in str(e)
|
||||
)
|
||||
|
||||
self.assertTrue(thrown)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user