Add renorm forward-ad (#100798)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100798
Approved by: https://github.com/soulitzer
parent d89c719160
commit 1c2dfdf30c
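This PR gives renorm a forward-mode AD (JVP) formula. As a minimal sketch of what that enables (the shapes and the maxnorm value below are arbitrary illustration, not taken from the PR), a directional derivative can now be pushed through torch.renorm with the public torch.autograd.forward_ad API:

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(3, 4)   # primal input
t = torch.randn(3, 4)   # tangent (perturbation direction)

with fwAD.dual_level():
    dual_in = fwAD.make_dual(x, t)
    dual_out = torch.renorm(dual_in, p=2, dim=0, maxnorm=1.0)
    y, y_t = fwAD.unpack_dual(dual_out)   # y_t is the JVP added by this PR

Before this change the renorm call inside the dual level would fail with a forward-AD-not-implemented error, which is what the xfail entries removed in the test hunks below were tracking.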
@@ -1122,6 +1122,7 @@ class TestOperators(TestCase):
         xfail('as_strided_scatter', ''),
         xfail('masked.cumprod', ''),
         xfail("_upsample_bilinear2d_aa"), # hit vmap fallback, which is disabled
+        xfail("renorm"), # hit vmap fallback, which is disabled
     }))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
     def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
@@ -1383,7 +1384,6 @@ class TestOperators(TestCase):
         xfail('nn.functional.hardsigmoid', ''), # NYI: forward AD for hardsigmoid_backward
         xfail('nn.functional.huber_loss', ''), # NYI: forward AD for huber_loss_backward
         xfail('NumpyCubeNotComposableAutogradFunction'), # not composable
-        xfail('renorm', ''), # NYI: forward AD for renorm
         xfail('ormqr', ''), # NYI: forward AD for ormqr
         xfail('nn.functional.multilabel_margin_loss', ''), # NYI: multilabel_margin_loss_forward
         xfail('nn.functional.soft_margin_loss', ''), # NYI: forward-AD for soft_margin_loss_backward
@@ -1543,7 +1543,6 @@ class TestOperators(TestCase):
         xfail('normal', 'number_mean'), # calls random op
         xfail('pca_lowrank'), # calls random op
         xfail('quantile'), # Batching rule not implemented for aten::equal
-        xfail('renorm'), # Forward AD not implemented and no decomposition
         xfail('scatter_reduce', 'prod'), # Forward AD not implemented and no decomposition
         xfail('_segment_reduce', 'lengths'), # Forward AD not implemented and no decomposition
         xfail('_segment_reduce', 'offsets'), # Forward AD not implemented and no decomposition
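The three hunks above are from the functorch operator tests (class TestOperators): the forward-AD xfails for renorm are dropped now that the formula exists, while test_vmapjvpall_has_batch_rule gains an xfail because vmap-of-jvp for renorm still goes through the sequential vmap fallback, which that test disables. A rough sketch of the composition in question, assuming the default configuration where the fallback is enabled (it appears to be disabled only inside the test suite):

import torch
from torch.func import jvp, vmap

def f(x):
    return torch.renorm(x, p=2, dim=0, maxnorm=1.0)

# Plain jvp: covered by the new forward-AD formula.
x = torch.randn(3, 5)
t = torch.randn(3, 5)
y, y_t = jvp(f, (x,), (t,))

# vmap over jvp: no dedicated batch rule for the jvp yet, so this relies on
# the (slower) sequential vmap fallback.
xb = torch.randn(4, 3, 5)
tb = torch.randn(4, 3, 5)
y_t_batched = vmap(lambda xi, ti: jvp(f, (xi,), (ti,))[1])(xb, tb)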
@@ -1345,6 +1345,7 @@

 - name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
   self: renorm_backward(grad, self, p, dim, maxnorm)
+  result: renorm_jvp(self_p, self_t, p, dim, maxnorm)

 - name: repeat(Tensor self, SymInt[] repeats) -> Tensor
   self: repeat_backward(grad, repeats, self.sym_sizes())
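The new result: line is what the autograd codegen dispatches to in forward mode. For reference, writing c for maxnorm, n = ||x||_p for the norm taken over every dimension except dim (independently for each slice along dim), and ν̇ for the norm's own JVP (computed by norm_jvp below), the formula implemented in the next hunk is, up to the small epsilon the implementation adds to the norm for numerical safety:

y = \begin{cases} x, & n \le c \\ \dfrac{c}{n}\, x, & n > c \end{cases}
\qquad
\dot{y} = \begin{cases} \dot{x}, & n \le c \\ \dfrac{c}{n}\left( \dot{x} - \dfrac{\dot{\nu}}{n}\, x \right), & n > c \end{cases}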
@@ -1523,6 +1523,54 @@ Tensor renorm_backward(
   return at::where(norm > maxnorm, grad_norm.to(grad.scalar_type()), grad);
 }

+Tensor renorm_jvp(
+    const Tensor& self_p,
+    const Tensor& self_t,
+    const Scalar& p,
+    int64_t dim,
+    const Scalar& maxnorm) {
+  auto self_sizes = self_p.sizes();
+  dim = c10::maybe_wrap_dim(dim, self_sizes.size());
+
+  at::DimVector reduce_dims(self_sizes.size());
+  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
+  reduce_dims.erase(reduce_dims.begin() + dim);
+
+  // For cuda half, calculate norm in float precision then cast
+  // normalization factor to half
+  auto dtype = self_p.scalar_type();
+  auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
+  Tensor norm = [&self_p, &p, &reduce_dims, acc_type, dtype]() {
+    if (acc_type != dtype) {
+      return at::linalg_vector_norm(
+          self_p,
+          p.toDouble(),
+          reduce_dims,
+          /*keepdim=*/true,
+          /*dtype=*/acc_type);
+    } else {
+      return at::linalg_vector_norm(
+          self_p,
+          p.toDouble(),
+          reduce_dims,
+          /*keepdim=*/true);
+    }
+  }();
+
+  auto double_maxnorm = maxnorm.toDouble();
+  auto invnorm = (norm + 1e-7).reciprocal();
+  auto factor = invnorm * double_maxnorm;
+
+  return where(
+      norm > double_maxnorm,
+      factor *
+          (self_t -
+           self_p * invnorm *
+               norm_jvp(
+                   self_p, self_t, p, norm, reduce_dims, /*keepdim=*/true)),
+      self_t);
+}
+
 Tensor repeat_backward(
     Tensor grad,
     c10::SymIntArrayRef repeats,
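As a sanity check on the C++ above, here is a hypothetical pure-Python mirror of renorm_jvp compared against what forward-mode AD now returns. The helper name renorm_jvp_reference and the explicit norm-derivative expression are mine (real inputs, finite p > 0, ignoring the zero-norm corner cases that norm_jvp handles), not part of the PR:

import torch
import torch.autograd.forward_ad as fwAD

def renorm_jvp_reference(x, t, p, dim, maxnorm, eps=1e-7):
    # Norm over every dimension except `dim`, kept for broadcasting.
    reduce_dims = [d for d in range(x.dim()) if d != dim]
    norm = torch.linalg.vector_norm(x, p, dim=reduce_dims, keepdim=True)
    invnorm = (norm + eps).reciprocal()  # matches the 1e-7 guard above
    # JVP of the p-norm itself: sum(|x|^(p-1) * sign(x) * t) / ||x||^(p-1)
    norm_t = (x.abs().pow(p - 1) * x.sgn() * t).sum(reduce_dims, keepdim=True) / norm.pow(p - 1)
    return torch.where(norm > maxnorm, maxnorm * invnorm * (t - x * invnorm * norm_t), t)

x = torch.randn(3, 4, dtype=torch.double)
t = torch.randn(3, 4, dtype=torch.double)
with fwAD.dual_level():
    _, jvp_out = fwAD.unpack_dual(torch.renorm(fwAD.make_dual(x, t), 2, 0, 2.0))
print(torch.allclose(jvp_out, renorm_jvp_reference(x, t, 2, 0, 2.0)))  # expected: True

This mirrors the where(norm > double_maxnorm, ..., self_t) structure of the C++: slices whose norm is at most maxnorm pass the tangent through unchanged, the rest get the rescaled and norm-corrected tangent.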
@@ -303,6 +303,12 @@ at::Tensor renorm_backward(
     const at::Scalar& p,
     int64_t dim,
     const at::Scalar& maxnorm);
+at::Tensor renorm_jvp(
+    const at::Tensor& self_p,
+    const at::Tensor& self_t,
+    const at::Scalar& p,
+    int64_t dim,
+    const at::Scalar& maxnorm);
 at::Tensor repeat_backward(
     at::Tensor grad,
     at::SymIntArrayRef repeats,
@@ -16448,6 +16448,8 @@ op_db: List[OpInfo] = [
            dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_renorm,
            error_inputs_func=error_inputs_renorm,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
            skips=(
                # RuntimeError: Difference from float64 is larger with decomposition
                # linalg_vector_norm.default than original on output 0.
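The two added flags switch on the generic forward-AD and forward-over-reverse (fwgrad_bwgrad) coverage that the OpInfo-based tests run for renorm. Outside the test suite, the same property can be spot-checked with gradcheck's forward-AD mode; a minimal sketch, with arbitrary shapes not taken from the OpInfo entry:

import torch
from torch.autograd import gradcheck

def fn(x):
    return torch.renorm(x, p=2, dim=0, maxnorm=1.0)

x = torch.randn(5, 3, dtype=torch.double, requires_grad=True)
assert gradcheck(fn, (x,), check_forward_ad=True)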