[inductor] enable bf32 for mkldnn linear pointwise/binary in inductor (#127294)

When `torch.backends.mkldnn.matmul.fp32_precision == 'bf16'`, we also enable mkldnn linear in the inductor path and allow it to run with bf16 as the computation data type.
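
A minimal usage sketch of what this enables (layer sizes and the ReLU epilogue are illustrative; assumes a CPU with mkldnn bf16 support):

```python
import torch

# Relax fp32 matmul/linear precision so mkldnn may compute in bf16 ("bf32").
torch.backends.mkldnn.matmul.fp32_precision = "bf16"

mod = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).eval()
x = torch.randn(2, 64)

with torch.no_grad():
    # Inductor packs the fp32 linear as an mkldnn linear with the bf16 fpmath
    # mode set, and the ReLU is fused as a pointwise epilogue.
    out = torch.compile(mod)(x)
```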

Test plan:
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_unary
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_fp32
python test/inductor/test_mkldnn_pattern_matcher.py -k test_multi_linear_share_same_input
```
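
To manually confirm that the bf16 fpmath mode is actually picked up, one option is oneDNN verbose logging (a sketch; the exact verbose fields, e.g. an `attr-fpmath:bf16` entry, depend on the oneDNN version):

```python
import torch

torch.backends.mkldnn.matmul.fp32_precision = "bf16"
mod = torch.nn.Linear(64, 64).eval()
x = torch.randn(2, 64)

with torch.no_grad(), torch.backends.mkldnn.verbose(torch.backends.mkldnn.VERBOSE_ON):
    torch.compile(mod)(x)  # inspect the onednn_verbose lines printed for inner_product
```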

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127294
Approved by: https://github.com/jgong5, https://github.com/jansel

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
haozhe.zhu, 2025-07-07 01:37:46 +00:00, committed by PyTorch MergeBot
parent d26ca5de05
commit 815545f2dd
3 changed files with 48 additions and 9 deletions

@@ -68,6 +68,11 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
namespace at::native {

static bool use_mkldnn_bf32_linear() {
  return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16" &&
      mkldnn_bf16_device_check();
}

Tensor mkldnn_linear(
    const Tensor& self,
    const Tensor& weight_t, const std::optional<Tensor>& bias_opt) {
@@ -251,7 +256,9 @@ Tensor mkldnn_linear_pointwise(
        it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
    op_attr = it->second(scalars, algorithm);
  }
  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
  }
  if (mkldnn_bias.has_value()) {
    ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
        mkldnn_input,
@@ -341,6 +348,10 @@ Tensor mkldnn_linear_pointwise_binary(
  auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);
  auto aprop_kind = ideep::prop_kind::forward_inference;

  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
  }
  if (mkldnn_bias.has_value()) {
    ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
        mkldnn_input,


@@ -699,6 +699,7 @@ class TestPatternMatcherGeneric(TestPatternMatcherBase):
class TestPatternMatcher(TestPatternMatcherBase):
    @bf32_on_and_off()
    def test_linear_unary(self, device="cpu"):
        self.device = device
@@ -729,6 +730,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
            dtypes.append(torch.float32)
        options = itertools.product(unary_list, [True, False], dtypes)
        for unary_fn, bias, dtype in options:
            metrics.reset()
@@ -739,7 +742,7 @@
            def matcher_check_fn():
                match_nodes = unary_list[unary_fn]
                if self._check_unary_is_decomposed(unary_fn):
                if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn):
                    # Has extra dtype conversion nodes for autocast.
                    match_nodes += 2
                self.assertEqual(
@@ -751,9 +754,14 @@
                )
            self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
            # only generated 1 kernel for "to"
            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
            # only generated 1 kernel for "to_dtype"
            expected_kernel_count = 2 if TEST_ACL else 1
            if dtype == torch.float32:
                # In BF32, input is float32, will not generate kernel for "to_dtype"
                expected_kernel_count -= 1
            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)

    @bf32_on_and_off()
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    def test_linear_fp32(self, device="cpu"):
        self.device = device
@@ -901,6 +909,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
            # 1 kernel for "to_lowp", 2 kernels for unary ops
            self.assertEqual(metrics.generated_kernel_count, 3)

    @bf32_on_and_off()
    def test_linear_binary(self, device="cpu"):
        self.device = device
@@ -922,6 +931,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
            dtypes.append(torch.bfloat16)
        if is_mkldnn_fp16_supported(self.device):
            dtypes.append(torch.float16)
        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
            dtypes.append(torch.float32)
        options = itertools.product(
            binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
        )
@@ -958,7 +969,12 @@
                matcher_check_fn,
                check_autocast=dtype,
            )
            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
            # only generated 1 kernel for "to_dtype"
            expected_kernel_count = 2 if TEST_ACL else 1
            if dtype == torch.float32:
                # In BF32, input is float32, will not generate kernel for "to_dtype"
                expected_kernel_count -= 1
            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)

    def test_linear_binary_broadcast_shapes(self, device="cpu"):
        self.device = device

@@ -1228,10 +1228,15 @@ if torch._C._has_mkldnn:
            torch.bfloat16,
            torch.float16,
        )
        bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
        use_bf16_for_fp32_weight = (
            bf32_matmul_enabled and weight_meta_value.dtype == torch.float32
        )
        compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
        # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
        # on aarch64, use mkldnn op for fp32 as well if acl is enabled
        if (
            not is_lp_weight
            not compute_with_lp
            and not mkldnn._is_mkldnn_acl_supported()
            and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
        ):
@@ -1444,16 +1449,23 @@ if torch._C._has_mkldnn:
            torch.bfloat16,
            torch.float16,
        )
        bf32_matmul_enabled = (
            torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
        )
        use_bf16_for_fp32_weight = (
            bf32_matmul_enabled and weight_dtype == torch.float32
        )
        compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
        batch_size = input.meta.get("val").shape[0]
        if has_free_symbols(batch_size):
            assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
            assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), (
                f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
            )
        packed_weight_node = mkldnn_device_op.pack_linear_weight(
            graph, is_lp_weight, transpose_weight_node, batch_size
            graph, compute_with_lp, transpose_weight_node, batch_size
        )
        packed_linear_node = mkldnn_device_op.pack_linear(
            graph, is_lp_weight, batch_size, input, packed_weight_node, bias
            graph, compute_with_lp, batch_size, input, packed_weight_node, bias
        )
        linear_node.replace_all_uses_with(packed_linear_node)