Fix nuclear norm with requires_grad=True (#26303)

Summary:
Changelog:
- Conditionally set compute_uv in the at::svd call used internally by at::nuclear_norm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26303

Test Plan:
- Add tests in common_methods_invocations.py

Refixes: https://github.com/pytorch/pytorch/issues/18275

Differential Revision: D17605357

Pulled By: ezyang

fbshipit-source-id: d87d60afe678e2546dca6992ea66f2daeb6b0346
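
For context, a minimal reproduction of the failure this commit fixes (illustrative snippet, not part of the commit; the exact error text comes from svd_backward and varies by version):

    import torch

    # Before this fix, at::svd was called with compute_uv=false in the
    # forward pass, so svd_backward raised an error on the backward call
    # below; after the fix it succeeds.
    x = torch.randn(3, 3, requires_grad=True)
    loss = torch.norm(x, p='nuc')  # nuclear norm = sum of singular values
    loss.backward()
    print(x.grad.shape)  # torch.Size([3, 3])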
Author:    vishwakftw
Date:      2019-09-26 12:06:34 -07:00
Committer: Facebook Github Bot
Commit:    43b07ff2c4 (parent 0e3389dced)
3 changed files with 16 additions and 3 deletions

aten/src/ATen/native/LinearAlgebra.cpp

@@ -6,6 +6,7 @@
#include <ATen/TensorUtils.h>
#include <ATen/Parallel.h>
#include <ATen/LegacyTHFunctionsCPU.h>
+#include <ATen/core/grad_mode.h>
#include <functional>
#include <numeric>
#include <vector>
@@ -549,7 +550,11 @@ Tensor nuclear_norm(const Tensor& self, bool keepdim) {
self.dim() == 2,
"Expected a tensor with 2 dimensions, but got a tensor with ",
self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-return at::sum(std::get<1>(at::svd(self)), 0, keepdim);
+// svd_backward errors out when U and V were not computed, so the backward
+// pass for nuclear_norm would throw if compute_uv were unconditionally false.
+// Hence U and V are computed only when autograd will actually need them.
+return at::sum(std::get<1>(at::svd(self, /*some=*/true,
+/*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), 0, keepdim);
}
Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
@@ -557,14 +562,19 @@ Tensor &nuclear_norm_out(Tensor& result, const Tensor& self, bool keepdim) {
self.dim() == 2,
"Expected a tensor with 2 dimensions, but got a tensor with ",
self.dim(), " dimension", self.dim()==1 ? "" : "s", " instead.");
-return at::sum_out(result, std::get<1>(at::svd(self)), 0, keepdim);
+return at::sum_out(result, std::get<1>(at::svd(self, /*some=*/true, /*compute_uv=*/false)), 0, keepdim);
}
Tensor nuclear_norm(const Tensor& self, IntArrayRef dim, bool keepdim) {
TORCH_CHECK(dim.size() == 2, "nuclear norm requires a 'dim' argument of size 2");
Tensor p = _move_to_end(self, dim);
-return at::sum(std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
+// svd_backward errors out when U and V were not computed, so the backward
+// pass for nuclear_norm would throw if compute_uv were unconditionally false.
+// Hence U and V are computed only when autograd will actually need them.
+return at::sum(std::get<1>(at::svd(p, /*some=*/true,
+/*compute_uv=*/at::GradMode::is_enabled() && self.is_variable() && self.requires_grad())), -1, keepdim);
}
Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
@@ -572,6 +582,7 @@ Tensor& nuclear_norm_out(Tensor& result, const Tensor& self, IntArrayRef dim, bool keepdim) {
Tensor p = _move_to_end(self, dim);
return at::sum_out(result, std::get<1>(at::svd(p, /*some=*/true, /*compute_uv=*/false)), -1, keepdim);
}
static inline Tensor _chain_matmul_general(TensorList matrices, std::vector<std::vector<int64_t>>& order, int64_t i, int64_t j) {
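
In Python terms, the conditional added above amounts to the following sketch (an illustrative analogue using the public torch.svd API, not the actual dispatch path; nuclear_norm_sketch is a hypothetical name):

    import torch

    def nuclear_norm_sketch(t, keepdim=False):
        # Request U and V from the SVD only when autograd will need them,
        # since svd_backward errors out when they were not computed.
        needs_uv = torch.is_grad_enabled() and t.requires_grad
        u, s, v = torch.svd(t, some=True, compute_uv=needs_uv)
        # Nuclear norm = sum of singular values.
        return s.sum(-1, keepdim=keepdim)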

test/common_methods_invocations.py

@@ -534,6 +534,7 @@ def method_tests():
('norm', (S, S), ('fro',), 'fro_default'),
('norm', (S, S), ('fro', [0, 1],), 'fro'),
('norm', (S, S), ('nuc',), 'nuc', (), NO_ARGS, [skipIfNoLapack]),
+('norm', (S, S, S), ('nuc', [1, 2]), 'nuc_batched', (), NO_ARGS, [skipIfNoLapack]),
('norm', (S, S), (-1,), 'neg_1'),
('norm', (S, S), (-2,), 'neg_2'),
('norm', (S, S), (-0.5,), 'neg_0_5'),
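
The new 'nuc_batched' entry drives generated autograd checks for the batched nuclear norm. A standalone approximation of what it exercises (hypothetical snippet; gradcheck may need well-separated singular values to pass reliably):

    import torch
    from torch.autograd import gradcheck

    # Batched nuclear norm over dims [1, 2] of a 3-D input, checked against
    # numerical gradients; double precision, as gradcheck requires.
    x = torch.randn(5, 5, 5, dtype=torch.double, requires_grad=True)
    assert gradcheck(lambda t: torch.norm(t, p='nuc', dim=[1, 2]), (x,))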

test/test_jit.py

@@ -16334,6 +16334,7 @@ EXCLUDE_SCRIPT = {
'test_norm_fro',
'test_norm_fro_default',
'test_norm_nuc',
+'test_norm_nuc_batched',
# aten op has additional cudnn argument
'test_nn_unfold',
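
As a sanity check on why U and V are needed in the backward pass: for a full-rank matrix with distinct singular values, the gradient of the nuclear norm is U @ V.T, which svd_backward can only form if compute_uv was true in the forward pass (illustrative snippet, not part of the commit):

    import torch

    # d||A||_* / dA = U @ V.T for full-rank A with distinct singular values.
    a = torch.randn(4, 4, dtype=torch.double, requires_grad=True)
    torch.norm(a, p='nuc').backward()
    u, s, v = torch.svd(a.detach())
    print(torch.allclose(a.grad, u @ v.t()))  # True up to numerical tolerance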