[Inductor] Fallback embedding when sparse is True (#150659)

**Summary**
Fix issue: https://github.com/pytorch/pytorch/issues/150656, fallback `embedding` when sparse is True.

**Test Plan**
```
python -u -m pytest -s -v test/inductor/test_torchinductor.py -k test_embedding_sparse
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150659
Approved by: https://github.com/jansel
This commit is contained in:
leslie-fang-intel 2025-04-03 20:40:04 -07:00 committed by PyTorch MergeBot
parent 2e23768d25
commit d6887f444f
3 changed files with 19 additions and 0 deletions

View File

@ -5360,6 +5360,19 @@ class CommonTemplate:
(torch.randint(10, [2, 8]),),
)
def test_embedding_sparse(self):
# Fix https://github.com/pytorch/pytorch/issues/150656
def fn(weight, indices):
return F.embedding(indices, weight, sparse=True)
indices = torch.randint(10, (2, 3))
weight = torch.randn(10, 3, requires_grad=True)
self.common(
fn,
(weight, indices),
)
def test_mean(self):
def fn(x):
return (

View File

@ -137,6 +137,7 @@ test_failures = {
"test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)),
"test_flip_cat_dynamic_shapes": TestFailure(("cpu",)),
"test_pad_single_dynamic_shapes": TestFailure(("cpu",)),
"test_embedding_sparse_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
#
# Failed to find for loop/triton kernel:
#

View File

@ -3376,6 +3376,11 @@ def gather(x, dim, index, sparse_grad=False):
@register_lowering(aten.embedding, type_promotion_kind=None)
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
if sparse:
return fallback_handler(aten.embedding.default)(
weight, indices, padding_idx, scale_grad_by_freq, sparse
)
assert not sparse
assert isinstance(weight, TensorBox)
assert isinstance(indices, TensorBox)