allow to use bf16 as fp32 internal precision for mkldnn conv (#126050)
Allow using `BF16` as the internal computation data type for FP32 convolutions by setting `torch.backends.mkldnn.conv.fp32_precision = "bf16"`.

### Test Plan

```
python test/test_mkldnn.py -k conv
```

### Benchmarking

FP32 conv2d vs. conv2d with BF16 internal computation on SPR (Sapphire Rapids).

Single core:

Input | fp32 ms | bf16 internal ms | Speedup
-- | -- | -- | --
IC: 64, OC: 256, kernel: 1, stride: 1, N: 256, H: 56, W: 56, G: 1, pad: 0 | 185.5071 | 83.4749 | 2.22
IC: 128, OC: 512, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 194.7558 | 79.1683 | 2.46
IC: 256, OC: 256, kernel: 3, stride: 1, N: 1, H: 16, W: 16, G: 1, pad: 0 | 1.9213 | 1.3690 | 1.40

56 cores:

Input | fp32 ms | bf16 internal ms | Speedup
-- | -- | -- | --
IC: 64, OC: 256, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 6.5804 | 7.4349 | 0.89
IC: 128, OC: 512, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 4.9940 | 3.8093 | 1.31
IC: 256, OC: 1024, kernel: 1, stride: 1, N: 256, H: 14, W: 14, G: 1, pad: 0 | 8.8359 | 5.5802 | 1.58
IC: 1024, OC: 256, kernel: 1, stride: 1, N: 256, H: 14, W: 14, G: 1, pad: 0 | 16.5800 | 9.2367 | 1.80
IC: 256, OC: 256, kernel: 3, stride: 1, N: 1, H: 16, W: 16, G: 1, pad: 0 | 79.5436 | 38.3861 | 2.07

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126050
Approved by: https://github.com/jgong5, https://github.com/jansel
Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
This commit is contained in:
parent 0a63053fe9
commit 5a2db5152d
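For context, here is a minimal usage sketch of the new knob. The timing loop is illustrative and the shape is taken from the first single-core row above; this is not the exact benchmark harness behind those numbers:

```python
import time
import torch

# Shape from the first single-core row: IC 64, OC 256, 1x1 kernel, N 256, 56x56.
x = torch.randn(256, 64, 56, 56)
conv = torch.nn.Conv2d(64, 256, kernel_size=1, stride=1, padding=0)

def bench(label, iters=100):
    with torch.inference_mode():
        for _ in range(10):  # warm-up
            conv(x)
        t0 = time.perf_counter()
        for _ in range(iters):
            conv(x)
        ms = (time.perf_counter() - t0) / iters * 1e3
    print(f"{label}: {ms:.4f} ms")

torch.set_num_threads(1)  # mirror the single-core rows
bench("fp32")             # default: strict fp32 internal math
# New in this PR: keep fp32 tensors and weights, let oneDNN compute
# internally in bf16. Only takes effect on CPUs with native bf16 support
# (e.g. SPR); elsewhere fp32 math is kept.
torch.backends.mkldnn.conv.fp32_precision = "bf16"
bench("bf16 internal")
```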
Changes to the mkldnn convolution implementation (C++):

```diff
@@ -155,6 +155,12 @@ static void check_shape_forward(const Tensor& input,
 // but weight/bias and grad_weight/grad_bias are always CPU tensor.
 //
 
+static bool mkldnn_conv_enabled_fpmath_mode_bf16(){
+  return at::globalContext().float32Precision("mkldnn", "conv") == "bf16" &&
+      mkldnn_bf16_device_check();
+}
+
 static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {
   auto memory_format = at::MemoryFormat::Contiguous;
   if (is_channels_last) {
@@ -163,7 +169,7 @@ static inline at::MemoryFormat mkldnn_convolution_memory_format(int64_t dims, bool is_channels_last) {
   return memory_format;
 }
 
-static void _mkldnn_convolution_out (
+static void _mkldnn_convolution_out(
     const Tensor& input_t,
     const Tensor& weight_t,
     const Tensor& bias,
@@ -261,6 +267,10 @@ static Tensor _mkldnn_convolution(
     output.resize_(output_sizes, memory_format);
     y = itensor_from_tensor(output);
   }
+  if (mkldnn_conv_enabled_fpmath_mode_bf16() &&
+      input_t.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
   _mkldnn_convolution_out(
       input_t,
       weight_t,
@@ -442,6 +452,10 @@ Tensor mkldnn_convolution_pointwise_binary(
     op_attr.set_post_ops(po);
     auto aprop_kind = ideep::prop_kind::forward_inference;
 
+    if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){
+      op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+    }
+
     if (bias.defined()) {
       const ideep::tensor b = itensor_from_tensor(bias);
       ideep::convolution_forward::compute_binary(
@@ -579,6 +593,10 @@ Tensor& mkldnn_convolution_pointwise_binary_(
     op_attr = ideep::attr_t::fuse_sum();
   }
   auto aprop_kind = ideep::prop_kind::forward_inference;
+  if (mkldnn_conv_enabled_fpmath_mode_bf16() &&
+      input_t.scalar_type() == at::kFloat) {
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
   _mkldnn_convolution_out(
       input_t,
       weight_t,
@@ -697,6 +715,10 @@ Tensor _mkldnn_convolution_transpose(
     y = itensor_from_tensor(output);
   }
 
+  if (mkldnn_conv_enabled_fpmath_mode_bf16() && input_t.scalar_type() ==at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
+
   if (bias.defined()) {
     const ideep::tensor b = itensor_from_tensor(bias, /*from_const_data_ptr*/true);
     ideep::convolution_transpose_forward::compute_v3(
```
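The new helper gates the fast path on two conditions: the `fp32_precision` knob must be `"bf16"` and the CPU must pass the bf16 device check. A hedged Python-side probe of the same logic; `torch.ops.mkldnn._is_mkldnn_bf16_supported` is an internal op that exists in recent builds, but treat its availability as an assumption:

```python
import torch

def conv_bf16_fastpath_possible() -> bool:
    # Mirrors mkldnn_conv_enabled_fpmath_mode_bf16() above: knob set to "bf16"
    # AND hardware bf16 support (mkldnn_bf16_device_check() on the C++ side).
    return (
        torch.backends.mkldnn.conv.fp32_precision == "bf16"
        and torch.ops.mkldnn._is_mkldnn_bf16_supported()  # assumed internal op
    )

torch.backends.mkldnn.conv.fp32_precision = "bf16"
print(conv_bf16_fastpath_possible())
```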
Changes to `test/test_mkldnn.py`:

```diff
@@ -27,6 +27,7 @@ from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     dtypes,
 )
+from torch.testing._internal.common_mkldnn import bf32_on_and_off
 
 # batched grad doesn't support mkldnn
 gradcheck = functools.partial(gradcheck, check_batched_grad=False)
@@ -264,7 +265,10 @@ class TestMkldnn(TestCase):
                 loss1.backward()
             if not train or (train and dim != 1):
                 y_mkldnn = mkldnn_conv(x2).to_dense()
-                self.assertEqual(y_aten, y_mkldnn)
+                if self.precision != 0:
+                    self.assertEqual(y_aten, y_mkldnn, atol=self.precision, rtol=self.precision)
+                else:
+                    self.assertEqual(y_aten, y_mkldnn)
             if not train:
                 self._test_serialization(mkldnn_conv, (x.to_mkldnn(),))
                 self._test_tracing(mkldnn_conv, (x.to_mkldnn(),))
@@ -280,12 +284,15 @@ class TestMkldnn(TestCase):
             if bias:
                 self.assertEqual(conv.bias.grad, mkldnn_conv.bias.grad)
 
+    @bf32_on_and_off()
     def test_conv1d(self):
         self._test_conv_base(dim=1)
 
+    @bf32_on_and_off()
     def test_conv2d(self):
         self._test_conv_base(dim=2)
 
+    @bf32_on_and_off()
     def test_conv3d(self):
         self._test_conv_base(dim=3)
 
@@ -400,6 +407,7 @@ class TestMkldnn(TestCase):
             self.assertEqual(conv1.bias.grad, conv2.bias.grad, atol=prec, rtol=prec)
             self.assertEqual(x1.grad, x2.grad, atol=prec, rtol=prec)
 
+    @bf32_on_and_off()
     def test_conv_nhwc_fp32(self):
         self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.contiguous_format, dtype=torch.float32)
         self._test_conv_deconv_nhwc_base(torch.nn.Conv2d, torch.channels_last, dtype=torch.float32)
@@ -435,6 +443,7 @@ class TestMkldnn(TestCase):
         self._test_conv_deconv_nhwc_base(torch.nn.Conv3d, torch.channels_last_3d, dtype=dtype, prec=prec)
 
 
+    @bf32_on_and_off()
     def test_conv_transpose_nhwc_fp32(self):
         self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.contiguous_format, dtype=torch.float32)
         self._test_conv_deconv_nhwc_base(torch.nn.ConvTranspose2d, torch.channels_last, dtype=torch.float32)
@@ -509,7 +518,11 @@ class TestMkldnn(TestCase):
         if train:
             y.sum().backward()
 
-        self.assertEqual(y, y_ref)
+        if self.precision != 0:
+            self.assertEqual(y, y_ref, atol=self.precision, rtol=self.precision)
+        else:
+            self.assertEqual(y, y_ref)
+
         if train:
             self.assertEqual(x.grad, x_ref.grad)
             self.assertEqual(conv.weight.grad,
@@ -519,12 +532,15 @@ class TestMkldnn(TestCase):
         if bias:
             self.assertEqual(conv.bias.grad, conv_ref.bias.grad)
 
+    @bf32_on_and_off()
     def test_conv_transpose1d(self):
         self._test_conv_transpose_base(dim=1)
 
+    @bf32_on_and_off()
     def test_conv_transpose2d(self):
         self._test_conv_transpose_base(dim=2)
 
+    @bf32_on_and_off()
     def test_conv_transpose3d(self):
         self._test_conv_transpose_base(dim=3)
```
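For reference, a sketch of how a decorated test picks up the relaxed tolerance. The test class and shapes below are invented for illustration; only the decorator and the `self.precision` pattern come from the diff above:

```python
import torch
from torch.testing._internal.common_mkldnn import bf32_on_and_off
from torch.testing._internal.common_utils import TestCase, run_tests

class ConvBF32Example(TestCase):
    @bf32_on_and_off()  # runs the body twice: bf16 internal math on, then off
    def test_conv2d_example(self):
        x = torch.randn(2, 8, 16, 16)
        conv = torch.nn.Conv2d(8, 8, kernel_size=3, padding=1)
        y = conv(x)
        y_ref = conv(x.clone())
        # Same pattern as above: widen tolerances only when the bf32 pass
        # has set self.precision to a nonzero value (1e-2 by default).
        if self.precision != 0:
            self.assertEqual(y, y_ref, atol=self.precision, rtol=self.precision)
        else:
            self.assertEqual(y, y_ref)

if __name__ == "__main__":
    run_tests()
```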
Changes to `torch/testing/_internal/common_mkldnn.py`:

```diff
@@ -20,24 +20,30 @@ def bf32_is_not_fp32():
 
 @contextlib.contextmanager
 def bf32_off():
-    old_matmul_precision = torch.get_float32_matmul_precision()
+    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
+    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
     try:
-        torch.set_float32_matmul_precision("highest")
+        torch.backends.mkldnn.matmul.fp32_precision = "ieee"
+        torch.backends.mkldnn.conv.fp32_precision = "ieee"
         yield
     finally:
-        torch.set_float32_matmul_precision(old_matmul_precision)
+        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
+        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
 
 
 @contextlib.contextmanager
-def bf32_on(self, bf32_precision=1e-5):
-    old_matmul_precision = torch.get_float32_matmul_precision()
+def bf32_on(self, bf32_precision=1e-2):
+    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
+    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
     old_precision = self.precision
     try:
-        torch.set_float32_matmul_precision("medium")
+        torch.backends.mkldnn.matmul.fp32_precision = "bf16"
+        torch.backends.mkldnn.conv.fp32_precision = "bf16"
         self.precision = bf32_precision
         yield
     finally:
-        torch.set_float32_matmul_precision(old_matmul_precision)
+        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
+        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
         self.precision = old_precision
@@ -45,7 +51,7 @@ def bf32_on(self, bf32_precision=1e-5):
 # allow_bf32=True, another with allow_bf32=False. When running with
 # allow_bf32=True, it will use reduced precision as specified by the
 # argument
-def bf32_on_and_off(bf32_precision=1e-5):
+def bf32_on_and_off(bf32_precision=1e-2):
     def with_bf32_disabled(self, function_call):
         with bf32_off():
             function_call()
```
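Unlike `bf32_on`, which needs the test instance so it can adjust `self.precision`, `bf32_off()` is a plain context manager, so it can also pin strict fp32 math around a single block. A small sketch under that assumption:

```python
import torch
from torch.testing._internal.common_mkldnn import bf32_off

x = torch.randn(1, 8, 16, 16)
conv = torch.nn.Conv2d(8, 8, kernel_size=3, padding=1)

torch.backends.mkldnn.conv.fp32_precision = "bf16"  # fast path enabled
with bf32_off():
    # Inside the block both matmul and conv fp32_precision are "ieee",
    # so this forward runs with strict fp32 math.
    y_exact = conv(x)
# On exit the previous ("bf16") setting is restored.
y_fast = conv(x)
print((y_exact - y_fast).abs().max())  # small but nonzero on bf16-capable CPUs
```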