pytorch/torch/nested
Joel Schlosser c9e2b3fefe NJT: Return correct number of outputs for chunk() on the batch dim (#141604)
Old logic was completely wrong, returning `chunk_size` chunks instead of the intended number. The original test didn't catch this because `chunk_size == num_chunks` :p New OpInfo-based testing covers it though.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141604
Approved by: https://github.com/soulitzer
ghstack dependencies: #141500, #140736, #140161, #141392, #141506
2024-11-27 02:31:23 +00:00
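The fix concerns how many pieces `chunk()` should return along the batch dim. Below is a minimal pure-Python sketch of the documented `torch.chunk` semantics. A plain list stands in for an NJT's batch of components. It shows why `chunk_size` and the number of chunks differ in general but coincide in the case the original test used. The helpers `chunk_sizes`, `chunk_batch`, and `old_num_outputs` are hypothetical illustrations, not the code in `torch/nested/_internal`.

```python
import math


def chunk_sizes(dim_size: int, chunks: int) -> list[int]:
    """Sizes of the pieces produced by torch.chunk-style splitting.

    Assumes the documented semantics: every piece holds ceil(dim_size / chunks)
    elements except possibly the last, so fewer than `chunks` pieces may result.
    The zero-size edge case is deliberately not modeled here.
    """
    if dim_size <= 0 or chunks <= 0:
        raise ValueError("dim_size and chunks must be positive")
    chunk_size = math.ceil(dim_size / chunks)
    sizes = []
    remaining = dim_size
    while remaining > 0:
        sizes.append(min(chunk_size, remaining))
        remaining -= chunk_size
    return sizes


def chunk_batch(items: list, chunks: int) -> list[list]:
    """Group batch items (standing in for NJT components) into chunks."""
    out, start = [], 0
    for size in chunk_sizes(len(items), chunks):
        out.append(items[start:start + size])
        start += size
    return out


def old_num_outputs(dim_size: int, chunks: int) -> int:
    """Hypothetical model of the old bug: it produced `chunk_size` outputs."""
    return math.ceil(dim_size / chunks)


if __name__ == "__main__":
    batch = list(range(6))
    print(chunk_batch(batch, 2))                    # 2 pieces of 3 items each
    print(len(chunk_sizes(6, 2)), old_num_outputs(6, 2))  # correct vs. buggy count
    print(len(chunk_sizes(4, 2)), old_num_outputs(4, 2))  # both 2: bug is hidden
```

With 6 items and 2 requested chunks, `chunk_size` is 3, so returning `chunk_size` outputs yields 3 pieces instead of 2. With 4 items and 2 chunks, both values are 2, which is why a test of that shape could not tell them apart.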
_internal NJT: Return correct number of outputs for chunk() on the batch dim (#141604) 2024-11-27 02:31:23 +00:00
__init__.py Allow NJT by default for weights_only torch.load (take 2) (#140739) 2024-11-19 02:44:53 +00:00