Add renorm forward-ad (#100798)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100798
Approved by: https://github.com/soulitzer
Authored by Li-Huai (Allan) Lin on 2023-06-05 16:29:48 +00:00, committed by PyTorch MergeBot
parent d89c719160
commit 1c2dfdf30c
5 changed files with 58 additions and 2 deletions
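
For context, a minimal sketch of what this commit enables, using the public dual-number API in torch.autograd.forward_ad (illustrative only; assumes a build that includes this change, and is not part of the diff):

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3, 4)   # primal input
t = torch.randn(3, 4)   # tangent, i.e. the direction of differentiation

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)
    # renorm rescales each slice along dim=0 whose 2-norm exceeds maxnorm=1.0
    dual_y = torch.renorm(dual_x, p=2, dim=0, maxnorm=1.0)
    y, y_t = fwAD.unpack_dual(dual_y)   # y_t is the JVP; before this change renorm had no forward-AD formula

print(y_t.shape)   # torch.Size([3, 4])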

View File

@@ -1122,6 +1122,7 @@ class TestOperators(TestCase):
xfail('as_strided_scatter', ''),
xfail('masked.cumprod', ''),
xfail("_upsample_bilinear2d_aa"), # hit vmap fallback, which is disabled
xfail("renorm"), # hit vmap fallback, which is disabled
}))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
@@ -1383,7 +1384,6 @@ class TestOperators(TestCase):
xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
xfail('NumpyCubeNotComposableAutogradFunction'), # not composable
xfail('renorm', ''), # NYI: forward AD for renorm
xfail('ormqr', ''), # NYI: forward AD for ormqr
xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for soft_margin_loss_backward
@@ -1543,7 +1543,6 @@ class TestOperators(TestCase):
xfail('normal', 'number_mean'), # calls random op
xfail('pca_lowrank'), # calls random op
xfail('quantile'), # Batching rule not implemented for aten::equal
xfail('renorm'), # Forward AD not implemented and no decomposition
xfail('scatter_reduce', 'prod'), # Forward AD not implemented and no decomposition
xfail('_segment_reduce', 'lengths'), # Forward AD not implemented and no decomposition
xfail('_segment_reduce', 'offsets'), # Forward AD not implemented and no decomposition
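
The renorm xfails in the two lists above are dropped because forward AD for renorm now exists; the new xfail applies only to the has-batch-rule variant, where the vmap fallback is disabled. A sketch of the composition that test exercises, written with the public torch.func API rather than the test harness (illustrative):

import torch
from torch.func import jvp, vmap

def renorm_jvp_fn(x, t):
    # forward-mode derivative of renorm at primal x along tangent t
    return jvp(lambda x_: torch.renorm(x_, 2, 0, 1.0), (x,), (t,))[1]

x = torch.randn(8, 3, 4)   # a batch of 8 primals
t = torch.randn(8, 3, 4)   # a batch of 8 tangents

# Expected to run through the vmap fallback (a per-example loop) until the
# ops used by the new JVP all have batching rules, hence the xfail above.
out = vmap(renorm_jvp_fn)(x, t)
print(out.shape)   # torch.Size([8, 3, 4])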

View File

@@ -1345,6 +1345,7 @@
- name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
self: renorm_backward(grad, self, p, dim, maxnorm)
result: renorm_jvp(self_p, self_t, p, dim, maxnorm)
- name: repeat(Tensor self, SymInt[] repeats) -> Tensor
self: repeat_backward(grad, repeats, self.sym_sizes())
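
The `result:` line is the forward-mode (JVP) formula the autograd codegen wires up for renorm, with `self_p`/`self_t` the primal and tangent of `self`. A quick numerical sanity check of that formula (illustrative; the shapes and the choice of p=2, dim=0, maxnorm=1.0 are arbitrary, and the scale of x guarantees the non-trivial branch where the norm exceeds maxnorm):

import torch
from torch.func import jvp

p, dim, maxnorm = 2.0, 0, 1.0
x = 5.0 * torch.randn(3, 4, dtype=torch.double)   # every row norm >> maxnorm
t = torch.randn_like(x)

_, jvp_out = jvp(lambda x_: torch.renorm(x_, p, dim, maxnorm), (x,), (t,))

# Closed form on the rescaled rows:
#   y = maxnorm * x / ||x||    =>    dy = maxnorm/||x|| * (t - x * d||x|| / ||x||)
norm = torch.linalg.vector_norm(x, p, dim=1, keepdim=True)   # reduce over the non-renormed dim
dnorm = (x * t).sum(dim=1, keepdim=True) / norm              # JVP of the 2-norm
expected = maxnorm / norm * (t - x * dnorm / norm)

print(torch.allclose(jvp_out, expected, atol=1e-6))   # True, up to the 1e-7 epsilon in the implementation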

View File

@@ -1523,6 +1523,54 @@ Tensor renorm_backward(
return at::where(norm > maxnorm, grad_norm.to(grad.scalar_type()), grad);
}
Tensor renorm_jvp(
const Tensor& self_p,
const Tensor& self_t,
const Scalar& p,
int64_t dim,
const Scalar& maxnorm) {
auto self_sizes = self_p.sizes();
dim = c10::maybe_wrap_dim(dim, self_sizes.size());
at::DimVector reduce_dims(self_sizes.size());
std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
reduce_dims.erase(reduce_dims.begin() + dim);
// For cuda half, calculate norm in float precision then cast
// normalization factor to half
auto dtype = self_p.scalar_type();
auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
Tensor norm = [&self_p, &p, &reduce_dims, acc_type, dtype]() {
if (acc_type != dtype) {
return at::linalg_vector_norm(
self_p,
p.toDouble(),
reduce_dims,
/*keepdim=*/true,
/*dtype=*/acc_type);
} else {
return at::linalg_vector_norm(
self_p,
p.toDouble(),
reduce_dims,
/*keepdim=*/true);
}
}();
auto double_maxnorm = maxnorm.toDouble();
auto invnorm = (norm + 1e-7).reciprocal();
auto factor = invnorm * double_maxnorm;
return where(
norm > double_maxnorm,
factor *
(self_t -
self_p * invnorm *
norm_jvp(
self_p, self_t, p, norm, reduce_dims, /*keepdim=*/true)),
self_t);
}
Tensor repeat_backward(
Tensor grad,
c10::SymIntArrayRef repeats,
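
A rough Python transcription of renorm_jvp above, mainly to spell out the math behind the norm_jvp call (illustrative: renorm_jvp_ref is a made-up name, the p-norm derivative shown is valid for p >= 1 away from zeros, and the real C++ additionally switches to an accumulate dtype for half precision on CUDA):

import torch

def renorm_jvp_ref(x, t, p, dim, maxnorm):
    dim = dim % x.dim()                                    # wrap negative dims
    reduce_dims = [d for d in range(x.dim()) if d != dim]  # norm over all other dims
    norm = torch.linalg.vector_norm(x, p, dim=reduce_dims, keepdim=True)
    invnorm = (norm + 1e-7).reciprocal()
    factor = invnorm * maxnorm
    # d||x||_p along t:  sum(sign(x) * |x|^(p-1) * t) / ||x||_p^(p-1)
    dnorm = (x.sign() * x.abs().pow(p - 1) * t).sum(dim=reduce_dims, keepdim=True)
    dnorm = dnorm / norm.pow(p - 1)
    return torch.where(norm > maxnorm, factor * (t - x * invnorm * dnorm), t)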

View File

@@ -303,6 +303,12 @@ at::Tensor renorm_backward(
const at::Scalar& p,
int64_t dim,
const at::Scalar& maxnorm);
at::Tensor renorm_jvp(
const at::Tensor& self_p,
const at::Tensor& self_t,
const at::Scalar& p,
int64_t dim,
const at::Scalar& maxnorm);
at::Tensor repeat_backward(
at::Tensor grad,
at::SymIntArrayRef repeats,

View File

@@ -16448,6 +16448,8 @@ op_db: List[OpInfo] = [
dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
sample_inputs_func=sample_inputs_renorm,
error_inputs_func=error_inputs_renorm,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
skips=(
# RuntimeError: Difference from float64 is larger with decomposition
# linalg_vector_norm.default than original on output 0.
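
In spirit, the two new flags turn on the generic forward-AD checks for this OpInfo; the same properties can be checked by hand with gradcheck/gradgradcheck (illustrative sketch; to my understanding check_forward_ad corresponds to supports_forward_ad and check_fwd_over_rev to supports_fwgrad_bwgrad):

import torch
from torch.autograd import gradcheck, gradgradcheck

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
fn = lambda inp: torch.renorm(inp, 2, 0, 1.0)

print(gradcheck(fn, (x,), check_forward_ad=True))        # forward-mode AD vs. numerical Jacobian
print(gradgradcheck(fn, (x,), check_fwd_over_rev=True))  # forward-over-reverse second derivative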