[Intel GPU] Xpu matmul implementation for complex dtype (#160867)
Enables complex-dtype support for four ops on XPU: `mm`, `bmm`, `addmm`, and `baddbmm`. The implementation now calls the functions introduced in https://github.com/intel/torch-xpu-ops/pull/1992. Complex-dtype tests for the matmul operators are added as well; more detailed tests will be enabled in https://github.com/intel/torch-xpu-ops/pull/1993.

CI runs showed that the `test_comprehensive_linalg_eig_xpu` tests internally call matmul with a complex dtype. With this PR they start to pass, so `linalg.eig` is removed from `inductor_expected_failures_single_sample["xpu"]`; otherwise it would fail with an `Unexpected success` message.

Part of: https://github.com/intel/torch-xpu-ops/issues/1853

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160867
Approved by: https://github.com/guangyey, https://github.com/ZhiweiYan-96, https://github.com/gujinghui, https://github.com/EikanWang, https://github.com/Silv3S, https://github.com/CuiYifeng, https://github.com/jansel
This commit is contained in:
parent 516e58965a
commit d97f6550a2
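For context, a minimal usage sketch of what this change enables (illustrative example, not from the PR; assumes a PyTorch build with XPU support and an available `xpu` device):

```python
import torch

# Before this PR, complex matmul on XPU failed the oneDNN check with
# "Complex datatype matmul is not supported in oneDNN"; now the four ops
# below dispatch to the torch-xpu-ops complex implementations.
a = torch.randn(2, 3, 4, dtype=torch.complex64, device="xpu")
b = torch.randn(2, 4, 5, dtype=torch.complex64, device="xpu")
c = torch.randn(2, 3, 5, dtype=torch.complex64, device="xpu")

mm_res = torch.mm(a[0], b[0])              # 2-D matmul
bmm_res = torch.bmm(a, b)                  # batched matmul
addmm_res = torch.addmm(c[0], a[0], b[0])  # beta * input + alpha * (mat1 @ mat2)
baddbmm_res = torch.baddbmm(c, a, b)       # batched variant of addmm
```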
@@ -2,6 +2,7 @@
 #include <ATen/WrapDimUtilsMulti.h>
 #include <ATen/native/Resize.h>
 #include <ATen/native/mkldnn/xpu/detail/oneDNN.h>
+#include <ATen/native/xpu/Blas.h>
 #include <torch/library.h>
 
 #ifndef AT_PER_OPERATOR_HEADERS
@@ -50,9 +51,13 @@ Tensor& addmm_out(
       mat1.dtype(),
       " != ",
       mat2.dtype())
 
   // complex case
-  TORCH_CHECK(
-      !mat1.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    at::native::addmm_complex_out_xpu(self, mat1, mat2, beta, alpha, result);
+    return result;
+  }
 
   std::vector<int64_t> result_shape = {mat1.size(0), mat2.size(1)};
   result.resize_(result_shape);
@@ -167,8 +172,11 @@ Tensor& mm_out(const Tensor& self, const Tensor& mat2, Tensor& result) {
     return result;
   }
 
-  TORCH_CHECK(
-      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  if (self.is_complex()) {
+    at::native::mm_complex_out_xpu(self, mat2, result);
+    return result;
+  }
 
   onednn::matmul(result, self, mat2, Tensor(), true, onednn::Attr());
   return result;
@@ -208,9 +216,12 @@ Tensor& baddbmm_out(
       input.sizes());
 
   // complex case
-  TORCH_CHECK(
-      !batch1.is_complex(),
-      "Complex datatype matmul is not supported in oneDNN");
+  if (input.is_complex()) {
+    at::native::baddbmm_complex_out_xpu(
+        input, batch1, batch2, beta, alpha, result);
+    return result;
+  }
 
   // general case
   onednn::Attr attr;
@@ -257,8 +268,13 @@ Tensor& bmm_out(const Tensor& self, const Tensor& batch2, Tensor& result) {
     return result;
   }
 
-  TORCH_CHECK(
-      !self.is_complex(), "Complex datatype matmul is not supported in oneDNN");
+  // complex case
+  if (self.is_complex()) {
+    at::native::bmm_complex_out_xpu(self, batch2, result);
+    return result;
+  }
 
   onednn::matmul(result, self, batch2, at::Tensor(), true, onednn::Attr());
   return result;
 }
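All four hunks above share the same pattern: detect a complex input, forward to the matching `*_complex_out_xpu` entry point from torch-xpu-ops, and return early, leaving the real-dtype path through `onednn::matmul` untouched. For intuition, a complex GEMM can always be validated against its real-arithmetic decomposition; a minimal sketch of that identity (illustrative only, not a claim about how the XPU kernels are implemented):

```python
import torch

def mm_complex_via_real(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # (Ar + i*Ai) @ (Br + i*Bi) = (Ar@Br - Ai@Bi) + i*(Ar@Bi + Ai@Br)
    real = a.real @ b.real - a.imag @ b.imag
    imag = a.real @ b.imag + a.imag @ b.real
    return torch.complex(real, imag)

a = torch.randn(3, 4, dtype=torch.complex64)
b = torch.randn(4, 5, dtype=torch.complex64)
torch.testing.assert_close(mm_complex_via_real(a, b), a @ b)
```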
@@ -294,7 +294,6 @@ inductor_expected_failures_single_sample["xpu"] = {
         i32,
         i64,
     }, # align with cuda.
-    "linalg.eig": {f32, f64},
    ("linalg.pinv", "singular"): {f64},
    # could not create a primitive
    "addmv": {f64},
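This removal matters because expected-failure lists are strict: once the underlying op starts passing, a stale entry turns the run into a failure. A minimal sketch of that mechanism (hypothetical harness for illustration, not the actual inductor test-suite code):

```python
def run_with_xfail(test_fn, expected_to_fail: bool) -> str:
    """Strict xfail semantics: a passing test still marked as an
    expected failure is itself reported as a failure."""
    try:
        test_fn()
    except Exception:
        if expected_to_fail:
            return "expected failure"  # still broken, as recorded
        raise                          # a real regression
    if expected_to_fail:
        raise AssertionError("Unexpected success")  # stale entry
    return "pass"

# After this PR complex matmul works on XPU, so a passing linalg.eig test
# left in inductor_expected_failures_single_sample["xpu"] would hit the
# "Unexpected success" branch unless the entry is removed.
```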
@@ -238,7 +238,7 @@ class TestBasicGEMM(TestCase):
         )
 
     @precisionOverride({torch.float: 1e-4, torch.double: 1e-6, torch.half: 1e-1})
-    @dtypes(torch.float32, torch.half, torch.double)
+    @dtypes(torch.float32, torch.half, torch.double, torch.complex64)
     @tf32_on_and_off(0.05)
     def test_addmm(self, device, dtype):
         self._test_addmm_impl(torch.addmm, None, device, dtype)
@@ -313,6 +313,7 @@ class TestBasicGEMM(TestCase):
         torch.half,
         torch.float32,
         torch.float64,
+        torch.complex64,
     )
     @tf32_on_and_off(0.05)
     def test_mm(self, device, dtype):
@@ -416,7 +417,7 @@ class TestBasicGEMM(TestCase):
             _test_mm(n, m, p, dtype, genf)
 
     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
-    @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64)
+    @dtypes(torch.float32, torch.bfloat16, torch.half, torch.float64, torch.complex64)
     @tf32_on_and_off(0.05)
     def test_bmm(self, device, dtype):
         batch_sizes = [1, 10]
@@ -533,7 +534,7 @@ class TestBasicGEMM(TestCase):
         self.assertEqual(res7, ref)
 
     @precisionOverride({torch.half: 0.05, torch.bfloat16: 0.05})
-    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64)
     @tf32_on_and_off(0.005)
     def test_addbmm(self, device, dtype):
         num_batches = 2
@@ -637,7 +638,7 @@ class TestBasicGEMM(TestCase):
         self._test_addbmm_baddbmm("addbmm", b1, b2, ref, out_tensor)
 
     @precisionOverride({torch.half: 0.1, torch.bfloat16: 0.5, torch.float64: 1e-6})
-    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half)
+    @dtypes(torch.float64, torch.float32, torch.bfloat16, torch.half, torch.complex64)
     @tf32_on_and_off(0.01)
     def test_baddbmm(self, device, dtype):
         num_batches = 10
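The extended `@dtypes` decorators mean each of these GEMM tests now also runs with `torch.complex64`. A minimal sketch of the kind of check this coverage implies (hypothetical helper, not the actual `TestBasicGEMM` code; assumes an available `xpu` device):

```python
import torch

def check_addmm_complex(device: str = "xpu", dtype=torch.complex64) -> None:
    # Build complex operands on CPU, compute a CPU reference result,
    # then compare against the device result.
    m = torch.randn(10, 25, dtype=dtype)
    m1 = torch.randn(10, 50, dtype=dtype)
    m2 = torch.randn(50, 25, dtype=dtype)

    ref = torch.addmm(m, m1, m2)  # CPU reference path
    res = torch.addmm(m.to(device), m1.to(device), m2.to(device))
    torch.testing.assert_close(res.cpu(), ref)
```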
third_party/xpu.txt (vendored)
@@ -1 +1 @@
-ce9db15136c5e8ba1b51710aae574ce4791c5d73
+30dcaa83183b189b00dc827fb39234714fe4e46d