[inductor] enable bf32 for mkldnn linear pointwise/binary in inductor (#127294)
When `torch.backends.mkldnn.matmul.fp32_precision == 'bf16'`, we also enable mkldnn linear in the inductor path and allow it to run with a bf16 computation data type.

Test plan:
```
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_unary
python test/inductor/test_mkldnn_pattern_matcher.py -k test_linear_fp32
python test/inductor/test_mkldnn_pattern_matcher.py -k test_multi_linear_share_same_input
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127294
Approved by: https://github.com/jgong5, https://github.com/jansel
Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
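For context, a minimal usage sketch of the path this commit enables (not taken from the PR; it assumes the `torch.backends.mkldnn.matmul.fp32_precision` flag shown in the tests is settable from Python, and the model/shapes are illustrative):

```python
import torch

# Flag referenced in this commit: allow bf16 computation for fp32 mkldnn matmuls/linears.
torch.backends.mkldnn.matmul.fp32_precision = "bf16"

# Illustrative fp32 model; Linear + ReLU matches the linear-unary fusion pattern.
mod = torch.nn.Sequential(torch.nn.Linear(64, 128), torch.nn.ReLU()).eval()
x = torch.randn(8, 64)  # float32 input

with torch.no_grad():
    out = torch.compile(mod)(x)  # inductor may lower this to mkldnn linear with bf16 fpmath
```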
parent d26ca5de05
commit 815545f2dd
@@ -68,6 +68,11 @@ mkldnn_scaled_mm(const Tensor& mat1, const Tensor& mat2,
 namespace at::native {

+static bool use_mkldnn_bf32_linear() {
+  return at::globalContext().float32Precision("mkldnn", "matmul") == "bf16" &&
+      mkldnn_bf16_device_check();
+}
+
 Tensor mkldnn_linear(
     const Tensor& self,
     const Tensor& weight_t, const std::optional<Tensor>& bias_opt) {

@@ -251,7 +256,9 @@ Tensor mkldnn_linear_pointwise(
         it != fusion_unary_attr_map().end(), "Fusion behavior undefined.");
     op_attr = it->second(scalars, algorithm);
   }

+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,

@@ -341,6 +348,10 @@ Tensor mkldnn_linear_pointwise_binary(
   auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);
   auto aprop_kind = ideep::prop_kind::forward_inference;

+  if (use_mkldnn_bf32_linear() && input_t.scalar_type() == at::kFloat){
+    op_attr.set_fpmath_mode(dnnl_fpmath_mode_bf16);
+  }
+
   if (mkldnn_bias.has_value()) {
     ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
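For readers following the C++ change above: the new helper only switches oneDNN into bf16 fpmath mode when the mkldnn matmul precision flag is set to `bf16` and the CPU supports bf16. A rough Python-side sketch of that gate (the op `torch.ops.mkldnn._is_mkldnn_bf16_supported` is assumed here as the analogue of `mkldnn_bf16_device_check()`; it is not part of this diff):

```python
import torch

def use_mkldnn_bf32_linear() -> bool:
    # Mirror of the C++ gate: precision flag opted into bf16 *and* hardware support.
    flag_is_bf16 = torch.backends.mkldnn.matmul.fp32_precision == "bf16"
    bf16_supported = torch.ops.mkldnn._is_mkldnn_bf16_supported()  # assumed analogue
    return flag_is_bf16 and bf16_supported
```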
@@ -699,6 +699,7 @@ class TestPatternMatcherGeneric(TestPatternMatcherBase):
 
 
 class TestPatternMatcher(TestPatternMatcherBase):
+    @bf32_on_and_off()
     def test_linear_unary(self, device="cpu"):
         self.device = device
 
@@ -729,6 +730,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
             dtypes.append(torch.bfloat16)
         if is_mkldnn_fp16_supported(self.device):
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(unary_list, [True, False], dtypes)
         for unary_fn, bias, dtype in options:
             metrics.reset()
 
@@ -739,7 +742,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
 
         def matcher_check_fn():
             match_nodes = unary_list[unary_fn]
-            if self._check_unary_is_decomposed(unary_fn):
+            if dtype != torch.float32 and self._check_unary_is_decomposed(unary_fn):
                 # Has extra dtype conversion nodes for autocast.
                 match_nodes += 2
             self.assertEqual(
 
@@ -751,9 +754,14 @@ class TestPatternMatcher(TestPatternMatcherBase):
             )
 
         self._test_common(mod, (v,), matcher_check_fn, check_autocast=dtype)
-        # only generated 1 kernel for "to"
-        self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+        # only generated 1 kernel for "to_dtype"
+        expected_kernel_count = 2 if TEST_ACL else 1
+        if dtype == torch.float32:
+            # In BF32, input is float32, will not generate kernel for "to_dtype"
+            expected_kernel_count -= 1
+        self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
+    @bf32_on_and_off()
     @unittest.skipIf(not TEST_MKL, "Test requires MKL")
     def test_linear_fp32(self, device="cpu"):
         self.device = device
 
@@ -901,6 +909,7 @@ class TestPatternMatcher(TestPatternMatcherBase):
         # 1 kernel for "to_lowp", 2 kernels for unary ops
         self.assertEqual(metrics.generated_kernel_count, 3)
 
+    @bf32_on_and_off()
     def test_linear_binary(self, device="cpu"):
         self.device = device
 
@@ -922,6 +931,8 @@ class TestPatternMatcher(TestPatternMatcherBase):
             dtypes.append(torch.bfloat16)
         if is_mkldnn_fp16_supported(self.device):
             dtypes.append(torch.float16)
+        if torch.backends.mkldnn.matmul.fp32_precision == "bf16":
+            dtypes.append(torch.float32)
         options = itertools.product(
             binary_list, [[2, 3, 10], [2, 10]], [True, False], dtypes
         )
 
@@ -958,7 +969,12 @@ class TestPatternMatcher(TestPatternMatcherBase):
                 matcher_check_fn,
                 check_autocast=dtype,
             )
-            self.assertEqual(metrics.generated_kernel_count, 2 if TEST_ACL else 1)
+            # only generated 1 kernel for "to_dtype"
+            expected_kernel_count = 2 if TEST_ACL else 1
+            if dtype == torch.float32:
+                # In BF32, input is float32, will not generate kernel for "to_dtype"
+                expected_kernel_count -= 1
+            self.assertEqual(metrics.generated_kernel_count, expected_kernel_count)
 
     def test_linear_binary_broadcast_shapes(self, device="cpu"):
         self.device = device
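The new float32 branches above run the existing linear tests with bf32 enabled and expect one fewer generated kernel, since an fp32 input needs no "to_dtype" cast. Outside the test harness, the behaviour can be checked end to end roughly like this (an illustrative sketch, not code from the PR; tolerances are arbitrary and only reflect that the matmul may accumulate in bf16):

```python
import torch

mod = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.ReLU()).eval()
x = torch.randn(4, 16)  # float32 input stays float32; only the computation may use bf16

with torch.no_grad():
    ref = mod(x)  # eager fp32 reference, computed before enabling bf32

torch.backends.mkldnn.matmul.fp32_precision = "bf16"  # flag exercised by the new fp32 test branches

with torch.no_grad():
    out = torch.compile(mod)(x)  # compiled path touched by this PR

# Looser tolerances than pure fp32, since the linear may compute in bf16.
torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)
```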
@@ -1228,10 +1228,15 @@ if torch._C._has_mkldnn:
             torch.bfloat16,
             torch.float16,
         )
+        bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+        use_bf16_for_fp32_weight = (
+            bf32_matmul_enabled and weight_meta_value.dtype == torch.float32
+        )
+        compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
         # on x86, for fp32, mkl should be enabled and batch_size should not be a free symbol.
         # on aarch64, use mkldnn op for fp32 as well if acl is enabled
         if (
-            not is_lp_weight
+            not compute_with_lp
             and not mkldnn._is_mkldnn_acl_supported()
             and ((not torch._C.has_mkl) or has_free_symbols(batch_size))
         ):
 
@@ -1444,16 +1449,23 @@ if torch._C._has_mkldnn:
             torch.bfloat16,
             torch.float16,
         )
+        bf32_matmul_enabled = (
+            torch.backends.mkldnn.matmul.fp32_precision == "bf16"  # type: ignore[attr-defined]
+        )
+        use_bf16_for_fp32_weight = (
+            bf32_matmul_enabled and weight_dtype == torch.float32
+        )
+        compute_with_lp = is_lp_weight or use_bf16_for_fp32_weight
         batch_size = input.meta.get("val").shape[0]
         if has_free_symbols(batch_size):
-            assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
+            assert compute_with_lp or mkldnn._is_mkldnn_acl_supported(), (
                 f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
             )
         packed_weight_node = mkldnn_device_op.pack_linear_weight(
-            graph, is_lp_weight, transpose_weight_node, batch_size
+            graph, compute_with_lp, transpose_weight_node, batch_size
         )
         packed_linear_node = mkldnn_device_op.pack_linear(
-            graph, is_lp_weight, batch_size, input, packed_weight_node, bias
+            graph, compute_with_lp, batch_size, input, packed_weight_node, bias
         )
 
         linear_node.replace_all_uses_with(packed_linear_node)
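The core of the fusion-pass change is the new `compute_with_lp` predicate: a weight is routed through the low-precision packing path either because it is already bf16/fp16 or because bf32 is enabled for fp32 weights. A standalone sketch of that decision, directly following the diff above (the free-standing function form and its name are illustrative; in the pass this logic is inlined):

```python
import torch

def should_compute_with_lp(weight_dtype: torch.dtype) -> bool:
    # Already low-precision weights always take the low-precision path.
    is_lp_weight = weight_dtype in (torch.bfloat16, torch.float16)
    # With the mkldnn precision flag set to "bf16", fp32 weights also qualify.
    bf32_matmul_enabled = torch.backends.mkldnn.matmul.fp32_precision == "bf16"
    use_bf16_for_fp32_weight = bf32_matmul_enabled and weight_dtype == torch.float32
    return is_lp_weight or use_bf16_for_fp32_weight
```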