mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Change name for namedtuple return of torch.linalg.svd (#57181)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/57181 Documentation for torch.linalg.svd says: > The returned decomposition is a named tuple `(U, S, Vh)` The documentation is correct while the implementation was wrong. Renamed `V` -> `Vh`. `h` stands for hermitian. This is a BC-breaking change but our linalg module is beta, therefore we can do it without a deprecation notice or aliases. Test Plan: Imported from OSS Reviewed By: ngimel Differential Revision: D28142162 Pulled By: mruberry fbshipit-source-id: 5e6e0ae5a63300f2db1575ca3259df381f8e1a7e
This commit is contained in:
parent
58f32fa5fd
commit
18fed3dfbe
|
|
@ -2894,7 +2894,7 @@ std::tuple<Tensor&, Tensor&, Tensor&> svd_out(const Tensor& self, bool some, boo
|
|||
1. the 2nd parameter is bool some=True, which if effectively the opposite
|
||||
of full_matrices=True
|
||||
|
||||
2. svd returns V, while linalg.svd returns VT = V^T (for real inputs) or VT = V^H (for complex inputs).
|
||||
2. svd returns V, while linalg.svd returns Vh = V^T (for real inputs) or Vh = V^H (for complex inputs).
|
||||
To accommodate the difference, we transpose() and conj() V upon return
|
||||
*/
|
||||
|
||||
|
|
@ -2906,8 +2906,8 @@ std::tuple<Tensor, Tensor, Tensor> linalg_svd(const Tensor& self, bool full_matr
|
|||
Tensor U, S, V;
|
||||
std::tie(U, S, V) = at::_svd_helper(self, some, /*compute_uv=*/true);
|
||||
|
||||
Tensor VT = V.conj().transpose(-2, -1);
|
||||
return std::make_tuple(U, S, VT);
|
||||
Tensor Vh = V.conj().transpose(-2, -1);
|
||||
return std::make_tuple(U, S, Vh);
|
||||
|
||||
}
|
||||
|
||||
|
|
@ -2917,21 +2917,21 @@ static void svd_resize_and_copy(const char *name, const Tensor& src, Tensor &dst
|
|||
dst.copy_(src);
|
||||
}
|
||||
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> linalg_svd_out(const Tensor& self, bool full_matrices, Tensor& U, Tensor& S, Tensor& VT) {
|
||||
std::tuple<Tensor&, Tensor&, Tensor&> linalg_svd_out(const Tensor& self, bool full_matrices, Tensor& U, Tensor& S, Tensor& Vh) {
|
||||
checkSameDevice("svd", U, self, "U");
|
||||
checkSameDevice("svd", S, self, "S");
|
||||
checkSameDevice("svd", VT, self, "VT");
|
||||
checkSameDevice("svd", Vh, self, "Vh");
|
||||
checkLinalgCompatibleDtype("linalg_svd", U, self, "U");
|
||||
checkLinalgCompatibleDtype("linalg_svd", VT, self, "VT");
|
||||
checkLinalgCompatibleDtype("linalg_svd", Vh, self, "Vh");
|
||||
// singular values are always real-valued here
|
||||
ScalarType real_dtype = toValueType(self.scalar_type());
|
||||
checkLinalgCompatibleDtype("linalg_svd", S.scalar_type(), real_dtype, "S");
|
||||
Tensor U_tmp, S_tmp, VT_tmp;
|
||||
std::tie(U_tmp, S_tmp, VT_tmp) = at::native::linalg_svd(self, full_matrices);
|
||||
Tensor U_tmp, S_tmp, Vh_tmp;
|
||||
std::tie(U_tmp, S_tmp, Vh_tmp) = at::native::linalg_svd(self, full_matrices);
|
||||
svd_resize_and_copy("U", U_tmp, U);
|
||||
svd_resize_and_copy("S", S_tmp, S);
|
||||
svd_resize_and_copy("V", VT_tmp, VT);
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&>(U, S, VT);
|
||||
svd_resize_and_copy("V", Vh_tmp, Vh);
|
||||
return std::tuple<Tensor&, Tensor&, Tensor&>(U, S, Vh);
|
||||
}
|
||||
|
||||
Tensor linalg_svdvals(const Tensor& input) {
|
||||
|
|
|
|||
|
|
@ -9636,10 +9636,10 @@
|
|||
dispatch:
|
||||
CPU, CUDA: linalg_vector_norm_out
|
||||
|
||||
- func: linalg_svd.U(Tensor self, bool full_matrices=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)
|
||||
- func: linalg_svd.U(Tensor self, bool full_matrices=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)
|
||||
python_module: linalg
|
||||
|
||||
- func: linalg_svd(Tensor self, bool full_matrices=True) -> (Tensor U, Tensor S, Tensor V)
|
||||
- func: linalg_svd(Tensor self, bool full_matrices=True) -> (Tensor U, Tensor S, Tensor Vh)
|
||||
python_module: linalg
|
||||
variants: function
|
||||
|
||||
|
|
|
|||
|
|
@ -2793,7 +2793,7 @@ class TestLinalg(TestCase):
|
|||
|
||||
out_u = torch.empty(0, dtype=dtype, device=device)
|
||||
if svd == torch.linalg.svd:
|
||||
msg = "but got VT with dtype Int"
|
||||
msg = "but got Vh with dtype Int"
|
||||
else:
|
||||
msg = "but got V with dtype Int"
|
||||
with self.assertRaisesRegex(RuntimeError, msg):
|
||||
|
|
|
|||
|
|
@ -62,7 +62,8 @@ class TestNamedTupleAPI(TestCase):
|
|||
names=('values', 'indices'), hasout=True),
|
||||
op(operators=['kthvalue'], input=(1, 0),
|
||||
names=('values', 'indices'), hasout=True),
|
||||
op(operators=['svd', '_svd_helper', 'linalg_svd'], input=(), names=('U', 'S', 'V'), hasout=True),
|
||||
op(operators=['svd', '_svd_helper'], input=(), names=('U', 'S', 'V'), hasout=True),
|
||||
op(operators=['linalg_svd'], input=(), names=('U', 'S', 'Vh'), hasout=True),
|
||||
op(operators=['slogdet'], input=(), names=('sign', 'logabsdet'), hasout=False),
|
||||
op(operators=['qr', 'linalg_qr'], input=(), names=('Q', 'R'), hasout=True),
|
||||
op(operators=['solve'], input=(a,), names=('solution', 'LU'), hasout=True),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user