[inductor] add lowering for repeat_interleave.Tensor with output size specified (#147160) (#158462)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158462
Approved by: https://github.com/eellison
Markus Hoehnerbach 2025-08-13 09:37:48 -07:00 committed by PyTorch MergeBot
parent 1ea688f9a2
commit 182efe31db
4 changed files with 53 additions and 1 deletion

View File

@@ -13717,6 +13717,35 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
        args = (inp, repeats, output_size)
        self.assertEqual(fn(*args), torch.compile(fn)(*args))

    @parametrize("dtype", [torch.int32, torch.int64])
    @parametrize("nd", [1, 2])
    def test_repeat_interleave_Tensor_decomp(self, dtype, nd):
        # https://github.com/pytorch/pytorch/issues/147160
        def f(input, repeats):
            return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1

        input = torch.arange(1, 2**nd + 1, dtype=dtype, device=self.device).reshape(
            [2] * nd
        )
        repeat = torch.tensor([1, 2], device=self.device)
        if input.device.type == "mps" and dtype == torch.int64:
            raise unittest.SkipTest(
                "torch.compile fails this test with mps & int64, "
                "see https://github.com/pytorch/pytorch/issues/159408"
            )

        f_compiled = torch.compile(f)
        output, (code,) = run_and_get_code(f_compiled, input, repeat)
        reference = f(input, repeat)
        self.assertEqual(output, reference)
        # we don't lower when cpp_wrapper is used because it cannot generate
        # proper example inputs during autotuning
        can_lower = (not config.cpp_wrapper) and (input.device.type != "mps")
        has_lowered = not re.search(r"repeat_interleave.Tensor", code)
        self.assertEqual(has_lowered, can_lower)

    # end of class CommonTemplate - add new tests here
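For reference, the check the new test performs can be reproduced outside the test harness; a minimal sketch, not part of this patch, assuming a working Inductor backend and using run_and_get_code from torch._inductor.utils as the test suite does:

import re
import torch
from torch._inductor.utils import run_and_get_code

def f(x, repeats):
    # output_size=3 equals repeats.sum(), so the decomposition can apply
    return torch.repeat_interleave(x, repeats, dim=0, output_size=3) + 1

x = torch.arange(1, 5).reshape(2, 2)
repeats = torch.tensor([1, 2])

compiled = torch.compile(f)
out, (code,) = run_and_get_code(compiled, x, repeats)
assert torch.equal(out, f(x, repeats))
# if the decomposition applied, the aten op name no longer appears in the generated code
print("lowered:", re.search(r"repeat_interleave.Tensor", code) is None)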

View File

@@ -348,7 +348,7 @@ test_failures = {
    "test_rand_like_deterministic_dynamic_shapes": TestFailure(
        ("cpu", "cuda", "xpu"), is_skip=True
    ),
-    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
+    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "xpu")),
    "test_slice_mutation2_dynamic_shapes": TestFailure(
        ("cpu", "cuda", "xpu"), is_skip=True
    ),

View File

@@ -1154,3 +1154,25 @@ def rrelu_with_noise_functional(
    else:
        negative_slope = (lower + upper) / 2
    return aten.leaky_relu(self, negative_slope), torch.Tensor()


@register_decomposition(aten.repeat_interleave.Tensor)
def repeat_interleave_Tensor(
    repeat: torch.Tensor,
    output_size: Optional[int] = None,
) -> torch.Tensor:
    if config.triton.autotune_at_compile_time:
        # We can't compile-time auto-tune this because
        # it expects specific data in `repeat`
        return NotImplemented
    if output_size is None or type(output_size) is not int:
        return NotImplemented
    if repeat.device.type == "mps":
        return NotImplemented
    assert repeat.dtype in [torch.int32, torch.int64]
    assert repeat.ndim == 1
    cumsum = repeat.cumsum(0)
    pos = torch.arange(output_size, device=repeat.device)
    return torch.searchsorted(
        cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
    )
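As a quick sanity check (a sketch, not part of this patch), the cumsum/searchsorted combination above matches what eager torch.repeat_interleave returns for the repeats-only overload, i.e. the tensor of repeated indices:

import torch

repeats = torch.tensor([1, 2])          # emit index 0 once, index 1 twice
output_size = int(repeats.sum())        # 3; must be a plain int for the decomposition

cumsum = repeats.cumsum(0)              # tensor([1, 3])
pos = torch.arange(output_size)         # tensor([0, 1, 2])
# output position p maps to source index i iff cumsum[i-1] <= p < cumsum[i],
# which is exactly searchsorted(cumsum, p, right=True)
decomposed = torch.searchsorted(cumsum, pos, right=True)

reference = torch.repeat_interleave(repeats)   # tensor([0, 1, 1])
assert torch.equal(decomposed, reference)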

View File

@@ -2879,6 +2879,7 @@ make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
# index_reduce requires fallback when use_scatter_fallback(...) returns True
make_fallback(aten.index_reduce)
make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
# Register with type_promotion_kind None.
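Keeping the fallback registered matters because the decomposition above declines (returns NotImplemented) when output_size is missing or not a plain int, on mps, or under compile-time autotuning; in those cases the op is still expected to compile through the ATen kernel. A rough sketch of that path, assuming a default torch.compile setup:

import torch

def g(x, repeats):
    # no output_size here, so the decomposition returns NotImplemented and
    # the registered fallback handles aten.repeat_interleave.Tensor
    return torch.repeat_interleave(x, repeats, dim=0)

x = torch.arange(4.0).reshape(2, 2)
repeats = torch.tensor([1, 2])
assert torch.equal(torch.compile(g)(x, repeats), g(x, repeats))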