[Intel GPU] Xpu matmul implementation for complex dtype (#160867)

This PR enables complex datatype support for four XPU ops: `mm`, `bmm`, `addmm`, and `baddbmm`. The implementation now calls the functions added in https://github.com/intel/torch-xpu-ops/pull/1992.
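
For illustration only (not part of this PR's diff): a minimal sketch of what the change enables, assuming a PyTorch build with XPU support and an Intel GPU device available.

```python
import torch

# Minimal sketch: complex matmul directly on an XPU device.
a = torch.randn(4, 4, dtype=torch.complex64, device="xpu")
b = torch.randn(4, 4, dtype=torch.complex64, device="xpu")

c = torch.mm(a, b)        # previously raised "Complex datatype matmul is not supported in oneDNN"
d = torch.addmm(c, a, b)  # bmm/baddbmm take the same complex dispatch path
```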

Complex datatype tests for the matmul operators were also added. More detailed tests will be enabled in https://github.com/intel/torch-xpu-ops/pull/1993.
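
A rough sketch of the kind of check the new complex64 coverage performs (shapes and tolerances here are illustrative, not the exact test parameters):

```python
import torch

# Compare a batched complex matmul on XPU against a CPU reference.
a = torch.randn(2, 5, 3, dtype=torch.complex64, device="xpu")
b = torch.randn(2, 3, 4, dtype=torch.complex64, device="xpu")

res_xpu = torch.bmm(a, b)
res_ref = torch.bmm(a.cpu(), b.cpu())
torch.testing.assert_close(res_xpu.cpu(), res_ref, rtol=1e-4, atol=1e-5)
```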

CI runs showed that the `test_comprehensive_linalg_eig_xpu` tests internally call matmul with a complex datatype. With this PR those tests start to pass, so `linalg.eig` was removed from `inductor_expected_failures_single_sample["xpu"]`; otherwise it would fail with an `Unexpected success` message.
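
An illustrative sketch of why `linalg.eig` ends up exercising complex matmul: eigenpairs of a real matrix are complex in general, so verifying `A @ V ≈ V @ diag(w)` goes through the complex `mm` path. The device and tolerances below are assumptions, not taken from the test.

```python
import torch

# Eigenvectors/eigenvalues of a real matrix are complex in general,
# so this verification multiplies complex matrices on the device.
A = torch.randn(3, 3, device="xpu")
w, V = torch.linalg.eig(A)

lhs = A.to(V.dtype) @ V        # complex matrix multiply
rhs = V @ torch.diag_embed(w)  # V @ diag(w), also a complex matmul
torch.testing.assert_close(lhs, rhs, rtol=1e-4, atol=1e-4)
```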

Part of: https://github.com/intel/torch-xpu-ops/issues/1853

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160867
Approved by: https://github.com/guangyey, https://github.com/ZhiweiYan-96, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/Silv3S, https://github.com/CuiYifeng, https://github.com/jansel
Pawel Swider 2025-10-25 17:13:10 +00:00 committed by PyTorch MergeBot
parent 516e58965a
commit d97f6550a2
4 changed files with 31 additions and 15 deletions


@@ -2,6 +2,7 @@
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
+#include <ATen/native/xpu/Blas.h>
 #include <torch/library.h>
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -50,9 +51,13 @@ Tensor& addmm_out(
       mat1.dtype(),
       " != ",
       mat2.dtype())
   // complex case
-  TORCH_CHECK(
-      !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
+    return result;
+  }
   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
   result.resize_(result_shape);
@@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
     return result;
   }
-  TORCH_CHECK(
-      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    at::native::mm_complex_out_xpu(self, mat2, result);
+    return result;
+  }
   onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
   return result;
@@ -208,9 +216,12 @@ Tensor& baddbmm_out(
       input.sizes());
   // complex case
-  TORCH_CHECK(
-      !batch1.is_complex(),
-      "Complex datatype matmul is not supported in oneDNN");
+  if (input.is_complex()) {
+    at::native::baddbmm_complex_out_xpu(
+        input, batch1, batch2, beta, alpha, result);
+    return result;
+  }
   // general case
   onednn::Attr attr;
@@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
     return result;
   }
-  TORCH_CHECK(
-      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  // complex case
+  if (self.is_complex()) {
+    at::native::bmm_complex_out_xpu(self, batch2, result);
+    return result;
+  }
   onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
   return result;
 }


@ -294,7 +294,6 @@ inductor_expected_failures_single_sample["xpu"] = {
i32,
i64,
}, # align with cuda.
"linalg.eig": {f32, f64},
("linalg.pinv", "singular"): {f64},
# could not create a primitive
"addmv": {f64},


@@ -238,7 +238,7 @@ class TestBasicGEMM(TestCase):
     )
     @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1})
-    @dtypes(torch.float32, torch.half, torch.double)
+    @dtypes(torch.float32, torch.half, torch.double, torch.complex64)
     @tf32_on_and_off(0.05)
     def test_addmm(self, device, dtype):
         self._test_addmm_impl(torch.addmm, None, device, dtype)
@@ -313,6 +313,7 @@ class TestBasicGEMM(TestCase):
         torch.half,
         torch.float32,
         torch.float64,
+        torch.complex64,
     )
     @tf32_on_and_off(0.05)
     def test_mm(self, device, dtype):
@@ -416,7 +417,7 @@ class TestBasicGEMM(TestCase):
         _test_mm(n, m, p, dtype, genf)
     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
-    @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64)
+    @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64, torch.complex64)
     @tf32_on_and_off(0.05)
     def test_bmm(self, device, dtype):
         batch_sizes = [1, 10]
@@ -533,7 +534,7 @@ class TestBasicGEMM(TestCase):
         self.assertEqual(res7, ref)
     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
-    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64)
     @tf32_on_and_off(0.005)
     def test_addbmm(self, device, dtype):
         num_batches = 2
@@ -637,7 +638,7 @@ class TestBasicGEMM(TestCase):
         self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
     @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6})
-    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64)
     @tf32_on_and_off(0.01)
     def test_baddbmm(self, device, dtype):
         num_batches = 10

third_party/xpu.txt (vendored)

@@ -1 +1 @@
-ce9db15136c5e8ba1b51710aae574ce4791c5d73
+30dcaa83183b189b00dc827fb39234714fe4e46d