[Intel GPU] Xpu matmul implementation for complex dtype (#160867)
Enables complex datatype support for 4 ops: `mm`, `bmm`, `addmm`, `baddbmm` on XPU. From now on, the implementation calls the functions created in https://github.com/intel/torch-xpu-ops/pull/1992. Complex datatype tests for the matmul operators were also added; more detailed tests are going to be enabled in https://github.com/intel/torch-xpu-ops/pull/1993.

CI runs found that the `test_comprehensive_linalg_eig_xpu` tests were internally calling matmul with a complex datatype. With this PR the test starts to pass, so `linalg.eig` was removed from `inductor_expected_failures_single_sample["xpu"]`; otherwise it failed with an `Unexpected success` message.

Part of: https://github.com/intel/torch-xpu-ops/issues/1853

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160867

Approved by: https://github.com/guangyey, https://github.com/ZhiweiYan-96, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/Silv3S, https://github.com/CuiYifeng, https://github.com/jansel
This commit is contained in:
parent 516e58965a
commit d97f6550a2
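As a usage illustration of what this change enables (a minimal sketch, not part of the PR; it assumes a machine with an XPU device and an XPU-enabled PyTorch build):

import torch

# Before this change, complex matmul on XPU raised:
# "Complex datatype matmul is not supported in oneDNN".
a = torch.randn(3, 4, dtype=torch.complex64, device="xpu")
b = torch.randn(4, 5, dtype=torch.complex64, device="xpu")
c = torch.mm(a, b)  # now dispatches to the complex XPU implementation

b1 = torch.randn(2, 3, 4, dtype=torch.complex64, device="xpu")
b2 = torch.randn(2, 4, 5, dtype=torch.complex64, device="xpu")
d = torch.bmm(b1, b2)  # batched complex matmul is supported as well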
@@ -2,6 +2,7 @@
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
+#include <ATen/native/xpu/Blas.h>
 #include <torch/library.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -50,9 +51,13 @@ Tensor& addmm_out(
       mat1.dtype(),
       " != ",
       mat2.dtype())
 
   // complex case
-  TORCH_CHECK(
-      !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
+
+    return result;
+  }
 
   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
   result.resize_(result_shape);
@@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
     return result;
   }
 
-  TORCH_CHECK(
-      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    at::native::mm_complex_out_xpu(self, mat2, result);
+
+    return result;
+  }
 
   onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
   return result;
@@ -208,9 +216,12 @@ Tensor& baddbmm_out(
       input.sizes());
 
   // complex case
-  TORCH_CHECK(
-      !batch1.is_complex(),
-      "Complex datatype matmul is not supported in oneDNN");
+  if (input.is_complex()) {
+    at::native::baddbmm_complex_out_xpu(
+        input, batch1, batch2, beta, alpha, result);
+
+    return result;
+  }
 
   // general case
   onednn::Attr attr;
@@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
     return result;
   }
 
-  TORCH_CHECK(
-      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  // complex case
+  if (self.is_complex()) {
+    at::native::bmm_complex_out_xpu(self, batch2, result);
+
+    return result;
+  }
 
   onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
   return result;
 }
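All four entry points follow the same pattern: the old TORCH_CHECK rejecting complex inputs is replaced by an early branch into the complex implementations from torch-xpu-ops. A sketch of how one might sanity-check the new path against a CPU reference (a hypothetical standalone snippet, not part of the PR; tolerances chosen arbitrarily):

import torch

def check_complex_matmul(device="xpu"):
    # The CPU result serves as the reference for the XPU complex path.
    m1 = torch.randn(5, 6, dtype=torch.complex64)
    m2 = torch.randn(6, 7, dtype=torch.complex64)
    ref = torch.mm(m1, m2)
    out = torch.mm(m1.to(device), m2.to(device))
    torch.testing.assert_close(out.cpu(), ref, rtol=1e-4, atol=1e-5)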
@@ -294,7 +294,6 @@ inductor_expected_failures_single_sample["xpu"] = {
         i32,
         i64,
     }, # align with cuda.
-    "linalg.eig": {f32, f64},
     ("linalg.pinv", "singular"): {f64},
     # could not create a primitive
     "addmv": {f64},
@@ -238,7 +238,7 @@ class TestBasicGEMM(TestCase):
         )
 
     @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1})
-    @dtypes(torch.float32, torch.half, torch.double)
+    @dtypes(torch.float32, torch.half, torch.double, torch.complex64)
     @tf32_on_and_off(0.05)
     def test_addmm(self, device, dtype):
         self._test_addmm_impl(torch.addmm, None, device, dtype)
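The decorator change above means `test_addmm` now also runs under `torch.complex64`. For reference, the semantics exercised by `addmm` (an illustrative snippet, not part of the test file; the scale factors are arbitrary):

import torch

M = torch.randn(2, 3, dtype=torch.complex64)
m1 = torch.randn(2, 4, dtype=torch.complex64)
m2 = torch.randn(4, 3, dtype=torch.complex64)
# addmm computes beta * M + alpha * (m1 @ m2)
out = torch.addmm(M, m1, m2, beta=2, alpha=1)
torch.testing.assert_close(out, 2 * M + m1 @ m2)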
@@ -313,6 +313,7 @@ class TestBasicGEMM(TestCase):
         torch.half,
         torch.float32,
         torch.float64,
+        torch.complex64,
     )
     @tf32_on_and_off(0.05)
     def test_mm(self, device, dtype):
@@ -416,7 +417,7 @@ class TestBasicGEMM(TestCase):
             _test_mm(n, m, p, dtype, genf)
 
     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
-    @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64)
+    @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64, torch.complex64)
     @tf32_on_and_off(0.05)
     def test_bmm(self, device, dtype):
         batch_sizes = [1, 10]
@@ -533,7 +534,7 @@ class TestBasicGEMM(TestCase):
         self.assertEqual(res7, ref)
 
     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
-    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64)
     @tf32_on_and_off(0.005)
     def test_addbmm(self, device, dtype):
         num_batches = 2
@@ -637,7 +638,7 @@ class TestBasicGEMM(TestCase):
         self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
 
     @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6})
-    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64)
     @tf32_on_and_off(0.01)
     def test_baddbmm(self, device, dtype):
         num_batches = 10
third_party/xpu.txt (vendored)
@@ -1 +1 @@
-ce9db15136c5e8ba1b51710aae574ce4791c5d73
+30dcaa83183b189b00dc827fb39234714fe4e46d