mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Fix tensordot allowing negative dims (#31954)
Summary: fixes https://github.com/pytorch/pytorch/issues/31926 Pull Request resolved: https://github.com/pytorch/pytorch/pull/31954 Differential Revision: D19331847 Pulled By: zou3519 fbshipit-source-id: e30dd9517917c056a52be7d16f23247fe28f4e28
This commit is contained in:
parent
8ea49e7a08
commit
b6f43afaca
|
|
@ -10917,6 +10917,10 @@ class TestTorchDeviceType(TestCase):
|
|||
c = torch.tensordot(a, b, dims=2).cpu()
|
||||
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy(),
|
||||
axes=2))
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, "expects dims >= 0"):
|
||||
torch.tensordot(a, b, dims=-1)
|
||||
|
||||
self.assertEqual(c, cn)
|
||||
c = torch.tensordot(a, b).cpu()
|
||||
cn = torch.from_numpy(np.tensordot(a.cpu().numpy(), b.cpu().numpy()))
|
||||
|
|
|
|||
|
|
@ -572,9 +572,9 @@ def tensordot(a, b, dims=2):
|
|||
contract or explicit lists of dimensions for :attr:`a` and
|
||||
:attr:`b` respectively
|
||||
|
||||
When called with an integer argument :attr:`dims` = :math:`d`, and the number of
|
||||
dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`, respectively,
|
||||
it computes
|
||||
When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
|
||||
the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
|
||||
respectively, :func:`~torch.tensordot` computes
|
||||
|
||||
.. math::
|
||||
r_{i_0,...,i_{m-d}, i_d,...,i_n}
|
||||
|
|
@ -582,7 +582,7 @@ def tensordot(a, b, dims=2):
|
|||
|
||||
When called with :attr:`dims` of the list form, the given dimensions will be contracted
|
||||
in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
|
||||
in these dimensions must match, but :attr:`tensordot` will deal with broadcasted
|
||||
in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
|
||||
dimensions.
|
||||
|
||||
Examples::
|
||||
|
|
@ -610,6 +610,8 @@ def tensordot(a, b, dims=2):
|
|||
else:
|
||||
if isinstance(dims, torch.Tensor):
|
||||
dims = dims.item()
|
||||
if dims < 0:
|
||||
raise RuntimeError("tensordot expects dims >= 0, but got dims={}".format(dims))
|
||||
dims_a = list(range(-dims, 0))
|
||||
dims_b = list(range(dims))
|
||||
return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user