pytorch/torch/nested
Joel Schlosser a8382847f4 Support rms_norm() for NJT (#135872)
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation on the Python NJT side for nested-specific constraints (normalization over the ragged dim is not supported for now).
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.
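
To illustrate why normalizing over a non-ragged dim is the easy case, here is a minimal pure-Python sketch of RMS normalization applied over the trailing feature dim of a ragged batch. It is not the decomp from this PR: the helper name `rms_norm_ragged`, the nested-list representation, and the `eps=1e-6` default are assumptions made for illustration only.

```python
import math


def rms_norm_ragged(batch, weight=None, eps=1e-6):
    """RMS-normalize every feature vector of a ragged batch over its last dim.

    `batch` is a list of sequences; sequence lengths may differ (the ragged
    dim), but every feature vector has the same length D (a non-ragged dim).
    Hypothetical helper, not part of torch.nested.
    """
    out = []
    for seq in batch:  # ragged dim: len(seq) varies per sample
        normed_seq = []
        for vec in seq:  # fixed-size feature vector of length D
            mean_sq = sum(v * v for v in vec) / len(vec)
            scale = 1.0 / math.sqrt(mean_sq + eps)
            w = weight if weight is not None else [1.0] * len(vec)
            normed_seq.append([v * scale * wi for v, wi in zip(vec, w)])
        out.append(normed_seq)
    return out


# Two samples with different sequence lengths (2 and 1), feature dim D = 2.
batch = [
    [[3.0, 4.0], [1.0, 1.0]],
    [[0.0, 2.0]],
]
normed = rms_norm_ragged(batch)
```

Because the reduction only ever spans the fixed-size feature dim, the per-sample sequence lengths pass through unchanged. Reducing over the ragged dim would need a different number of elements per sample, which is the case the validation logic rejects for now.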
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #125947
2024-09-17 18:09:20 +00:00
_internal Support rms_norm() for NJT (#135872) 2024-09-17 18:09:20 +00:00
__init__.py Implement 2D version of masked_select for nestedtensors (#133889) 2024-08-20 21:46:32 +00:00