[inductor] add lowering for repeat_interleave.Tensor with output size specified (#147160) (#158462)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158462
Approved by: https://github.com/eellison
Markus Hoehnerbach 2025-07-29 16:41:26 -07:00 committed by PyTorch MergeBot
parent 8fedcfa59a
commit f89c28cc6b
4 changed files with 49 additions and 1 deletion
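
For context, a minimal illustrative sketch (not part of the diff) of the operation the new lowering targets: torch.repeat_interleave with a tensor of per-row repeat counts, where passing output_size fixes the length of the result up front so the compiler does not need to inspect the repeat values to know the output shape.

import torch

x = torch.tensor([[1, 2], [3, 4]])
repeats = torch.tensor([1, 2])

# Row 0 is kept once, row 1 twice; output_size=3 pins the result length.
y = torch.repeat_interleave(x, repeats, dim=0, output_size=3)
# tensor([[1, 2],
#         [3, 4],
#         [3, 4]])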


@@ -13654,6 +13654,31 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        inputs = (torch.randn(4, device=self.device),)
        self.common(Model(), inputs)

    @parametrize("dtype", [torch.int32, torch.int64])
    def test_repeat_interleave_Tensor_decomp(self, dtype):
        # https://github.com/pytorch/pytorch/issues/147160
        def f(input, repeats):
            return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1

        input = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=self.device)
        repeat = torch.tensor([1, 2], device=self.device)
        if input.device.type == "mps" and dtype == torch.int64:
            raise unittest.SkipTest(
                "torch.compile fails this test with mps & int64, "
                "see https://github.com/pytorch/pytorch/issues/159408"
            )
        f_compiled = torch.compile(f)
        output, (code,) = run_and_get_code(f_compiled, input, repeat)
        reference = f(input, repeat)
        self.assertEqual(output, reference)
        # we don't lower when the cpp_wrapper is used because it cannot generate
        # proper examples during autotune
        can_lower = (not config.cpp_wrapper) and (input.device.type != "mps")
        has_lowered = not re.search(r"repeat_interleave.Tensor", code)
        self.assertEqual(has_lowered, can_lower)


@dataclasses.dataclass
class TestFailure:


@@ -347,7 +347,7 @@ test_failures = {
    "test_rand_like_deterministic_dynamic_shapes": TestFailure(
        ("cpu", "cuda", "xpu"), is_skip=True
    ),
-    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
+    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "xpu")),
    "test_slice_mutation2_dynamic_shapes": TestFailure(
        ("cpu", "cuda", "xpu"), is_skip=True
    ),


@@ -1159,3 +1159,25 @@ def rrelu_with_noise_functional(
    else:
        negative_slope = (lower + upper) / 2
        return aten.leaky_relu(self, negative_slope), torch.Tensor()


@register_decomposition(aten.repeat_interleave.Tensor)
def repeat_interleave_Tensor(
    repeat: torch.Tensor,
    output_size: Optional[int] = None,
) -> torch.Tensor:
    if config.triton.autotune_at_compile_time:
        # We can't compile-time auto-tune this because
        # it expects specific data in `repeat`
        return NotImplemented
    if output_size is None or type(output_size) is not int:
        return NotImplemented
    if repeat.device.type == "mps":
        return NotImplemented
    assert repeat.dtype in [torch.int32, torch.int64]
    assert repeat.ndim == 1
    cumsum = repeat.cumsum(0)
    pos = torch.arange(output_size, device=repeat.device)
    return torch.searchsorted(
        cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
    )
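
As an illustration of the decomposition above (an eager-mode sketch, not part of the diff): the cumulative sum of the repeats marks where each source row's segment ends in the output, and a right-sided searchsorted over the output positions recovers the source index of every output row; gathering with those indices reproduces repeat_interleave.

import torch

repeats = torch.tensor([1, 2])
cumsum = repeats.cumsum(0)                          # tensor([1, 3])
pos = torch.arange(3)                               # output positions 0, 1, 2
idx = torch.searchsorted(cumsum, pos, right=True)   # tensor([0, 1, 1])

x = torch.tensor([[1, 2], [3, 4]])
assert torch.equal(
    x.index_select(0, idx),
    torch.repeat_interleave(x, repeats, dim=0, output_size=3),
)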


@@ -2874,6 +2874,7 @@ make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
# index_reduce requires fallback when use_scatter_fallback(...) returns True
make_fallback(aten.index_reduce)
make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
# Register with type_promotion_kind None.
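
For reference (an eager-mode sketch, not part of the diff): aten.repeat_interleave.Tensor is the single-tensor overload that maps repeat counts to output indices. Without output_size its length equals repeats.sum(), i.e. it is data dependent, which is why a fallback stays registered for the cases the decomposition declines (no output_size, mps, or compile-time autotuning).

import torch

repeats = torch.tensor([1, 2])
# Indices form: the length is repeats.sum() and only known once the data is read.
print(torch.repeat_interleave(repeats))                 # tensor([0, 1, 1])
# With output_size the length is known up front, which enables the decomposition.
print(torch.repeat_interleave(repeats, output_size=3))  # tensor([0, 1, 1])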