fix kl_div for negative targets

ghstack-source-id: d69d60f4fe
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69212
Philip Meier 2022-02-08 14:36:26 +01:00
parent 334339a3d2
commit 20c2bb4c9f
7 changed files with 73 additions and 82 deletions

View File

@@ -147,21 +147,10 @@ Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Ten
return apply_loss_reduction(output, reduction);
}
Tensor _kl_div_log_target(const Tensor& input, const Tensor& target, int64_t reduction) {
auto output = at::exp(target) * (target - input);
return apply_loss_reduction(output, reduction);
}
Tensor _kl_div_non_log_target(const Tensor& input, const Tensor& target, int64_t reduction) {
auto output_pos = target * (at::log(target) - input);
auto zeros = at::zeros_like(output_pos, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
auto output = at::where(target > 0, output_pos, zeros);
return apply_loss_reduction(output, reduction);
}
Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
return log_target ? _kl_div_log_target(input, target, reduction)
: _kl_div_non_log_target(input, target, reduction);
auto output = log_target ? at::exp(target) * (target - input)
: at::xlogy(target, target) - target * input;
return apply_loss_reduction(output, reduction);
}
Tensor kl_div_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
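The rewritten forward collapses the two helpers into a single expression: at::xlogy(target, target) already follows the 0 * log(0) := 0 convention, so the explicit where() masking, which also silently zeroed out negative targets, is no longer needed. A minimal Python sketch of the same pointwise computation (the helper name kl_div_pointwise is illustrative, not part of this change), checked against torch.nn.functional.kl_div:

import torch

def kl_div_pointwise(input, target, log_target=False):
    # Mirrors the rewritten C++ forward above.
    if log_target:
        return torch.exp(target) * (target - input)
    # xlogy(0, 0) == 0, so the 0 * log(0) := 0 convention holds without a
    # where() mask; negative targets are no longer silently zeroed out
    # (the log produces NaN for them instead).
    return torch.xlogy(target, target) - target * input

inp = torch.rand(4).log()                 # log-probabilities
tgt = torch.tensor([0.0, 0.2, 0.3, 0.5])  # probabilities, including an exact zero
ref = torch.nn.functional.kl_div(inp, tgt, reduction="none")
assert torch.allclose(kl_div_pointwise(inp, tgt), ref)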

View File

@@ -7092,16 +7092,16 @@ class TestONNXRuntime(unittest.TestCase):
@skipIfUnsupportedMinOpsetVersion(9)
def test_kldiv_loss(self):
x = torch.randn(5)
y = torch.randn(5)
x = torch.rand(5).log()
y = torch.rand(5)
self._kldiv_loss(x, y)
x = torch.randn(2, 3, 5)
y = torch.randn(2, 3, 5)
x = torch.rand(2, 3, 5).log()
y = torch.rand(2, 3, 5)
self._kldiv_loss(x, y)
x = torch.randn(2, 3, 5, 7)
y = torch.randn(2, 3, 5, 7)
x = torch.rand(2, 3, 5, 7).log()
y = torch.rand(2, 3, 5, 7)
self._kldiv_loss(x, y)
def _kldiv_loss(self, x, y):
@@ -7111,7 +7111,7 @@ class TestONNXRuntime(unittest.TestCase):
self.loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
def forward(self, input, target):
return self.loss(input, target)
return self.loss(input, target.log())
self.run_test(KLDivLossNone(), input=(x, y))
@@ -7131,7 +7131,7 @@ class TestONNXRuntime(unittest.TestCase):
self.loss = torch.nn.KLDivLoss(reduction="sum", log_target=True)
def forward(self, input, target):
return self.loss(input, target)
return self.loss(input, target.log())
self.run_test(KLDivLossSum(), input=(x, y))
@@ -7151,7 +7151,7 @@ class TestONNXRuntime(unittest.TestCase):
self.loss = torch.nn.KLDivLoss(reduction="batchmean", size_average=False, log_target=True)
def forward(self, input, target):
return self.loss(input, target)
return self.loss(input, target.log())
self.run_test(KLDivLossMiniBatchMean(), input=(x, y))
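The ONNX tests now sample inputs as torch.rand(...).log() instead of torch.randn, and pass the target through .log() whenever log_target=True. A short usage sketch of the same pattern, outside the test harness:

import torch

# Sampling pattern used by the updated tests: rand() yields values in [0, 1),
# and .log() maps the input into the log-space that KLDivLoss expects;
# randn() would produce values that are not valid (log-)probabilities.
x = torch.rand(2, 3, 5).log()   # input: log-probabilities
y = torch.rand(2, 3, 5)         # target: probabilities

loss = torch.nn.KLDivLoss(reduction="sum", log_target=True)
out = loss(x, y.log())          # with log_target=True, the target is passed in log-space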

View File

@@ -1740,9 +1740,9 @@
self: not_implemented("embedding_renorm")
- name: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
self: kl_div_backward(grad, self, target, reduction, log_target)
target: kl_div_target_backward(grad, self, target, reduction, log_target)
result: apply_loss_reduction(kl_div_backward(self_t, self_p, target_p, at::Reduction::None, log_target) + kl_div_target_backward(target_t, self_p, target_p, at::Reduction::None, log_target), reduction)
self: kl_div_backward_aux(grad, self, target, reduction, log_target)
target: kl_div_target_backward_aux(grad, self, target, reduction, log_target)
result: apply_loss_reduction(kl_div_backward_aux(self_t, self_p, target_p, at::Reduction::None, log_target) + kl_div_target_backward_aux(target_t, self_p, target_p, at::Reduction::None, log_target), reduction)
- name: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
self: l1_loss_backward(grad, self, target, reduction)
@@ -2170,11 +2170,6 @@
self: zeros_like(grad)
result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result))
- name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
grad_output: kl_div_double_backward_grad_output(grad, self, target, reduction, log_target)
self: zeros_like(grad)
target: zeros_like(grad)
- name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
grad_output: l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
self: l1_loss_double_backward(grad, grad_output, self, target, reduction)
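The derivative entries now route both the self and target gradients, as well as the forward-mode result, through the renamed *_aux helpers, and the dedicated double-backward entry for kl_div_backward is removed. A hedged sketch of the first- and second-order checks this wiring is meant to support, assuming a build that contains this change:

import torch
from torch.autograd import gradcheck, gradgradcheck

i = torch.rand(3, dtype=torch.double).log().requires_grad_()
t = torch.rand(3, dtype=torch.double).add(0.1).requires_grad_()

# First- and second-order gradients through both input and target.
fn = lambda i, t: torch.nn.functional.kl_div(i, t, reduction="sum")
assert gradcheck(fn, (i, t))
assert gradgradcheck(fn, (i, t))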

View File

@@ -1378,31 +1378,33 @@ Tensor infinitely_differentiable_logit_backward(
}
}
Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) {
auto result = kl_div_backward(grad, input, target, at::Reduction::None, log_target);
// TODO: remove kl_div_backward from aten namespace and drop the aux here
Tensor kl_div_backward_aux(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
auto grad_input = (
log_target ? -at::exp(target)
: -target
) * grad_output;
if (reduction == at::Reduction::Mean) {
return result.mean();
} else if (reduction == at::Reduction::Sum) {
return result.sum();
grad_input /= input.numel();
}
return result;
return grad_input;
}
// Compute derivatives for targets.
Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) {
// TODO: remove kl_div_backward from aten namespace and drop the aux here
Tensor kl_div_target_backward_aux(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
Tensor grad_target;
if (!log_target) {
if (!areAnyTensorSubclassLike({self, target}) && !grad_output._is_zerotensor()) {
grad_target = grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.);
if (!areAnyTensorSubclassLike({input, target}) && !grad_output._is_zerotensor()) {
grad_target = grad_output.mul(target.log().add_(1).sub_(input)).masked_fill_(target == 0, 0.);
} else {
grad_target = grad_output.mul(target.log().add(1).sub(self)).masked_fill(target == 0, 0.);
grad_target = grad_output.mul(target.log().add(1).sub(input)).masked_fill(target == 0, 0.);
}
}
else {
if (!areAnyTensorSubclassLike({self, target})) {
grad_target = grad_output.mul(target.add(1).sub_(self).mul_(target.exp()));
if (!areAnyTensorSubclassLike({input, target})) {
grad_target = grad_output.mul(target.add(1).sub_(input).mul_(target.exp()));
} else {
grad_target = grad_output.mul(target.add(1).sub(self).mul_(target.exp()));
grad_target = grad_output.mul(target.add(1).sub(input).mul_(target.exp()));
}
}
@@ -1413,7 +1415,6 @@ Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, in
grad_target.div(target.numel());
}
}
return grad_target;
}
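The renamed helpers implement the analytic derivatives of the new forward: with a plain target, dL/dinput = -target and dL/dtarget = log(target) + 1 - input (set to 0 where target == 0); with log_target, dL/dinput = -exp(target) and dL/dtarget = exp(target) * (target + 1 - input). A quick numerical check of the plain-target case against autograd on the xlogy-based forward (illustrative only, strictly positive targets):

import torch

inp = torch.rand(5, dtype=torch.double).add(0.1).log().requires_grad_()
tgt = torch.rand(5, dtype=torch.double).add(0.1).requires_grad_()

loss = (torch.xlogy(tgt, tgt) - tgt * inp).sum()
loss.backward()

# d/d(input)  [t*log(t) - t*i] = -t
# d/d(target) [t*log(t) - t*i] = log(t) + 1 - i
assert torch.allclose(inp.grad, -tgt.detach())
assert torch.allclose(tgt.grad, tgt.detach().log() + 1 - inp.detach())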

View File

@@ -126,7 +126,8 @@ at::Tensor glu_double_backward_grad_output(const at::Tensor & grad, const at::Te
at::Tensor infinitely_differentiable_silu_backward(const at::Tensor& grad_output, const at::Tensor& input);
at::Tensor infinitely_differentiable_mish_backward(const at::Tensor& grad_output, const at::Tensor& input);
Tensor infinitely_differentiable_logit_backward(const Tensor& grad, const Tensor& self, c10::optional<double> eps);
at::Tensor kl_div_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction, bool log_target);
Tensor kl_div_backward_aux(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target);
Tensor kl_div_target_backward_aux(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target);
Tensor binary_cross_entropy_target_backward(
const Tensor& grad,
const Tensor& self,
@@ -262,7 +263,6 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
const c10::optional<Tensor> & save_invstd,
std::array<bool,3> output_mask);
std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res);
Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target);
Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim,
bool complex_input, bool complex_output,
bool inverse, IntArrayRef checked_signal_sizes,

View File

@@ -15532,22 +15532,13 @@ op_db: List[OpInfo] = [
"nn.functional.kl_div",
sample_inputs_func=sample_inputs_kl_div,
dtypes=floating_types_and(torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64),
backward_dtypesIfCPU=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64),
dtypesIfCUDA=floating_types_and(
torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64
),
backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.int8, torch.int16, torch.int32, torch.int64),
supports_out=False,
check_batched_grad=False,
supports_forward_ad=True,
skips=(
# See https://github.com/pytorch/pytorch/issues/65466
DecorateInfo(
unittest.expectedFailure,
"TestGradients",
"test_fn_gradgrad",
),
),
gradcheck_fast_mode=False,
check_batched_grad=False,
),
OpInfo(
"diagflat",

View File

@@ -4,7 +4,7 @@ import tempfile
import unittest
from copy import deepcopy
from functools import reduce
from functools import reduce, partial, wraps
from itertools import product
from operator import mul
from math import pi
@@ -499,7 +499,7 @@ def bce_with_logistic_no_reduce_scalar_test():
def kldivloss_with_target_no_reduce_test():
i = torch.rand(10, 10).log()
i = torch.rand(10, 10)
return dict(
fullname='KLDivLoss_with_target_no_reduce',
constructor=wrap_functional(
@@ -514,13 +514,13 @@ def kldivloss_with_target_no_reduce_test():
def kldivloss_no_reduce_test():
t = torch.randn(10, 10)
t = torch.rand(10, 10)
return dict(
fullname='KLDivLoss_no_reduce',
constructor=wrap_functional(
lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
input_fn=lambda: torch.rand(10, 10).log(),
input_fn=lambda: torch.rand(10, 10),
cpp_var_map={'i': '_get_input()', 't': t},
reference_fn=lambda i, *_:
loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
@@ -530,13 +530,13 @@ def kldivloss_no_reduce_test():
def kldivloss_no_reduce_scalar_test():
t = torch.randn(())
t = torch.rand(())
return dict(
fullname='KLDivLoss_no_reduce_scalar',
constructor=wrap_functional(
lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
input_fn=lambda: torch.rand(()).log(),
input_fn=lambda: torch.rand(()),
cpp_var_map={'i': '_get_input()', 't': t},
reference_fn=lambda i, *_:
loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
@@ -545,7 +545,7 @@ def kldivloss_no_reduce_scalar_test():
def kldivloss_with_log_target_no_reduce_test():
i = torch.rand(10, 10).log()
i = torch.rand(10, 10)
return dict(
fullname='KLDivLoss_with_log_target_no_reduce',
constructor=wrap_functional(
@@ -560,13 +560,13 @@ def kldivloss_with_log_target_no_reduce_test():
def kldivloss_no_reduce_log_target_test():
t = torch.randn(10, 10)
t = torch.rand(10, 10).log()
return dict(
fullname='KLDivLoss_no_reduce_log_target',
constructor=wrap_functional(
lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
input_fn=lambda: torch.rand(10, 10).log(),
input_fn=lambda: torch.rand(10, 10),
cpp_var_map={'i': '_get_input()', 't': t},
reference_fn=lambda i, *_:
loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
@@ -576,13 +576,13 @@ def kldivloss_no_reduce_log_target_test():
def kldivloss_no_reduce_scalar_log_target_test():
t = torch.randn(())
t = torch.rand(()).log()
return dict(
fullname='KLDivLoss_no_reduce_scalar_log_target',
constructor=wrap_functional(
lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
input_fn=lambda: torch.rand(()).log(),
input_fn=lambda: torch.rand(()),
cpp_var_map={'i': '_get_input()', 't': t},
reference_fn=lambda i, *_:
loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
@@ -4336,9 +4336,9 @@ for non_linear_activation in non_linear_activations_no_batch:
def kldivloss_reference(input, target, reduction='mean'):
safe_target = target * (target > 0).type_as(target)
safe_target_log = (safe_target + (target <= 0).type_as(target)).log()
result = safe_target * (safe_target_log - input)
result = target * (target.log() - input)
# continuous extension 0 * log(0) := 0
result[target == 0] = 0
if reduction == 'mean':
return result.mean()
elif reduction == 'sum':
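The reference now mirrors the new kernel: it evaluates target * (target.log() - input) directly and applies the 0 * log(0) := 0 convention explicitly, instead of building a masked safe_target. The explicit masking and torch.xlogy agree, for example:

import torch

i = torch.rand(4).log()
t = torch.tensor([0.0, 0.25, 0.25, 0.5])

ref = t * (t.log() - i)
ref[t == 0] = 0                     # continuous extension 0 * log(0) := 0
assert torch.allclose(ref, torch.xlogy(t, t) - t * i)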
@@ -4842,7 +4842,7 @@ criterion_tests = [
),
dict(
module_name='KLDivLoss',
input_fn=lambda: torch.rand(10, 10).log(),
input_fn=lambda: torch.rand(10, 10),
target_fn=lambda: torch.rand(10, 10),
reference_fn=lambda i, t, m:
kldivloss_reference(i, t, get_reduction(m)),
@@ -4850,10 +4850,12 @@ criterion_tests = [
),
dict(
module_name='KLDivLoss',
input_fn=lambda: torch.rand(10, 10).log(),
target_fn=lambda: torch.rand(10, 10),
constructor=wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True)),
cpp_constructor_args='torch::nn::KLDivLossOptions().log_target(true)',
input_fn=lambda: torch.rand(10, 10),
target_fn=lambda: torch.rand(10, 10).log(),
reference_fn=lambda i, t, m:
kldivloss_log_target_reference(i, t.log(), get_reduction(m)),
kldivloss_log_target_reference(i, t, get_reduction(m)),
check_sum_reduction=True,
desc='log_target',
),
@@ -5438,7 +5440,7 @@ criterion_tests = [
),
dict(
module_name='KLDivLoss',
input_fn=lambda: torch.rand(()).log(),
input_fn=lambda: torch.rand(()),
target_fn=lambda: torch.rand(()),
reference_fn=lambda i, t, m:
kldivloss_reference(i, t, get_reduction(m)),
@@ -5447,10 +5449,12 @@ criterion_tests = [
),
dict(
module_name='KLDivLoss',
input_fn=lambda: torch.rand(()).log(),
target_fn=lambda: torch.rand(()),
constructor=wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True)),
cpp_constructor_args='torch::nn::KLDivLossOptions().log_target(true)',
input_fn=lambda: torch.rand(()),
target_fn=lambda: torch.rand(()).log(),
reference_fn=lambda i, t, m:
kldivloss_log_target_reference(i, t.log(), get_reduction(m)),
kldivloss_log_target_reference(i, t, get_reduction(m)),
check_sum_reduction=True,
desc='scalar_log_target',
),
@@ -5647,7 +5651,7 @@ def single_batch_reference_criterion_fn(*args):
# Check that regression criterion work with no batch dimensions
regression_criterion_no_batch = [
'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'KLDivLoss', 'HuberLoss', 'SmoothL1Loss'
'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
]
reductions = ['none', 'mean', 'sum']
for name, reduction in product(regression_criterion_no_batch, reductions):
@@ -5661,6 +5665,17 @@ for name, reduction in product(regression_criterion_no_batch, reductions):
)
criterion_tests.append(regression_test_info)
for reduction in reductions:
regression_test_info = dict(
fullname=f"KLDivLoss_no_batch_dim_{reduction}",
constructor=lambda: nn.KLDivLoss(reduction=reduction),
input_fn=lambda: torch.rand((3,)),
target_fn=lambda: torch.rand((3,)),
reference_fn=single_batch_reference_criterion_fn,
test_cpp_api_parity=False,
)
criterion_tests.append(regression_test_info)
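KLDivLoss is removed from the generic regression_criterion_no_batch list above and covered by these dedicated entries, which sample both input and target with torch.rand for every reduction mode. A short unbatched usage sketch along the same lines:

import torch

# KLDivLoss is elementwise, so it accepts inputs without a batch dimension,
# which is what the entries above exercise for each reduction mode.
loss = torch.nn.KLDivLoss(reduction="mean")
i = torch.rand(3).log()   # unbatched log-probabilities
t = torch.rand(3)         # unbatched probabilities
out = loss(i, t)          # 0-dim tensor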
# Check that classification criterion work with no batch dimensions
# List of tuples of (name, input_fn, target_fn)