Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific stuff (no normalization supported over the ragged dim for now) on the python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #125947

| Name |
|---|
| .. |
| _internal |
| __init__.py |
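
The commit message above says RMS normalization is supported for nested jagged tensors (NJT) only over non-ragged dims. As a rough illustration of what that means, here is a minimal pure-Python sketch that uses lists of lists in place of an NJT. It is not PyTorch's implementation or decomposition. The helper names `rms_norm_vector` and `ragged_rms_norm` and the default `eps` are hypothetical, chosen only for this example.

```python
import math


def rms_norm_vector(x, eps=1e-6):
    """RMS-normalize one dense feature vector: x / sqrt(mean(x^2) + eps).

    Hypothetical helper for illustration only; no learnable weight is applied.
    """
    mean_sq = sum(v * v for v in x) / len(x)
    scale = 1.0 / math.sqrt(mean_sq + eps)
    return [v * scale for v in x]


def ragged_rms_norm(batch, eps=1e-6):
    """Normalize over the last (dense) dim of a ragged batch.

    `batch` is a list of sequences. Sequences may differ in length (the
    ragged dim), but every token is a feature vector of the same size.
    The ragged dim is never reduced over, only iterated.
    """
    return [[rms_norm_vector(tok, eps) for tok in seq] for seq in batch]


# Two sequences of different lengths (1 and 3 tokens), feature size 2.
batch = [
    [[3.0, 4.0]],
    [[1.0, 1.0], [0.0, 2.0], [6.0, 8.0]],
]
out = ragged_rms_norm(batch)
```

Each output vector has a mean square close to 1, and the per-sequence lengths are unchanged, because the reduction only ever runs along the dense feature dim.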