mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add new reduction mode in kl_div (#14457)
Summary: Fixes #6622 . We used to average over all elements for kl divergence, which is not aligned with its math definition. This PR corrects the default reduction behavior of KL divergence that it now naverages over batch dimension. - In KL, default behavior `reduction=mean` averages over batch dimension. While for most other loss functions, `reduction=mean` averages over all elements. - We used to support scalar tensor as well. For BC purpose, we still support it, no reduction is performed on scalar tensor. - Added a new reduction mode called `batchmean` which has the correct behavior for KL. Add a warning to make `batchmean` as default for KL instead of `mean` in next major release. - [deprecated]I chose to not add a new reduction option, since "mean over batch dimension" is kinda special, and it only makes sense in few cases like KL. We don't want to explain why there's a option "batchmean" but it's not applicable for all other functions. I'm open to discussion on this one, as I cannot think of a perfect solution for this. Pull Request resolved: https://github.com/pytorch/pytorch/pull/14457 Differential Revision: D13236016 Pulled By: ailzhang fbshipit-source-id: 905cc7b3bfc35a11d7cf098b1ebc382170a087a7
This commit is contained in:
parent
773f4d8081
commit
ef91cfd68b
|
|
@ -2208,6 +2208,8 @@ def kldivloss_reference(input, target, reduction='mean'):
|
|||
return result.mean()
|
||||
elif reduction == 'sum':
|
||||
return result.sum()
|
||||
elif reduction == 'batchmean' and results.dim() != 0:
|
||||
return result.sum() / result.size(0)
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4016,6 +4016,19 @@ class TestNN(NNTestCase):
|
|||
with self.assertRaisesRegex(ValueError, 'Expected.*batch_size'):
|
||||
F.nll_loss(x, t)
|
||||
|
||||
def test_KLDivLoss_batch_mean(self):
|
||||
input_shape = (2, 5)
|
||||
log_prob1 = F.log_softmax(torch.randn(input_shape), 1)
|
||||
prob2 = F.softmax(torch.randn(input_shape), 1)
|
||||
|
||||
loss = nn.KLDivLoss(reduction='batchmean')
|
||||
l = loss(log_prob1, prob2)
|
||||
|
||||
loss_none_reduce = nn.KLDivLoss(reduction='sum')(log_prob1, prob2)
|
||||
expected = loss_none_reduce / input_shape[0]
|
||||
|
||||
self.assertEqual(l, expected)
|
||||
|
||||
@unittest.skipIf(not (TEST_CUDNN and TEST_CUDNN_VERSION >= 7000), "needs cudnn >= 7.0")
|
||||
def test_CTCLoss_cudnn(self):
|
||||
target_lengths = [30, 25, 20]
|
||||
|
|
|
|||
|
|
@ -1885,17 +1885,40 @@ def kl_div(input, target, size_average=None, reduce=None, reduction='mean'):
|
|||
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
|
||||
batch element instead and ignores :attr:`size_average`. Default: ``True``
|
||||
reduction (string, optional): Specifies the reduction to apply to the output:
|
||||
'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
|
||||
'mean': the sum of the output will be divided by the number of
|
||||
elements in the output, 'sum': the output will be summed. Note: :attr:`size_average`
|
||||
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
|
||||
specifying either of those two args will override :attr:`reduction`. Default: 'mean'
|
||||
'none' | 'batchmean' | 'sum' | 'mean'.
|
||||
'none': no reduction will be applied
|
||||
'batchmean': the sum of the output will be divided by the batchsize
|
||||
'sum': the output will be summed
|
||||
'mean': the output will be divided by the number of elements in the output
|
||||
Default: 'mean'
|
||||
|
||||
.. note:: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
|
||||
and in the meantime, specifying either of those two args will override :attr:`reduction`.
|
||||
|
||||
.. note:: `reduction='mean'` doesn't return the true kl divergence value, please use
|
||||
`reduction='batchmean'` which aligns with KL math definition.
|
||||
In the next major release, 'mean' will be changed to be the same as 'batchmean'.
|
||||
"""
|
||||
if size_average is not None or reduce is not None:
|
||||
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
|
||||
else:
|
||||
reduction_enum = _Reduction.get_enum(reduction)
|
||||
return torch.kl_div(input, target, reduction_enum)
|
||||
if reduction == 'mean':
|
||||
warnings.warn("reduction: 'mean' divides the total loss by both the batch size and the support size."
|
||||
"'batchmean' divides only by the batch size, and aligns with the KL div math definition."
|
||||
"'mean' will be changed to behave the same as 'batchmean' in the next major release.")
|
||||
|
||||
# special case for batchmean
|
||||
if reduction == 'batchmean':
|
||||
reduction_enum = _Reduction.get_enum('sum')
|
||||
else:
|
||||
reduction_enum = _Reduction.get_enum(reduction)
|
||||
|
||||
reduced = torch.kl_div(input, target, reduction_enum)
|
||||
|
||||
if reduction == 'batchmean' and input.dim() != 0:
|
||||
reduced = reduced / input.size()[0]
|
||||
|
||||
return reduced
|
||||
|
||||
|
||||
@torch._jit_internal.weak_script
|
||||
|
|
|
|||
|
|
@ -282,7 +282,7 @@ class KLDivLoss(_Loss):
|
|||
|
||||
As with :class:`~torch.nn.NLLLoss`, the `input` given is expected to contain
|
||||
*log-probabilities*. However, unlike :class:`~torch.nn.NLLLoss`, `input` is not
|
||||
restricted to a 2D Tensor, because the criterion is applied element-wise.
|
||||
restricted to a 2D Tensor.
|
||||
The targets are given as *probabilities* (i.e. without taking the logarithm).
|
||||
|
||||
This criterion expects a `target` `Tensor` of the same size as the
|
||||
|
|
@ -303,31 +303,14 @@ class KLDivLoss(_Loss):
|
|||
\operatorname{sum}(L), & \text{if}\; \text{size\_average} = \text{False}.
|
||||
\end{cases}
|
||||
|
||||
By default, the losses are averaged for each minibatch over observations
|
||||
**as well as** over dimensions. However, if the field
|
||||
:attr:`size_average` is set to ``False``, the losses are instead summed.
|
||||
In default reduction mode 'mean', the losses are averaged for each minibatch over observations
|
||||
**as well as** over dimensions. 'batchmean' mode gives the correct KL divergence where losses
|
||||
are averaged over batch dimension only. 'mean' mode's behavior will be changed to the same as
|
||||
'batchmean' in the next major release.
|
||||
|
||||
.. _Kullback-Leibler divergence:
|
||||
https://en.wikipedia.org/wiki/Kullback-Leibler_divergence
|
||||
|
||||
.. note:: The default averaging means that the loss is actually **not** the
|
||||
KL Divergence because the terms are already probability weighted.
|
||||
A future release of PyTorch may move the default loss closer to the
|
||||
mathematical definition.
|
||||
|
||||
To get the real KL Divergence, use ``size_average=False``, and
|
||||
then divide the output by the batch size.
|
||||
|
||||
Example::
|
||||
|
||||
>>> loss = nn.KLDivLoss(size_average=False)
|
||||
>>> batch_size = 5
|
||||
>>> log_probs1 = F.log_softmax(torch.randn(batch_size, 10), 1)
|
||||
>>> probs2 = F.softmax(torch.randn(batch_size, 10), 1)
|
||||
>>> loss(log_probs1, probs2) / batch_size
|
||||
tensor(0.7142)
|
||||
|
||||
|
||||
Args:
|
||||
size_average (bool, optional): Deprecated (see :attr:`reduction`). By default,
|
||||
the losses are averaged over each loss element in the batch. Note that for
|
||||
|
|
@ -339,11 +322,18 @@ class KLDivLoss(_Loss):
|
|||
on :attr:`size_average`. When :attr:`reduce` is ``False``, returns a loss per
|
||||
batch element instead and ignores :attr:`size_average`. Default: ``True``
|
||||
reduction (string, optional): Specifies the reduction to apply to the output:
|
||||
'none' | 'mean' | 'sum'. 'none': no reduction will be applied,
|
||||
'mean': the sum of the output will be divided by the number of
|
||||
elements in the output, 'sum': the output will be summed. Note: :attr:`size_average`
|
||||
and :attr:`reduce` are in the process of being deprecated, and in the meantime,
|
||||
specifying either of those two args will override :attr:`reduction`. Default: 'mean'
|
||||
'none' | 'batchmean' | 'sum' | 'mean'.
|
||||
'none': no reduction will be applied.
|
||||
'batchmean': the sum of the output will be divided by batchsize.
|
||||
'sum': the output will be summed.
|
||||
'mean': the output will be divided by the number of elements in the output.
|
||||
Default: 'mean'
|
||||
.. note:: :attr:`size_average` and :attr:`reduce` are in the process of being deprecated,
|
||||
and in the meantime, specifying either of those two args will override :attr:`reduction`.
|
||||
.. note:: `reduction='mean'` doesn't return the true kl divergence value, please use
|
||||
`reduction='batchmean'` which aligns with KL math definition.
|
||||
In the next major release, 'mean' will be changed to be the same as 'batchmean'.
|
||||
|
||||
|
||||
Shape:
|
||||
- input: :math:`(N, *)` where `*` means, any number of additional
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user