diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp index 396b9746754..95f9029c8df 100644 --- a/aten/src/ATen/autocast_mode.cpp +++ b/aten/src/ATen/autocast_mode.cpp @@ -595,12 +595,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) { KERNEL_CPU(ADD_NS(linalg_tensorsolve), "linalg_tensorsolve", Tensor(const Tensor &, const Tensor &, at::OptionalIntArrayRef), fp32) KERNEL_CPU(ADD_NS(fake_quantize_per_tensor_affine), "fake_quantize_per_tensor_affine", Tensor (const Tensor &, double, int64_t, int64_t, int64_t), fp32) - m.impl(TORCH_SELECTIVE_NAME("aten::eig"), - TORCH_FN((&WrapFunction (const Tensor &, bool), - std::tuple (const Tensor &, bool), - &ADD_NS(eig)>::type::call))); - m.impl(TORCH_SELECTIVE_NAME("aten::geqrf"), TORCH_FN((&WrapFunction (const Tensor &), diff --git a/aten/src/ATen/native/BatchLinearAlgebra.cpp b/aten/src/ATen/native/BatchLinearAlgebra.cpp index 7464e12fd7d..09bffa1a743 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebra.cpp @@ -3168,66 +3168,6 @@ Tensor linalg_eigvals(const Tensor& input) { return values; } -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -DEFINE_DISPATCH(eig_stub); - -std::tuple eig_out(const Tensor& self, bool eigenvectors, Tensor& e, Tensor& v) { - TORCH_WARN_ONCE( - "torch.eig is deprecated in favor of torch.linalg.eig and will be removed in a future ", - "PyTorch release.\n", - "torch.linalg.eig returns complex tensors of dtype cfloat or cdouble rather than real tensors ", - "mimicking complex tensors.\n", - "L, _ = torch.eig(A)\n", - "should be replaced with\n", - "L_complex = torch.linalg.eigvals(A)\n", - "and\n", - "L, V = torch.eig(A, eigenvectors=True)\n", - "should be replaced with\n", - "L_complex, V_complex = torch.linalg.eig(A)" - ); - TORCH_CHECK(self.dim() == 2, "input should be 2 dimensional"); - TORCH_CHECK(self.size(0) == self.size(1), "input should be square"); - TORCH_CHECK(self.isfinite().all().item(), "input should not contain infs or NaNs"); - checkSameDevice("torch.eig", e, self, "eigenvalues"); - checkLinalgCompatibleDtype("torch.eig", e, self, "eigenvalues"); - if (eigenvectors) { - checkSameDevice("torch.eig", v, self, "eigenvectors"); - checkLinalgCompatibleDtype("torch.eig", v, self, "eigenvectors"); - } - int64_t n = self.size(-1); - - if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { - at::native::resize_output(e, {n}); - } else { - at::native::resize_output(e, {n, 2}); - } - if (eigenvectors) { - at::native::resize_output(v, self.sizes()); - } - - // optimization: if self is empty, we can immediately return the empty - // tensors, instead of getting empty tensors from eig_helper - if (self.numel() == 0) { - return std::tuple(e, v); - } - - Tensor vals_, vecs_; - std::tie(vals_, vecs_) = eig_stub(self.device().type(), self, eigenvectors); - e.copy_(vals_); - if (eigenvectors) { - v.copy_(vecs_); - } - return std::tuple(e, v); -} - -std::tuple eig(const Tensor& self, bool eigenvectors) { - Tensor e = at::empty({0}, self.options()); - Tensor v = at::empty({0}, self.options()); - at::eig_out(e, v, self, eigenvectors); - return std::tuple(e, v); -} - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /* torch.svd, implemented in terms of torch.linalg.svd. There are two main diff --git a/aten/src/ATen/native/BatchLinearAlgebra.h b/aten/src/ATen/native/BatchLinearAlgebra.h index 531595f3544..a86be95f40b 100644 --- a/aten/src/ATen/native/BatchLinearAlgebra.h +++ b/aten/src/ATen/native/BatchLinearAlgebra.h @@ -231,10 +231,6 @@ using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, b DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub); -using eig_fn = std::tuple (*)(const Tensor&, bool&); - -DECLARE_DISPATCH(eig_fn, eig_stub); - using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/); DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub); diff --git a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp index 5b18dbe2d5f..3fe9fc13769 100644 --- a/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp +++ b/aten/src/ATen/native/BatchLinearAlgebraKernel.cpp @@ -127,87 +127,6 @@ Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper) return result; } -template -void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vecs_, int* info_ptr) { -#if !AT_BUILD_WITH_LAPACK() - TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ", - "PyTorch with LAPACK. Please use PyTorch built with LAPACK support."); -#else - using value_t = typename c10::scalar_value_type::type; - - char jobvr = eigenvectors ? 'V' : 'N'; - int64_t n = self.size(-1); - auto self_data = self.data_ptr(); - - auto vals_data = vals_.data_ptr(); - scalar_t* wr = vals_data; - - scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr() : nullptr; - // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions) - int ldvr = eigenvectors ? n : 1; - - Tensor rwork; - value_t* rwork_data = nullptr; - if (self.is_complex()) { - ScalarType real_dtype = toRealValueType(typeMetaToScalarType(self.dtype())); - rwork = at::empty({n*2}, self.options().dtype(real_dtype)); - rwork_data = rwork.data_ptr(); - } - - if (n > 0) { - // call lapackEig once to get the optimal size for work data - scalar_t wkopt; - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - lapackEig('N', jobvr, n, self_data, n, wr, - nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, info_ptr); - int lwork = std::max(1, real_impl(wkopt)); - - // call again to do the actual work - Tensor work = at::empty({lwork}, self.dtype()); - lapackEig('N', jobvr, n, self_data, n, wr, - nullptr, 1, vecs_data, ldvr, work.data_ptr(), lwork, rwork_data, info_ptr); - } -#endif -} - -std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvectors) { - int64_t n = self.size(-1); - // lapackEig function expects the input to be column major, or stride {1, n}, - // so we must set the stride manually since the default stride for tensors is - // row major, {n, 1} - Tensor self_ = at::empty_strided( - {n, n}, - {1, n}, - at::TensorOptions(self.dtype())); - self_.copy_(self); - - auto options = self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); - - // the API is slightly different for the complex vs real case: if the input - // is complex, eigenvals will be a vector of complex. If the input is real, - // eigenvals will be a (n, 2) matrix containing the real and imaginary parts - // in each column - Tensor vals_; - if (self.is_complex()) { - vals_ = at::empty({n}, options); - } else { - vals_ = at::empty_strided({n, 2}, {1, n}, options); - } - Tensor vecs_ = eigenvectors - ? at::empty_strided({n, n}, {1, n}, options) - : Tensor(); - - // NOLINTNEXTLINE(cppcoreguidelines-init-variables) - auto infos = at::zeros({}, self.options().dtype(kInt)); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cpu", [&]{ - apply_eig(self_, eigenvectors, vals_, vecs_, infos.data_ptr()); - }); - // NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage) - at::_linalg_check_errors(infos, "eig", /*is_matrix*/true); - - return std::tuple(vals_, vecs_); -} - /* Computes the eigenvalues and eigenvectors of n-by-n matrix 'input'. This is an in-place routine, content of 'input', 'values', 'vectors' is overwritten. @@ -1200,12 +1119,6 @@ REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl); -REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl); -REGISTER_AVX512_DISPATCH(eig_stub, &eig_kernel_impl); -REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl); -REGISTER_VSX_DISPATCH(eig_stub, &eig_kernel_impl); -REGISTER_ZVECTOR_DISPATCH(eig_stub, &eig_kernel_impl); - REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel); REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel); diff --git a/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp b/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp index cb6cacb3630..f5816c8c674 100644 --- a/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp +++ b/aten/src/ATen/native/cuda/LinearAlgebraStubs.cpp @@ -93,11 +93,6 @@ void lazy_linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvecto linalg_eigh_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, upper, compute_eigenvectors); } -std::tuple lazy_eig_kernel(const Tensor& self, bool& eigenvectors) { - loadLazyTorchLinalgLibrary(); - return eig_stub(DeviceType::CUDA, self, eigenvectors); -} - void lazy_linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) { getTorchLinalgLibrary(); linalg_eig_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, input, compute_eigenvectors); @@ -155,7 +150,6 @@ REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel); REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel); REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel); REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel); -REGISTER_CUDA_DISPATCH(eig_stub, &lazy_eig_kernel); REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel); REGISTER_CUDA_DISPATCH(svd_stub, &lazy_svd_kernel) REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve); diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp index 061e7e86de8..a7d379ec462 100644 --- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp +++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebra.cpp @@ -2036,96 +2036,6 @@ void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, c REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel); -// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -// magmaEig uses a hybrid CPU-GPU algorithm, which takes and return CPU -// memory. So, we accept a GPU tensor, copy it to CPU memory, and later copy -// the returned values from CPU to GPU. See also magmaSymeig, which uses a -// similar approach. - -template -static void apply_eig(const Tensor& self, bool eigenvectors, Tensor& out_eigvals, Tensor& out_eigvecs, - int* info_ptr) { -#if !AT_MAGMA_ENABLED() -TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorch with MAGMA. " - "Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA."); -#else - TORCH_INTERNAL_ASSERT(self.device() == at::kCPU, "Internal error: apply_eig needs a CPU tensor"); - using value_t = typename c10::scalar_value_type::type; - magma_vec_t jobvr = eigenvectors ? MagmaVec : MagmaNoVec; - magma_int_t n = magma_int_cast(self.size(-1), "n"); - auto self_data = self.data_ptr(); - - auto out_eigvals_data = out_eigvals.data_ptr(); - scalar_t *wr = out_eigvals_data; - - scalar_t *vr_data = NULL; - magma_int_t ldvr = 1; - if (jobvr == MagmaVec) - { - vr_data = out_eigvecs.data_ptr(); - ldvr = n; - } - - value_t *rwork_data = nullptr; - if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { - ALLOCATE_ARRAY(rwork_data, value_t, n*2); - } - - if (n > 0) { - // call magmaEig once to get the optimal size of work_data - scalar_t wkopt; - magma_int_t info; - magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, &wkopt, -1, rwork_data, &info); - magma_int_t lwork = static_cast(real_impl(wkopt)); - - // call it a 2nd time to to the actual work - scalar_t *work_data = nullptr; - ALLOCATE_ARRAY(work_data, scalar_t, lwork); - magmaEig(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, work_data, lwork, rwork_data, &info); - *info_ptr = info; - } -#endif -} - -/* - * Internal helper; like eig_cuda but: - * 1. assume that self is a square matrix of side "n" - * 2. return CPU tensors (because this is what magmaEig returns), which will be copied to GPU memory - * by the caller - */ -std::tuple eig_kernel_impl(const Tensor& self, bool& eigenvectors) { - int64_t n = self.size(-1); - // copy self to pinned CPU memory - auto self_working_copy = at::empty_strided( - {n, n}, // square matrix - {1, n}, // column-ordered, as magmaEig expects - at::TensorOptions(at::kCPU).dtype(self.dtype()).pinned_memory(true)); - self_working_copy.copy_(self); - - // tensors holding the results. We use empty_strided to make them column-ordered - auto options = self.options().device(at::kCPU).memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT); - Tensor out_eigvals; - if (isComplexType(at::typeMetaToScalarType(self.dtype()))) { - out_eigvals = at::empty({n}, options); - } else { - out_eigvals = at::empty_strided({n, 2}, {1, n}, options); - } - auto out_eigvecs = eigenvectors - ? at::empty_strided({n, n}, {1, n}, options) - : Tensor(); - - auto infos = at::zeros({}, self_working_copy.options().dtype(kInt)); - AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cuda", [&]{ - apply_eig(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, infos.data_ptr()); - }); - at::_linalg_check_errors(infos, "eig", /*is_matrix*/true); - - return std::tuple(out_eigvals, out_eigvecs); -} - -REGISTER_CUDA_DISPATCH(eig_stub, &eig_kernel_impl); - // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ /* diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index d3aeeb00f84..b95623e9bda 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -8113,15 +8113,6 @@ CUDA: _symeig_helper_cuda autogen: _symeig_helper.out -- func: eig.e(Tensor self, bool eigenvectors=False, *, Tensor(a!) e, Tensor(b!) v) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) - dispatch: - CompositeExplicitAutograd: eig_out - -- func: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors) - variants: method, function - dispatch: - CompositeExplicitAutograd: eig - - func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) - func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 9c4264316fd..467a26a02df 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -345,7 +345,6 @@ Tensor class reference Tensor.dot Tensor.double Tensor.dsplit - Tensor.eig Tensor.element_size Tensor.eq Tensor.eq_ diff --git a/docs/source/torch.rst b/docs/source/torch.rst index a530c5af136..3defab649a9 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -558,7 +558,6 @@ BLAS and LAPACK Operations cholesky_inverse cholesky_solve dot - eig geqrf ger inner diff --git a/functorch/test/test_ops.py b/functorch/test/test_ops.py index 03d9a6c25a6..28605b422ba 100644 --- a/functorch/test/test_ops.py +++ b/functorch/test/test_ops.py @@ -688,7 +688,6 @@ class TestOperators(TestCase): skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule xfail('__getitem__', ''), # dynamic error xfail('_masked.prod'), # calls aten::item - xfail('eig'), # calls aten::item xfail('linalg.eig'), # Uses aten::allclose xfail('linalg.householder_product'), # needs select_scatter xfail('nanquantile', device_type='cpu'), # checks q via a .item() call @@ -923,7 +922,6 @@ class TestOperators(TestCase): xfail('cummax'), xfail('cummin'), xfail('cumprod'), - xfail('eig'), xfail('nansum'), xfail('nanmean'), xfail('special.log_ndtr'), @@ -1142,7 +1140,6 @@ class TestOperators(TestCase): xfail('_masked.softmin', ''), # NYI: forward-AD for _softmax_backward_data xfail('cdist', ''), # NYI: forward-AD for _cdist_forward xfail('cholesky', ''), # NYI: forward-AD for cholesky - xfail('eig', ''), # NYI: forward-AD for eig xfail('logcumsumexp', ''), # NYI: forward-AD for logcumsumexp xfail('nn.functional.embedding_bag', ''), # NYI: forward-AD for _embedding_bag xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py index e8e5d8c254f..af5c5e12faf 100644 --- a/functorch/test/test_vmap.py +++ b/functorch/test/test_vmap.py @@ -3294,7 +3294,6 @@ class TestVmapOperatorsOpInfo(TestCase): skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format xfail('complex'), xfail('copysign'), - xfail('eig'), xfail('histogram'), xfail('index_fill'), xfail('nansum'), diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 71560c5c055..32388cf6681 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -61,6 +61,8 @@ ALLOW_LIST = [ ("aten::slice_backward", datetime.date(9999, 1, 1)), ("aten::diagonal_backward", datetime.date(9999, 1, 1)), ("aten::rowwise_prune", datetime.date(9999, 1, 1)), + ("aten::eig", datetime.date(9999, 1, 1)), + ("aten::eig.e", datetime.date(9999, 1, 1)), ("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)), ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)), ("aten::randperm", datetime.date(9999, 1, 1)), diff --git a/test/test_autograd.py b/test/test_autograd.py index 341aea5c919..6da0c9f3062 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -3744,20 +3744,6 @@ class TestAutograd(TestCase): out.backward() # TODO: update these tests to use the linalg module and move to test_linalg.py - @skipIfNoLapack - def test_eig_no_eigenvectors(self): - A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True) - w, v = torch.eig(A, eigenvectors=False) - with self.assertRaisesRegex(RuntimeError, 'is not differentiable'): - torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)]) - - @skipIfNoLapack - def test_eig_complex_eigenvalues(self): - A = torch.tensor([[0., -1.], [1., 0.]], dtype=torch.float32, requires_grad=True) - w, v = torch.eig(A, eigenvectors=True) - with self.assertRaisesRegex(RuntimeError, 'does not support complex eigenvalues'): - torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)]) - @skipIfNoLapack def test_symeig_no_eigenvectors(self): A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True) diff --git a/test/test_linalg.py b/test/test_linalg.py index fce5da2f42b..c546da5e85f 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -148,6 +148,13 @@ class TestLinalg(TestCase): with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"): b.solve(a) + def test_eig_removed_error(self, device): + a = make_tensor(5, 5, device=device, dtype=torch.float32) + with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"): + torch.eig(a) + with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"): + a.eig() + @skipCUDAIfNoMagma @skipCPUIfNoLapack @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble) @@ -1758,122 +1765,6 @@ class TestLinalg(TestCase): expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0) self.assertEqual(result, expected) - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - @dtypes(*floating_and_complex_types()) - def test_old_eig_basic(self, device, dtype): - a = torch.tensor([[1.96, 0.00, 0.00, 0.00, 0.00], - [-6.49, 3.80, 0.00, 0.00, 0.00], - [-0.47, -6.39, 4.17, 0.00, 0.00], - [-7.20, 1.50, -1.51, 5.70, 0.00], - [-0.65, -6.34, 2.67, 1.80, -7.10]], - dtype=dtype, device=device).t() - e = torch.eig(a)[0] - ee, vv = torch.eig(a, True) - te = torch.tensor((), dtype=dtype, device=device) - tv = torch.tensor((), dtype=dtype, device=device) - eee, vvv = torch.eig(a, True, out=(te, tv)) - self.assertEqual(e, ee, atol=1e-12, rtol=0) - self.assertEqual(ee, eee, atol=1e-12, rtol=0) - self.assertEqual(ee, te, atol=1e-12, rtol=0) - self.assertEqual(vv, vvv, atol=1e-12, rtol=0) - self.assertEqual(vv, tv, atol=1e-12, rtol=0) - # - # compare with numpy - np_e, np_v = np.linalg.eig(a.cpu().numpy()) - if dtype.is_complex: - self.assertEqual(ee, np_e) - else: - # np_e.shape == (n, 2), where each column contain the real and - # imaginary parts of the result - self.assertEqual(ee[:, 0], np_e) # real part - self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part - self.assertEqual(vv, np_v) - - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - @dtypes(torch.double, torch.float) - def test_old_eig_reuse(self, device, dtype): - X = torch.randn(4, 4, dtype=dtype, device=device) - X = torch.mm(X.t(), X) - e = torch.zeros(4, 2, dtype=dtype, device=device) - v = torch.zeros(4, 4, dtype=dtype, device=device) - torch.eig(X, True, out=(e, v)) - Xhat = np.matmul(np.matmul(v.cpu(), torch.diag(e.select(1, 0)).cpu()), v.t().cpu()) - if dtype is torch.float: - atol = 1e-7 - rtol = 1e-5 - else: - atol = 1e-8 - rtol = 0 - self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong') - self.assertTrue(v.is_contiguous(), 'V is not contiguous') - - torch.eig(X, True, out=(e, v)) - Xhat = np.matmul(v.cpu(), np.matmul(e.select(1, 0).diag().cpu(), v.t().cpu())) - self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong') - self.assertTrue(v.is_contiguous(), 'V is not contiguous') - - @skipCPUIfNoLapack - @skipCUDAIfNoMagma - @dtypes(torch.double, torch.float) - def test_old_eig_invalid_input(self, device, dtype): - # test invalid input - self.assertRaisesRegex( - RuntimeError, - 'input should be 2 dimensional', - lambda: torch.eig(torch.ones((2)))) - self.assertRaisesRegex( - RuntimeError, - 'input should be square', - lambda: torch.eig(torch.ones((2, 3)))) - self.assertRaisesRegex( - RuntimeError, - 'input should not contain infs or NaNs', - lambda: torch.eig(np.inf * torch.ones((2, 2)))) - self.assertRaisesRegex( - RuntimeError, - 'input should not contain infs or NaNs', - lambda: torch.eig(np.nan * torch.ones((2, 2)))) - - @skipCUDAIfNoMagma - @skipCPUIfNoLapack - @dtypes(torch.double, torch.float) - def test_old_eig_out(self, device, dtype): - # the out version of torch.eig needs to be tested manually: we can't - # use the "test_out=True" parameter to tensor_op_tests because the - # signature is irregular (since we have *two* output vectors) - t = torch.randn(10, 10, dtype=dtype, device=device) - evals, evecs = torch.eig(t, eigenvectors=True) - # - # check that the out= version computes the same values as the normal one - out_evals = torch.empty_like(evals) - out_evecs = torch.empty_like(evecs) - evals2, evecs2 = torch.eig(t, eigenvectors=True, out=(out_evals, out_evecs)) - # check that the out tensors were used in-place - self.assertEqual(evals2.data_ptr(), out_evals.data_ptr()) - self.assertEqual(evecs2.data_ptr(), out_evecs.data_ptr()) - # check that the result is the same as the non-out version - self.assertEqual(evals, out_evals) - self.assertEqual(evecs, out_evecs) - # - # check what happens in the eigenvectors=False case - out_evals = torch.empty_like(evals) - out_evecs = torch.tensor([1, 2, 3], dtype=dtype, device=device) - evals2, evecs2 = torch.eig(t, eigenvectors=False, out=(out_evals, out_evecs)) - # check that the out_evals was used in-place - self.assertEqual(evals2.data_ptr(), out_evals.data_ptr()) - self.assertEqual(evals, out_evals) - # check that out_evecs was NOT touched at all - assert out_evecs.tolist() == [1, 2, 3] - # - # check that we complain if we pass an out vector of the wrong dtype - wrong_out = torch.empty((0, 0), dtype=int) - with self.assertRaisesRegex(RuntimeError, r"Expected .* but got .*"): - torch.eig(t, eigenvectors=True, out=(wrong_out, out_evecs)) - with self.assertRaisesRegex(RuntimeError, r"Expected .* but got .*"): - torch.eig(t, eigenvectors=True, out=(out_evals, wrong_out)) - @skipCPUIfNoLapack @skipCUDAIfNoMagma # NumPy computes only in float64 and complex128 precisions @@ -7407,12 +7298,6 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)), fn(torch.slogdet, (0, 0))) - # eig, symeig - evalues, evectors = fn(torch.eig, (0, 0), True) - self.assertEqual([(0, 2), (0, 0)], [evalues.shape, evectors.shape]) - evalues, evectors = fn(torch.symeig, (0, 0), True) - self.assertEqual([(0,), (0, 0)], [evalues.shape, evectors.shape]) - # lstsq self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0, 0), torch.randn(0, 0))) self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0,), torch.randn(0, 0))) diff --git a/test/test_meta.py b/test/test_meta.py index eb4bfb566d8..2697a6339a2 100644 --- a/test/test_meta.py +++ b/test/test_meta.py @@ -443,7 +443,6 @@ meta_function_expected_failures = { torch.cholesky : {f64, f32, c128, c64}, torch.cholesky_inverse : {f64, f32, c128, c64}, torch.cholesky_solve : {f64, f32, c128, c64}, - torch.eig : {f64, f32, c128, c64}, torch.linalg.eig : {f64, f32, c128, c64}, torch.linalg.eigvals : {f64, f32, c128, c64}, torch.linalg.lstsq : {f64, f32, c128, c64}, @@ -633,7 +632,6 @@ meta_dispatch_expected_failures = { aten.cholesky_solve.out : {c64, c128, f64, f32}, aten.count_nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, aten.count_nonzero.dim_IntList : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8}, - aten.eig.default : {c64, c128, f64, f32}, aten.geqrf.default : {c64, c128, f64, f32}, aten.linalg_eig.default : {c64, c128, f64, f32}, aten.linalg_householder_product.default : {c64, c128, f64, f32}, diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 65d90625e5d..00409cf8dd4 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -13,7 +13,7 @@ from collections import namedtuple path = os.path.dirname(os.path.realpath(__file__)) aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.yaml') all_operators_with_namedtuple_return = { - 'max', 'min', 'aminmax', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', + 'max', 'min', 'aminmax', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'qr', 'geqrf', 'slogdet', 'sort', 'topk', 'lstsq', 'linalg_inv_ex', 'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "_linalg_eigh", "_unpack_dual", 'linalg_qr', 'linalg_svd', '_linalg_svd', 'linalg_slogdet', '_linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask', @@ -77,7 +77,7 @@ class TestNamedTupleAPI(TestCase): op(operators=['_linalg_slogdet'], input=(), names=('sign', 'logabsdet', 'LU', 'pivots'), hasout=True), op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True), op(operators=['geqrf'], input=(), names=('a', 'tau'), hasout=True), - op(operators=['symeig', 'eig'], input=(True,), names=('eigenvalues', 'eigenvectors'), hasout=True), + op(operators=['symeig'], input=(True,), names=('eigenvalues', 'eigenvectors'), hasout=True), op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True), op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True), op(operators=['linalg_eig'], input=(), names=('eigenvalues', 'eigenvectors'), hasout=True), diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 966c34eae58..d2f4d34d4cf 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -978,7 +978,6 @@ symbolic_tensor_failures = { xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition xfail('double', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition - xfail('eig', ''), # aten.eig.default - couldn't find symbolic meta function/decomposition xfail('einsum', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('expand_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 5a8bf46319f..456ff56a670 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -582,9 +582,6 @@ grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)" mask: 'not_implemented("native_dropout_backward: mask")' -- name: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors) - self: eig_backward(grads, self, eigenvectors, eigenvalues, eigenvectors_return) - - name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) self: zeros_like(self) result: self_t.zero_() diff --git a/torch/__init__.py b/torch/__init__.py index a6e8bc295d0..22fffec424d 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -929,7 +929,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack from . import _masked # Import removed ops with error message about removal -from ._linalg_utils import solve +from ._linalg_utils import eig, solve def _register_device_module(device_type, module): diff --git a/torch/_linalg_utils.py b/torch/_linalg_utils.py index d7f6798dd9d..b9261cb25ae 100644 --- a/torch/_linalg_utils.py +++ b/torch/_linalg_utils.py @@ -96,9 +96,17 @@ def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]: return E, Z -# This function was deprecated and removed +# These functions were deprecated and removed # This nice error message can be removed in version 1.13+ def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]: raise RuntimeError( "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.solve` function instead.", ) + + +def eig( + self: Tensor, eigenvectors: bool = False, *, e=None, v=None +) -> Tuple[Tensor, Tensor]: + raise RuntimeError( + "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.eig` function instead.", + ) diff --git a/torch/_tensor.py b/torch/_tensor.py index 8db059ea708..9812a6148b4 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -638,6 +638,11 @@ class Tensor(torch._C._TensorBase): return solve(self, other) + def eig(self, eigenvectors=False): + from ._linalg_utils import eig + + return eig(self, eigenvectors=eigenvectors) + def lu(self, pivot=True, get_infos=False): r"""See :func:`torch.lu`""" # If get_infos is True, then we don't need to check for errors and vice versa diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index 3380942c028..0cccd817272 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -1692,15 +1692,6 @@ See :func:`torch.dot` """, ) -add_docstr_all( - "eig", - r""" -eig(eigenvectors=False) -> (Tensor, Tensor) - -See :func:`torch.eig` -""", -) - add_docstr_all( "element_size", r""" diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index 9941367b791..d39d27103fe 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -3999,92 +3999,6 @@ Example:: """, ) -add_docstr( - torch.eig, - r""" -eig(input, eigenvectors=False, *, out=None) -> (Tensor, Tensor) - -Computes the eigenvalues and eigenvectors of a real square matrix. - -.. note:: - Since eigenvalues and eigenvectors might be complex, backward pass is supported only - if eigenvalues and eigenvectors are all real valued. - - When :attr:`input` is on CUDA, :func:`torch.eig() ` causes - host-device synchronization. - -.. warning:: - - :func:`torch.eig` is deprecated in favor of :func:`torch.linalg.eig` - and will be removed in a future PyTorch release. - :func:`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble` - rather than real tensors mimicking complex tensors. - - ``L, _ = torch.eig(A)`` should be replaced with - - .. code :: python - - L_complex = torch.linalg.eigvals(A) - - ``L, V = torch.eig(A, eigenvectors=True)`` should be replaced with - - .. code :: python - - L_complex, V_complex = torch.linalg.eig(A) - -Args: - input (Tensor): the square matrix of shape :math:`(n \times n)` for which the eigenvalues and eigenvectors - will be computed - eigenvectors (bool): ``True`` to compute both eigenvalues and eigenvectors; - otherwise, only eigenvalues will be computed - -Keyword args: - out (tuple, optional): the output tensors - -Returns: - (Tensor, Tensor): A namedtuple (eigenvalues, eigenvectors) containing - - - **eigenvalues** (*Tensor*): Shape :math:`(n \times 2)`. Each row is an eigenvalue of ``input``, - where the first element is the real part and the second element is the imaginary part. - The eigenvalues are not necessarily ordered. - - **eigenvectors** (*Tensor*): If ``eigenvectors=False``, it's an empty tensor. - Otherwise, this tensor of shape :math:`(n \times n)` can be used to compute normalized (unit length) - eigenvectors of corresponding eigenvalues as follows. - If the corresponding `eigenvalues[j]` is a real number, column `eigenvectors[:, j]` is the eigenvector - corresponding to `eigenvalues[j]`. - If the corresponding `eigenvalues[j]` and `eigenvalues[j + 1]` form a complex conjugate pair, then the - true eigenvectors can be computed as - :math:`\text{true eigenvector}[j] = eigenvectors[:, j] + i \times eigenvectors[:, j + 1]`, - :math:`\text{true eigenvector}[j + 1] = eigenvectors[:, j] - i \times eigenvectors[:, j + 1]`. - -Example:: - - Trivial example with a diagonal matrix. By default, only eigenvalues are computed: - - >>> a = torch.diag(torch.tensor([1, 2, 3], dtype=torch.double)) - >>> e, v = torch.eig(a) - >>> e - tensor([[1., 0.], - [2., 0.], - [3., 0.]], dtype=torch.float64) - >>> v - tensor([], dtype=torch.float64) - - Compute also the eigenvectors: - - >>> e, v = torch.eig(a, eigenvectors=True) - >>> e - tensor([[1., 0.], - [2., 0.], - [3., 0.]], dtype=torch.float64) - >>> v - tensor([[1., 0., 0.], - [0., 1., 0.], - [0., 0., 1.]], dtype=torch.float64) - -""", -) - add_docstr( torch.eq, r""" diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp index 7ad92e83f08..de4cb9741f5 100644 --- a/torch/csrc/autograd/FunctionsManual.cpp +++ b/torch/csrc/autograd/FunctionsManual.cpp @@ -3440,151 +3440,6 @@ Tensor svd_backward( return gA; } -// The implementation follows: -// "An extended collection of matrix derivative results for forward and reverse -// mode algorithmic differentiation" -// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf -// However, the reference does not cover the constraints on eigenvectors to have -// 1-norm. See the details below. -Tensor eig_backward( - const std::vector& grads, - const Tensor& self, - bool is_eigvec_tensor_nonempty, - const Tensor& eigenvalues, - const Tensor& eigenvectors) { - at::NoTF32Guard disable_tf32; - TORCH_CHECK( - is_eigvec_tensor_nonempty, - "eig_backward: torch.eig(eigenvalues=False) is not differentiable. ", - "Please use torch.linalg.eigvals"); - - // variable names correspond to the ones in the reference document - auto D = eigenvalues; - const auto& U = eigenvectors; - auto D_grad = grads[0]; - auto U_grad = grads[1]; - - // The condition below is trying to marry torch.eig and torch.linalg.eig - // for real inputs. - // - // For real inputs torch.eig returns a real 2D tensor representing real and - // complex components of eigenvalues, while torch.linalg.eig will most likely - // always return complex eigenvalues. - if (!self.is_complex()) { - Tensor is_imag_eigvals_zero; - // path for torch.eig with always a "real" 2D tensor of eigenvalues - if (!D.is_complex()) { - // narrow extracts the column corresponding to the imaginary part - is_imag_eigvals_zero = (D.narrow(-1, 1, 1) == 0.0).min(); - } - // path for torch.linalg.eig with always a complex tensor of eigenvalues - else { - is_imag_eigvals_zero = (at::imag(D) == 0.0).min(); - // insert an additional dimension to be compatible with torch.eig. - // Recall that it produces 2D tensors. - // We extract only the real parts as there is no support for - // complex eigenvalues with real inputs yet. - D = at::real(D).unsqueeze(-1); - D_grad = at::real(D_grad).unsqueeze(-1); - } - // No support for complex eigenvalues for real inputs yet. - TORCH_CHECK( - at::is_scalar_tensor_true(is_imag_eigvals_zero), - "eig_backward: Backward calculation does not support complex eigenvalues for real inputs at the moment."); - } else { - // torch.eig returns 2d tensors for eigenvalues, - // while torch.linalg.eig returns 1d. - // Hence we insert additional dimension for complex input, - // such that the same code could be used for both methods. - // It will become unnecessary once torch.eig is deprecated. - D = D.unsqueeze(-1); - if (D_grad.defined()) { - D_grad = D_grad.unsqueeze(-1); - } - } - - if (!D_grad.defined() && !U_grad.defined()) { - return at::zeros_like(self, at::MemoryFormat::Contiguous); - } - - // Adapting the result from the reference above for the complex input, we get: - // - // A_grad = U^{-H} (D_grad + F.conj() * (U^H U_grad)) U^H, - // where M^H := (M.mT()).conj() and * is the Hadamard (element-wise) product. - // - // torch.eig/torch.linalg.eig produce eigenvectors which are - // normalized to 1 norm, and the reference does not take that into account. - // Hence, we have to modify the formula accordingly. - // - // Normalization to 1 norm imposes the following constraint on the - // eigenvectors, i.e. (U^H U) * I = I, where I is an identity matrix. Forward - // AD for this expression yields: (dU^H U + U^H dU) * I = 0 => U^H dU * I = 0 - // <=> diag(U^H dU) = 0, which means that each i-th column of U is orthogonal - // to the i-th column of dU. Now, the value of dU which does not take this - // constraint into consideration comes straight from the reference: dU = U(F * - // U^{-1} dA U). To make sure that U^H dU * I = 0, and using U^H U * I = I - // (normalization), we propose a modifed forward AD for U: dU_new = dU - U(U^H - // dU * I) (think of Gram-Schmidt) - // - // The rest is very similar to what is done in the reference and we finally - // arrive at: - // - // A_grad = U^{-H} (D_grad + (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj()) - // U^H - // = U^{-H} (eigenvalues_contribs + eigenvectors_contrib) U^H, where - // eigenvalues_contribs := D_grad, - // eigenvectors_contribs := (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj(). - // The contributions from the eigenvectors and the eigenvalues are computed - // below, and then we solve the system U^H A_grad = (eigenvalues_contribs + - // eigenvectors_contribs) U_H to produce A_grad. - - // contribution from the eigenvectors - Tensor U_contrib; - if (U_grad.defined()) { - // narrow extracts the column corresponding to the real part - D = D.narrow(-1, 0, 1); - auto F = (D.mT() - D); - if (!F.is_complex()) { - F.diagonal(0, -2, -1).fill_(INFINITY); - F.pow_(-1); - } else { - // The F matrix construction for complex eigenvalues - // if different from its real counterpart. - // There is no complex INFINITY, and we cannot use - // - // F.pow_(-1); - // F.diagonal(0, -2, -1).fill_(0); - // - // as it breaks gradgradcheck by double backward - // propagating nans through F.pow_(-1) at zero, - // the point of discontinuity. - // Hence this hack below. - F.diagonal(0, -2, -1).fill_(1); - F.pow_(-1); - F.diagonal(0, -2, -1).fill_(0); - } - auto U_grad_proj_onto_U = at::matmul(U.mH(), U_grad); - auto Uh_U = at::matmul(U.mH(), U); - U_contrib = (U_grad_proj_onto_U - - Uh_U * U_grad_proj_onto_U.diagonal(0, -2, -1).unsqueeze(-2)) * - F.conj(); - } else { - U_contrib = at::zeros_like(self, at::MemoryFormat::Contiguous); - } - - // contributions from the eigenvalues - Tensor D_contrib; - if (D_grad.defined()) { - // narrow extracts the column corresponding to the real part - D_contrib = D_grad.narrow(-1, 0, 1); - } else { - D_contrib = at::zeros_like(D, at::MemoryFormat::Contiguous); - } - - return at::linalg_solve( - U.mH(), at::matmul(U_contrib, U.mH()) + D_contrib * U.mH()); -} - Tensor linalg_eig_backward( const Tensor& gL, const Tensor& gV, diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h index d8fb0923eed..83b3f51dc98 100644 --- a/torch/csrc/autograd/FunctionsManual.h +++ b/torch/csrc/autograd/FunctionsManual.h @@ -628,12 +628,6 @@ Tensor linalg_qr_backward( const Tensor& Q, const Tensor& R, const c10::string_view mode); -Tensor eig_backward( - const std::vector& grads, - const Tensor& self, - bool eigenvectors, - const Tensor& lambda, - const Tensor& v); Tensor linalg_matrix_exp_differential( const Tensor& self, const Tensor& grad, diff --git a/torch/overrides.py b/torch/overrides.py index 8ea0ae8c022..d77950ad8ad 100644 --- a/torch/overrides.py +++ b/torch/overrides.py @@ -254,6 +254,7 @@ def get_ignored_functions() -> Set[Callable]: Tensor.__subclasshook__, Tensor.__hash__, Tensor.as_subclass, + Tensor.eig, Tensor.reinforce, Tensor.new, Tensor.new_tensor, @@ -487,7 +488,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]: torch.hsmm: lambda mat1, mat2: -1, torch.dsplit: lambda input, indices_or_sections: -1, torch.dstack: lambda tensors, out=None: -1, - torch.eig: lambda input, eigenvectors=False, out=None: -1, torch.linalg.eig: lambda input, out=None: -1, torch.linalg.eigvals: lambda input, out=None: -1, torch.linalg.eigh: lambda input, UPLO="L", out=None: -1, diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py index 4048fe40c34..9a9d8c7a62a 100644 --- a/torch/testing/_internal/common_methods_invocations.py +++ b/torch/testing/_internal/common_methods_invocations.py @@ -2147,14 +2147,6 @@ def error_inputs_renorm(op_info, device, **kwargs): yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError, error_regex="needs at least 2 dimensions, got 0 dimensions") -def error_inputs_eig(op_info, device, **kwargs): - zero_d = torch.randn((), device=device) - - yield ErrorInput(SampleInput(zero_d, args=(False,)), error_type=RuntimeError, - error_regex="input should be 2 dimensional") - - yield ErrorInput(SampleInput(zero_d, args=(True,)), error_type=RuntimeError, - error_regex="input should be 2 dimensional") def error_inputs_ormqr(op_info, device, **kwargs): # this is only implemented on cpu @@ -4822,41 +4814,6 @@ def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad) -def sample_inputs_eig(op_info, device, dtype, requires_grad=False, **kwargs): - eigvecs = make_tensor((S, S), device=device, dtype=dtype, - low=None, high=None) - eigvals = make_tensor((S,), device=device, dtype=dtype, - low=None, high=None) - # we produce only diagonazible inputs which do not have - # complex eigenvalues for real inputs, as there is no - # backward implementation for real inputs with complex - # eigenvalues yet. - input = (eigvecs * eigvals.unsqueeze(-2)) @ eigvecs.inverse() - input.requires_grad_(requires_grad) - - def process_output(eigpair): - eigvals, eigvecs = eigpair - if dtype.is_complex: - # eig produces eigenvectors which are normalized to 1 norm. - # Note that if v is an eigenvector, so is v * e^{i \phi}, - # and |v| = |v * e^{i \phi}| = 1. - # This, however, makes the eigenvector backward computation process - # rather unstable unless the objective function is gauge-invariant, - # that is if f(z) == f(|z|), for example. - # Hence for complex inputs we ignore the phases and return only - # the absolute values. - return eigvals, eigvecs.abs() - else: - return eigvals, eigvecs - - return [ - SampleInput( - input, - kwargs=dict(eigenvectors=True), - output_process_fn_grad=process_output - ), - ] - def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs): def c(t): @@ -13001,16 +12958,6 @@ op_db: List[OpInfo] = [ supports_sparse_bsr=True, supports_sparse_bsc=True, supports_autograd=False), - OpInfo('eig', - op=torch.eig, - dtypes=floating_and_complex_types(), - sample_inputs_func=sample_inputs_eig, - error_inputs_func=error_inputs_eig, - decorators=[ - skipCUDAIfNoMagma, - skipCPUIfNoLapack, - ], - ), OpInfo('einsum', # we need this lambda because SampleInput expects tensor input as the first argument # TODO(@heitorschueroff) update SampleInput to handle such cases