Remove deprecated torch.eig (#70982)
The time has come to remove deprecated linear algebra related functions. This PR removes `torch.eig`.

cc @jianyuh @nikitaved @pearu @mruberry @walterddr @IvanYashchuk @xwang233 @Lezcano

Pull Request resolved: https://github.com/pytorch/pytorch/pull/70982
Approved by: https://github.com/Lezcano, https://github.com/malfet
This commit is contained in:
parent c4a5255df7
commit 01c54ad6de
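For readers migrating call sites, the sketch below spells out the replacements named in the deprecation notice that this PR deletes. It is not part of the diff, and `A` is just a placeholder square matrix.

```python
import torch

A = torch.randn(4, 4, dtype=torch.double)  # placeholder square matrix

# Before (removed):  L, _ = torch.eig(A)
# After: eigenvalues only, returned as a complex tensor
L_complex = torch.linalg.eigvals(A)

# Before (removed):  L, V = torch.eig(A, eigenvectors=True)
# After: complex eigenvalues and eigenvectors
L_complex, V_complex = torch.linalg.eig(A)

# If the old real (n, 2) layout (real part, imaginary part per row) is needed,
# it can be recovered from the complex eigenvalues:
L_pairs = torch.view_as_real(L_complex)
```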
@@ -595,12 +595,6 @@ TORCH_LIBRARY_IMPL(aten, AutocastCPU, m) {
  KERNEL_CPU(ADD_NS(linalg_tensorsolve), "linalg_tensorsolve", Tensor(const Tensor &, const Tensor &, at::OptionalIntArrayRef), fp32)
  KERNEL_CPU(ADD_NS(fake_quantize_per_tensor_affine), "fake_quantize_per_tensor_affine", Tensor (const Tensor &, double, int64_t, int64_t, int64_t), fp32)

  m.impl(TORCH_SELECTIVE_NAME("aten::eig"),
         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
                                 std::tuple<Tensor, Tensor> (const Tensor &, bool),
                                 std::tuple<Tensor, Tensor> (const Tensor &, bool),
                                 &ADD_NS(eig)>::type::call)));

  m.impl(TORCH_SELECTIVE_NAME("aten::geqrf"),
         TORCH_FN((&WrapFunction<CastPolicy::fp32, DeviceType::CPU,
                                 std::tuple<Tensor, Tensor> (const Tensor &),
@@ -3168,66 +3168,6 @@ Tensor linalg_eigvals(const Tensor& input) {
  return values;
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DEFINE_DISPATCH(eig_stub);

std::tuple<Tensor&, Tensor&> eig_out(const Tensor& self, bool eigenvectors, Tensor& e, Tensor& v) {
  TORCH_WARN_ONCE(
    "torch.eig is deprecated in favor of torch.linalg.eig and will be removed in a future ",
    "PyTorch release.\n",
    "torch.linalg.eig returns complex tensors of dtype cfloat or cdouble rather than real tensors ",
    "mimicking complex tensors.\n",
    "L, _ = torch.eig(A)\n",
    "should be replaced with\n",
    "L_complex = torch.linalg.eigvals(A)\n",
    "and\n",
    "L, V = torch.eig(A, eigenvectors=True)\n",
    "should be replaced with\n",
    "L_complex, V_complex = torch.linalg.eig(A)"
  );
  TORCH_CHECK(self.dim() == 2, "input should be 2 dimensional");
  TORCH_CHECK(self.size(0) == self.size(1), "input should be square");
  TORCH_CHECK(self.isfinite().all().item<bool>(), "input should not contain infs or NaNs");
  checkSameDevice("torch.eig", e, self, "eigenvalues");
  checkLinalgCompatibleDtype("torch.eig", e, self, "eigenvalues");
  if (eigenvectors) {
    checkSameDevice("torch.eig", v, self, "eigenvectors");
    checkLinalgCompatibleDtype("torch.eig", v, self, "eigenvectors");
  }
  int64_t n = self.size(-1);

  if (isComplexType(at::typeMetaToScalarType(self.dtype()))) {
    at::native::resize_output(e, {n});
  } else {
    at::native::resize_output(e, {n, 2});
  }
  if (eigenvectors) {
    at::native::resize_output(v, self.sizes());
  }

  // optimization: if self is empty, we can immediately return the empty
  // tensors, instead of getting empty tensors from eig_helper
  if (self.numel() == 0) {
    return std::tuple<Tensor&, Tensor&>(e, v);
  }

  Tensor vals_, vecs_;
  std::tie(vals_, vecs_) = eig_stub(self.device().type(), self, eigenvectors);
  e.copy_(vals_);
  if (eigenvectors) {
    v.copy_(vecs_);
  }
  return std::tuple<Tensor&, Tensor&>(e, v);
}

std::tuple<Tensor,Tensor> eig(const Tensor& self, bool eigenvectors) {
  Tensor e = at::empty({0}, self.options());
  Tensor v = at::empty({0}, self.options());
  at::eig_out(e, v, self, eigenvectors);
  return std::tuple<Tensor, Tensor>(e, v);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_svd ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

/* torch.svd, implemented in terms of torch.linalg.svd. There are two main
@@ -231,10 +231,6 @@ using cholesky_inverse_fn = Tensor& (*)(Tensor& /*result*/, Tensor& /*infos*/, b

DECLARE_DISPATCH(cholesky_inverse_fn, cholesky_inverse_stub);

using eig_fn = std::tuple<Tensor, Tensor> (*)(const Tensor&, bool&);

DECLARE_DISPATCH(eig_fn, eig_stub);

using linalg_eig_fn = void (*)(Tensor& /*eigenvalues*/, Tensor& /*eigenvectors*/, Tensor& /*infos*/, const Tensor& /*input*/, bool /*compute_eigenvectors*/);

DECLARE_DISPATCH(linalg_eig_fn, linalg_eig_stub);
@@ -127,87 +127,6 @@ Tensor& cholesky_inverse_kernel_impl(Tensor& result, Tensor& infos, bool upper)
|
|||
return result;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void apply_eig(const Tensor& self, bool eigenvectors, Tensor& vals_, Tensor& vecs_, int* info_ptr) {
|
||||
#if !AT_BUILD_WITH_LAPACK()
|
||||
TORCH_CHECK(false, "Calling torch.eig on a CPU tensor requires compiling ",
|
||||
"PyTorch with LAPACK. Please use PyTorch built with LAPACK support.");
|
||||
#else
|
||||
using value_t = typename c10::scalar_value_type<scalar_t>::type;
|
||||
|
||||
char jobvr = eigenvectors ? 'V' : 'N';
|
||||
int64_t n = self.size(-1);
|
||||
auto self_data = self.data_ptr<scalar_t>();
|
||||
|
||||
auto vals_data = vals_.data_ptr<scalar_t>();
|
||||
scalar_t* wr = vals_data;
|
||||
|
||||
scalar_t* vecs_data = eigenvectors ? vecs_.data_ptr<scalar_t>() : nullptr;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
|
||||
int ldvr = eigenvectors ? n : 1;
|
||||
|
||||
Tensor rwork;
|
||||
value_t* rwork_data = nullptr;
|
||||
if (self.is_complex()) {
|
||||
ScalarType real_dtype = toRealValueType(typeMetaToScalarType(self.dtype()));
|
||||
rwork = at::empty({n*2}, self.options().dtype(real_dtype));
|
||||
rwork_data = rwork.data_ptr<value_t>();
|
||||
}
|
||||
|
||||
if (n > 0) {
|
||||
// call lapackEig once to get the optimal size for work data
|
||||
scalar_t wkopt;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
lapackEig<scalar_t, value_t>('N', jobvr, n, self_data, n, wr,
|
||||
nullptr, 1, vecs_data, ldvr, &wkopt, -1, rwork_data, info_ptr);
|
||||
int lwork = std::max<int>(1, real_impl<scalar_t, value_t>(wkopt));
|
||||
|
||||
// call again to do the actual work
|
||||
Tensor work = at::empty({lwork}, self.dtype());
|
||||
lapackEig<scalar_t, value_t>('N', jobvr, n, self_data, n, wr,
|
||||
nullptr, 1, vecs_data, ldvr, work.data_ptr<scalar_t>(), lwork, rwork_data, info_ptr);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> eig_kernel_impl(const Tensor& self, bool& eigenvectors) {
|
||||
int64_t n = self.size(-1);
|
||||
// lapackEig function expects the input to be column major, or stride {1, n},
|
||||
// so we must set the stride manually since the default stride for tensors is
|
||||
// row major, {n, 1}
|
||||
Tensor self_ = at::empty_strided(
|
||||
{n, n},
|
||||
{1, n},
|
||||
at::TensorOptions(self.dtype()));
|
||||
self_.copy_(self);
|
||||
|
||||
auto options = self.options().memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
|
||||
// the API is slightly different for the complex vs real case: if the input
|
||||
// is complex, eigenvals will be a vector of complex. If the input is real,
|
||||
// eigenvals will be a (n, 2) matrix containing the real and imaginary parts
|
||||
// in each column
|
||||
Tensor vals_;
|
||||
if (self.is_complex()) {
|
||||
vals_ = at::empty({n}, options);
|
||||
} else {
|
||||
vals_ = at::empty_strided({n, 2}, {1, n}, options);
|
||||
}
|
||||
Tensor vecs_ = eigenvectors
|
||||
? at::empty_strided({n, n}, {1, n}, options)
|
||||
: Tensor();
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
auto infos = at::zeros({}, self.options().dtype(kInt));
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cpu", [&]{
|
||||
apply_eig<scalar_t>(self_, eigenvectors, vals_, vecs_, infos.data_ptr<int>());
|
||||
});
|
||||
// NOLINTNEXTLINE(clang-analyzer-core.CallAndMessage)
|
||||
at::_linalg_check_errors(infos, "eig", /*is_matrix*/true);
|
||||
|
||||
return std::tuple<Tensor, Tensor>(vals_, vecs_);
|
||||
}
|
||||
|
||||
/*
|
||||
Computes the eigenvalues and eigenvectors of n-by-n matrix 'input'.
|
||||
This is an in-place routine, content of 'input', 'values', 'vectors' is overwritten.
|
||||
|
|
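The removed CPU kernel in the hunk above (`apply_eig` / `eig_kernel_impl`) copies the input into a column-major buffer because the LAPACK eigensolver it wraps expects Fortran-ordered data with stride `{1, n}`. A small, hypothetical Python illustration of that layout, not part of the diff:

```python
import torch

n = 3
# Column-major ("Fortran order") buffer: shape (n, n) with strides (1, n),
# mirroring the at::empty_strided({n, n}, {1, n}, ...) call in the removed kernel.
a = torch.empty_strided((n, n), (1, n), dtype=torch.double)
a.copy_(torch.randn(n, n, dtype=torch.double))
print(a.stride())  # (1, 3): walking down a column touches contiguous memory
```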
@@ -1200,12 +1119,6 @@ REGISTER_AVX2_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_VSX_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);
REGISTER_ZVECTOR_DISPATCH(cholesky_inverse_stub, &cholesky_inverse_kernel_impl);

REGISTER_ARCH_DISPATCH(eig_stub, DEFAULT, &eig_kernel_impl);
REGISTER_AVX512_DISPATCH(eig_stub, &eig_kernel_impl);
REGISTER_AVX2_DISPATCH(eig_stub, &eig_kernel_impl);
REGISTER_VSX_DISPATCH(eig_stub, &eig_kernel_impl);
REGISTER_ZVECTOR_DISPATCH(eig_stub, &eig_kernel_impl);

REGISTER_ARCH_DISPATCH(linalg_eig_stub, DEFAULT, &linalg_eig_kernel);
REGISTER_AVX512_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
REGISTER_AVX2_DISPATCH(linalg_eig_stub, &linalg_eig_kernel);
@@ -93,11 +93,6 @@ void lazy_linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvecto
  linalg_eigh_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, upper, compute_eigenvectors);
}

std::tuple<Tensor, Tensor> lazy_eig_kernel(const Tensor& self, bool& eigenvectors) {
  loadLazyTorchLinalgLibrary();
  return eig_stub(DeviceType::CUDA, self, eigenvectors);
}

void lazy_linalg_eig_kernel(Tensor& eigenvalues, Tensor& eigenvectors, Tensor& infos, const Tensor& input, bool compute_eigenvectors) {
  getTorchLinalgLibrary();
  linalg_eig_stub(DeviceType::CUDA, eigenvalues, eigenvectors, infos, input, compute_eigenvectors);
@@ -155,7 +150,6 @@ REGISTER_CUDA_DISPATCH(orgqr_stub, &lazy_orgqr_kernel);
REGISTER_CUDA_DISPATCH(ormqr_stub, &lazy_ormqr_kernel);
REGISTER_CUDA_DISPATCH(geqrf_stub, &lazy_geqrf_kernel);
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &lazy_linalg_eigh_kernel);
REGISTER_CUDA_DISPATCH(eig_stub, &lazy_eig_kernel);
REGISTER_CUDA_DISPATCH(linalg_eig_stub, &lazy_linalg_eig_kernel);
REGISTER_CUDA_DISPATCH(svd_stub, &lazy_svd_kernel)
REGISTER_CUDA_DISPATCH(lu_solve_stub, &lazy_lu_solve);
@@ -2036,96 +2036,6 @@ void linalg_eigh_kernel(const Tensor& eigenvalues, const Tensor& eigenvectors, c
|
|||
|
||||
REGISTER_CUDA_DISPATCH(linalg_eigh_stub, &linalg_eigh_kernel);
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
// magmaEig uses a hybrid CPU-GPU algorithm, which takes and return CPU
|
||||
// memory. So, we accept a GPU tensor, copy it to CPU memory, and later copy
|
||||
// the returned values from CPU to GPU. See also magmaSymeig, which uses a
|
||||
// similar approach.
|
||||
|
||||
template <typename scalar_t>
|
||||
static void apply_eig(const Tensor& self, bool eigenvectors, Tensor& out_eigvals, Tensor& out_eigvecs,
|
||||
int* info_ptr) {
|
||||
#if !AT_MAGMA_ENABLED()
|
||||
TORCH_CHECK(false, "Calling torch.eig on a CUDA tensor requires compiling PyTorch with MAGMA. "
|
||||
"Either transfer the tensor to the CPU before calling torch.eig or recompile with MAGMA.");
|
||||
#else
|
||||
TORCH_INTERNAL_ASSERT(self.device() == at::kCPU, "Internal error: apply_eig needs a CPU tensor");
|
||||
using value_t = typename c10::scalar_value_type<scalar_t>::type;
|
||||
magma_vec_t jobvr = eigenvectors ? MagmaVec : MagmaNoVec;
|
||||
magma_int_t n = magma_int_cast(self.size(-1), "n");
|
||||
auto self_data = self.data_ptr<scalar_t>();
|
||||
|
||||
auto out_eigvals_data = out_eigvals.data_ptr<scalar_t>();
|
||||
scalar_t *wr = out_eigvals_data;
|
||||
|
||||
scalar_t *vr_data = NULL;
|
||||
magma_int_t ldvr = 1;
|
||||
if (jobvr == MagmaVec)
|
||||
{
|
||||
vr_data = out_eigvecs.data_ptr<scalar_t>();
|
||||
ldvr = n;
|
||||
}
|
||||
|
||||
value_t *rwork_data = nullptr;
|
||||
if (isComplexType(at::typeMetaToScalarType(self.dtype()))) {
|
||||
ALLOCATE_ARRAY(rwork_data, value_t, n*2);
|
||||
}
|
||||
|
||||
if (n > 0) {
|
||||
// call magmaEig once to get the optimal size of work_data
|
||||
scalar_t wkopt;
|
||||
magma_int_t info;
|
||||
magmaEig<scalar_t, value_t>(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, &wkopt, -1, rwork_data, &info);
|
||||
magma_int_t lwork = static_cast<magma_int_t>(real_impl<scalar_t, value_t>(wkopt));
|
||||
|
||||
// call it a 2nd time to to the actual work
|
||||
scalar_t *work_data = nullptr;
|
||||
ALLOCATE_ARRAY(work_data, scalar_t, lwork);
|
||||
magmaEig<scalar_t, value_t>(MagmaNoVec, jobvr, n, self_data, n, wr, NULL, 1, vr_data, ldvr, work_data, lwork, rwork_data, &info);
|
||||
*info_ptr = info;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* Internal helper; like eig_cuda but:
|
||||
* 1. assume that self is a square matrix of side "n"
|
||||
* 2. return CPU tensors (because this is what magmaEig returns), which will be copied to GPU memory
|
||||
* by the caller
|
||||
*/
|
||||
std::tuple<Tensor, Tensor> eig_kernel_impl(const Tensor& self, bool& eigenvectors) {
|
||||
int64_t n = self.size(-1);
|
||||
// copy self to pinned CPU memory
|
||||
auto self_working_copy = at::empty_strided(
|
||||
{n, n}, // square matrix
|
||||
{1, n}, // column-ordered, as magmaEig expects
|
||||
at::TensorOptions(at::kCPU).dtype(self.dtype()).pinned_memory(true));
|
||||
self_working_copy.copy_(self);
|
||||
|
||||
// tensors holding the results. We use empty_strided to make them column-ordered
|
||||
auto options = self.options().device(at::kCPU).memory_format(LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
Tensor out_eigvals;
|
||||
if (isComplexType(at::typeMetaToScalarType(self.dtype()))) {
|
||||
out_eigvals = at::empty({n}, options);
|
||||
} else {
|
||||
out_eigvals = at::empty_strided({n, 2}, {1, n}, options);
|
||||
}
|
||||
auto out_eigvecs = eigenvectors
|
||||
? at::empty_strided({n, n}, {1, n}, options)
|
||||
: Tensor();
|
||||
|
||||
auto infos = at::zeros({}, self_working_copy.options().dtype(kInt));
|
||||
AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(self.scalar_type(), "eig_cuda", [&]{
|
||||
apply_eig<scalar_t>(self_working_copy, eigenvectors, out_eigvals, out_eigvecs, infos.data_ptr<int>());
|
||||
});
|
||||
at::_linalg_check_errors(infos, "eig", /*is_matrix*/true);
|
||||
|
||||
return std::tuple<Tensor, Tensor>(out_eigvals, out_eigvecs);
|
||||
}
|
||||
|
||||
REGISTER_CUDA_DISPATCH(eig_stub, &eig_kernel_impl);
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ linalg_eig ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
/*
|
||||
|
|
|
|||
|
|
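The removed CUDA implementation in the hunk above notes that magmaEig is a hybrid CPU-GPU algorithm that takes and returns CPU memory, so the kernel staged the GPU input through a pinned, column-major host copy and copied results back afterwards. A rough sketch of that staging step from Python, assuming a CUDA device is available (illustration only, not part of the diff):

```python
import torch

x_gpu = torch.randn(4, 4, device="cuda", dtype=torch.double)

# Pinned (page-locked), column-major host buffer, analogous to the
# at::empty_strided(..., pinned_memory(true)) allocation in the removed kernel.
x_cpu = torch.empty_strided((4, 4), (1, 4), dtype=x_gpu.dtype, pin_memory=True)
x_cpu.copy_(x_gpu)  # device-to-host copy; eig results would later be copied back
```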
@@ -8113,15 +8113,6 @@
    CUDA: _symeig_helper_cuda
  autogen: _symeig_helper.out

- func: eig.e(Tensor self, bool eigenvectors=False, *, Tensor(a!) e, Tensor(b!) v) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)
  dispatch:
    CompositeExplicitAutograd: eig_out

- func: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors)
  variants: method, function
  dispatch:
    CompositeExplicitAutograd: eig

- func: svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)

- func: svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)
@@ -345,7 +345,6 @@ Tensor class reference
    Tensor.dot
    Tensor.double
    Tensor.dsplit
    Tensor.eig
    Tensor.element_size
    Tensor.eq
    Tensor.eq_
@@ -558,7 +558,6 @@ BLAS and LAPACK Operations
    cholesky_inverse
    cholesky_solve
    dot
    eig
    geqrf
    ger
    inner
@@ -688,7 +688,6 @@ class TestOperators(TestCase):
        skip('linalg.svdvals'), # # really annoying thing where it passes correctness check but not has_batch_rule
        xfail('__getitem__', ''), # dynamic error
        xfail('_masked.prod'), # calls aten::item
        xfail('eig'), # calls aten::item
        xfail('linalg.eig'), # Uses aten::allclose
        xfail('linalg.householder_product'), # needs select_scatter
        xfail('nanquantile', device_type='cpu'), # checks q via a .item() call
@@ -923,7 +922,6 @@ class TestOperators(TestCase):
        xfail('cummax'),
        xfail('cummin'),
        xfail('cumprod'),
        xfail('eig'),
        xfail('nansum'),
        xfail('nanmean'),
        xfail('special.log_ndtr'),
@@ -1142,7 +1140,6 @@ class TestOperators(TestCase):
        xfail('_masked.softmin', ''), # NYI: forward-AD for _softmax_backward_data
        xfail('cdist', ''), # NYI: forward-AD for _cdist_forward
        xfail('cholesky', ''), # NYI: forward-AD for cholesky
        xfail('eig', ''), # NYI: forward-AD for eig
        xfail('logcumsumexp', ''), # NYI: forward-AD for logcumsumexp
        xfail('nn.functional.embedding_bag', ''), # NYI: forward-AD for _embedding_bag
        xfail('nn.functional.grid_sample', ''), # NYI: forward AD for grid_sampler_2d
@@ -3294,7 +3294,6 @@ class TestVmapOperatorsOpInfo(TestCase):
        skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
        xfail('complex'),
        xfail('copysign'),
        xfail('eig'),
        xfail('histogram'),
        xfail('index_fill'),
        xfail('nansum'),
@@ -61,6 +61,8 @@ ALLOW_LIST = [
    ("aten::slice_backward", datetime.date(9999, 1, 1)),
    ("aten::diagonal_backward", datetime.date(9999, 1, 1)),
    ("aten::rowwise_prune", datetime.date(9999, 1, 1)),
    ("aten::eig", datetime.date(9999, 1, 1)),
    ("aten::eig.e", datetime.date(9999, 1, 1)),
    ("aten::adaptive_avg_pool3d_backward", datetime.date(9999, 1, 1)),
    ("aten::_embedding_bag_dense_backward", datetime.date(9999, 1, 1)),
    ("aten::randperm", datetime.date(9999, 1, 1)),
@@ -3744,20 +3744,6 @@ class TestAutograd(TestCase):
        out.backward()

    # TODO: update these tests to use the linalg module and move to test_linalg.py
    @skipIfNoLapack
    def test_eig_no_eigenvectors(self):
        A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
        w, v = torch.eig(A, eigenvectors=False)
        with self.assertRaisesRegex(RuntimeError, 'is not differentiable'):
            torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)])

    @skipIfNoLapack
    def test_eig_complex_eigenvalues(self):
        A = torch.tensor([[0., -1.], [1., 0.]], dtype=torch.float32, requires_grad=True)
        w, v = torch.eig(A, eigenvectors=True)
        with self.assertRaisesRegex(RuntimeError, 'does not support complex eigenvalues'):
            torch.autograd.backward([w, v], [torch.ones_like(w), torch.ones_like(v)])

    @skipIfNoLapack
    def test_symeig_no_eigenvectors(self):
        A = torch.tensor([[1., 2.], [2., 4.]], dtype=torch.float32, requires_grad=True)
@@ -148,6 +148,13 @@ class TestLinalg(TestCase):
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            b.solve(a)

    def test_eig_removed_error(self, device):
        a = make_tensor(5, 5, device=device, dtype=torch.float32)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            torch.eig(a)
        with self.assertRaisesRegex(RuntimeError, "This function was deprecated since version 1.9 and is now removed"):
            a.eig()

    @skipCUDAIfNoMagma
    @skipCPUIfNoLapack
    @dtypes(torch.float, torch.double, torch.cfloat, torch.cdouble)
@@ -1758,122 +1765,6 @@ class TestLinalg(TestCase):
|
|||
expected = torch.pow(x.pow(3).abs().sum(1), 1.0 / 3.0)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(*floating_and_complex_types())
|
||||
def test_old_eig_basic(self, device, dtype):
|
||||
a = torch.tensor([[1.96, 0.00, 0.00, 0.00, 0.00],
|
||||
[-6.49, 3.80, 0.00, 0.00, 0.00],
|
||||
[-0.47, -6.39, 4.17, 0.00, 0.00],
|
||||
[-7.20, 1.50, -1.51, 5.70, 0.00],
|
||||
[-0.65, -6.34, 2.67, 1.80, -7.10]],
|
||||
dtype=dtype, device=device).t()
|
||||
e = torch.eig(a)[0]
|
||||
ee, vv = torch.eig(a, True)
|
||||
te = torch.tensor((), dtype=dtype, device=device)
|
||||
tv = torch.tensor((), dtype=dtype, device=device)
|
||||
eee, vvv = torch.eig(a, True, out=(te, tv))
|
||||
self.assertEqual(e, ee, atol=1e-12, rtol=0)
|
||||
self.assertEqual(ee, eee, atol=1e-12, rtol=0)
|
||||
self.assertEqual(ee, te, atol=1e-12, rtol=0)
|
||||
self.assertEqual(vv, vvv, atol=1e-12, rtol=0)
|
||||
self.assertEqual(vv, tv, atol=1e-12, rtol=0)
|
||||
#
|
||||
# compare with numpy
|
||||
np_e, np_v = np.linalg.eig(a.cpu().numpy())
|
||||
if dtype.is_complex:
|
||||
self.assertEqual(ee, np_e)
|
||||
else:
|
||||
# np_e.shape == (n, 2), where each column contain the real and
|
||||
# imaginary parts of the result
|
||||
self.assertEqual(ee[:, 0], np_e) # real part
|
||||
self.assertEqual(ee[:, 1], torch.zeros(ee.shape[0], dtype=dtype)) # imaginary part
|
||||
self.assertEqual(vv, np_v)
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.double, torch.float)
|
||||
def test_old_eig_reuse(self, device, dtype):
|
||||
X = torch.randn(4, 4, dtype=dtype, device=device)
|
||||
X = torch.mm(X.t(), X)
|
||||
e = torch.zeros(4, 2, dtype=dtype, device=device)
|
||||
v = torch.zeros(4, 4, dtype=dtype, device=device)
|
||||
torch.eig(X, True, out=(e, v))
|
||||
Xhat = np.matmul(np.matmul(v.cpu(), torch.diag(e.select(1, 0)).cpu()), v.t().cpu())
|
||||
if dtype is torch.float:
|
||||
atol = 1e-7
|
||||
rtol = 1e-5
|
||||
else:
|
||||
atol = 1e-8
|
||||
rtol = 0
|
||||
self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong')
|
||||
self.assertTrue(v.is_contiguous(), 'V is not contiguous')
|
||||
|
||||
torch.eig(X, True, out=(e, v))
|
||||
Xhat = np.matmul(v.cpu(), np.matmul(e.select(1, 0).diag().cpu(), v.t().cpu()))
|
||||
self.assertEqual(X, Xhat, atol=atol, rtol=rtol, msg='VeV\' wrong')
|
||||
self.assertTrue(v.is_contiguous(), 'V is not contiguous')
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
@dtypes(torch.double, torch.float)
|
||||
def test_old_eig_invalid_input(self, device, dtype):
|
||||
# test invalid input
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'input should be 2 dimensional',
|
||||
lambda: torch.eig(torch.ones((2))))
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'input should be square',
|
||||
lambda: torch.eig(torch.ones((2, 3))))
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'input should not contain infs or NaNs',
|
||||
lambda: torch.eig(np.inf * torch.ones((2, 2))))
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
'input should not contain infs or NaNs',
|
||||
lambda: torch.eig(np.nan * torch.ones((2, 2))))
|
||||
|
||||
@skipCUDAIfNoMagma
|
||||
@skipCPUIfNoLapack
|
||||
@dtypes(torch.double, torch.float)
|
||||
def test_old_eig_out(self, device, dtype):
|
||||
# the out version of torch.eig needs to be tested manually: we can't
|
||||
# use the "test_out=True" parameter to tensor_op_tests because the
|
||||
# signature is irregular (since we have *two* output vectors)
|
||||
t = torch.randn(10, 10, dtype=dtype, device=device)
|
||||
evals, evecs = torch.eig(t, eigenvectors=True)
|
||||
#
|
||||
# check that the out= version computes the same values as the normal one
|
||||
out_evals = torch.empty_like(evals)
|
||||
out_evecs = torch.empty_like(evecs)
|
||||
evals2, evecs2 = torch.eig(t, eigenvectors=True, out=(out_evals, out_evecs))
|
||||
# check that the out tensors were used in-place
|
||||
self.assertEqual(evals2.data_ptr(), out_evals.data_ptr())
|
||||
self.assertEqual(evecs2.data_ptr(), out_evecs.data_ptr())
|
||||
# check that the result is the same as the non-out version
|
||||
self.assertEqual(evals, out_evals)
|
||||
self.assertEqual(evecs, out_evecs)
|
||||
#
|
||||
# check what happens in the eigenvectors=False case
|
||||
out_evals = torch.empty_like(evals)
|
||||
out_evecs = torch.tensor([1, 2, 3], dtype=dtype, device=device)
|
||||
evals2, evecs2 = torch.eig(t, eigenvectors=False, out=(out_evals, out_evecs))
|
||||
# check that the out_evals was used in-place
|
||||
self.assertEqual(evals2.data_ptr(), out_evals.data_ptr())
|
||||
self.assertEqual(evals, out_evals)
|
||||
# check that out_evecs was NOT touched at all
|
||||
assert out_evecs.tolist() == [1, 2, 3]
|
||||
#
|
||||
# check that we complain if we pass an out vector of the wrong dtype
|
||||
wrong_out = torch.empty((0, 0), dtype=int)
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected .* but got .*"):
|
||||
torch.eig(t, eigenvectors=True, out=(wrong_out, out_evecs))
|
||||
with self.assertRaisesRegex(RuntimeError, r"Expected .* but got .*"):
|
||||
torch.eig(t, eigenvectors=True, out=(out_evals, wrong_out))
|
||||
|
||||
@skipCPUIfNoLapack
|
||||
@skipCUDAIfNoMagma
|
||||
# NumPy computes only in float64 and complex128 precisions
|
||||
|
|
@@ -7407,12 +7298,6 @@ scipy_lobpcg | {:10.2e} | {:10.2e} | {:6} | N/A
        self.assertEqual((torch.tensor(1., device=device), torch.tensor(0., device=device)),
                         fn(torch.slogdet, (0, 0)))

        # eig, symeig
        evalues, evectors = fn(torch.eig, (0, 0), True)
        self.assertEqual([(0, 2), (0, 0)], [evalues.shape, evectors.shape])
        evalues, evectors = fn(torch.symeig, (0, 0), True)
        self.assertEqual([(0,), (0, 0)], [evalues.shape, evectors.shape])

        # lstsq
        self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0, 0), torch.randn(0, 0)))
        self.assertRaises(RuntimeError, lambda: torch.lstsq(torch.randn(0,), torch.randn(0, 0)))
@@ -443,7 +443,6 @@ meta_function_expected_failures = {
    torch.cholesky : {f64, f32, c128, c64},
    torch.cholesky_inverse : {f64, f32, c128, c64},
    torch.cholesky_solve : {f64, f32, c128, c64},
    torch.eig : {f64, f32, c128, c64},
    torch.linalg.eig : {f64, f32, c128, c64},
    torch.linalg.eigvals : {f64, f32, c128, c64},
    torch.linalg.lstsq : {f64, f32, c128, c64},
@@ -633,7 +632,6 @@ meta_dispatch_expected_failures = {
    aten.cholesky_solve.out : {c64, c128, f64, f32},
    aten.count_nonzero.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.count_nonzero.dim_IntList : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
    aten.eig.default : {c64, c128, f64, f32},
    aten.geqrf.default : {c64, c128, f64, f32},
    aten.linalg_eig.default : {c64, c128, f64, f32},
    aten.linalg_householder_product.default : {c64, c128, f64, f32},
@@ -13,7 +13,7 @@ from collections import namedtuple
path = os.path.dirname(os.path.realpath(__file__))
aten_native_yaml = os.path.join(path, '../aten/src/ATen/native/native_functions.yaml')
all_operators_with_namedtuple_return = {
    'max', 'min', 'aminmax', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig', 'eig',
    'max', 'min', 'aminmax', 'median', 'nanmedian', 'mode', 'kthvalue', 'svd', 'symeig',
    'qr', 'geqrf', 'slogdet', 'sort', 'topk', 'lstsq', 'linalg_inv_ex',
    'triangular_solve', 'cummax', 'cummin', 'linalg_eigh', "_linalg_eigh", "_unpack_dual", 'linalg_qr',
    'linalg_svd', '_linalg_svd', 'linalg_slogdet', '_linalg_slogdet', 'fake_quantize_per_tensor_affine_cachemask',
@@ -77,7 +77,7 @@ class TestNamedTupleAPI(TestCase):
            op(operators=['_linalg_slogdet'], input=(), names=('sign', 'logabsdet', 'LU', 'pivots'), hasout=True),
            op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True),
            op(operators=['geqrf'], input=(), names=('a', 'tau'), hasout=True),
            op(operators=['symeig', 'eig'], input=(True,), names=('eigenvalues', 'eigenvectors'), hasout=True),
            op(operators=['symeig'], input=(True,), names=('eigenvalues', 'eigenvectors'), hasout=True),
            op(operators=['triangular_solve'], input=(a,), names=('solution', 'cloned_coefficient'), hasout=True),
            op(operators=['lstsq'], input=(a,), names=('solution', 'QR'), hasout=True),
            op(operators=['linalg_eig'], input=(), names=('eigenvalues', 'eigenvectors'), hasout=True),
@@ -978,7 +978,6 @@ symbolic_tensor_failures = {
    xfail('dist', ''), # aten.dist.default - couldn't find symbolic meta function/decomposition
    xfail('double', ''), # aten._to_copy.default - couldn't find symbolic meta function/decomposition
    xfail('dsplit', ''), # aten.slice.Tensor - couldn't find symbolic meta function/decomposition
    xfail('eig', ''), # aten.eig.default - couldn't find symbolic meta function/decomposition
    xfail('einsum', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('expand_as', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
    xfail('fft.fft2', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
@@ -582,9 +582,6 @@
  grad_output: "native_dropout_double_backward(grad, grad_output, mask, scale)"
  mask: 'not_implemented("native_dropout_backward: mask")'

- name: eig(Tensor self, bool eigenvectors=False) -> (Tensor eigenvalues, Tensor eigenvectors)
  self: eig_backward(grads, self, eigenvectors, eigenvalues, eigenvectors_return)

- name: eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
  self: zeros_like(self)
  result: self_t.zero_()
@@ -929,7 +929,7 @@ from torch.utils.dlpack import from_dlpack, to_dlpack
from . import _masked

# Import removed ops with error message about removal
from ._linalg_utils import solve
from ._linalg_utils import eig, solve


def _register_device_module(device_type, module):
@@ -96,9 +96,17 @@ def symeig(A: Tensor, largest: Optional[bool] = False) -> Tuple[Tensor, Tensor]:
    return E, Z


# This function was deprecated and removed
# These functions were deprecated and removed
# This nice error message can be removed in version 1.13+
def solve(input: Tensor, A: Tensor, *, out=None) -> Tuple[Tensor, Tensor]:
    raise RuntimeError(
        "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.solve` function instead.",
    )


def eig(
    self: Tensor, eigenvectors: bool = False, *, e=None, v=None
) -> Tuple[Tensor, Tensor]:
    raise RuntimeError(
        "This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.eig` function instead.",
    )
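`torch.eig` and `Tensor.eig` therefore still resolve to Python-level stubs after this PR, but only so they can raise a descriptive error (the behaviour exercised by the new `test_eig_removed_error` earlier in this diff). Roughly, a call now looks like this illustrative session (not part of the diff):

```python
>>> import torch
>>> torch.eig(torch.eye(2))
Traceback (most recent call last):
  ...
RuntimeError: This function was deprecated since version 1.9 and is now removed. Please use the `torch.linalg.eig` function instead.
```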
@@ -638,6 +638,11 @@ class Tensor(torch._C._TensorBase):

        return solve(self, other)

    def eig(self, eigenvectors=False):
        from ._linalg_utils import eig

        return eig(self, eigenvectors=eigenvectors)

    def lu(self, pivot=True, get_infos=False):
        r"""See :func:`torch.lu`"""
        # If get_infos is True, then we don't need to check for errors and vice versa
@@ -1692,15 +1692,6 @@ See :func:`torch.dot`
""",
)

add_docstr_all(
    "eig",
    r"""
eig(eigenvectors=False) -> (Tensor, Tensor)

See :func:`torch.eig`
""",
)

add_docstr_all(
    "element_size",
    r"""
@@ -3999,92 +3999,6 @@ Example::
|
|||
""",
|
||||
)
|
||||
|
||||
add_docstr(
|
||||
torch.eig,
|
||||
r"""
|
||||
eig(input, eigenvectors=False, *, out=None) -> (Tensor, Tensor)
|
||||
|
||||
Computes the eigenvalues and eigenvectors of a real square matrix.
|
||||
|
||||
.. note::
|
||||
Since eigenvalues and eigenvectors might be complex, backward pass is supported only
|
||||
if eigenvalues and eigenvectors are all real valued.
|
||||
|
||||
When :attr:`input` is on CUDA, :func:`torch.eig() <torch.eig>` causes
|
||||
host-device synchronization.
|
||||
|
||||
.. warning::
|
||||
|
||||
:func:`torch.eig` is deprecated in favor of :func:`torch.linalg.eig`
|
||||
and will be removed in a future PyTorch release.
|
||||
:func:`torch.linalg.eig` returns complex tensors of dtype `cfloat` or `cdouble`
|
||||
rather than real tensors mimicking complex tensors.
|
||||
|
||||
``L, _ = torch.eig(A)`` should be replaced with
|
||||
|
||||
.. code :: python
|
||||
|
||||
L_complex = torch.linalg.eigvals(A)
|
||||
|
||||
``L, V = torch.eig(A, eigenvectors=True)`` should be replaced with
|
||||
|
||||
.. code :: python
|
||||
|
||||
L_complex, V_complex = torch.linalg.eig(A)
|
||||
|
||||
Args:
|
||||
input (Tensor): the square matrix of shape :math:`(n \times n)` for which the eigenvalues and eigenvectors
|
||||
will be computed
|
||||
eigenvectors (bool): ``True`` to compute both eigenvalues and eigenvectors;
|
||||
otherwise, only eigenvalues will be computed
|
||||
|
||||
Keyword args:
|
||||
out (tuple, optional): the output tensors
|
||||
|
||||
Returns:
|
||||
(Tensor, Tensor): A namedtuple (eigenvalues, eigenvectors) containing
|
||||
|
||||
- **eigenvalues** (*Tensor*): Shape :math:`(n \times 2)`. Each row is an eigenvalue of ``input``,
|
||||
where the first element is the real part and the second element is the imaginary part.
|
||||
The eigenvalues are not necessarily ordered.
|
||||
- **eigenvectors** (*Tensor*): If ``eigenvectors=False``, it's an empty tensor.
|
||||
Otherwise, this tensor of shape :math:`(n \times n)` can be used to compute normalized (unit length)
|
||||
eigenvectors of corresponding eigenvalues as follows.
|
||||
If the corresponding `eigenvalues[j]` is a real number, column `eigenvectors[:, j]` is the eigenvector
|
||||
corresponding to `eigenvalues[j]`.
|
||||
If the corresponding `eigenvalues[j]` and `eigenvalues[j + 1]` form a complex conjugate pair, then the
|
||||
true eigenvectors can be computed as
|
||||
:math:`\text{true eigenvector}[j] = eigenvectors[:, j] + i \times eigenvectors[:, j + 1]`,
|
||||
:math:`\text{true eigenvector}[j + 1] = eigenvectors[:, j] - i \times eigenvectors[:, j + 1]`.
|
||||
|
||||
Example::
|
||||
|
||||
Trivial example with a diagonal matrix. By default, only eigenvalues are computed:
|
||||
|
||||
>>> a = torch.diag(torch.tensor([1, 2, 3], dtype=torch.double))
|
||||
>>> e, v = torch.eig(a)
|
||||
>>> e
|
||||
tensor([[1., 0.],
|
||||
[2., 0.],
|
||||
[3., 0.]], dtype=torch.float64)
|
||||
>>> v
|
||||
tensor([], dtype=torch.float64)
|
||||
|
||||
Compute also the eigenvectors:
|
||||
|
||||
>>> e, v = torch.eig(a, eigenvectors=True)
|
||||
>>> e
|
||||
tensor([[1., 0.],
|
||||
[2., 0.],
|
||||
[3., 0.]], dtype=torch.float64)
|
||||
>>> v
|
||||
tensor([[1., 0., 0.],
|
||||
[0., 1., 0.],
|
||||
[0., 0., 1.]], dtype=torch.float64)
|
||||
|
||||
""",
|
||||
)
|
||||
|
||||
add_docstr(
|
||||
torch.eq,
|
||||
r"""
|
||||
|
|
|
|||
|
|
@@ -3440,151 +3440,6 @@ Tensor svd_backward(
|
|||
return gA;
|
||||
}
|
||||
|
||||
// The implementation follows:
|
||||
// "An extended collection of matrix derivative results for forward and reverse
|
||||
// mode algorithmic differentiation"
|
||||
// https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf
|
||||
// However, the reference does not cover the constraints on eigenvectors to have
|
||||
// 1-norm. See the details below.
|
||||
Tensor eig_backward(
|
||||
const std::vector<torch::autograd::Variable>& grads,
|
||||
const Tensor& self,
|
||||
bool is_eigvec_tensor_nonempty,
|
||||
const Tensor& eigenvalues,
|
||||
const Tensor& eigenvectors) {
|
||||
at::NoTF32Guard disable_tf32;
|
||||
TORCH_CHECK(
|
||||
is_eigvec_tensor_nonempty,
|
||||
"eig_backward: torch.eig(eigenvalues=False) is not differentiable. ",
|
||||
"Please use torch.linalg.eigvals");
|
||||
|
||||
// variable names correspond to the ones in the reference document
|
||||
auto D = eigenvalues;
|
||||
const auto& U = eigenvectors;
|
||||
auto D_grad = grads[0];
|
||||
auto U_grad = grads[1];
|
||||
|
||||
// The condition below is trying to marry torch.eig and torch.linalg.eig
|
||||
// for real inputs.
|
||||
//
|
||||
// For real inputs torch.eig returns a real 2D tensor representing real and
|
||||
// complex components of eigenvalues, while torch.linalg.eig will most likely
|
||||
// always return complex eigenvalues.
|
||||
if (!self.is_complex()) {
|
||||
Tensor is_imag_eigvals_zero;
|
||||
// path for torch.eig with always a "real" 2D tensor of eigenvalues
|
||||
if (!D.is_complex()) {
|
||||
// narrow extracts the column corresponding to the imaginary part
|
||||
is_imag_eigvals_zero = (D.narrow(-1, 1, 1) == 0.0).min();
|
||||
}
|
||||
// path for torch.linalg.eig with always a complex tensor of eigenvalues
|
||||
else {
|
||||
is_imag_eigvals_zero = (at::imag(D) == 0.0).min();
|
||||
// insert an additional dimension to be compatible with torch.eig.
|
||||
// Recall that it produces 2D tensors.
|
||||
// We extract only the real parts as there is no support for
|
||||
// complex eigenvalues with real inputs yet.
|
||||
D = at::real(D).unsqueeze(-1);
|
||||
D_grad = at::real(D_grad).unsqueeze(-1);
|
||||
}
|
||||
// No support for complex eigenvalues for real inputs yet.
|
||||
TORCH_CHECK(
|
||||
at::is_scalar_tensor_true(is_imag_eigvals_zero),
|
||||
"eig_backward: Backward calculation does not support complex eigenvalues for real inputs at the moment.");
|
||||
} else {
|
||||
// torch.eig returns 2d tensors for eigenvalues,
|
||||
// while torch.linalg.eig returns 1d.
|
||||
// Hence we insert additional dimension for complex input,
|
||||
// such that the same code could be used for both methods.
|
||||
// It will become unnecessary once torch.eig is deprecated.
|
||||
D = D.unsqueeze(-1);
|
||||
if (D_grad.defined()) {
|
||||
D_grad = D_grad.unsqueeze(-1);
|
||||
}
|
||||
}
|
||||
|
||||
if (!D_grad.defined() && !U_grad.defined()) {
|
||||
return at::zeros_like(self, at::MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
// Adapting the result from the reference above for the complex input, we get:
|
||||
//
|
||||
// A_grad = U^{-H} (D_grad + F.conj() * (U^H U_grad)) U^H,
|
||||
// where M^H := (M.mT()).conj() and * is the Hadamard (element-wise) product.
|
||||
//
|
||||
// torch.eig/torch.linalg.eig produce eigenvectors which are
|
||||
// normalized to 1 norm, and the reference does not take that into account.
|
||||
// Hence, we have to modify the formula accordingly.
|
||||
//
|
||||
// Normalization to 1 norm imposes the following constraint on the
|
||||
// eigenvectors, i.e. (U^H U) * I = I, where I is an identity matrix. Forward
|
||||
// AD for this expression yields: (dU^H U + U^H dU) * I = 0 => U^H dU * I = 0
|
||||
// <=> diag(U^H dU) = 0, which means that each i-th column of U is orthogonal
|
||||
// to the i-th column of dU. Now, the value of dU which does not take this
|
||||
// constraint into consideration comes straight from the reference: dU = U(F *
|
||||
// U^{-1} dA U). To make sure that U^H dU * I = 0, and using U^H U * I = I
|
||||
// (normalization), we propose a modifed forward AD for U: dU_new = dU - U(U^H
|
||||
// dU * I) (think of Gram-Schmidt)
|
||||
//
|
||||
// The rest is very similar to what is done in the reference and we finally
|
||||
// arrive at:
|
||||
//
|
||||
// A_grad = U^{-H} (D_grad + (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj())
|
||||
// U^H
|
||||
// = U^{-H} (eigenvalues_contribs + eigenvectors_contrib) U^H, where
|
||||
// eigenvalues_contribs := D_grad,
|
||||
// eigenvectors_contribs := (U^H U_grad - U^H U (U^H U_grad * I)) * F.conj().
|
||||
// The contributions from the eigenvectors and the eigenvalues are computed
|
||||
// below, and then we solve the system U^H A_grad = (eigenvalues_contribs +
|
||||
// eigenvectors_contribs) U_H to produce A_grad.
|
||||
|
||||
// contribution from the eigenvectors
|
||||
Tensor U_contrib;
|
||||
if (U_grad.defined()) {
|
||||
// narrow extracts the column corresponding to the real part
|
||||
D = D.narrow(-1, 0, 1);
|
||||
auto F = (D.mT() - D);
|
||||
if (!F.is_complex()) {
|
||||
F.diagonal(0, -2, -1).fill_(INFINITY);
|
||||
F.pow_(-1);
|
||||
} else {
|
||||
// The F matrix construction for complex eigenvalues
|
||||
// if different from its real counterpart.
|
||||
// There is no complex INFINITY, and we cannot use
|
||||
//
|
||||
// F.pow_(-1);
|
||||
// F.diagonal(0, -2, -1).fill_(0);
|
||||
//
|
||||
// as it breaks gradgradcheck by double backward
|
||||
// propagating nans through F.pow_(-1) at zero,
|
||||
// the point of discontinuity.
|
||||
// Hence this hack below.
|
||||
F.diagonal(0, -2, -1).fill_(1);
|
||||
F.pow_(-1);
|
||||
F.diagonal(0, -2, -1).fill_(0);
|
||||
}
|
||||
auto U_grad_proj_onto_U = at::matmul(U.mH(), U_grad);
|
||||
auto Uh_U = at::matmul(U.mH(), U);
|
||||
U_contrib = (U_grad_proj_onto_U -
|
||||
Uh_U * U_grad_proj_onto_U.diagonal(0, -2, -1).unsqueeze(-2)) *
|
||||
F.conj();
|
||||
} else {
|
||||
U_contrib = at::zeros_like(self, at::MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
// contributions from the eigenvalues
|
||||
Tensor D_contrib;
|
||||
if (D_grad.defined()) {
|
||||
// narrow extracts the column corresponding to the real part
|
||||
D_contrib = D_grad.narrow(-1, 0, 1);
|
||||
} else {
|
||||
D_contrib = at::zeros_like(D, at::MemoryFormat::Contiguous);
|
||||
}
|
||||
|
||||
return at::linalg_solve(
|
||||
U.mH(), at::matmul(U_contrib, U.mH()) + D_contrib * U.mH());
|
||||
}
|
||||
|
||||
Tensor linalg_eig_backward(
|
||||
const Tensor& gL,
|
||||
const Tensor& gV,
|
||||
|
|
|
|||
|
|
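For reference, the reverse-mode rule spelled out in the comments of the removed `eig_backward` (hunk above) can be restated compactly. Notation follows those comments: `M^H = conj(M^T)`, the Hadamard product is written elementwise, `I` is the identity, `U`/`D` are the eigenvectors/eigenvalues, and `A_grad`, `D_grad`, `U_grad` the incoming gradients. This is only a restatement of the removed comments, not a new derivation.

```latex
% Result adapted from the cited reference, before accounting for the
% eigenvector normalization:
A_{\mathrm{grad}} = U^{-H}\bigl(D_{\mathrm{grad}}
    + \overline{F} \circ (U^{H} U_{\mathrm{grad}})\bigr) U^{H}

% After enforcing the unit-norm constraint on the eigenvectors (the
% Gram-Schmidt-style correction described in the removed comments):
A_{\mathrm{grad}} = U^{-H}\bigl(D_{\mathrm{grad}}
    + \bigl(U^{H} U_{\mathrm{grad}}
        - U^{H} U\,(U^{H} U_{\mathrm{grad}} \circ I)\bigr) \circ \overline{F}\bigr) U^{H}
```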
@@ -628,12 +628,6 @@ Tensor linalg_qr_backward(
    const Tensor& Q,
    const Tensor& R,
    const c10::string_view mode);
Tensor eig_backward(
    const std::vector<torch::autograd::Variable>& grads,
    const Tensor& self,
    bool eigenvectors,
    const Tensor& lambda,
    const Tensor& v);
Tensor linalg_matrix_exp_differential(
    const Tensor& self,
    const Tensor& grad,
@@ -254,6 +254,7 @@ def get_ignored_functions() -> Set[Callable]:
        Tensor.__subclasshook__,
        Tensor.__hash__,
        Tensor.as_subclass,
        Tensor.eig,
        Tensor.reinforce,
        Tensor.new,
        Tensor.new_tensor,
@@ -487,7 +488,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
        torch.hsmm: lambda mat1, mat2: -1,
        torch.dsplit: lambda input, indices_or_sections: -1,
        torch.dstack: lambda tensors, out=None: -1,
        torch.eig: lambda input, eigenvectors=False, out=None: -1,
        torch.linalg.eig: lambda input, out=None: -1,
        torch.linalg.eigvals: lambda input, out=None: -1,
        torch.linalg.eigh: lambda input, UPLO="L", out=None: -1,
@@ -2147,14 +2147,6 @@ def error_inputs_renorm(op_info, device, **kwargs):
    yield ErrorInput(SampleInput(zero_d, args=(0.5, 0, 1.0)), error_type=RuntimeError,
                     error_regex="needs at least 2 dimensions, got 0 dimensions")

def error_inputs_eig(op_info, device, **kwargs):
    zero_d = torch.randn((), device=device)

    yield ErrorInput(SampleInput(zero_d, args=(False,)), error_type=RuntimeError,
                     error_regex="input should be 2 dimensional")

    yield ErrorInput(SampleInput(zero_d, args=(True,)), error_type=RuntimeError,
                     error_regex="input should be 2 dimensional")

def error_inputs_ormqr(op_info, device, **kwargs):
    # this is only implemented on cpu
@@ -4822,41 +4814,6 @@ def sample_inputs_hardtanh(op_info, device, dtype, requires_grad=False, **kwargs
|
|||
|
||||
yield from sample_inputs_elementwise_unary(op_info, device, dtype, requires_grad)
|
||||
|
||||
def sample_inputs_eig(op_info, device, dtype, requires_grad=False, **kwargs):
|
||||
eigvecs = make_tensor((S, S), device=device, dtype=dtype,
|
||||
low=None, high=None)
|
||||
eigvals = make_tensor((S,), device=device, dtype=dtype,
|
||||
low=None, high=None)
|
||||
# we produce only diagonazible inputs which do not have
|
||||
# complex eigenvalues for real inputs, as there is no
|
||||
# backward implementation for real inputs with complex
|
||||
# eigenvalues yet.
|
||||
input = (eigvecs * eigvals.unsqueeze(-2)) @ eigvecs.inverse()
|
||||
input.requires_grad_(requires_grad)
|
||||
|
||||
def process_output(eigpair):
|
||||
eigvals, eigvecs = eigpair
|
||||
if dtype.is_complex:
|
||||
# eig produces eigenvectors which are normalized to 1 norm.
|
||||
# Note that if v is an eigenvector, so is v * e^{i \phi},
|
||||
# and |v| = |v * e^{i \phi}| = 1.
|
||||
# This, however, makes the eigenvector backward computation process
|
||||
# rather unstable unless the objective function is gauge-invariant,
|
||||
# that is if f(z) == f(|z|), for example.
|
||||
# Hence for complex inputs we ignore the phases and return only
|
||||
# the absolute values.
|
||||
return eigvals, eigvecs.abs()
|
||||
else:
|
||||
return eigvals, eigvecs
|
||||
|
||||
return [
|
||||
SampleInput(
|
||||
input,
|
||||
kwargs=dict(eigenvectors=True),
|
||||
output_process_fn_grad=process_output
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def sample_inputs_einsum(op_info, device, dtype, requires_grad=False, **kwargs):
|
||||
def c(t):
|
||||
|
|
@@ -13001,16 +12958,6 @@ op_db: List[OpInfo] = [
           supports_sparse_bsr=True,
           supports_sparse_bsc=True,
           supports_autograd=False),
    OpInfo('eig',
           op=torch.eig,
           dtypes=floating_and_complex_types(),
           sample_inputs_func=sample_inputs_eig,
           error_inputs_func=error_inputs_eig,
           decorators=[
               skipCUDAIfNoMagma,
               skipCPUIfNoLapack,
           ],
           ),
    OpInfo('einsum',
           # we need this lambda because SampleInput expects tensor input as the first argument
           # TODO(@heitorschueroff) update SampleInput to handle such cases