mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[functorch.dims] Fix corner cases with permute (#88226)
Previously the permute function was extended to behave like the `order` function for first-class dimensions. However, unlike `permute`, `order` doesn't have a keyword argment `dims`, and there is no way to add it in a way that makes both permute an order to continue to have the same behavior. So this change just removes the extra functionality of permute, which wasn't documented anyway. Fixes #88187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88226 Approved by: https://github.com/zou3519
This commit is contained in:
parent
84a302e534
commit
4a84d69f50
|
|
@ -1767,6 +1767,9 @@ static PyObject* order(PyObject *_,
|
|||
PyObject *kwnames) {
|
||||
Arena A;
|
||||
PY_BEGIN
|
||||
if (kwnames) {
|
||||
py::raise_error(PyExc_TypeError, "unexpected keyword arguments %S", kwnames);
|
||||
}
|
||||
AT_ASSERT(nargs-- > 0);
|
||||
Slice<DimEntry> orig_levels;
|
||||
Slice<DimEntry> levels;
|
||||
|
|
|
|||
|
|
@ -102,9 +102,9 @@ wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
|
|||
del _Tensor.ndim
|
||||
|
||||
if use_c:
|
||||
_Tensor.permute = _Tensor.order = _C._instancemethod(_C.order)
|
||||
_Tensor.order = _C._instancemethod(_C.order)
|
||||
else:
|
||||
_Tensor.permute = _Tensor.order = reference.positional
|
||||
_Tensor.order = reference.positional
|
||||
|
||||
_def('mean')
|
||||
_def('sum')
|
||||
|
|
|
|||
|
|
@ -592,6 +592,17 @@ class TestMin(TestCase):
|
|||
BB = torch.mm(B[j], C) # 3, 4, 2
|
||||
assert list(torch.mm(AA.T, BB).order(i, j).shape) == [3, 3, 2, 2]
|
||||
|
||||
def test_permute_orig(self):
|
||||
d = dims(1)
|
||||
t_fc = torch.rand(1, 2, 3, 4)[d]
|
||||
assert t_fc.permute(dims=(1, 0, 2)).shape == t_fc.permute(1, 0, 2).shape
|
||||
|
||||
def test_order_keyword(self):
|
||||
d = dims(1)
|
||||
t = torch.rand(3)[d]
|
||||
self.assertRaises(TypeError, lambda: t.order(wrong=3))
|
||||
|
||||
|
||||
|
||||
skip_functorch_only = ['test_time_mm_fuse', 'test_attn_cuda']
|
||||
class TestMinFunctorchOnly(TestMin):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user