pytorch/torch/testing/_internal/common_mkldnn.py
haozhe.zhu 5a2db5152d allow to use bf16 as fp32 internal precision for mkldnn conv (#126050)
Allow using `BF16` as the internal computation data type for FP32 convolution via `torch.backends.mkldnn.conv.fp32_precision="bf16"`.
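
For example (a minimal sketch; the conv shapes are illustrative):

```python
import torch

# Opt in: FP32 convolutions may compute internally in BF16
# (inputs and outputs remain FP32).
torch.backends.mkldnn.conv.fp32_precision = "bf16"

conv = torch.nn.Conv2d(64, 256, kernel_size=1)
x = torch.randn(256, 64, 56, 56)
y = conv(x)  # FP32 in/out; oneDNN may use BF16 internally

# Switch back to strict FP32 computation, as the bf32_off()
# test helper below does.
torch.backends.mkldnn.conv.fp32_precision = "ieee"
```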

### Test Plan
python test/test_mkldnn.py -k conv

### Benchmarking

FP32 conv2d vs. conv2d with BF16 internal computation on SPR (Intel Sapphire Rapids).

Single core:

Input | FP32 (ms) | BF16 internal (ms) | Speedup
-- | -- | -- | --
IC: 64, OC: 256, kernel: 1, stride: 1, N: 256, H: 56, W: 56, G: 1, pad: 0 | 185.5071 | 83.4749 | 2.22
IC: 128, OC: 512, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 194.7558 | 79.1683 | 2.46
IC: 256, OC: 256, kernel: 3, stride: 1, N: 1, H: 16, W: 16, G: 1, pad: 0 | 1.9213 | 1.3690 | 1.40

56 cores:

Input | FP32 (ms) | BF16 internal (ms) | Speedup
-- | -- | -- | --
IC: 64, OC: 256, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 6.5804 | 7.4349 | 0.89
IC: 128, OC: 512, kernel: 1, stride: 1, N: 256, H: 28, W: 28, G: 1, pad: 0 | 4.9940 | 3.8093 | 1.31
IC: 256, OC: 1024, kernel: 1, stride: 1, N: 256, H: 14, W: 14, G: 1, pad: 0 | 8.8359 | 5.5802 | 1.58
IC: 1024, OC: 256, kernel: 1, stride: 1, N: 256, H: 14, W: 14, G: 1, pad: 0 | 16.5800 | 9.2367 | 1.80
IC: 256, OC: 256, kernel: 3, stride: 1, N: 1, H: 16, W: 16, G: 1, pad: 0 | 79.5436 | 38.3861 | 2.07

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126050
Approved by: https://github.com/jgong5, https://github.com/jansel

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
2025-07-02 01:31:23 +00:00


# mypy: ignore-errors

import contextlib
import functools
import inspect

import torch


# Test whether the hardware BF32 math mode is enabled. It is enabled only
# when:
# - MKLDNN is available
# - BF16 is supported by MKLDNN
def bf32_is_not_fp32():
    if not torch.backends.mkldnn.is_available():
        return False
    if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
        return False
    return True


@contextlib.contextmanager
def bf32_off():
    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "ieee"
        torch.backends.mkldnn.conv.fp32_precision = "ieee"
        yield
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision


@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-2):
    old_matmul_precision = torch.backends.mkldnn.matmul.fp32_precision
    old_conv_precision = torch.backends.mkldnn.conv.fp32_precision
    old_precision = self.precision
    try:
        torch.backends.mkldnn.matmul.fp32_precision = "bf16"
        torch.backends.mkldnn.conv.fp32_precision = "bf16"
        self.precision = bf32_precision
        yield
    finally:
        torch.backends.mkldnn.matmul.fp32_precision = old_matmul_precision
        torch.backends.mkldnn.conv.fp32_precision = old_conv_precision
        self.precision = old_precision


# This is a wrapper that wraps a test to run it twice: once with bf32
# disabled and once with bf32 enabled. When running with bf32 enabled,
# it uses the reduced comparison precision specified by `bf32_precision`.
def bf32_on_and_off(bf32_precision=1e-2):
    def with_bf32_disabled(self, function_call):
        with bf32_off():
            function_call()

    def with_bf32_enabled(self, function_call):
        with bf32_on(self, bf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Bind positional arguments to their parameter names so the test
            # can be re-invoked uniformly with keyword arguments.
            kwargs.update(zip(arg_names, args))
            # Only run both variants when bf32 can actually differ from fp32
            # on this machine and the test targets float32 on CPU.
            cond = bf32_is_not_fp32()
            if "device" in kwargs:
                cond = cond and (torch.device(kwargs["device"]).type == "cpu")
            if "dtype" in kwargs:
                cond = cond and (kwargs["dtype"] == torch.float)
            if cond:
                with_bf32_disabled(kwargs["self"], lambda: f(**kwargs))
                with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped

    return wrapper
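

# Example usage (an illustrative sketch): the decorator targets
# TestCase-style methods with a `precision` attribute, which bf32_on()
# temporarily overrides. The class and test body below are hypothetical.
#
#     class TestMkldnnConv(TestCase):
#         @bf32_on_and_off(bf32_precision=1e-2)
#         def test_conv2d(self, device, dtype):
#             # On eligible CPU/float32 configs this runs twice: once with
#             # fp32_precision="ieee" and once with fp32_precision="bf16".
#             ...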