Fix tensordot allowing negative dims (#31954)

Summary:
fixes https://github.com/pytorch/pytorch/issues/31926
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31954

Differential Revision: D19331847

Pulled By: zou3519

fbshipit-source-id: e30dd9517917c056a52be7d16f23247fe28f4e28
This commit is contained in:
Tongzhou Wang 2020-01-10 07:40:26 -08:00 committed by Facebook Github Bot
parent 8ea49e7a08
commit b6f43afaca
2 changed files with 10 additions and 4 deletions

View File

@ -10917,6 +10917,10 @@ class TestTorchDeviceType(TestCase):
c = torch.tensordot(a, b, dims=2).cpu()
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
axes=2))
with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
torch.tensordot(a, b, dims=-1)
self.assertEqual(c, cn)
c = torch.tensordot(a, b).cpu()
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))

View File

@ -572,9 +572,9 @@ def tensordot(a, b, dims=2):
contract or explicit lists of dimensions for :attr:`a` and
:attr:`b` respectively
When called with an integer argument :attr:`dims` = :math:`d`, and the number of
dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, respectively,
it computes
When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
respectively, :func:`~torch.tensordot` computes
.. math::
r_{i_0,...,i_{m-d}, i_d,...,i_n}
@ -582,7 +582,7 @@ def tensordot(a, b, dims=2):
When called with :attr:`dims` of the list form, the given dimensions will be contracted
in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
in these dimensions must match, but :attr:`tensordot` will deal with broadcasted
in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
dimensions.
Examples::
@ -610,6 +610,8 @@ def tensordot(a, b, dims=2):
else:
if isinstance(dims, torch.Tensor):
dims = dims.item()
if dims < 0:
raise RuntimeError("tensordot expects dims >= 0, but got dims={}".format(dims))
dims_a = list(range(-dims, 0))
dims_b = list(range(dims))
return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)