From 5a2db5152d23f76dbb45d20008d9af68e761e8d1 Mon Sep 17 00:00:00 2001 From: "haozhe.zhu" Date: Tue, 1 Jul 2025 08:14:50 +0000 Subject: [PATCH] allow to use bf16 as fp32 internal precision for mkldnn conv (#126050) Allow to use `BF16` as the internal computation data types by `torch.backends.mkldnn.conv.fp32_precision="bf16"` ### TestPlan python test/test_mkldnn.py -k conv ### Benchmarking FP32 conv2d vs. BF16 internal computation conv2d on SPR Single core: Input | fp32 ms | bf16 internal ms | Speed up -- | -- | -- | -- IC: 64, OC: 256, kernel: 1, stride: 1, N: 256, H: 56, W: 56, G: 1, pad: 0 | 185.5071 | 83.4749 | 2.22 IC: 128, OC: 512, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 194.7558 | 79.1683| 2.46 IC: 256, OC: 256, kernel: 3, stride: 1, N: 1, H: 16, W: 16, G: 1, pad: 0 | 1.9213 | 1.3690 | 1.40 56 cores: Input | fp32 ms | bf16 internal ms | Speed up -- | -- | -- | -- IC: 64, OC: 256, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 6.5804 | 7.4349 | 0.89 IC: 128, OC: 512, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 4.9940 | 3.8093 | 1.31 IC: 256, OC: 1024, kernel: 1, stride: 1, N: 256, H: 14, W: 14, G: 1, pad: 0 | 8.8359 | 5.5802 | 1.58 IC: 1024, OC: 256, kernel: 1, stride: 1, N: 256, H: 14, W: 14, G: 1, pad: 0 | 16.5800 | 9.2367 | 1.80 IC: 256, OC: 256, kernel: 3, stride: 1, N: 1, H: 16, W: 16, G: 1, pad: 0 | 79.5436 | 38.3861 | 2.07 Pull Request resolved: https://github.com/pytorch/pytorch/pull/126050 Approved by: https://github.com/jgong5, https://github.com/jansel Co-authored-by: Jiang, Yanbing --- aten/src/ATen/native/mkldnn/Conv.cpp | 24 +++++++++++++++++++++++- test/test_mkldnn.py | 20 ++++++++++++++++++-- torch/testing/_internal/common_mkldnn.py | 22 ++++++++++++++-------- 3 files changed, 55 insertions(+), 11 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/Conv.cpp b/aten/src/ATen/native/mkldnn/Conv.cpp index d13fe6b2328..8de7d84cada 100644 --- a/aten/src/ATen/native/mkldnn/Conv.cpp +++ b/aten/src/ATen/native/mkldnn/Conv.cpp @@ -155,6 +155,12 @@ static void check_shape_forward(const Tensor& input, // but weight/bias and grad_weight/grad_bias are always CPU tensor. // +static bool mkldnn_conv_enabled_fpmath_mode_bf16(){ + return at::globalContext().float32Precision("mkldnn", "conv") == "bf16" && + mkldnn_bf16_device_check(); +} + + static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) { auto memory_format = at::MemoryFormat::Contiguous; if (is_channels_last) { @@ -163,7 +169,7 @@ static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bo return memory_format; } -static void _mkldnn_convolution_out ( +static void _mkldnn_convolution_out( const Tensor& input_t, const Tensor& weight_t, const Tensor& bias, @@ -261,6 +267,10 @@ static Tensor _mkldnn_convolution( output.resize_(output_sizes, memory_format); y = itensor_from_tensor(output); } + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + input_t.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } _mkldnn_convolution_out( input_t, weight_t, @@ -442,6 +452,10 @@ Tensor mkldnn_convolution_pointwise_binary( op_attr.set_post_ops(po); auto aprop_kind = ideep::prop_kind::forward_inference; + if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } + if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias); ideep::convolution_forward::compute_binary( @@ -579,6 +593,10 @@ Tensor& mkldnn_convolution_pointwise_binary_( op_attr = ideep::attr_t::fuse_sum(); } auto aprop_kind = ideep::prop_kind::forward_inference; + if (mkldnn_conv_enabled_fpmath_mode_bf16() && + input_t.scalar_type() == at::kFloat) { + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } _mkldnn_convolution_out( input_t, weight_t, @@ -697,6 +715,10 @@ Tensor _mkldnn_convolution_transpose( y = itensor_from_tensor(output); } + if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){ + op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16); + } + if (bias.defined()) { const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true); ideep::convolution_transpose_forward::compute_v3( diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py index 1dfb42758e4..0f73a71c182 100644 --- a/test/test_mkldnn.py +++ b/test/test_mkldnn.py @@ -27,6 +27,7 @@ from torch.testing._internal.common_device_type import ( instantiate_device_type_tests, dtypes, ) +from torch.testing._internal.common_mkldnn import bf32_on_and_off # batched grad doesn't support mkldnn gradcheck = functools.partial(gradcheck, check_batched_grad=False) @@ -264,7 +265,10 @@ class TestMkldnn(TestCase): loss1.backward() if not train or (train and dim != 1): y_mkldnn = mkldnn_conv(x2).to_dense() - self.assertEqual(y_aten, y_mkldnn) + if self.precision != 0: + self.assertEqual(y_aten, y_mkldnn, atol=self.precision, rtol=self.precision) + else: + self.assertEqual(y_aten, y_mkldnn) if not train: self._test_serialization(mkldnn_conv, (x.to_mkldnn(),)) self._test_tracing(mkldnn_conv, (x.to_mkldnn(),)) @@ -280,12 +284,15 @@ class TestMkldnn(TestCase): if bias: self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad) + @bf32_on_and_off() def test_conv1d(self): self._test_conv_base(dim=1) + @bf32_on_and_off() def test_conv2d(self): self._test_conv_base(dim=2) + @bf32_on_and_off() def test_conv3d(self): self._test_conv_base(dim=3) @@ -400,6 +407,7 @@ class TestMkldnn(TestCase): self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec) self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec) + @bf32_on_and_off() def test_conv_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32) @@ -435,6 +443,7 @@ class TestMkldnn(TestCase): self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec) + @bf32_on_and_off() def test_conv_transpose_nhwc_fp32(self): self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32) self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32) @@ -509,7 +518,11 @@ class TestMkldnn(TestCase): if train: y.sum().backward() - self.assertEqual(y, y_ref) + if self.precision != 0: + self.assertEqual(y, y_ref, atol=self.precision, rtol=self.precision) + else: + self.assertEqual(y, y_ref) + if train: self.assertEqual(x.grad, x_ref.grad) self.assertEqual(conv.weight.grad, @@ -519,12 +532,15 @@ class TestMkldnn(TestCase): if bias: self.assertEqual(conv.bias.grad, conv_ref.bias.grad) + @bf32_on_and_off() def test_conv_transpose1d(self): self._test_conv_transpose_base(dim=1) + @bf32_on_and_off() def test_conv_transpose2d(self): self._test_conv_transpose_base(dim=2) + @bf32_on_and_off() def test_conv_transpose3d(self): self._test_conv_transpose_base(dim=3) diff --git a/torch/testing/_internal/common_mkldnn.py b/torch/testing/_internal/common_mkldnn.py index f9a05cf807a..ffaed6c7e00 100644 --- a/torch/testing/_internal/common_mkldnn.py +++ b/torch/testing/_internal/common_mkldnn.py @@ -20,24 +20,30 @@ def bf32_is_not_fp32(): @contextlib.contextmanager def bf32_off(): - old_matmul_precision = torch.get_float32_matmul_precision() + old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision + old_conv_precision = torch.backends.mkldnn.conv.fp32_precision try: - torch.set_float32_matmul_precision("highest") + torch.backends.mkldnn.matmul.fp32_precision = "ieee" + torch.backends.mkldnn.conv.fp32_precision = "ieee" yield finally: - torch.set_float32_matmul_precision(old_matmul_precision) + torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision + torch.backends.mkldnn.conv.fp32_precision = old_conv_precision @contextlib.contextmanager -def bf32_on(self, bf32_precision=1e-5): - old_matmul_precision = torch.get_float32_matmul_precision() +def bf32_on(self, bf32_precision=1e-2): + old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision + old_conv_precision = torch.backends.mkldnn.conv.fp32_precision old_precision = self.precision try: - torch.set_float32_matmul_precision("medium") + torch.backends.mkldnn.matmul.fp32_precision = "bf16" + torch.backends.mkldnn.conv.fp32_precision = "bf16" self.precision = bf32_precision yield finally: - torch.set_float32_matmul_precision(old_matmul_precision) + torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision + torch.backends.mkldnn.conv.fp32_precision = old_conv_precision self.precision = old_precision @@ -45,7 +51,7 @@ def bf32_on(self, bf32_precision=1e-5): # allow_bf32=True, another with allow_bf32=False. When running with # allow_bf32=True, it will use reduced precision as specified by the # argument -def bf32_on_and_off(bf32_precision=1e-5): +def bf32_on_and_off(bf32_precision=1e-2): def with_bf32_disabled(self, function_call): with bf32_off(): function_call()