[functorch.dims] Fix corner cases with permute (#88226)

Previously the permute function was extended to behave like the `order`
function for first-class dimensions. However, unlike `permute`,
`order` doesn't have a keyword argment `dims`, and there is no way to add
it in a way that makes both permute an order to continue to have the same
behavior. So this change just removes the extra functionality of permute,
which wasn't documented anyway. Fixes #88187
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88226
Approved by: https://github.com/zou3519
This commit is contained in:
Zachary DeVito 2022-11-01 11:35:23 -07:00 committed by PyTorch MergeBot
parent 84a302e534
commit 4a84d69f50
3 changed files with 16 additions and 2 deletions

View File

@ -1767,6 +1767,9 @@ static PyObject* order(PyObject *_,
PyObject *kwnames) {
Arena A;
PY_BEGIN
if (kwnames) {
py::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames);
}
AT_ASSERT(nargs-- > 0);
Slice<DimEntry> orig_levels;
Slice<DimEntry> levels;

View File

@ -102,9 +102,9 @@ wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
del _Tensor.ndim
if use_c:
_Tensor.permute = _Tensor.order = _C._instancemethod(_C.order)
_Tensor.order = _C._instancemethod(_C.order)
else:
_Tensor.permute = _Tensor.order = reference.positional
_Tensor.order = reference.positional
_def('mean')
_def('sum')

View File

@ -592,6 +592,17 @@ class TestMin(TestCase):
BB = torch.mm(B[j], C) # 3, 4, 2
assert list(torch.mm(AA.T, BB).order(i, j).shape) == [3, 3, 2, 2]
def test_permute_orig(self):
d = dims(1)
t_fc = torch.rand(1, 2, 3, 4)[d]
assert t_fc.permute(dims=(1, 0, 2)).shape == t_fc.permute(1, 0, 2).shape
def test_order_keyword(self):
d = dims(1)
t = torch.rand(3)[d]
self.assertRaises(TypeError, lambda: t.order(wrong=3))
skip_functorch_only = ['test_time_mm_fuse', 'test_attn_cuda']
class TestMinFunctorchOnly(TestMin):