mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make torch.split take symint as arg (#91724)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91724 Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
08a378a286
commit
b32b81a0c5
|
|
@ -2248,6 +2248,20 @@ class ReproTests(torch._dynamo.test_case.TestCase):
|
|||
gm(torch.zeros(6, 4), torch.tensor(2)),
|
||||
)
|
||||
|
||||
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
||||
def test_tensor_split(self):
|
||||
def f(x):
|
||||
return torch.split(x, x.shape[0] // 2, dim=0)[0]
|
||||
|
||||
gm, _ = torch._dynamo.export(
|
||||
f,
|
||||
torch.zeros(6, 4),
|
||||
aten_graph=True,
|
||||
tracing_mode="symbolic",
|
||||
)
|
||||
|
||||
self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -789,7 +789,7 @@ class Tensor(torch._C._TensorBase):
|
|||
except ValueError:
|
||||
pass
|
||||
|
||||
if isinstance(split_size, int):
|
||||
if isinstance(split_size, (int, torch.SymInt)):
|
||||
return torch._VF.split(self, split_size, dim) # type: ignore[attr-defined]
|
||||
else:
|
||||
return torch._VF.split_with_sizes(self, split_size, dim)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user