Enable additional tests for MPS CI runs (#134356)

As a follow-up to https://github.com/pytorch/pytorch/issues/133520, this adapts existing, previously unused tests for use in MPS CI runs, focusing on NHWC and other memory-format tests.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134356
Approved by: https://github.com/malfet, https://github.com/eqy, https://github.com/huydhn
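
For context, the enablement pattern the diffs below apply is the same in every test file: per-test dtype lists are narrowed for MPS with dtypesIfMPS, known gaps are marked with expectedFailureMPS rather than skipped, and each file opts into MPS instantiation via allow_mps=True. A minimal sketch of that pattern, assuming a PyTorch source checkout (torch.testing._internal ships with the repo, not the wheels; the class and test names here are hypothetical):

import platform

import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfMPS,
    expectedFailureMPS,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import TestCase, run_tests

# Same macOS version probe the PR adds to nn/test_convolution.py:
# e.g. 14.4 on Sonoma, -1.0 when mac_ver() returns an empty string.
product_version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1)


class TestExampleDeviceType(TestCase):
    # MPS gained complex support in macOS 14, so macOS 13 runs float only.
    @dtypes(torch.float, torch.cfloat)
    @dtypesIfMPS(
        *([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
    )
    def test_elementwise(self, device, dtype):
        x = torch.ones(4, device=device, dtype=dtype)
        self.assertEqual((x + x).sum().item(), 8)

    # Known gaps are marked expected-failure instead of skipped, so CI flags
    # them as soon as they start passing. The raise simulates such a gap.
    @expectedFailureMPS
    def test_known_gap(self, device):
        if torch.device(device).type == "mps":
            raise RuntimeError("simulated MPS gap")


# allow_mps=True is the opt-in this PR threads through each enabled test file.
instantiate_device_type_tests(TestExampleDeviceType, globals(), allow_mps=True)

if __name__ == "__main__":
    run_tests()

The expected-failure route (rather than a skip) is what lets CI notice when an op gains MPS support, which is why several @skipIfMps markers in the pooling diff below are swapped for @expectedFailureMPS.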
Siddharth Kotapati 2024-10-04 21:52:35 +00:00 committed by PyTorch MergeBot
parent 7c1d93944e
commit e27c0048db
8 changed files with 144 additions and 24 deletions

View File

@@ -223,6 +223,7 @@ jobs:
cache: pip
- name: Install dependencies
run: |
python3 -m pip install --upgrade pip
pip install pytest-rerunfailures==11.1.* pytest-flakefinder==1.1.* pytest-xdist==3.3.* expecttest==0.2.* fbscribelogger==0.1.* numpy==1.24.*
pip install torch --pre --index-url https://download.pytorch.org/whl/nightly/cpu/
- name: Run run_test.py (nonretryable)

View File

@@ -22,6 +22,8 @@ from torch.testing._internal.common_device_type import (
disableMkldnn,
dtypes,
dtypesIfCUDA,
dtypesIfMPS,
expectedFailureMPS,
instantiate_device_type_tests,
largeTensorTest,
onlyCPU,
@@ -37,6 +39,7 @@ from torch.testing._internal.common_device_type import (
skipCUDAIfRocm,
skipCUDAIfRocmVersionLessThan,
skipMeta,
skipMPS,
)
from torch.testing._internal.common_dtype import (
floating_and_complex_types_and,
@@ -50,6 +53,7 @@ from torch.testing._internal.common_utils import (
GRADCHECK_NONDET_TOL,
gradgradcheck,
instantiate_parametrized_tests,
IS_MACOS,
parametrize as parametrize_test,
run_tests,
set_default_dtype,
@@ -68,6 +72,13 @@ if TEST_SCIPY:
import scipy.ndimage
import scipy.signal
if IS_MACOS:
import platform
product_version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1)
else:
product_version = 0.0
class TestConvolutionNN(NNTestCase):
_do_cuda_memory_leak_check = True
@@ -1677,6 +1688,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
)
@dtypes(torch.float, torch.cfloat)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
@torch.backends.cudnn.flags(enabled=True, benchmark=False)
def test_conv1d_same_padding(self, device, dtype):
# Test padding='same' outputs the correct shape
@@ -1716,6 +1730,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
actual = F.conv1d(x, y, padding="same", dilation=3)
self.assertEqual(expect, actual)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
@dtypes(torch.float, torch.cfloat)
def test_conv2d_same_padding(self, device, dtype):
if dtype is torch.cfloat:
@@ -1768,6 +1785,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(expect, actual, rtol=rtol, atol=atol)
@dtypes(torch.float, torch.cfloat)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
def test_conv1d_valid_padding(self, device, dtype):
# Test F.conv1d padding='valid' is the same as no padding
x = torch.rand(1, 1, 10, device=device, dtype=dtype)
@@ -1777,6 +1797,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(expect, actual)
@dtypes(torch.float, torch.cfloat)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
def test_conv2d_valid_padding(self, device, dtype):
# Test F.conv2d padding='valid' is the same as no padding
x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype)
@@ -1795,6 +1818,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(expect, actual)
@dtypes(torch.float, torch.cfloat)
@dtypesIfMPS(torch.float)
def test_conv1d_same_padding_backward(self, device, dtype):
# Test F.conv1d gradients work with padding='same'
x = torch.rand(1, 1, 12, dtype=dtype, device=device, requires_grad=True)
@@ -1824,6 +1848,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(gy_expect, y.grad)
@dtypes(torch.float, torch.cfloat)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
@tf32_on_and_off(0.001)
def test_conv2d_same_padding_backward(self, device, dtype):
# Test F.conv2d gradients work with padding='same'
@@ -1855,6 +1882,10 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(gy_expect, y.grad)
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(
torch.float, torch.cfloat
) # Double, complex double not supported on MPS
@expectedFailureMPS # https://github.com/pytorch/pytorch/issues/107214
def test_conv3d_same_padding_backward(self, device, dtype):
check_forward_ad = torch.device(device).type != "xla"
@@ -1915,6 +1946,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
)
@dtypes(torch.float, torch.cfloat)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
def test_conv1d_valid_padding_backward(self, device, dtype):
# Test F.conv1d gradients work with padding='valid'
x = torch.rand(1, 1, 10, dtype=dtype, device=device, requires_grad=True)
@@ -1930,6 +1964,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
@unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
@dtypes(torch.float, torch.cfloat)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
@parametrize_test("mode", ("valid", "same"))
def test_conv1d_vs_scipy(self, device, dtype, mode):
t = make_tensor((1, 10), device=device, dtype=dtype)
@@ -1969,6 +2006,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
@unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
@dtypes(torch.float, torch.cfloat)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
@parametrize_test("mode", ("valid", "same"))
def test_conv2d_vs_scipy(self, device, dtype, mode):
t = make_tensor((1, 5, 10), device=device, dtype=dtype)
@@ -2008,6 +2048,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
_test(t, weight_odd, mode)
@unittest.skipIf(not TEST_SCIPY, "Scipy required for the test.")
@skipMPS # Results in CI are inconsistent, forced to skip
@dtypes(torch.float, torch.cfloat)
@parametrize_test("mode", ("valid", "same"))
def test_conv3d_vs_scipy(self, device, dtype, mode):
@@ -2061,6 +2102,9 @@ class TestConvolutionNNDeviceType(NNTestCase):
_test(t, weight_odd, mode)
@dtypes(torch.float, torch.complex64)
@dtypesIfMPS(
*([torch.float] if product_version < 14.0 else [torch.float, torch.cfloat])
) # Complex not supported on MacOS13
def test_conv2d_valid_padding_backward(self, device, dtype):
# Test F.conv2d gradients work with padding='valid'
x = torch.rand(1, 1, 1, 10, device=device, dtype=dtype, requires_grad=True)
@@ -2075,6 +2119,10 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(gy_expect, gy_actual)
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(
torch.float, torch.cfloat
) # Double, complex double not supported on MPS
@expectedFailureMPS # https://github.com/pytorch/pytorch/issues/107214
def test_conv3d_valid_padding_backward(self, device, dtype):
check_forward_ad = torch.device(device).type != "xla"
@@ -2101,7 +2149,15 @@ class TestConvolutionNNDeviceType(NNTestCase):
check_fwd_over_rev=check_forward_ad,
)
@parametrize_test("N", range(2, 4), name_fn=lambda N: f"ConvTranspose{N}d")
@parametrize_test(
arg_str="N",
arg_values=[
subtest(arg_values=(2), name="ConvTranspose2d"),
subtest(
arg_values=(3), name="ConvTranspose3d", decorators=[expectedFailureMPS]
),
],
)
def test_conv_transpose_with_output_size_and_no_batch_dim(self, device, N):
# For inputs with no batch dim, verify output is the correct shape when output_size is set.
# See https://github.com/pytorch/pytorch/issues/75889
@@ -3067,6 +3123,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
input_large = torch.randn(1, 1, 2048, 1024, dtype=dtype, device=device)
conv2(input_large)
@expectedFailureMPS # ConvTranspose 3D is not supported on MPS
def test_conv_noncontig_weights(self, device):
for dim in (1, 2, 3):
for grouped in (False, True):
@@ -3383,6 +3440,8 @@ class TestConvolutionNNDeviceType(NNTestCase):
)
@dtypes(torch.double, torch.cdouble)
@dtypesIfMPS(torch.float, torch.cfloat)
@expectedFailureMPS # https://github.com/pytorch/pytorch/issues/107214
def test_Conv2d_backward_depthwise(self, device, dtype):
x = torch.randn(2, 2, 4, 20, device=device, dtype=dtype, requires_grad=True)
weight = torch.randn(2, 1, 3, 5, device=device, dtype=dtype, requires_grad=True)
@@ -4032,7 +4091,7 @@ class TestConvolutionNNDeviceType(NNTestCase):
self.assertEqual(yref, y)
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals())
instantiate_device_type_tests(TestConvolutionNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestConvolutionNN)
if __name__ == "__main__":

View File

@@ -9,6 +9,10 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfMPS,
expectedFailureMPS,
expectedFailureMPSPre15,
expectedFailureXLA,
instantiate_device_type_tests,
)
@@ -169,6 +173,7 @@ class TestDropoutNNDeviceType(NNTestCase):
else:
self.assertNotEqual(permuted_inp, out)
@expectedFailureMPSPre15
def test_Dropout(self, device):
input = torch.empty(1000)
self._test_dropout(nn.Dropout, device, input)
@@ -207,8 +212,11 @@ class TestDropoutNNDeviceType(NNTestCase):
self.assertTrue(result[b, c].count_nonzero() in (0, channel_numel))
@expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA
def test_Dropout1d(self, device):
with set_default_dtype(torch.double):
@dtypes(torch.double)
@dtypesIfMPS(torch.float32)
@expectedFailureMPS
def test_Dropout1d(self, device, dtype):
with set_default_dtype(dtype):
N, C, L = (
random.randint(10, 15),
random.randint(10, 15),
@@ -279,6 +287,7 @@ class TestDropoutNNDeviceType(NNTestCase):
self._test_dropoutNd_channel_zero(nn.Dropout2d(p=0.5, inplace=True), input)
@expectedFailureXLA # seems like freeze_rng_state is not honoured by XLA
@expectedFailureMPS # Failing on current pytorch MPS
def test_Dropout3d(self, device):
b = random.randint(1, 5)
w = random.randint(1, 5)
@@ -315,7 +324,7 @@ class TestDropoutNNDeviceType(NNTestCase):
self.assertEqual(out.size(), x.size())
instantiate_device_type_tests(TestDropoutNNDeviceType, globals())
instantiate_device_type_tests(TestDropoutNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestDropoutNN)
if __name__ == "__main__":

View File

@@ -20,7 +20,9 @@ from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
dtypesIfMPS,
expectedFailureMeta,
expectedFailureMPS,
instantiate_device_type_tests,
largeTensorTest,
onlyCPU,
@@ -41,7 +43,6 @@ from torch.testing._internal.common_utils import (
parametrize as parametrize_test,
run_tests,
set_default_dtype,
skipIfMps,
skipIfTorchDynamo,
slowTest,
subtest,
@@ -818,6 +819,7 @@ torch.cuda.synchronize()
inp = torch.randn(16, 0, 20, 32, device=device)
avgpool(inp)
@expectedFailureMPS # max_pool3d_with_indices not supported on MPS
def test_pooling_shape(self, device):
"""Test the output shape calculation for pooling functions"""
@@ -1328,6 +1330,8 @@ torch.cuda.synchronize()
helper(1, 19, 20, 10, 8, 2, torch.channels_last)
@dtypes(torch.float, torch.double)
@dtypesIfMPS(torch.float)
@expectedFailureMPS # test_adaptive_pooling_max_nhwc currently fails on MPS - ISSUE#
def test_adaptive_pooling_max_nhwc(self, device, dtype):
def helper(input_size, output_plane_size, contig):
n_plane_dims = len(output_plane_size)
@@ -1379,6 +1383,8 @@ torch.cuda.synchronize()
helper((2, 1, 3, 3, 3), (1, 1, 1), contig)
@dtypes(torch.float, torch.double)
@dtypesIfMPS(torch.float)
@expectedFailureMPS # test_pooling_max_nhwc currently fails on MPS - ISSUE#
def test_pooling_max_nhwc(self, device, dtype):
def helper(n, c, h, w, kernel_size, stride, padding, dilation, contig, device):
output_height = math.floor(
@@ -1585,32 +1591,30 @@ torch.cuda.synchronize()
def test_MaxPool2d_indices(self, device, dtype):
self._test_maxpool_indices(2, device=device, dtype=dtype)
@skipIfMps
@expectedFailureMPS
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
@dtypes(torch.float)
def test_MaxPool3d_indices(self, device, dtype):
self._test_maxpool_indices(3, device=device, dtype=dtype)
@skipIfMps
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
@dtypes(torch.float)
def test_AdaptiveMaxPool1d_indices(self, device, dtype):
self._test_maxpool_indices(1, adaptive=True, device=device, dtype=dtype)
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
@skipIfMps
@dtypes(torch.float)
def test_AdaptiveMaxPool2d_indices(self, device, dtype):
self._test_maxpool_indices(2, adaptive=True, device=device, dtype=dtype)
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
@skipIfMps
@expectedFailureMPS
@dtypes(torch.float)
def test_AdaptiveMaxPool3d_indices(self, device, dtype):
self._test_maxpool_indices(3, adaptive=True, device=device, dtype=dtype)
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
@skipIfMps
@expectedFailureMPS
@dtypes(torch.float)
def test_maxpool_indices_no_batch_dim(self, device, dtype):
"""Check that indices with no batch dim is consistent with a single batch."""
@@ -1831,7 +1835,7 @@ torch.cuda.synchronize()
)
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
@skipIfMps
@expectedFailureMPS
@dtypes(torch.float)
def test_pool_large_size(self, device, dtype):
for op in ("max", "avg"):
@@ -1864,7 +1868,7 @@ torch.cuda.synchronize()
helper(nn.AdaptiveAvgPool2d((2**6, 2**6)))
@dtypesIfCUDA(*floating_types_and(torch.half, torch.bfloat16))
@skipIfMps
@expectedFailureMPS
@dtypes(torch.float)
def test_pool_invalid_size(self, device, dtype):
for op in ("max", "avg"):
@@ -1926,6 +1930,7 @@ torch.cuda.synchronize()
prec=0.05,
)
@expectedFailureMPS # max_pool3d_with_indices not supported on MPS device
def test_maxpool3d_non_square_backward(self, device):
# previous CUDA routine of this backward calculates kernel launch grid size
# with last two dimensions interchanged, so the tailing along the longer dim
@@ -1950,7 +1955,7 @@ torch.cuda.synchronize()
imgs_ = F.adaptive_max_pool3d(imgs, (Od, Oh, Ow))
instantiate_device_type_tests(TestPoolingNNDeviceType, globals())
instantiate_device_type_tests(TestPoolingNNDeviceType, globals(), allow_mps=True)
instantiate_parametrized_tests(TestPoolingNN)
if __name__ == "__main__":

View File

@@ -1417,7 +1417,16 @@ def get_selected_tests(options) -> List[str]:
options.exclude.extend(CPP_TESTS)
if options.mps:
selected_tests = ["test_mps", "test_metal", "test_modules", "test_nn"]
selected_tests = [
"test_mps",
"test_metal",
"test_modules",
"nn/test_convolution",
"nn/test_dropout",
"nn/test_pooling",
"test_view_ops",
"test_nn",
]
else:
# Exclude all mps tests otherwise
options.exclude.extend(["test_mps", "test_metal"])

View File

@@ -8970,8 +8970,8 @@ class TestNNDeviceType(NNTestCase):
else:
self.assertEqual(hx.grad, hx_device.grad)
@dtypesIfMPS(torch.float)
@dtypes(torch.double)
@dtypesIfMPS(torch.float)
def test_BatchNorm_empty(self, device, dtype):
mod = torch.nn.BatchNorm2d(3).to(device)
inp = torch.randn(0, 3, 2, 2, device=device, dtype=dtype)
@@ -9000,8 +9000,12 @@
def test_one_hot(self, device):
# cuda throws device assert for invalid data
# xla ignores out of bound indices
if self.device_type != 'cuda' and self.device_type != 'xla':
# xla & mps ignore out of bound indices
if (
self.device_type != 'cuda'
and self.device_type != 'xla'
and self.device_type != 'mps'
):
with self.assertRaises(RuntimeError):
torch.nn.functional.one_hot(torch.tensor([3, 4, -1, 0], device=device), -1)
@@ -12721,6 +12725,8 @@ if __name__ == '__main__':
def test_clip_grad_value(self, foreach, device):
if torch.device(device).type == 'xla' and foreach:
raise SkipTest('foreach not supported on XLA')
if torch.device(device).type == 'mps' and foreach:
raise SkipTest('foreach not supported on MPS')
l = nn.Linear(10, 10).to(device)
clip_value = 2.5
@@ -12750,6 +12756,8 @@ if __name__ == '__main__':
def test_clip_grad_norm(self, norm_type, foreach, device):
if torch.device(device).type == 'xla' and foreach:
raise SkipTest('foreach not supported on XLA')
if torch.device(device).type == 'mps' and foreach:
raise SkipTest('foreach not supported on MPS')
l = nn.Linear(10, 10).to(device)
max_norm = 2

View File

@@ -10,9 +10,11 @@ import torch
from torch.testing import make_tensor
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfMPS,
instantiate_device_type_tests,
onlyCPU,
onlyNativeDeviceTypes,
onlyNativeDeviceTypesAnd,
skipLazy,
skipMeta,
skipXLA,
@@ -22,6 +24,7 @@ from torch.testing._internal.common_dtype import (
all_types_and_complex_and,
complex_types,
floating_and_complex_types_and,
integral_types_and,
)
from torch.testing._internal.common_utils import (
gradcheck,
@@ -395,6 +398,9 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
@dtypesIfMPS(
*integral_types_and(torch.half, torch.bfloat16, torch.bool, torch.float32)
)
def test_view_tensor_split(self, device, dtype):
a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9)
a_split_dim0 = a.tensor_split(7, 0)
@@ -434,8 +440,9 @@
t[2, 2, 2] = 7
self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2])
@onlyNativeDeviceTypes
@onlyNativeDeviceTypesAnd("mps")
@dtypes(*all_types_and(torch.half, torch.bfloat16))
@dtypesIfMPS(*integral_types_and(torch.half, torch.bool, torch.float32))
def test_imag_noncomplex(self, device, dtype):
t = torch.ones((5, 5), dtype=dtype, device=device)
@@ -2030,7 +2037,7 @@ class TestOldViewOps(TestCase):
t.col_indices()
instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True)
instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True, allow_mps=True)
instantiate_device_type_tests(TestOldViewOps, globals())
if __name__ == "__main__":

View File

@@ -1646,10 +1646,6 @@ def expectedFailureMeta(fn):
return skipIfTorchDynamo()(expectedFailure("meta")(fn))
def expectedFailureMPS(fn):
return expectedFailure("mps")(fn)
def expectedFailureXLA(fn):
return expectedFailure("xla")(fn)
@@ -1658,6 +1654,32 @@ def expectedFailureHPU(fn):
return expectedFailure("hpu")(fn)
def expectedFailureMPS(fn):
return expectedFailure("mps")(fn)
def expectedFailureMPSPre15(fn):
import platform
version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1)
if not version or version < 1.0: # cpu or other unsupported device
return fn
if version < 15.0:
return expectedFailure("mps")(fn)
return fn
def expectedFailureMPSPre14(fn):
import platform
version = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1)
if not version or version < 1.0: # cpu or other unsupported device
return fn
if version < 14.0:
return expectedFailure("mps")(fn)
return fn
# Skips a test on CPU if LAPACK is not available.
def skipCPUIfNoLapack(fn):
return skipCPUIf(not torch._C.has_lapack, "PyTorch compiled without Lapack")(fn)
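
To round out this last hunk: the new expectedFailureMPSPre15 / expectedFailureMPSPre14 helpers are no-ops off-macOS and on new-enough macOS, and otherwise wrap the test in expectedFailure("mps"). A hedged usage sketch (the test body simulates a failure so the xfail bookkeeping stays consistent; names are hypothetical):

import platform

import torch
from torch.testing._internal.common_device_type import (
    expectedFailureMPSPre15,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import TestCase, run_tests

# Same probe the helpers use internally: -1.0 off-macOS.
MACOS_VERSION = float(".".join(platform.mac_ver()[0].split(".")[:2]) or -1)


class TestVersionGates(TestCase):
    # xfail only when running the MPS variant under macOS < 15; elsewhere the
    # decorator is a no-op, so the body must pass there.
    @expectedFailureMPSPre15
    def test_fixed_in_macos15(self, device):
        if torch.device(device).type == "mps" and MACOS_VERSION < 15.0:
            raise RuntimeError("simulated pre-macOS-15 MPS gap")


instantiate_device_type_tests(TestVersionGates, globals(), allow_mps=True)

if __name__ == "__main__":
    run_tests()

The Pre14 variant works the same way with a 14.0 cutoff, matching the macOS 14 complex-support boundary used throughout the nn/test_convolution.py hunks above.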