mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Fallback embedding when sparse is True (#150659)
**Summary** Fix issue: https://github.com/pytorch/pytorch/issues/150656, fallback `embedding` when sparse is True. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_torchinductor.py -k test_embedding_sparse ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150659 Approved by: https://github.com/jansel
This commit is contained in:
parent
2e23768d25
commit
d6887f444f
|
|
@ -5360,6 +5360,19 @@ class CommonTemplate:
|
|||
(torch.randint(10, [2, 8]),),
|
||||
)
|
||||
|
||||
def test_embedding_sparse(self):
|
||||
# Fix https://github.com/pytorch/pytorch/issues/150656
|
||||
def fn(weight, indices):
|
||||
return F.embedding(indices, weight, sparse=True)
|
||||
|
||||
indices = torch.randint(10, (2, 3))
|
||||
weight = torch.randn(10, 3, requires_grad=True)
|
||||
|
||||
self.common(
|
||||
fn,
|
||||
(weight, indices),
|
||||
)
|
||||
|
||||
def test_mean(self):
|
||||
def fn(x):
|
||||
return (
|
||||
|
|
|
|||
|
|
@ -137,6 +137,7 @@ test_failures = {
|
|||
"test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_flip_cat_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_pad_single_dynamic_shapes": TestFailure(("cpu",)),
|
||||
"test_embedding_sparse_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
|
||||
#
|
||||
# Failed to find for loop/triton kernel:
|
||||
#
|
||||
|
|
|
|||
|
|
@ -3376,6 +3376,11 @@ def gather(x, dim, index, sparse_grad=False):
|
|||
|
||||
@register_lowering(aten.embedding, type_promotion_kind=None)
|
||||
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
|
||||
if sparse:
|
||||
return fallback_handler(aten.embedding.default)(
|
||||
weight, indices, padding_idx, scale_grad_by_freq, sparse
|
||||
)
|
||||
|
||||
assert not sparse
|
||||
assert isinstance(weight, TensorBox)
|
||||
assert isinstance(indices, TensorBox)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user