pytorch/torch/testing/_internal
Joel Schlosser a8382847f4 Support rms_norm() for NJT (#135872)
`rms_norm()` is a nice-to-have for ViT :)

This PR:
* SymInt-ifies `rms_norm()`, allowing NJT to use the same decomp.
* Adds torch_function-based input validation logic for nested-specific constraints (normalization over the ragged dim is not supported for now) on the Python NJT side.
* Adds multi-dim support (on non-ragged, non-batch dims) to `mean()` for NJT.
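The validation rule in the second bullet (normalize only over non-ragged dims) can be illustrated without PyTorch. The following is a minimal pure-Python sketch under stated assumptions, not the NJT implementation: it models a jagged batch as nested lists, uses hypothetical names (`JaggedBatch`, `jagged_rms_norm`), and picks `eps=1e-6` as an illustrative default. It computes RMS normalization, x / sqrt(mean(x²) + eps) optionally scaled by `weight`, over the trailing feature dim only, and rejects any `normalized_shape` that would reach into the ragged dim.

```python
import math
from typing import List, Optional, Sequence

# Hypothetical model of a jagged batch: B components, each with a variable
# number of rows (the ragged dim); every row holds the same D features.
JaggedBatch = List[List[List[float]]]


def jagged_rms_norm(
    batch: JaggedBatch,
    normalized_shape: Sequence[int],
    weight: Optional[List[float]] = None,
    eps: float = 1e-6,  # illustrative default, an assumption of this sketch
) -> JaggedBatch:
    """Hypothetical RMS norm over the trailing, non-ragged feature dim only."""
    if len(normalized_shape) != 1:
        # Normalizing over more than the last dim would span the ragged dim,
        # which this sketch rejects (mirroring the restriction in the commit).
        raise ValueError("normalization over the ragged dim is not supported")
    d = normalized_shape[0]
    if weight is not None and len(weight) != d:
        raise ValueError("weight must match normalized_shape")
    out: JaggedBatch = []
    for component in batch:
        normed_rows = []
        for row in component:
            if len(row) != d:
                raise ValueError("every row must have D features")
            # Root mean square over the feature dim: sqrt(mean(x^2) + eps).
            rms = math.sqrt(sum(x * x for x in row) / d + eps)
            scaled = [x / rms for x in row]
            if weight is not None:
                # Optional elementwise affine scale, one weight per feature.
                scaled = [s * w for s, w in zip(scaled, weight)]
            normed_rows.append(scaled)
        out.append(normed_rows)
    return out


# Usage: two components with 1 and 2 rows respectively, D = 2.
batch = [[[3.0, 4.0]], [[1.0, 1.0], [2.0, -2.0]]]
result = jagged_rms_norm(batch, (2,), eps=0.0)
```

Each row is normalized independently, so the ragged lengths (1 and 2 rows here) pass through unchanged; only a `normalized_shape` covering the single trailing dim is accepted.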
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135872
Approved by: https://github.com/mikaylagawarecki
ghstack dependencies: #125947
2024-09-17 18:09:20 +00:00
codegen
data Add None return type to init (#132335) 2024-08-01 15:26:45 +00:00
distributed [DCP] Fixes the stateless optimizer issue of distributed state_dict (#135535) 2024-09-10 03:10:00 +00:00
generated
opinfo Support rms_norm() for NJT (#135872) 2024-09-17 18:09:20 +00:00
optests AOTDispatcher: limit cases when we detach() graph inputs to non-leaves (#134193) 2024-09-06 14:06:48 +00:00
test_module
__init__.py
autocast_test_lists.py Moving _run_autocast_outofplace to basic class named TestAutocast to reduce redundance (#134460) 2024-09-04 10:48:58 +00:00
autograd_function_db.py
check_kernel_launches.py
common_cuda.py Revert "Enable FlashAttention on Windows (#131906)" 2024-07-29 16:49:23 +00:00
common_device_type.py [BE][MPS] Prefer xfail to skip (#134858) 2024-08-31 00:29:48 +00:00
common_dist_composable.py
common_distributed.py [Traceable FSDP2][Inductor] Create grouped nodes for FSDP2 all-gather code block and reduce-scatter code block (after Buffer/Operation split) (#131510) 2024-07-27 08:39:58 +00:00
common_dtype.py
common_fsdp.py [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
common_jit.py
common_methods_invocations.py Revert "Add decomposition for permute_copy (#130944)" 2024-09-17 13:42:55 +00:00
common_mkldnn.py
common_modules.py Revert "Validate input types for torch.nn.Linear and torch.nn.Bilinear (#135596)" 2024-09-13 18:06:56 +00:00
common_nn.py Fix failures when default is flipped for weights_only (#127627) 2024-08-16 00:22:43 +00:00
common_optimizers.py parameterized test_graph_optims and test_graph_scaling_fused_optimizers (#133749) 2024-08-28 16:34:06 +00:00
common_pruning.py [BE][typing] fix types in common pruning (#132309) 2024-08-01 23:34:33 +00:00
common_quantization.py [export][training ir migration] quantized_decomposed.quantize_per_tensor decomposition (#134525) 2024-09-06 07:06:06 +00:00
common_quantized.py
common_subclass.py [subclasses] Do not fakeTensor const prop subclass args (#134855) 2024-09-03 13:31:49 +00:00
common_utils.py Exclude test_transformers and unit tests which require recent GPU arch (#132895) 2024-08-27 20:40:53 +00:00
composite_compliance.py
custom_op_db.py Revert "[BE] typing for decorators - _library/custom_ops (#131578)" 2024-07-28 03:29:32 +00:00
custom_tensor.py
dist_utils.py
dynamo_test_failures.py [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771) 2024-08-01 17:07:14 +00:00
hop_db.py [HOO] add hints_wrapper to support passing context hints (#132860) 2024-08-26 18:21:22 +00:00
hypothesis_utils.py
inductor_utils.py Revert "Add CI for Triton CPU backend (#135342)" 2024-09-16 18:33:33 +00:00
jit_metaprogramming_utils.py Add None return type to init (#132335) 2024-08-01 15:26:45 +00:00
jit_utils.py Add None return type to init (#132335) 2024-08-01 15:26:45 +00:00
logging_tensor.py
logging_utils.py
quantization_torch_package_models.py
static_module.py
torchbind_impls.py
triton_utils.py [aotinductor][UserDefinedTritonKernel] fix case with non-constexpr params declared after autotuned params (#134520) 2024-08-27 17:20:27 +00:00
two_tensor.py Add recursive metadata guard test (#131002) 2024-07-18 08:24:43 +00:00