pytorch/torch/testing
Joel Schlosser 83a3ee0699 Support embedding_bag() with NJT input (#135888)
Fixes #93843

`EmbeddingBag()` / `embedding_bag()` support 1D inputs with offsets to handle raggedness. NJT (nested jagged tensor) is a natural fit here, as it already maintains offsets of the same form. This PR updates the Python side to support NJT and adds corresponding OpInfo-based NJT tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135888
Approved by: https://github.com/cpuhrsch
2024-09-23 17:35:19 +00:00
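
The "1D input with offsets" layout mentioned in the commit message can be illustrated with a small plain-Python sketch. This is not the PyTorch implementation and does not call into torch; the helper names `flatten_bags` and `embedding_bag_sum` are hypothetical, the reduction is assumed to be a sum, and the offsets are assumed to hold only the start index of each bag.

```python
from typing import List, Sequence, Tuple


def flatten_bags(bags: Sequence[Sequence[int]]) -> Tuple[List[int], List[int]]:
    """Flatten ragged bags of indices into (values, offsets).

    offsets[i] is the position in `values` where bag i starts.
    (Hypothetical helper for illustration only.)
    """
    values: List[int] = []
    offsets: List[int] = []
    for bag in bags:
        offsets.append(len(values))
        values.extend(bag)
    return values, offsets


def embedding_bag_sum(
    weight: Sequence[Sequence[float]],
    values: Sequence[int],
    offsets: Sequence[int],
) -> List[List[float]]:
    """Sum the embedding rows of each bag described by (values, offsets).

    Assumes sum reduction; an empty bag yields a zero vector.
    """
    dim = len(weight[0])
    out: List[List[float]] = []
    for i, start in enumerate(offsets):
        # A bag ends where the next one starts; the last bag ends at len(values).
        end = offsets[i + 1] if i + 1 < len(offsets) else len(values)
        acc = [0.0] * dim
        for idx in values[start:end]:
            for d in range(dim):
                acc[d] += weight[idx][d]
        out.append(acc)
    return out


# Three bags of different lengths, including an empty one.
weight = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]
values, offsets = flatten_bags([[0, 2], [1], []])
result = embedding_bag_sum(weight, values, offsets)
```

A jagged container that already stores flat values alongside per-bag offsets can hand both pieces to such a routine directly, which is the fit the commit message describes.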
_internal Support embedding_bag() with NJT input (#135888) 2024-09-23 17:35:19 +00:00
__init__.py [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771) 2024-08-01 17:07:14 +00:00
_comparison.py Strict shape checking for NJTs with TestCase.assertEqual() (#131898) 2024-07-30 20:05:48 +00:00
_creation.py [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771) 2024-08-01 17:07:14 +00:00
_utils.py [BE][Easy][19/19] enforce style for empty lines in import segments in torch/[o-z]*/ (#129771) 2024-08-01 17:07:14 +00:00