allow to use bf16 as fp32 internal precision for mkldnn conv (#126050)

Allow to use `BF16` as the internal computation data types by `torch.backends.mkldnn.conv.fp32_precision="bf16"`

### TestPlan
python test/test_mkldnn.py -k conv

### Benchmarking

FP32 conv2d vs. BF16 internal computation conv2d on SPR

Single core:

Input | fp32 ms | bf16 internal  ms | Speed up
-- | -- | -- | --
IC:   64, OC: 256, kernel: 1, stride: 1, N: 256, H: 56, W: 56, G: 1, pad: 0 | 185.5071 | 83.4749 | 2.22
IC:   128, OC: 512, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 194.7558 | 79.1683| 2.46
IC: 256, OC: 256, kernel: 3, stride: 1,   N: 1, H: 16, W: 16, G: 1, pad: 0 | 1.9213 | 1.3690 | 1.40

56 cores:
Input | fp32 ms | bf16 internal ms | Speed up
-- | -- | -- | --
IC:   64, OC: 256, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 6.5804  | 7.4349 | 0.89
IC:   128, OC: 512, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 4.9940  | 3.8093 | 1.31
IC:   256, OC: 1024, kernel: 1, stride: 1, N: 256, H: 14, W: 14, G: 1, pad: 0 | 8.8359 | 5.5802 | 1.58
IC: 1024, OC: 256, kernel: 1, stride: 1,   N: 256, H: 14, W: 14, G: 1, pad: 0 | 16.5800 | 9.2367 | 1.80
IC: 256, OC: 256, kernel: 3, stride: 1,   N: 1, H: 16, W: 16, G: 1, pad: 0 | 79.5436 | 38.3861  | 2.07

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126050
Approved by: https://github.com/jgong5, https://github.com/jansel

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
This commit is contained in:
haozhe.zhu 2025-07-01 08:14:50 +00:00 committed by PyTorch MergeBot
parent 0a63053fe9
commit 5a2db5152d
3 changed files with 55 additions and 11 deletions

View File

@ -155,6 +155,12 @@ static void check_shape_forward(const Tensor& input,
// but weight/bias and grad_weight/grad_bias are always CPU tensor.
//
static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
return at::globalContext().float32Precision("mkldnn", "conv") == "bf16" &&
mkldnn_bf16_device_check();
}
static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {
auto memory_format = at::MemoryFormat::Contiguous;
if (is_channels_last) {
@ -163,7 +169,7 @@ static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bo
return memory_format;
}
static void _mkldnn_convolution_out (
static void _mkldnn_convolution_out(
const Tensor& input_t,
const Tensor& weight_t,
const Tensor& bias,
@ -261,6 +267,10 @@ static Tensor _mkldnn_convolution(
output.resize_(output_sizes, memory_format);
y = itensor_from_tensor(output);
}
if (mkldnn_conv_enabled_fpmath_mode_bf16() &&
input_t.scalar_type() == at::kFloat) {
op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
}
_mkldnn_convolution_out(
input_t,
weight_t,
@ -442,6 +452,10 @@ Tensor mkldnn_convolution_pointwise_binary(
op_attr.set_post_ops(po);
auto aprop_kind = ideep::prop_kind::forward_inference;
if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){
op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
}
if (bias.defined()) {
const ideep::tensor b = itensor_from_tensor(bias);
ideep::convolution_forward::compute_binary(
@ -579,6 +593,10 @@ Tensor& mkldnn_convolution_pointwise_binary_(
op_attr = ideep::attr_t::fuse_sum();
}
auto aprop_kind = ideep::prop_kind::forward_inference;
if (mkldnn_conv_enabled_fpmath_mode_bf16() &&
input_t.scalar_type() == at::kFloat) {
op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
}
_mkldnn_convolution_out(
input_t,
weight_t,
@ -697,6 +715,10 @@ Tensor _mkldnn_convolution_transpose(
y = itensor_from_tensor(output);
}
if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){
op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
}
if (bias.defined()) {
const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true);
ideep::convolution_transpose_forward::compute_v3(

View File

@ -27,6 +27,7 @@ from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
dtypes,
)
from torch.testing._internal.common_mkldnn import bf32_on_and_off
# batched grad doesn't support mkldnn
gradcheck = functools.partial(gradcheck, check_batched_grad=False)
@ -264,6 +265,9 @@ class TestMkldnn(TestCase):
loss1.backward()
if not train or (train and dim != 1):
y_mkldnn = mkldnn_conv(x2).to_dense()
if self.precision != 0:
self.assertEqual(y_aten, y_mkldnn, atol=self.precision, rtol=self.precision)
else:
self.assertEqual(y_aten, y_mkldnn)
if not train:
self._test_serialization(mkldnn_conv, (x.to_mkldnn(),))
@ -280,12 +284,15 @@ class TestMkldnn(TestCase):
if bias:
self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad)
@bf32_on_and_off()
def test_conv1d(self):
self._test_conv_base(dim=1)
@bf32_on_and_off()
def test_conv2d(self):
self._test_conv_base(dim=2)
@bf32_on_and_off()
def test_conv3d(self):
self._test_conv_base(dim=3)
@ -400,6 +407,7 @@ class TestMkldnn(TestCase):
self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)
@bf32_on_and_off()
def test_conv_nhwc_fp32(self):
self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
@ -435,6 +443,7 @@ class TestMkldnn(TestCase):
self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec)
@bf32_on_and_off()
def test_conv_transpose_nhwc_fp32(self):
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
@ -509,7 +518,11 @@ class TestMkldnn(TestCase):
if train:
y.sum().backward()
if self.precision != 0:
self.assertEqual(y, y_ref, atol=self.precision, rtol=self.precision)
else:
self.assertEqual(y, y_ref)
if train:
self.assertEqual(x.grad, x_ref.grad)
self.assertEqual(conv.weight.grad,
@ -519,12 +532,15 @@ class TestMkldnn(TestCase):
if bias:
self.assertEqual(conv.bias.grad, conv_ref.bias.grad)
@bf32_on_and_off()
def test_conv_transpose1d(self):
self._test_conv_transpose_base(dim=1)
@bf32_on_and_off()
def test_conv_transpose2d(self):
self._test_conv_transpose_base(dim=2)
@bf32_on_and_off()
def test_conv_transpose3d(self):
self._test_conv_transpose_base(dim=3)

View File

@ -20,24 +20,30 @@ def bf32_is_not_fp32():
@contextlib.contextmanager
def bf32_off():
old_matmul_precision = torch.get_float32_matmul_precision()
old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
try:
torch.set_float32_matmul_precision("highest")
torch.backends.mkldnn.matmul.fp32_precision = "ieee"
torch.backends.mkldnn.conv.fp32_precision = "ieee"
yield
finally:
torch.set_float32_matmul_precision(old_matmul_precision)
torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-5):
old_matmul_precision = torch.get_float32_matmul_precision()
def bf32_on(self, bf32_precision=1e-2):
old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
old_precision = self.precision
try:
torch.set_float32_matmul_precision("medium")
torch.backends.mkldnn.matmul.fp32_precision = "bf16"
torch.backends.mkldnn.conv.fp32_precision = "bf16"
self.precision = bf32_precision
yield
finally:
torch.set_float32_matmul_precision(old_matmul_precision)
torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
self.precision = old_precision
@ -45,7 +51,7 @@ def bf32_on(self, bf32_precision=1e-5):
# allow_bf32=True, another with allow_bf32=False. When running with
# allow_bf32=True, it will use reduced precision as specified by the
# argument
def bf32_on_and_off(bf32_precision=1e-5):
def bf32_on_and_off(bf32_precision=1e-2):
def with_bf32_disabled(self, function_call):
with bf32_off():
function_call()