Make torch.split take symint as arg (#91724)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91724
Approved by: https://github.com/voznesenskym
This commit is contained in:
Tugsbayasgalan Manlaibaatar 2023-01-06 17:07:41 +00:00 committed by PyTorch MergeBot
parent 08a378a286
commit b32b81a0c5
2 changed files with 15 additions and 1 deletions

View File

@ -2248,6 +2248,20 @@ class ReproTests(torch._dynamo.test_case.TestCase):
gm(torch.zeros(6, 4), torch.tensor(2)),
)
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
def test_tensor_split(self):
def f(x):
return torch.split(x, x.shape[0] // 2, dim=0)[0]
gm, _ = torch._dynamo.export(
f,
torch.zeros(6, 4),
aten_graph=True,
tracing_mode="symbolic",
)
self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View File

@ -789,7 +789,7 @@ class Tensor(torch._C._TensorBase):
except ValueError:
pass
if isinstance(split_size, int):
if isinstance(split_size, (int, torch.SymInt)):
return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined]
else:
return torch._VF.split_with_sizes(self, split_size, dim)