pytorch/torch/testing/_internal
Joel Schlosser 83a3ee0699 Support embedding_bag() with NJT input (#135888)
Fixes #93843

`EmbeddingBag()` / `embedding_bag()` support 1D inputs with offsets to handle raggedness. NJT (nested jagged tensor) is a natural fit here, as it already maintains offsets of the same form. This PR updates the Python side to support NJT and adds corresponding OpInfo-based NJT tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135888
Approved by: https://github.com/cpuhrsch
2024-09-23 17:35:19 +00:00
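The offsets convention mentioned in the commit message can be illustrated with a small pure-Python reference. This is a sketch of the semantics only, not PyTorch's implementation. The helpers `to_values_and_offsets` and `embedding_bag_ref` are hypothetical names, and only the `sum`, `mean`, and `max` modes are modeled. It assumes, following the `torch.nn.functional.embedding_bag` documentation, that `offsets` holds each bag's start position and that `include_last_offset=True` accepts one extra trailing entry equal to the number of indices. That CSR-style form (length B + 1, starting at 0) is the form an NJT's offsets take.

```python
from typing import List, Sequence, Tuple


def to_values_and_offsets(seqs: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: flatten ragged sequences into (values, offsets).

    offsets has len(seqs) + 1 entries, starting at 0 and ending at len(values),
    mirroring the jagged (CSR-style) layout an NJT keeps.
    """
    values: List[int] = []
    offsets: List[int] = [0]
    for seq in seqs:
        values.extend(seq)
        offsets.append(len(values))
    return values, offsets


def embedding_bag_ref(
    weight: Sequence[Sequence[float]],
    indices: Sequence[int],
    offsets: Sequence[int],
    mode: str = "sum",
    include_last_offset: bool = False,
) -> List[List[float]]:
    """Hypothetical pure-Python model of embedding_bag on a 1D input with offsets."""
    # Bag i covers indices[bounds[i]:bounds[i + 1]]. Without the trailing
    # offset, the final bag implicitly ends at len(indices).
    bounds = list(offsets) if include_last_offset else list(offsets) + [len(indices)]
    dim = len(weight[0])
    out: List[List[float]] = []
    for start, end in zip(bounds[:-1], bounds[1:]):
        rows = [weight[j] for j in indices[start:end]]
        if not rows:
            # Empty bags yield a zero vector.
            out.append([0.0] * dim)
        elif mode == "sum":
            out.append([sum(col) for col in zip(*rows)])
        elif mode == "mean":
            out.append([sum(col) / len(rows) for col in zip(*rows)])
        elif mode == "max":
            out.append([max(col) for col in zip(*rows)])
        else:
            raise ValueError(f"unsupported mode: {mode}")
    return out


if __name__ == "__main__":
    weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [3.0, -1.0]]
    # Three ragged bags of lengths 2, 1, and 3.
    values, offsets = to_values_and_offsets([[1, 2], [0], [3, 1, 2]])
    print(embedding_bag_ref(weight, values, offsets, include_last_offset=True))
    # [[2.0, 3.0], [1.0, 0.0], [5.0, 2.0]]
```

Passing the jagged offsets with `include_last_offset=True` is what lets the two representations line up directly; dropping the trailing entry (`offsets[:-1]`) gives the equivalent default form.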
codegen
data Add None return type to init (#132335) 2024-08-01 15:26:45 +00:00
distributed Fix ROCm skip decorator for test_ddp_tp and multiprocess UTs (#136161) 2024-09-18 11:01:23 +00:00
generated
opinfo Support embedding_bag() with NJT input (#135888) 2024-09-23 17:35:19 +00:00
optests AOTDispatcher: limit cases when we detach() graph inputs to non-leaves (#134193) 2024-09-06 14:06:48 +00:00
test_module
__init__.py
autocast_test_lists.py Add _addmm_activation to lower precision cast policy on AutocastCPU (#135936) 2024-09-18 16:31:27 +00:00
autograd_function_db.py
check_kernel_launches.py
common_cuda.py Revert "Enable FlashAttention on Windows (#131906)" 2024-07-29 16:49:23 +00:00
common_device_type.py [BE][MPS] Prefer xfail to skip (#134858) 2024-08-31 00:29:48 +00:00
common_dist_composable.py
common_distributed.py [c10d] Make test compatible for new pytest (#136158) 2024-09-18 17:10:55 +00:00
common_dtype.py Enable UFMT on common_device_type.py and common_dtype.py (#128490) 2024-06-15 00:07:42 +00:00
common_fsdp.py [reland][dtensor] move DTensor to public namespace (#134203) 2024-09-08 17:08:40 +00:00
common_jit.py
common_methods_invocations.py Remove prims.slice_in_dim and prims.slice (#136150) 2024-09-23 01:27:22 +00:00
common_mkldnn.py
common_modules.py Revert "Validate input types for torch.nn.Linear and torch.nn.Bilinear (#135596)" 2024-09-13 18:06:56 +00:00
common_nn.py Fix failures when default is flipped for weights_only (#127627) 2024-08-16 00:22:43 +00:00
common_optimizers.py parameterized test_graph_optims and test_graph_scaling_fused_optimizers (#133749) 2024-08-28 16:34:06 +00:00
common_pruning.py [BE][typing] fix types in common pruning (#132309) 2024-08-01 23:34:33 +00:00
common_quantization.py [export][training ir migration] quantized_decomposed.quantize_per_tensor decomposition (#134525) 2024-09-06 07:06:06 +00:00
common_quantized.py
common_subclass.py [subclasses] Do not fakeTensor const prop subclass args (#134855) 2024-09-03 13:31:49 +00:00
common_utils.py XFAIL test_segfault (#136252) 2024-09-19 04:17:06 +00:00
composite_compliance.py
custom_op_db.py Revert "[BE] typing for decorators - _library/custom_ops (#131578)" 2024-07-28 03:29:32 +00:00
custom_tensor.py Nested tensor subclass support (#127431) 2024-06-26 04:45:22 +00:00
dist_utils.py
dynamo_test_failures.py [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771) 2024-08-01 17:07:14 +00:00
hop_db.py [HOO] add hints_wrapper to support passing context hints (#132860) 2024-08-26 18:21:22 +00:00
hypothesis_utils.py
inductor_utils.py Revert "Add CI for Triton CPU backend (#135342)" 2024-09-16 18:33:33 +00:00
jit_metaprogramming_utils.py Add None return type to init (#132335) 2024-08-01 15:26:45 +00:00
jit_utils.py Add None return type to init (#132335) 2024-08-01 15:26:45 +00:00
logging_tensor.py [compiled autograd] Fix LoggingTensor flaky test (#126144) 2024-05-16 22:23:02 +00:00
logging_utils.py
quantization_torch_package_models.py
static_module.py Flip default value for mypy disallow_untyped_defs [9/11] (#127846) 2024-06-08 18:50:06 +00:00
torchbind_impls.py Allow kwargs in _remove_effect_tokens_pass (#130491) 2024-07-11 19:03:19 +00:00
triton_utils.py [aotinductor][UserDefinedTritonKernel] fix case with non-constexpr params declared after autotuned params (#134520) 2024-08-27 17:20:27 +00:00
two_tensor.py Add recursive metadata guard test (#131002) 2024-07-18 08:24:43 +00:00