from __future__ import absolute_import, division, print_function, unicode_literals

import copy
import unittest

import torch
from torch.utils import mkldnn as mkldnn_utils
from common_utils import TestCase, run_tests

from torch.autograd.gradcheck import gradgradcheck, gradcheck


# Comment out the line below to find out which CI machines have the MKL-DNN build disabled.
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMkldnn(TestCase):
    def test_conversion(self):
        # Round-trip both a contiguous and a non-contiguous (sliced) tensor.
        for cpu_tensor in [torch.randn((1, 2, 3, 4),
                                       dtype=torch.float, device=torch.device('cpu')),
                           torch.randn((1, 2, 3, 4, 5),
                                       dtype=torch.float, device=torch.device('cpu'))[:, :, :, :, 1]]:
            cpu_tensor.requires_grad_()
            mkldnn_tensor = cpu_tensor.to_mkldnn()
            cpu_tensor_1 = mkldnn_tensor.to_dense()
            self.assertEqual(cpu_tensor, cpu_tensor_1)
            self.assertEqual(mkldnn_tensor.dtype, torch.float)
            self.assertEqual(mkldnn_tensor.device, torch.device('cpu'))
            self.assertEqual(mkldnn_tensor.size(), torch.Size([1, 2, 3, 4]))
            self.assertEqual(mkldnn_tensor.numel(), cpu_tensor.numel())
            self.assertEqual(mkldnn_tensor.element_size(), cpu_tensor.element_size())
            # MKL-DNN tensors are opaque: they have no storage, so data_ptr() must raise.
            self.assertRaisesRegex(RuntimeError,
                                   "Cannot access data pointer of Tensor that doesn't have storage",
                                   lambda: mkldnn_tensor.data_ptr() != 0)

    def test_unsupported(self):
        # unsupported dtypes, on CPU and (if available) on GPU
        for dtype in [torch.double, torch.half, torch.uint8, torch.int8,
                      torch.short, torch.int, torch.long]:
            with self.assertRaises(RuntimeError) as context:
                torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cpu')).to_mkldnn()
            if torch.cuda.is_available():
                with self.assertRaises(RuntimeError) as context:
                    torch.randn(1, 2, 3, 4, dtype=dtype, device=torch.device('cuda')).to_mkldnn()
        # supported dtype, but on GPU
        if torch.cuda.is_available():
            with self.assertRaises(RuntimeError) as context:
                torch.randn(1, 2, 3, 4, dtype=torch.float, device=torch.device('cuda')).to_mkldnn()
        # some factory functions
        for creator in [torch.empty, torch.ones, torch.zeros, torch.randn, torch.rand]:
            with self.assertRaises(RuntimeError) as context:
                creator(1, 2, 3, 4, dtype=torch.float, device=torch.device('cpu'), layout=torch._mkldnn)

    def test_autograd_to_mkldnn(self):
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32, requires_grad=True)

        def func(root):
            return root.to_mkldnn().to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')
        self.assertWarnsRegex(lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')

    def test_autograd_from_mkldnn(self):
        # MKLDNN only supports float32
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        def func(root):
            return root.to_dense()

        # because MKLDNN only supports float32, we need to lessen the precision.
        # these numbers are just empirical results that seem to work.
        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
                              'double precision floating point')
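
    # Note on the two autograd tests above: gradcheck/gradgradcheck are built
    # around float64 inputs, so with float32-only MKL-DNN tensors they loosen
    # atol/rtol and assert on the "double precision floating point" warning
    # instead of silencing it.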
    def test_detach(self):
        root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()

        # detach() returns a new tensor without requires_grad; the root keeps it.
        detach = root.detach()
        self.assertEqual((4, 5), detach.size())
        self.assertFalse(detach.requires_grad)
        self.assertTrue(root.requires_grad)

        # detach_() clears requires_grad in place.
        detach_ = root.detach_()
        self.assertEqual((4, 5), detach_.size())
        self.assertFalse(detach_.requires_grad)
        self.assertFalse(root.requires_grad)

    def test_repr(self):
        self.assertTrue("layout=torch._mkldnn" in str(
            torch.randn((1, 2, 3, 4),
                        dtype=torch.float, device=torch.device('cpu')).to_mkldnn()))

    def test_conv2d(self):
        for groups in [1, 4]:
            N = torch.randint(3, 10, (1,)).item()
            C = torch.randint(1, 3, (1,)).item() * groups
            M = torch.randint(1, 3, (1,)).item() * groups
            x = torch.randn(N, C, 224, 224, dtype=torch.float32) * 100
            for bias in [True, False]:
                conv2d = torch.nn.Conv2d(in_channels=C,
                                         out_channels=M,
                                         kernel_size=3,
                                         stride=2,
                                         padding=1,
                                         bias=bias,
                                         groups=groups).float()
                mkldnn_conv2d = mkldnn_utils.to_mkldnn(copy.deepcopy(conv2d))
                self.assertEqual(
                    conv2d(x),
                    mkldnn_conv2d(x.to_mkldnn()).to_dense())

    def test_relu(self):
        x = torch.randn((4, 5), dtype=torch.float32) * 10
        self.assertEqual(torch.relu(x), torch.relu(x.to_mkldnn()).to_dense())

    def test_relu_(self):
        x1 = torch.randn((4, 5), dtype=torch.float32) * 10
        x2 = x1.clone().to_mkldnn()
        self.assertEqual(torch.relu_(x1), torch.relu_(x2).to_dense())

    def test_max_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        max_pool2d = torch.nn.MaxPool2d(
            kernel_size=3,
            stride=2,
            padding=1)

        self.assertEqual(
            max_pool2d(x),
            max_pool2d(x.to_mkldnn()).to_dense())

    def test_avg_pool2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 10, (1,)).item()
        x = torch.randn(N, C, 64, 64, dtype=torch.float32) * 10

        for count_include_pad in [True, False]:
            avg_pool2d = torch.nn.AvgPool2d(
                kernel_size=3,
                stride=2,
                padding=1,
                count_include_pad=count_include_pad)

            self.assertEqual(
                avg_pool2d(x),
                avg_pool2d(x.to_mkldnn()).to_dense())

    def test_batch_norm2d(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10

        # TODO: support training
        for train in [False]:
            bn = torch.nn.BatchNorm2d(C).float().train(train)
            mkldnn_bn = mkldnn_utils.to_mkldnn(copy.deepcopy(bn))
            self.assertEqual(
                bn(x),
                mkldnn_bn(x.to_mkldnn()).to_dense())
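
    # A minimal sketch (underscore-prefixed, so not collected as a test) of the
    # module-conversion pattern the conv/bn tests above exercise: to_mkldnn()
    # on the module converts its weights to the MKL-DNN layout, to_mkldnn() on
    # the input converts activations, and to_dense() brings the result back.
    # The model and shapes below are illustrative assumptions only.
    def _module_conversion_sketch(self):
        model = torch.nn.Conv2d(3, 8, kernel_size=3).float()
        mkldnn_model = mkldnn_utils.to_mkldnn(copy.deepcopy(model))
        x = torch.randn(1, 3, 32, 32, dtype=torch.float32)
        return mkldnn_model(x.to_mkldnn()).to_dense()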
    def test_add(self):
        N = torch.randint(3, 10, (1,)).item()
        C = torch.randint(3, 100, (1,)).item()
        alpha = torch.randn(1, dtype=torch.float32).item()

        x = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        y = torch.randn(N, C, 35, 45, dtype=torch.float32) * 10
        mx = x.to_mkldnn()
        my = y.to_mkldnn()

        # add
        self.assertEqual(
            x + y,
            (mx + my).to_dense())

        self.assertEqual(
            torch.add(x, y, alpha=alpha),
            torch.add(mx, my, alpha=alpha).to_dense())

        # add_
        x += y
        mx += my
        self.assertEqual(x, mx.to_dense())

        # add_out
        out = x.clone()
        mkldnn_out = out.to_mkldnn()
        torch.add(x, y, alpha=alpha, out=out)
        torch.add(mx, my, alpha=alpha, out=mkldnn_out)
        self.assertEqual(out, mkldnn_out.to_dense())

    def test_view(self):
        # view() is unsupported on opaque MKL-DNN tensors; reshape() must be used.
        x = torch.randn(3, 4, 5, dtype=torch.float32).to_mkldnn()
        self.assertRaisesRegex(RuntimeError,
                               "Change to use reshape",
                               lambda: x.view(x.size(0), -1))

    def test_reshape(self):
        x = torch.randn(3, 4, 5, dtype=torch.float32) * 10
        size = (x.size(0), -1)

        self.assertEqual(
            x.reshape(size),
            x.to_mkldnn().reshape(size).to_dense(),
        )

    def test_clone(self):
        x = torch.randn(4, 5, dtype=torch.float32) * 10
        self.assertEqual(
            x.clone(),
            x.to_mkldnn().clone().to_dense(),
        )

    def test_linear(self):
        in_features = torch.randint(3, 10, (1,)).item()
        out_features = torch.randint(3, 100, (1,)).item()
        x = torch.randn(3, in_features, dtype=torch.float32) * 10

        for bias in [True, False]:
            # pass bias through; the loop variable was previously unused
            linear = torch.nn.Linear(in_features, out_features, bias=bias).float()
            mkldnn_linear = mkldnn_utils.to_mkldnn(copy.deepcopy(linear))
            self.assertEqual(
                linear(x),
                mkldnn_linear(x.to_mkldnn()).to_dense())


if __name__ == '__main__':
    run_tests()