mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Fix nuclear norm with requires_grad=True (#26303)
Summary:
Changelog:
- Selectively assign compute_uv in the at::svd used internally in the implementation of at::nuclear_norm

Pull Request resolved: https://github.com/pytorch/pytorch/pull/26303

Test Plan:
- Add tests in common_method_invocations.py

Refixes: https://github.com/pytorch/pytorch/issues/18275

Differential Revision: D17605357

Pulled By: ezyang

fbshipit-source-id: d87d60afe678e2546dca6992ea66f2daeb6b0346
This commit is contained in:
parent
0e3389dced
commit
43b07ff2c4
|
|
@ -6,6 +6,7 @@
|
|||
#include <ATen/TensorUtils.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/LegacyTHFunctionsCPU.h>
|
||||
#include <ATen/core/grad_mode.h>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <vector>
|
||||
|
|
@ -549,7 +550,11 @@ Tensor nuclear_norm(const Tensor& self, bool keepdim) {
|
|||
self.dim() == 2,
|
||||
"Expected a tensor with 2 dimensions, but got a tensor with ",
|
||||
self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
|
||||
return at::sum(std::get<1>(at::svd(self)), 0, keepdim);
|
||||
// Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
|
||||
// would end up throwing an error as a result if U and V aren't computed.
|
||||
// Due to this, we have to compute U and V conditionally.
|
||||
return at::sum(std::get<1>(at::svd(self, /*some=*/true,
|
||||
/*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), 0, keepdim);
|
||||
}
|
||||
|
||||
Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
|
||||
|
|
@ -557,14 +562,19 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
|
|||
self.dim() == 2,
|
||||
"Expected a tensor with 2 dimensions, but got a tensor with ",
|
||||
self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
|
||||
return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
|
||||
return at::sum_out(result, std::get<1>(at::svd(self, /*some=*/true, /*compute_uv=*/false)), 0, keepdim);
|
||||
|
||||
}
|
||||
|
||||
Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
|
||||
TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
|
||||
|
||||
Tensor p = _move_to_end(self, dim);
|
||||
return at::sum(std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
|
||||
// Since we error out on svd_backward when we don't compute U and V, the backward pass for nuclear_norm
|
||||
// would end up throwing an error as a result if U and V aren't computed.
|
||||
// Due to this, we have to compute U and V conditionally.
|
||||
return at::sum(std::get<1>(at::svd(p, /*some=*/true,
|
||||
/*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), -1, keepdim);
|
||||
}
|
||||
|
||||
Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
|
||||
|
|
@ -572,6 +582,7 @@ Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bo
|
|||
|
||||
Tensor p = _move_to_end(self, dim);
|
||||
return at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
|
||||
|
||||
}
|
||||
|
||||
static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
|
||||
|
|
|
|||
|
|
@ -534,6 +534,7 @@ def method_tests():
|
|||
('norm', (S, S), ('fro',), 'fro_default'),
|
||||
('norm', (S, S), ('fro', [0, 1],), 'fro'),
|
||||
('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipIfNoLapack]),
|
||||
('norm', (S, S, S), ('nuc', [1, 2]), 'nuc_batched', (), NO_ARGS, [skipIfNoLapack]),
|
||||
('norm', (S, S), (-1,), 'neg_1'),
|
||||
('norm', (S, S), (-2,), 'neg_2'),
|
||||
('norm', (S, S), (-0.5,), 'neg_0_5'),
|
||||
|
|
|
|||
|
|
@ -16334,6 +16334,7 @@ EXCLUDE_SCRIPT = {
|
|||
'test_norm_fro',
|
||||
'test_norm_fro_default',
|
||||
'test_norm_nuc',
|
||||
'test_norm_nuc_batched',
|
||||
|
||||
# aten op has additional cudnn argument
|
||||
'test_nn_unfold',
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user