From b6f43afacaff516a8e1681afac367bb4cb15aa62 Mon Sep 17 00:00:00 2001 From: Tongzhou Wang Date: Fri, 10 Jan 2020 07:40:26 -0800 Subject: [PATCH] Fix tensordot allowing negative dims (#31954) Summary: fixes https://github.com/pytorch/pytorch/issues/31926 Pull Request resolved: https://github.com/pytorch/pytorch/pull/31954 Differential Revision: D19331847 Pulled By: zou3519 fbshipit-source-id: e30dd9517917c056a52be7d16f23247fe28f4e28 --- test/test_torch.py | 4 ++++ torch/functional.py | 10 ++++++---- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/test/test_torch.py b/test/test_torch.py index 91db02397b1..475482b7047 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -10917,6 +10917,10 @@ class TestTorchDeviceType(TestCase): c = torch.tensordot(a, b, dims=2).cpu() cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(), axes=2)) + + with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"): + torch.tensordot(a, b, dims=-1) + self.assertEqual(c, cn) c = torch.tensordot(a, b).cpu() cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy())) diff --git a/torch/functional.py b/torch/functional.py index 1a2cdb3ec9c..afed5c91217 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -572,9 +572,9 @@ def tensordot(a, b, dims=2): contract or explicit lists of dimensions for :attr:`a` and :attr:`b` respectively - When called with an integer argument :attr:`dims` = :math:`d`, and the number of - dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, respectively, - it computes + When called with a non-negative integer argument :attr:`dims` = :math:`d`, and + the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, + respectively, :func:`~torch.tensordot` computes .. math:: r_{i_0,...,i_{m-d}, i_d,...,i_n} @@ -582,7 +582,7 @@ def tensordot(a, b, dims=2): When called with :attr:`dims` of the list form, the given dimensions will be contracted in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes - in these dimensions must match, but :attr:`tensordot` will deal with broadcasted + in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted dimensions. Examples:: @@ -610,6 +610,8 @@ def tensordot(a, b, dims=2): else: if isinstance(dims, torch.Tensor): dims = dims.item() + if dims < 0: + raise RuntimeError("tensordot expects dims >= 0, but got dims={}".format(dims)) dims_a = list(range(-dims, 0)) dims_b = list(range(dims)) return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)