From d6887f444fa61b46c9c31114028484e658f3dc99 Mon Sep 17 00:00:00 2001 From: leslie-fang-intel Date: Thu, 3 Apr 2025 20:40:04 -0700 Subject: [PATCH] [Inductor] Fallback embedding when sparse is True (#150659) **Summary** Fix issue: https://github.com/pytorch/pytorch/issues/150656, fallback `embedding` when sparse is True. **Test Plan** ``` python -u -m pytest -s -v test/inductor/test_torchinductor.py -k test_embedding_sparse ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/150659 Approved by: https://github.com/jansel --- test/inductor/test_torchinductor.py | 13 +++++++++++++ .../test_torchinductor_codegen_dynamic_shapes.py | 1 + torch/_inductor/lowering.py | 5 +++++ 3 files changed, 19 insertions(+) diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index 30aafa06206..54524be5314 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -5360,6 +5360,19 @@ class CommonTemplate: (torch.randint(10, [2, 8]),), ) + def test_embedding_sparse(self): + # Fix https://github.com/pytorch/pytorch/issues/150656 + def fn(weight, indices): + return F.embedding(indices, weight, sparse=True) + + indices = torch.randint(10, (2, 3)) + weight = torch.randn(10, 3, requires_grad=True) + + self.common( + fn, + (weight, indices), + ) + def test_mean(self): def fn(x): return ( diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py index c090b7b7846..29d74152bf4 100644 --- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py @@ -137,6 +137,7 @@ test_failures = { "test_mul_index_expr_dynamic_shapes": TestFailure(("cpu",)), "test_flip_cat_dynamic_shapes": TestFailure(("cpu",)), "test_pad_single_dynamic_shapes": TestFailure(("cpu",)), + "test_embedding_sparse_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")), # # Failed to find for loop/triton kernel: # diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 7fcf7904185..24520887f6a 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -3376,6 +3376,11 @@ def gather(x, dim, index, sparse_grad=False): @register_lowering(aten.embedding, type_promotion_kind=None) def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False): + if sparse: + return fallback_handler(aten.embedding.default)( + weight, indices, padding_idx, scale_grad_by_freq, sparse + ) + assert not sparse assert isinstance(weight, TensorBox) assert isinstance(indices, TensorBox)