diff --git a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp
index 7ef9aa5689d..fb117ccc63f 100644
--- a/aten/src/ATen/native/mkldnn/xpu/Blas.cpp
+++ b/aten/src/ATen/native/mkldnn/xpu/Blas.cpp
@@ -2,6 +2,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -50,9 +51,13 @@ Tensor& addmm_out(
       mat1.dtype(),
       " != ",
       mat2.dtype())
+
   // complex case
-  TORCH_CHECK(
-      !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
+
+    return result;
+  }
 
   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
   result.resize_(result_shape);
@@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
     return result;
   }
 
-  TORCH_CHECK(
-      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    at::native::mm_complex_out_xpu(self, mat2, result);
+
+    return result;
+  }
 
   onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
   return result;
@@ -208,9 +216,12 @@ Tensor& baddbmm_out(
       input.sizes());
 
   // complex case
-  TORCH_CHECK(
-      !batch1.is_complex(),
-      "Complex datatype matmul is not supported in oneDNN");
+  if (input.is_complex()) {
+    at::native::baddbmm_complex_out_xpu(
+        input, batch1, batch2, beta, alpha, result);
+
+    return result;
+  }
 
   // general case
   onednn::Attr attr;
@@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
     return result;
   }
 
-  TORCH_CHECK(
-      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  // complex case
+  if (self.is_complex()) {
+    at::native::bmm_complex_out_xpu(self, batch2, result);
+
+    return result;
+  }
+
   onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
   return result;
 }
diff --git a/test/inductor/test_torchinductor_opinfo.py b/test/inductor/test_torchinductor_opinfo.py
index bfbccda5dd8..1c9b39a1bd0 100644
--- a/test/inductor/test_torchinductor_opinfo.py
+++ b/test/inductor/test_torchinductor_opinfo.py
@@ -294,7 +294,6 @@ inductor_expected_failures_single_sample["xpu"] = {
         i32,
         i64,
     },  # align with cuda.
- "linalg.eig": {f32, f64}, ("linalg.pinv", "singular"): {f64}, # could not create a primitive "addmv": {f64}, diff --git a/test/xpu/test_gemm.py b/test/xpu/test_gemm.py index f2a273ccc33..566c1711532 100644 --- a/test/xpu/test_gemm.py +++ b/test/xpu/test_gemm.py @@ -238,7 +238,7 @@ class TestBasicGEMM(TestCase): ) @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1}) - @dtypes(torch.float32, torch.half, torch.double) + @dtypes(torch.float32, torch.half, torch.double, torch.complex64) @tf32_on_and_off(0.05) def test_addmm(self, device, dtype): self._test_addmm_impl(torch.addmm, None, device, dtype) @@ -313,6 +313,7 @@ class TestBasicGEMM(TestCase): torch.half, torch.float32, torch.float64, + torch.complex64, ) @tf32_on_and_off(0.05) def test_mm(self, device, dtype): @@ -416,7 +417,7 @@ class TestBasicGEMM(TestCase): _test_mm(n, m, p, dtype, genf) @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) - @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64) + @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64, torch.complex64) @tf32_on_and_off(0.05) def test_bmm(self, device, dtype): batch_sizes = [1, 10] @@ -533,7 +534,7 @@ class TestBasicGEMM(TestCase): self.assertEqual(res7, ref) @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05}) - @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half) + @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64) @tf32_on_and_off(0.005) def test_addbmm(self, device, dtype): num_batches = 2 @@ -637,7 +638,7 @@ class TestBasicGEMM(TestCase): self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor) @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6}) - @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half) + @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64) @tf32_on_and_off(0.01) def test_baddbmm(self, device, dtype): num_batches = 10 diff --git a/third_party/xpu.txt b/third_party/xpu.txt index 47097a86a01..4664f1a16dc 100644 --- a/third_party/xpu.txt +++ b/third_party/xpu.txt @@ -1 +1 @@ -ce9db15136c5e8ba1b51710aae574ce4791c5d73 +30dcaa83183b189b00dc827fb39234714fe4e46d