diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py
index 9a755b98b28..4a1734ebc34 100644
--- a/test/functorch/test_ops.py
+++ b/test/functorch/test_ops.py
@@ -1122,6 +1122,7 @@ class TestOperators(TestCase):
         xfail('as_strided_scatter', ''),
         xfail('masked.cumprod', ''),
         xfail("_upsample_bilinear2d_aa"),  # hit vmap fallback, which is disabled
+        xfail("renorm"),  # hit vmap fallback, which is disabled
     }))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
     def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
@@ -1383,7 +1384,6 @@ class TestOperators(TestCase):
         xfail('nn.functional.hardsigmoid', ''),  # NYI: forward AD for hardsigmoid_backward
         xfail('nn.functional.huber_loss', ''),  # NYI: forward AD for huber_loss_backward
         xfail('NumpyCubeNotComposableAutogradFunction'),  # not composable
-        xfail('renorm', ''),  # NYI: forward AD for renorm
         xfail('ormqr', ''),  # NYI: forward AD for ormqr
         xfail('nn.functional.multilabel_margin_loss', ''),  # NYI: multilabel_margin_loss_forward
         xfail('nn.functional.soft_margin_loss', ''),  # NYI: forward-AD for soft_margin_loss_backward
@@ -1543,7 +1543,6 @@ class TestOperators(TestCase):
         xfail('normal', 'number_mean'),  # calls random op
         xfail('pca_lowrank'),  # calls random op
         xfail('quantile'),  # Batching rule not implemented for aten::equal
-        xfail('renorm'),  # Forward AD not implemented and no decomposition
         xfail('scatter_reduce', 'prod'),  # Forward AD not implemented and no decomposition
         xfail('_segment_reduce', 'lengths'),  # Forward AD not implemented and no decomposition
         xfail('_segment_reduce', 'offsets'),  # Forward AD not implemented and no decomposition
diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml
index 69d042700fa..7998ceb989b 100644
--- a/tools/autograd/derivatives.yaml
+++ b/tools/autograd/derivatives.yaml
@@ -1345,6 +1345,7 @@
 
 - name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
   self: renorm_backward(grad, self, p, dim, maxnorm)
+  result: renorm_jvp(self_p, self_t, p, dim, maxnorm)
 
 - name: repeat(Tensor self, SymInt[] repeats) -> Tensor
   self: repeat_backward(grad, repeats, self.sym_sizes())
diff --git a/torch/csrc/autograd/FunctionsManual.cpp b/torch/csrc/autograd/FunctionsManual.cpp
index 372347b5eea..d60d2107118 100644
--- a/torch/csrc/autograd/FunctionsManual.cpp
+++ b/torch/csrc/autograd/FunctionsManual.cpp
@@ -1523,6 +1523,54 @@ Tensor renorm_backward(
   return at::where(norm > maxnorm, grad_norm.to(grad.scalar_type()), grad);
 }
 
+Tensor renorm_jvp(
+    const Tensor& self_p,
+    const Tensor& self_t,
+    const Scalar& p,
+    int64_t dim,
+    const Scalar& maxnorm) {
+  auto self_sizes = self_p.sizes();
+  dim = c10::maybe_wrap_dim(dim, self_sizes.size());
+
+  at::DimVector reduce_dims(self_sizes.size());
+  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
+  reduce_dims.erase(reduce_dims.begin() + dim);
+
+  // For cuda half, calculate norm in float precision then cast
+  // normalization factor to half
+  auto dtype = self_p.scalar_type();
+  auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
+  Tensor norm = [&self_p, &p, &reduce_dims, acc_type, dtype]() {
+    if (acc_type != dtype) {
+      return at::linalg_vector_norm(
+          self_p,
+          p.toDouble(),
+          reduce_dims,
+          /*keepdim=*/true,
+          /*dtype=*/acc_type);
+    } else {
+      return at::linalg_vector_norm(
+          self_p,
+          p.toDouble(),
+          reduce_dims,
+          /*keepdim=*/true);
+    }
+  }();
+
+  auto double_maxnorm = maxnorm.toDouble();
+  auto invnorm = (norm + 1e-7).reciprocal();
+  auto factor = invnorm * double_maxnorm;
+
+  return where(
+      norm > double_maxnorm,
+      factor *
+          (self_t -
+           self_p * invnorm *
+               norm_jvp(
+                   self_p, self_t, p, norm, reduce_dims, /*keepdim=*/true)),
+      self_t);
+}
+
 Tensor repeat_backward(
     Tensor grad,
     c10::SymIntArrayRef repeats,
diff --git a/torch/csrc/autograd/FunctionsManual.h b/torch/csrc/autograd/FunctionsManual.h
index 023632bfbf9..684976bef25 100644
--- a/torch/csrc/autograd/FunctionsManual.h
+++ b/torch/csrc/autograd/FunctionsManual.h
@@ -303,6 +303,12 @@ at::Tensor renorm_backward(
     const at::Scalar& p,
     int64_t dim,
     const at::Scalar& maxnorm);
+at::Tensor renorm_jvp(
+    const at::Tensor& self_p,
+    const at::Tensor& self_t,
+    const at::Scalar& p,
+    int64_t dim,
+    const at::Scalar& maxnorm);
 at::Tensor repeat_backward(
     at::Tensor grad,
     at::SymIntArrayRef repeats,
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index de4801d595e..0ac5cb25ab0 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -16448,6 +16448,8 @@ op_db: List[OpInfo] = [
            dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_renorm,
            error_inputs_func=error_inputs_renorm,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
            skips=(
                # RuntimeError: Difference from float64 is larger with decomposition
                # linalg_vector_norm.default than original on output 0.
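
Not part of the patch itself: a minimal sketch of how the new forward-AD rule for renorm can be exercised once this lands, assuming torch.func.jvp and gradcheck's check_forward_ad flag are available in the build under test.

import torch
from torch.func import jvp

def f(inp):
    # Constrain each slice along dim=0 to an L2 norm of at most 1.0.
    return inp.renorm(p=2, dim=0, maxnorm=1.0)

x = torch.randn(3, 4, dtype=torch.double)
t = torch.randn_like(x)

# Forward-mode AD now uses the analytic renorm_jvp rule; before this change it
# raised, since forward AD for renorm was NYI and there was no decomposition.
out, out_tangent = jvp(f, (x,), (t,))

# Double-precision gradcheck can also validate the forward-mode rule directly.
torch.autograd.gradcheck(f, (x.detach().requires_grad_(),), check_forward_ad=True)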