diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 2ef129d5fe1..b9b9bef12e5 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -13654,6 +13654,31 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         inputs = (torch.randn(4, device=self.device),)
         self.common(Model(), inputs)
 
+    @parametrize("dtype", [torch.int32, torch.int64])
+    def test_repeat_interleave_Tensor_decomp(self, dtype):
+        # https://github.com/pytorch/pytorch/issues/147160
+        def f(input, repeats):
+            return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1
+
+        input = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=self.device)
+        repeat = torch.tensor([1, 2], device=self.device)
+
+        if input.device.type == "mps" and dtype == torch.int64:
+            raise unittest.SkipTest(
+                "torch.compile fails this test with mps & int64, "
+                "see https://github.com/pytorch/pytorch/issues/159408"
+            )
+
+        f_compiled = torch.compile(f)
+        output, (code,) = run_and_get_code(f_compiled, input, repeat)
+        reference = f(input, repeat)
+        self.assertEqual(output, reference)
+        # we don't lower when the cpp_wrapper is used because it cannot generate
+        # proper examples during autotune
+        can_lower = (not config.cpp_wrapper) and (input.device.type != "mps")
+        has_lowered = not re.search(r"repeat_interleave.Tensor", code)
+        self.assertEqual(has_lowered, can_lower)
+
 
 @dataclasses.dataclass
 class TestFailure:
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index 6a7d40b6b7c..ea5a0c7f40a 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -347,7 +347,7 @@ test_failures = {
     "test_rand_like_deterministic_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu"), is_skip=True
     ),
-    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
+    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "xpu")),
     "test_slice_mutation2_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu"), is_skip=True
     ),
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 2622ab6b95e..f799f647063 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -1159,3 +1159,25 @@ def rrelu_with_noise_functional(
     else:
         negative_slope = (lower + upper) / 2
     return aten.leaky_relu(self, negative_slope), torch.Tensor()
+
+
+@register_decomposition(aten.repeat_interleave.Tensor)
+def repeat_interleave_Tensor(
+    repeat: torch.Tensor,
+    output_size: Optional[int] = None,
+) -> torch.Tensor:
+    if config.triton.autotune_at_compile_time:
+        # We can't compile-time auto-tune this because
+        # it expects specific data in `repeat`
+        return NotImplemented
+    if output_size is None or type(output_size) is not int:
+        return NotImplemented
+    if repeat.device.type == "mps":
+        return NotImplemented
+    assert repeat.dtype in [torch.int32, torch.int64]
+    assert repeat.ndim == 1
+    cumsum = repeat.cumsum(0)
+    pos = torch.arange(output_size, device=repeat.device)
+    return torch.searchsorted(
+        cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
+    )
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 5c8110f6c39..1b50f9d9856 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -2874,6 +2874,7 @@ make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
 
 # index_reduce requires fallback when use_scatter_fallback(...) returns True
 make_fallback(aten.index_reduce)
+make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
 
 # Register with type_promotion_kind None.
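
Note (not part of the patch): the new decomposition targets aten.repeat_interleave.Tensor, the overload that maps a repeats tensor to an index map, and builds it from cumsum + searchsorted; the make_fallback(..., override_decomp=True) line keeps a fallback path for the cases where the decomposition returns NotImplemented (cpp_wrapper/compile-time autotuning, mps, unknown output_size). The following standalone eager-mode sketch, with illustrative variable names, shows why searchsorted over the cumulative sum reproduces the index map.

# Eager-mode sketch of the cumsum + searchsorted trick used above.
# For repeats = [1, 2] and output_size = 3, the index map is [0, 1, 1]:
# output position i belongs to the first run whose cumulative count exceeds i,
# which is exactly what searchsorted(..., right=True) computes.
import torch

repeat = torch.tensor([1, 2])
output_size = 3  # assumed to equal repeat.sum(), as the decomposition requires

cumsum = repeat.cumsum(0)               # tensor([1, 3]) -- end offset of each run
pos = torch.arange(output_size)         # tensor([0, 1, 2]) -- output positions
indices = torch.searchsorted(cumsum, pos, right=True)

print(indices)                          # tensor([0, 1, 1])
print(torch.repeat_interleave(repeat))  # tensor([0, 1, 1]) -- matches the eager op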