fix kl_div for negative targets

ghstack-source-id: d69d60f4fe
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69212

parent 334339a3d2
commit 20c2bb4c9f
@@ -147,21 +147,10 @@ Tensor margin_ranking_loss(const Tensor& input1, const Tensor& input2, const Ten
   return apply_loss_reduction(output, reduction);
 }
 
-Tensor _kl_div_log_target(const Tensor& input, const Tensor& target, int64_t reduction) {
-  auto output = at::exp(target) * (target - input);
-  return apply_loss_reduction(output, reduction);
-}
-
-Tensor _kl_div_non_log_target(const Tensor& input, const Tensor& target, int64_t reduction) {
-  auto output_pos = target * (at::log(target) - input);
-  auto zeros = at::zeros_like(output_pos, LEGACY_CONTIGUOUS_MEMORY_FORMAT);
-  auto output = at::where(target > 0, output_pos, zeros);
-  return apply_loss_reduction(output, reduction);
-}
-
 Tensor kl_div(const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
-  return log_target ? _kl_div_log_target(input, target, reduction)
-                    : _kl_div_non_log_target(input, target, reduction);
+  auto output = log_target ? at::exp(target) * (target - input)
+                           : at::xlogy(target, target) - target * input;
+  return apply_loss_reduction(output, reduction);
 }
 
 Tensor kl_div_backward_cpu(const Tensor& grad, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
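For illustration only (this sketch is not part of the commit): the rewritten forward can be mirrored in Python. torch.xlogy(t, t) evaluates t * log(t) with the continuous extension 0 * log(0) = 0, which is what lets the branch-free formula replace the old where/zeros masking while still propagating arbitrary (including negative) targets.

import torch

def kl_div_pointwise(input, target, log_target=False):
    # Mirrors the rewritten forward: per-element loss before reduction (sketch only).
    if log_target:
        return torch.exp(target) * (target - input)
    # torch.xlogy(t, t) = t * log(t), with 0 * log(0) defined as 0.
    return torch.xlogy(target, target) - target * input

x = torch.rand(4).log()   # input: log-probabilities
t = torch.rand(4)         # target: probabilities
t[0] = 0.0                # exercise the 0 * log(0) := 0 case
print(torch.allclose(kl_div_pointwise(x, t),
                     torch.nn.functional.kl_div(x, t, reduction='none')))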
@@ -7092,16 +7092,16 @@ class TestONNXRuntime(unittest.TestCase):
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_kldiv_loss(self):
 
-        x = torch.randn(5)
-        y = torch.randn(5)
+        x = torch.rand(5).log()
+        y = torch.rand(5)
         self._kldiv_loss(x, y)
 
-        x = torch.randn(2, 3, 5)
-        y = torch.randn(2, 3, 5)
+        x = torch.rand(2, 3, 5).log()
+        y = torch.rand(2, 3, 5)
         self._kldiv_loss(x, y)
 
-        x = torch.randn(2, 3, 5, 7)
-        y = torch.randn(2, 3, 5, 7)
+        x = torch.rand(2, 3, 5, 7).log()
+        y = torch.rand(2, 3, 5, 7)
         self._kldiv_loss(x, y)
 
     def _kldiv_loss(self, x, y):
@@ -7111,7 +7111,7 @@ class TestONNXRuntime(unittest.TestCase):
                 self.loss = torch.nn.KLDivLoss(reduction="none", log_target=True)
 
             def forward(self, input, target):
-                return self.loss(input, target)
+                return self.loss(input, target.log())
 
         self.run_test(KLDivLossNone(), input=(x, y))
 
@@ -7131,7 +7131,7 @@ class TestONNXRuntime(unittest.TestCase):
                 self.loss = torch.nn.KLDivLoss(reduction="sum", log_target=True)
 
             def forward(self, input, target):
-                return self.loss(input, target)
+                return self.loss(input, target.log())
 
         self.run_test(KLDivLossSum(), input=(x, y))
 
@@ -7151,7 +7151,7 @@ class TestONNXRuntime(unittest.TestCase):
                 self.loss = torch.nn.KLDivLoss(reduction="batchmean", size_average=False, log_target=True)
 
             def forward(self, input, target):
-                return self.loss(input, target)
+                return self.loss(input, target.log())
 
         self.run_test(KLDivLossMiniBatchMean(), input=(x, y))
 
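For illustration (assumed usage, not from the commit): KLDivLoss expects `input` as log-probabilities, and with log_target=True it expects the target in log-space as well, which is why the tests above now build x from torch.rand(...).log() and pass target.log() to the log-target modules.

import torch

x = torch.rand(2, 3, 5).log()   # input: log-probabilities
y = torch.rand(2, 3, 5)         # target: probabilities

loss_prob = torch.nn.KLDivLoss(reduction="none")                      # target given as probabilities
loss_logprob = torch.nn.KLDivLoss(reduction="none", log_target=True)  # target given in log-space

# Both conventions describe the same loss, so the two calls agree elementwise.
print(torch.allclose(loss_prob(x, y), loss_logprob(x, y.log())))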
@@ -1740,9 +1740,9 @@
   self: not_implemented("embedding_renorm")
 
 - name: kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
-  self: kl_div_backward(grad, self, target, reduction, log_target)
-  target: kl_div_target_backward(grad, self, target, reduction, log_target)
-  result: apply_loss_reduction(kl_div_backward(self_t, self_p, target_p, at::Reduction::None, log_target) + kl_div_target_backward(target_t, self_p, target_p, at::Reduction::None, log_target), reduction)
+  self: kl_div_backward_aux(grad, self, target, reduction, log_target)
+  target: kl_div_target_backward_aux(grad, self, target, reduction, log_target)
+  result: apply_loss_reduction(kl_div_backward_aux(self_t, self_p, target_p, at::Reduction::None, log_target) + kl_div_target_backward_aux(target_t, self_p, target_p, at::Reduction::None, log_target), reduction)
 
 - name: l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor
   self: l1_loss_backward(grad, self, target, reduction)
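These derivatives.yaml-style entries register analytic derivatives for both `self` and `target`, plus a forward-mode `result` rule. A hedged sketch of how such formulas are typically verified numerically (assumed test code, not part of the commit):

import torch
from torch.autograd import gradcheck

torch.manual_seed(0)
# Double precision and targets bounded away from zero keep the numerical check stable.
x = (torch.rand(3, 4, dtype=torch.double) + 0.1).log().requires_grad_()
t = (torch.rand(3, 4, dtype=torch.double) + 0.1).requires_grad_()

fn = lambda inp, tgt: torch.nn.functional.kl_div(inp, tgt, reduction='sum')

# Compares the registered analytic derivatives for input and target against finite differences.
print(gradcheck(fn, (x, t)))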
@@ -2170,11 +2170,6 @@
   self: zeros_like(grad)
   result: at::where((self_p > min_val).logical_and(self_p < max_val), grad_output_t, at::zeros({}, result.options()).expand_as(result))
 
-- name: kl_div_backward(Tensor grad_output, Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor
-  grad_output: kl_div_double_backward_grad_output(grad, self, target, reduction, log_target)
-  self: zeros_like(grad)
-  target: zeros_like(grad)
-
 - name: l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor
   grad_output: l1_loss_double_backward_grad_output(grad, grad_output, self, target, reduction)
   self: l1_loss_double_backward(grad, grad_output, self, target, reduction)
@@ -1378,31 +1378,33 @@ Tensor infinitely_differentiable_logit_backward(
   }
 }
 
-Tensor kl_div_double_backward_grad_output(const Tensor & grad, const Tensor & input, const Tensor & target, int64_t reduction, bool log_target) {
-  auto result = kl_div_backward(grad, input, target, at::Reduction::None, log_target);
+// TODO: remove kl_div_backward from aten namespace and drop the aux here
+Tensor kl_div_backward_aux(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
+  auto grad_input = (
+    log_target ? -at::exp(target)
+               : -target
+  ) * grad_output;
   if (reduction == at::Reduction::Mean) {
-    return result.mean();
-  } else if (reduction == at::Reduction::Sum) {
-    return result.sum();
+    grad_input /= input.numel();
   }
-  return result;
+  return grad_input;
 }
 
-// Compute derivatives for targets.
-Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target) {
+// TODO: remove kl_div_backward from aten namespace and drop the aux here
+Tensor kl_div_target_backward_aux(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target) {
   Tensor grad_target;
   if (!log_target) {
-    if (!areAnyTensorSubclassLike({self, target}) && !grad_output._is_zerotensor()) {
-      grad_target = grad_output.mul(target.log().add_(1).sub_(self)).masked_fill_(target == 0, 0.);
+    if (!areAnyTensorSubclassLike({input, target}) && !grad_output._is_zerotensor()) {
+      grad_target = grad_output.mul(target.log().add_(1).sub_(input)).masked_fill_(target == 0, 0.);
     } else {
-      grad_target = grad_output.mul(target.log().add(1).sub(self)).masked_fill(target == 0, 0.);
+      grad_target = grad_output.mul(target.log().add(1).sub(input)).masked_fill(target == 0, 0.);
     }
   }
   else {
-    if (!areAnyTensorSubclassLike({self, target})) {
-      grad_target = grad_output.mul(target.add(1).sub_(self).mul_(target.exp()));
+    if (!areAnyTensorSubclassLike({input, target})) {
+      grad_target = grad_output.mul(target.add(1).sub_(input).mul_(target.exp()));
     } else {
-      grad_target = grad_output.mul(target.add(1).sub(self).mul_(target.exp()));
+      grad_target = grad_output.mul(target.add(1).sub(input).mul_(target.exp()));
     }
   }
 
@@ -1413,7 +1415,6 @@ Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, in
     grad_target.div(target.numel());
   }
-  }
 
   return grad_target;
 }
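Restated in Python for illustration (not code from the commit), the un-reduced gradients implemented above are: d/d input = -target * grad (or -exp(target) * grad when the target is in log-space), and d/d target = (log(target) + 1 - input) * grad with the entries at target == 0 zeroed (or exp(target) * (target + 1 - input) * grad for log targets).

import torch

def kl_div_grads(grad_output, input, target, log_target=False):
    # Pointwise gradients of kl_div with reduction='none' (sketch).
    if log_target:
        grad_input = -torch.exp(target) * grad_output
        grad_target = (target + 1 - input) * torch.exp(target) * grad_output
    else:
        grad_input = -target * grad_output
        grad_target = ((target.log() + 1 - input) * grad_output).masked_fill(target == 0, 0.)
    return grad_input, grad_target

# Cross-check against autograd (targets kept away from 0 for a clean comparison).
x = torch.rand(3).log().requires_grad_()
t = (torch.rand(3) + 0.1).requires_grad_()
loss = torch.nn.functional.kl_div(x, t, reduction='none')
gi, gt = torch.autograd.grad(loss.sum(), (x, t))
ref_gi, ref_gt = kl_div_grads(torch.ones_like(loss), x.detach(), t.detach())
print(torch.allclose(gi, ref_gi), torch.allclose(gt, ref_gt))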
@@ -126,7 +126,8 @@ at::Tensor glu_double_backward_grad_output(const at::Tensor & grad, const at::Te
 at::Tensor infinitely_differentiable_silu_backward(const at::Tensor& grad_output, const at::Tensor& input);
 at::Tensor infinitely_differentiable_mish_backward(const at::Tensor& grad_output, const at::Tensor& input);
 Tensor infinitely_differentiable_logit_backward(const Tensor& grad, const Tensor& self, c10::optional<double> eps);
-at::Tensor kl_div_double_backward_grad_output(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & target, int64_t reduction, bool log_target);
+Tensor kl_div_backward_aux(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target);
+Tensor kl_div_target_backward_aux(const Tensor& grad_output, const Tensor& input, const Tensor& target, int64_t reduction, bool log_target);
 Tensor binary_cross_entropy_target_backward(
     const Tensor& grad,
     const Tensor& self,
@@ -262,7 +263,6 @@ std::tuple<Tensor, Tensor, Tensor> batchnorm_double_backward(
     const c10::optional<Tensor> & save_invstd,
     std::array<bool,3> output_mask);
 std::tuple<Tensor, Tensor> _euclidean_dist_backward(const Tensor & grad, const Tensor & x1, const Tensor & x2, const Tensor & res);
-Tensor kl_div_target_backward(Tensor grad_output, Tensor self, Tensor target, int64_t reduction, bool log_target);
 Tensor fft_backward(const Tensor& self, const Tensor& grad, int64_t signal_ndim,
     bool complex_input, bool complex_output,
     bool inverse, IntArrayRef checked_signal_sizes,
@@ -15532,22 +15532,13 @@ op_db: List[OpInfo] = [
         "nn.functional.kl_div",
         sample_inputs_func=sample_inputs_kl_div,
         dtypes=floating_types_and(torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64),
         backward_dtypesIfCPU=floating_types_and(torch.int8, torch.int16, torch.int32, torch.int64),
         dtypesIfCUDA=floating_types_and(
             torch.float16, torch.bfloat16, torch.int8, torch.int16, torch.int32, torch.int64
         ),
         backward_dtypesIfCUDA=floating_types_and(torch.float16, torch.int8, torch.int16, torch.int32, torch.int64),
         supports_out=False,
-        check_batched_grad=False,
         supports_forward_ad=True,
-        skips=(
-            # See https://github.com/pytorch/pytorch/issues/65466
-            DecorateInfo(
-                unittest.expectedFailure,
-                "TestGradients",
-                "test_fn_gradgrad",
-            ),
-        ),
+        gradcheck_fast_mode=False,
+        check_batched_grad=False,
     ),
     OpInfo(
         "diagflat",
@@ -4,7 +4,7 @@ import tempfile
 import unittest
 
 from copy import deepcopy
-from functools import reduce
+from functools import reduce, partial, wraps
 from itertools import product
 from operator import mul
 from math import pi
@@ -499,7 +499,7 @@ def bce_with_logistic_no_reduce_scalar_test():
 
 
 def kldivloss_with_target_no_reduce_test():
-    i = torch.rand(10, 10).log()
+    i = torch.rand(10, 10)
     return dict(
         fullname='KLDivLoss_with_target_no_reduce',
         constructor=wrap_functional(
@@ -514,13 +514,13 @@ def kldivloss_with_target_no_reduce_test():
 
 
 def kldivloss_no_reduce_test():
-    t = torch.randn(10, 10)
+    t = torch.rand(10, 10)
     return dict(
         fullname='KLDivLoss_no_reduce',
         constructor=wrap_functional(
             lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
         cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
-        input_fn=lambda: torch.rand(10, 10).log(),
+        input_fn=lambda: torch.rand(10, 10),
         cpp_var_map={'i': '_get_input()', 't': t},
         reference_fn=lambda i, *_:
             loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
@@ -530,13 +530,13 @@ def kldivloss_no_reduce_test():
 
 
 def kldivloss_no_reduce_scalar_test():
-    t = torch.randn(())
+    t = torch.rand(())
     return dict(
         fullname='KLDivLoss_no_reduce_scalar',
         constructor=wrap_functional(
             lambda i: F.kl_div(i, t.type_as(i), reduction='none')),
         cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone))',
-        input_fn=lambda: torch.rand(()).log(),
+        input_fn=lambda: torch.rand(()),
         cpp_var_map={'i': '_get_input()', 't': t},
         reference_fn=lambda i, *_:
             loss_reference_fns['KLDivLoss'](i, t.type_as(i), reduction='none'),
@@ -545,7 +545,7 @@ def kldivloss_no_reduce_scalar_test():
 
 
 def kldivloss_with_log_target_no_reduce_test():
-    i = torch.rand(10, 10).log()
+    i = torch.rand(10, 10)
     return dict(
         fullname='KLDivLoss_with_log_target_no_reduce',
         constructor=wrap_functional(
@@ -560,13 +560,13 @@ def kldivloss_with_log_target_no_reduce_test():
 
 
 def kldivloss_no_reduce_log_target_test():
-    t = torch.randn(10, 10)
+    t = torch.rand(10, 10).log()
     return dict(
         fullname='KLDivLoss_no_reduce_log_target',
         constructor=wrap_functional(
             lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
         cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
-        input_fn=lambda: torch.rand(10, 10).log(),
+        input_fn=lambda: torch.rand(10, 10),
         cpp_var_map={'i': '_get_input()', 't': t},
         reference_fn=lambda i, *_:
             loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
@@ -576,13 +576,13 @@ def kldivloss_no_reduce_log_target_test():
 
 
 def kldivloss_no_reduce_scalar_log_target_test():
-    t = torch.randn(())
+    t = torch.rand(()).log()
     return dict(
         fullname='KLDivLoss_no_reduce_scalar_log_target',
         constructor=wrap_functional(
             lambda i: F.kl_div(i, t.type_as(i), reduction='none', log_target=True)),
         cpp_function_call='F::kl_div(i, t.to(i.options()), F::KLDivFuncOptions().reduction(torch::kNone).log_target(true))',
-        input_fn=lambda: torch.rand(()).log(),
+        input_fn=lambda: torch.rand(()),
         cpp_var_map={'i': '_get_input()', 't': t},
         reference_fn=lambda i, *_:
             loss_reference_fns['KLDivLoss_log_target'](i, t.type_as(i), reduction='none'),
@@ -4336,9 +4336,9 @@ for non_linear_activation in non_linear_activations_no_batch:
 
 
 def kldivloss_reference(input, target, reduction='mean'):
-    safe_target = target * (target > 0).type_as(target)
-    safe_target_log = (safe_target + (target <= 0).type_as(target)).log()
-    result = safe_target * (safe_target_log - input)
+    result = target * (target.log() - input)
+    # continuous extension 0 * log(0) := 0
+    result[target == 0] = 0
     if reduction == 'mean':
         return result.mean()
     elif reduction == 'sum':
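The simplified reference computes the loss directly and then patches in the 0 * log(0) := 0 convention. A standalone cross-check against torch.nn.functional.kl_div (assumed verification code, not part of the commit):

import torch

def kldivloss_reference(input, target, reduction='mean'):
    result = target * (target.log() - input)
    # continuous extension 0 * log(0) := 0
    result[target == 0] = 0
    if reduction == 'mean':
        return result.mean()
    elif reduction == 'sum':
        return result.sum()
    return result

i = torch.rand(10, 10)
t = torch.rand(10, 10)
t[0, 0] = 0  # exercise the 0 * log(0) case
print(torch.allclose(kldivloss_reference(i, t, 'none'),
                     torch.nn.functional.kl_div(i, t, reduction='none')))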
@@ -4842,7 +4842,7 @@ criterion_tests = [
     ),
     dict(
         module_name='KLDivLoss',
-        input_fn=lambda: torch.rand(10, 10).log(),
+        input_fn=lambda: torch.rand(10, 10),
         target_fn=lambda: torch.rand(10, 10),
         reference_fn=lambda i, t, m:
             kldivloss_reference(i, t, get_reduction(m)),
@@ -4850,10 +4850,12 @@ criterion_tests = [
     ),
     dict(
         module_name='KLDivLoss',
-        input_fn=lambda: torch.rand(10, 10).log(),
-        target_fn=lambda: torch.rand(10, 10),
+        constructor=wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True)),
+        cpp_constructor_args='torch::nn::KLDivLossOptions().log_target(true)',
+        input_fn=lambda: torch.rand(10, 10),
+        target_fn=lambda: torch.rand(10, 10).log(),
         reference_fn=lambda i, t, m:
-            kldivloss_log_target_reference(i, t.log(), get_reduction(m)),
+            kldivloss_log_target_reference(i, t, get_reduction(m)),
         check_sum_reduction=True,
         desc='log_target',
     ),
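The `wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True))` constructor is a functools.partial dressed up to look like the original class, presumably so the test harness can still read attributes such as `__name__`. A small sketch of the idiom (illustrative only):

from functools import partial, wraps
import torch.nn as nn

constructor = wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True))

print(constructor.__name__)             # 'KLDivLoss' -- copied over by wraps()
loss = constructor(reduction='sum')     # additional kwargs still pass through
print(loss.log_target, loss.reduction)  # True sum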
@@ -5438,7 +5440,7 @@ criterion_tests = [
     ),
     dict(
         module_name='KLDivLoss',
-        input_fn=lambda: torch.rand(()).log(),
+        input_fn=lambda: torch.rand(()),
         target_fn=lambda: torch.rand(()),
         reference_fn=lambda i, t, m:
             kldivloss_reference(i, t, get_reduction(m)),
@@ -5447,10 +5449,12 @@ criterion_tests = [
     ),
     dict(
         module_name='KLDivLoss',
-        input_fn=lambda: torch.rand(()).log(),
-        target_fn=lambda: torch.rand(()),
+        constructor=wraps(nn.KLDivLoss)(partial(nn.KLDivLoss, log_target=True)),
+        cpp_constructor_args='torch::nn::KLDivLossOptions().log_target(true)',
+        input_fn=lambda: torch.rand(()),
+        target_fn=lambda: torch.rand(()).log(),
         reference_fn=lambda i, t, m:
-            kldivloss_log_target_reference(i, t.log(), get_reduction(m)),
+            kldivloss_log_target_reference(i, t, get_reduction(m)),
         check_sum_reduction=True,
         desc='scalar_log_target',
     ),
@@ -5647,7 +5651,7 @@ def single_batch_reference_criterion_fn(*args):
 
 # Check that regression criterion work with no batch dimensions
 regression_criterion_no_batch = [
-    'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'KLDivLoss', 'HuberLoss', 'SmoothL1Loss'
+    'L1Loss', 'MSELoss', 'PoissonNLLLoss', 'HuberLoss', 'SmoothL1Loss'
 ]
 reductions = ['none', 'mean', 'sum']
 for name, reduction in product(regression_criterion_no_batch, reductions):
@@ -5661,6 +5665,17 @@ for name, reduction in product(regression_criterion_no_batch, reductions):
     )
     criterion_tests.append(regression_test_info)
 
+for reduction in reductions:
+    regression_test_info = dict(
+        fullname=f"KLDivLoss_no_batch_dim_{reduction}",
+        constructor=lambda: nn.KLDivLoss(reduction=reduction),
+        input_fn=lambda: torch.rand((3,)),
+        target_fn=lambda: torch.rand((3,)),
+        reference_fn=single_batch_reference_criterion_fn,
+        test_cpp_api_parity=False,
+    )
+    criterion_tests.append(regression_test_info)
+
 
 # Check that classification criterion work with no batch dimensions
 # List of tuples of (name, input_fn, target_fn)