Pull Request resolved: https://github.com/pytorch/pytorch/pull/158462
Approved by: https://github.com/eellison
parent 1ea688f9a2
commit 182efe31db
@@ -13717,6 +13717,35 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar
         args = (inp, repeats, output_size)
         self.assertEqual(fn(*args), torch.compile(fn)(*args))
 
+    @parametrize("dtype", [torch.int32, torch.int64])
+    @parametrize("nd", [1, 2])
+    def test_repeat_interleave_Tensor_decomp(self, dtype, nd):
+        # https://github.com/pytorch/pytorch/issues/147160
+        def f(input, repeats):
+            return torch.repeat_interleave(input, repeats, dim=0, output_size=3) + 1
+
+        input = torch.tensor([[1, 2], [3, 4]], dtype=dtype, device=self.device)
+        input = torch.arange(1, 2**nd + 1, dtype=dtype, device=self.device).reshape(
+            [2] * nd
+        )
+        repeat = torch.tensor([1, 2], device=self.device)
+
+        if input.device.type == "mps" and dtype == torch.int64:
+            raise unittest.SkipTest(
+                "torch.compile fails this test with mps & int64, "
+                "see https://github.com/pytorch/pytorch/issues/159408"
+            )
+
+        f_compiled = torch.compile(f)
+        output, (code,) = run_and_get_code(f_compiled, input, repeat)
+        reference = f(input, repeat)
+        self.assertEqual(output, reference)
+        # we don't lower when the cpp_wrapper is used because it cannot generate
+        # proper examples during autotune
+        can_lower = (not config.cpp_wrapper) and (input.device.type != "mps")
+        has_lowered = not re.search(r"repeat_interleave.Tensor", code)
+        self.assertEqual(has_lowered, can_lower)
+
 
 # end of class CommonTemplate - add new tests here
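For reference, the call the new test compiles behaves as follows in eager mode: with repeats [1, 2] along dim=0, row 0 of the input is emitted once and row 1 twice, while output_size=3 only pre-declares the total output length (repeats.sum()) so the output shape is known without inspecting the repeats tensor. A minimal eager-mode sketch, illustrative only and not part of the diff:

import torch

# Same shapes as the 2-D case exercised by the test: each row is repeated
# repeats[row] times along dim=0.
x = torch.tensor([[1, 2], [3, 4]])
repeats = torch.tensor([1, 2])

# output_size=3 matches repeats.sum(), so the op knows the output shape up front.
y = torch.repeat_interleave(x, repeats, dim=0, output_size=3)
print(y)
# tensor([[1, 2],
#         [3, 4],
#         [3, 4]])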
@@ -348,7 +348,7 @@ test_failures = {
     "test_rand_like_deterministic_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu"), is_skip=True
     ),
-    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "cuda", "xpu")),
+    "test_repeat_interleave_2_dynamic_shapes": TestFailure(("cpu", "xpu")),
     "test_slice_mutation2_dynamic_shapes": TestFailure(
         ("cpu", "cuda", "xpu"), is_skip=True
     ),
@@ -1154,3 +1154,25 @@ def rrelu_with_noise_functional(
     else:
         negative_slope = (lower + upper) / 2
     return aten.leaky_relu(self, negative_slope), torch.Tensor()
+
+
+@register_decomposition(aten.repeat_interleave.Tensor)
+def repeat_interleave_Tensor(
+    repeat: torch.Tensor,
+    output_size: Optional[int] = None,
+) -> torch.Tensor:
+    if config.triton.autotune_at_compile_time:
+        # We can't compile-time auto-tune this because
+        # it expects specific data in `repeat`
+        return NotImplemented
+    if output_size is None or type(output_size) is not int:
+        return NotImplemented
+    if repeat.device.type == "mps":
+        return NotImplemented
+    assert repeat.dtype in [torch.int32, torch.int64]
+    assert repeat.ndim == 1
+    cumsum = repeat.cumsum(0)
+    pos = torch.arange(output_size, device=repeat.device)
+    return torch.searchsorted(
+        cumsum, pos, out_int32=(repeat.dtype == torch.int32), right=True
+    )
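The decomposition replaces the data-dependent repeat_interleave kernel with a cumulative sum followed by a binary search: output position i maps to input index j when i falls before cumsum[j], which is exactly searchsorted(cumsum, i, right=True). A standalone sketch of that equivalence, with example values chosen here purely for illustration (not taken from the diff):

import torch

repeat = torch.tensor([1, 2, 3], dtype=torch.int64)
output_size = int(repeat.sum())  # 6; must be a static int for the decomposition to apply

cumsum = repeat.cumsum(0)                          # [1, 3, 6]
pos = torch.arange(output_size)                    # [0, 1, 2, 3, 4, 5]
idx = torch.searchsorted(cumsum, pos, right=True)  # [0, 1, 1, 2, 2, 2]

# Same result as the eager op: index i is repeated repeat[i] times.
assert torch.equal(idx, torch.repeat_interleave(repeat))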
@@ -2879,6 +2879,7 @@ make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
 
 # index_reduce requires fallback when use_scatter_fallback(...) returns True
 make_fallback(aten.index_reduce)
+make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
 
 
 # Register with type_promotion_kind None.