Add renorm forward-ad (#100798)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100798 Approved by: https://github.com/soulitzer
commit 1c2dfdf30c
parent d89c719160
@@ -1122,6 +1122,7 @@ class TestOperators(TestCase):
         xfail('as_strided_scatter', ''),
         xfail('masked.cumprod', ''),
         xfail("_upsample_bilinear2d_aa"),  # hit vmap fallback, which is disabled
+        xfail("renorm"),  # hit vmap fallback, which is disabled
     }))
     @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
     def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
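The new xfail reflects that renorm now has a forward derivative, so test_vmapjvpall_has_batch_rule starts composing vmap with jvp over it; without a dedicated batching rule that composition falls through to the vmap fallback, which this test disables. A minimal sketch of the composition being exercised, using the public torch.func API (the shapes and renorm arguments here are illustrative, not taken from the test):

import torch
from torch.func import jvp, vmap

x = torch.randn(8, 3, 4)   # leading dim is the vmapped batch
t = torch.randn_like(x)    # tangents for forward-mode AD

def f(xi):
    # renorm each slice along dim 0 of the unbatched input to max 2-norm 1.0
    return torch.renorm(xi, 2, 0, 1.0)

# jvp returns (primal output, directional derivative); vmap batches it per sample.
out, tangent_out = vmap(lambda xi, ti: jvp(f, (xi,), (ti,)))(x, t)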
@@ -1383,7 +1384,6 @@ class TestOperators(TestCase):
         xfail('nn.functional.hardsigmoid', ''),  # NYI: forward AD for hardsigmoid_backward
         xfail('nn.functional.huber_loss', ''),  # NYI: forward AD for huber_loss_backward
         xfail('NumpyCubeNotComposableAutogradFunction'),  # not composable
-        xfail('renorm', ''),  # NYI: forward AD for renorm
         xfail('ormqr', ''),  # NYI: forward AD for ormqr
         xfail('nn.functional.multilabel_margin_loss', ''),  # NYI: multilabel_margin_loss_forward
         xfail('nn.functional.soft_margin_loss', ''),  # NYI: forward-AD for soft_margin_loss_backward
@@ -1543,7 +1543,6 @@ class TestOperators(TestCase):
         xfail('normal', 'number_mean'),  # calls random op
         xfail('pca_lowrank'),  # calls random op
         xfail('quantile'),  # Batching rule not implemented for aten::equal
-        xfail('renorm'),  # Forward AD not implemented and no decomposition
         xfail('scatter_reduce', 'prod'),  # Forward AD not implemented and no decomposition
         xfail('_segment_reduce', 'lengths'),  # Forward AD not implemented and no decomposition
         xfail('_segment_reduce', 'offsets'),  # Forward AD not implemented and no decomposition
@@ -1345,6 +1345,7 @@
 - name: renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor
   self: renorm_backward(grad, self, p, dim, maxnorm)
+  result: renorm_jvp(self_p, self_t, p, dim, maxnorm)
 
 - name: repeat(Tensor self, SymInt[] repeats) -> Tensor
   self: repeat_backward(grad, repeats, self.sym_sizes())
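The added result: line is what registers renorm_jvp as renorm's forward (JVP) derivative in the derivatives.yaml-style entry above. A quick way to exercise it from Python is the dual-number API in torch.autograd.forward_ad; this is only a usage sketch with made-up shapes, assuming a build that includes this change:

import torch
import torch.autograd.forward_ad as fwAD

x = torch.randn(4, 3)
t = torch.randn_like(x)   # tangent, i.e. the direction of the JVP

with fwAD.dual_level():
    dual_x = fwAD.make_dual(x, t)
    dual_y = torch.renorm(dual_x, 2, 0, 1.0)
    y, y_t = fwAD.unpack_dual(dual_y)   # y_t comes from renorm_jvp(x, t, p=2, dim=0, maxnorm=1.0)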
@@ -1523,6 +1523,54 @@ Tensor renorm_backward(
   return at::where(norm > maxnorm, grad_norm.to(grad.scalar_type()), grad);
 }
 
+Tensor renorm_jvp(
+    const Tensor& self_p,
+    const Tensor& self_t,
+    const Scalar& p,
+    int64_t dim,
+    const Scalar& maxnorm) {
+  auto self_sizes = self_p.sizes();
+  dim = c10::maybe_wrap_dim(dim, self_sizes.size());
+
+  at::DimVector reduce_dims(self_sizes.size());
+  std::iota(reduce_dims.begin(), reduce_dims.end(), 0);
+  reduce_dims.erase(reduce_dims.begin() + dim);
+
+  // For cuda half, calculate norm in float precision then cast
+  // normalization factor to half
+  auto dtype = self_p.scalar_type();
+  auto acc_type = at::toAccumulateType(dtype, /*is_cuda=*/true);
+  Tensor norm = [&self_p, &p, &reduce_dims, acc_type, dtype]() {
+    if (acc_type != dtype) {
+      return at::linalg_vector_norm(
+          self_p,
+          p.toDouble(),
+          reduce_dims,
+          /*keepdim=*/true,
+          /*dtype=*/acc_type);
+    } else {
+      return at::linalg_vector_norm(
+          self_p,
+          p.toDouble(),
+          reduce_dims,
+          /*keepdim=*/true);
+    }
+  }();
+
+  auto double_maxnorm = maxnorm.toDouble();
+  auto invnorm = (norm + 1e-7).reciprocal();
+  auto factor = invnorm * double_maxnorm;
+
+  return where(
+      norm > double_maxnorm,
+      factor *
+          (self_t -
+           self_p * invnorm *
+               norm_jvp(
+                   self_p, self_t, p, norm, reduce_dims, /*keepdim=*/true)),
+      self_t);
+}
+
 Tensor repeat_backward(
     Tensor grad,
     c10::SymIntArrayRef repeats,
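The formula implemented above: for a slice whose p-norm exceeds maxnorm, renorm returns y = maxnorm * x / ||x||_p, so its JVP in direction t is (maxnorm / ||x||_p) * (t - x * d||x||_p / ||x||_p); slices already within the bound pass their tangent through unchanged, which is the final where(...). A small numerical cross-check for p = 2 (a sketch, not part of the PR; the tolerances absorb the 1e-7 stabilizer in invnorm):

import torch
from torch.func import jvp

p, dim, maxnorm = 2.0, 0, 1.0
x = 3.0 * torch.randn(5, 4, dtype=torch.double)   # rows very likely exceed maxnorm
t = torch.randn_like(x)

_, auto_jvp = jvp(lambda x: torch.renorm(x, p, dim, maxnorm), (x,), (t,))

# Manual formula for dim=0: per-row 2-norm over the remaining dimension.
norm = torch.linalg.vector_norm(x, p, dim=1, keepdim=True)
dnorm = (x * t).sum(dim=1, keepdim=True) / norm    # d||x||_2 in direction t
manual = torch.where(norm > maxnorm, (maxnorm / norm) * (t - x * dnorm / norm), t)

torch.testing.assert_close(auto_jvp, manual, rtol=1e-5, atol=1e-7)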
@@ -303,6 +303,12 @@ at::Tensor renorm_backward(
     const at::Scalar& p,
     int64_t dim,
     const at::Scalar& maxnorm);
+at::Tensor renorm_jvp(
+    const at::Tensor& self_p,
+    const at::Tensor& self_t,
+    const at::Scalar& p,
+    int64_t dim,
+    const at::Scalar& maxnorm);
 at::Tensor repeat_backward(
     at::Tensor grad,
     at::SymIntArrayRef repeats,
@@ -16448,6 +16448,8 @@ op_db: List[OpInfo] = [
            dtypes=floating_and_complex_types_and(torch.float16, torch.bfloat16),
            sample_inputs_func=sample_inputs_renorm,
            error_inputs_func=error_inputs_renorm,
+           supports_forward_ad=True,
+           supports_fwgrad_bwgrad=True,
            skips=(
                # RuntimeError: Difference from float64 is larger with decomposition
                # linalg_vector_norm.default than original on output 0.
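With supports_forward_ad and supports_fwgrad_bwgrad enabled, the OpInfo-driven autograd tests now run forward-mode and forward-over-reverse checks for renorm. An equivalent standalone check (a sketch with an arbitrary sample, not one of the OpInfo samples):

import torch

x = torch.randn(3, 4, dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(
    lambda x: torch.renorm(x, 2, 0, 1.0), (x,),
    check_forward_ad=True,    # exercises the new renorm_jvp path
    check_backward_ad=True,
)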