Add randint_like tensor overload for high (#154899)

Fixes #135664
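
For context, the new overload makes the pattern from the linked issue compile without a graph break: torch.randint_like now accepts a 0-dim CPU integer tensor as the exclusive upper bound. A minimal sketch of the newly supported call (mirroring the test added below):

    import torch

    actions = torch.ones(4, dtype=torch.int64)
    n_act = torch.tensor(5)  # 0-dim CPU tensor used as the exclusive upper bound
    out = torch.randint_like(actions, n_act)  # uniform over [0, 5), same shape/dtype as actions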

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154899
Approved by: https://github.com/StrongerXi
bobrenjc93 2025-06-05 08:06:33 -07:00 committed by PyTorch MergeBot
parent 7e4c097b07
commit fc77269262
6 changed files with 50 additions and 0 deletions


@@ -1212,6 +1212,28 @@ Tensor randint_like(
  return result.random_(0, high, std::nullopt);
}

Tensor randint_like(
    const Tensor& self,
    const Tensor& high,
    std::optional<ScalarType> dtype,
    std::optional<Layout> layout,
    std::optional<Device> device,
    std::optional<bool> pin_memory,
    std::optional<c10::MemoryFormat> optional_memory_format) {
  TORCH_CHECK(
      high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(),
      "high must be a scalar tensor and on CPU");
  int64_t high_scalar = high.item<int64_t>();
  return at::native::randint_like(
      self,
      high_scalar,
      dtype,
      layout,
      device,
      pin_memory,
      optional_memory_format);
}

Tensor randint_like(
    const Tensor& self,
    int64_t low,
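
Note the TORCH_CHECK above: high must be zero-dimensional, single-element, and on CPU before its value is read with .item(). A hedged illustration of the constraint in eager mode:

    import torch

    x = torch.zeros(3, dtype=torch.int64)
    torch.randint_like(x, torch.tensor(10))      # OK: numel() == 1, ndimension() == 0, CPU
    # torch.randint_like(x, torch.tensor([10]))  # fails the check: 1-D, so it raises
    #                                            # "high must be a scalar tensor and on CPU"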


@@ -4781,6 +4781,14 @@
    CompositeExplicitAutograd: randint_like
  autogen: randint_like.out

- func: randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  tags: nondeterministic_seeded
  dispatch:
    # NB: Although this composite mutates on the inside, it is
    # non-differentiable so NonFunctional doesn't apply
    CompositeExplicitAutograd: randint_like
  autogen: randint_like.Tensor_out

- func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
  tags: nondeterministic_seeded
  dispatch:
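
Once the schema above is registered, the overload is also reachable directly through the dispatcher as aten::randint_like.Tensor. A small sketch, assuming a build that includes this change:

    import torch

    x = torch.zeros(2, dtype=torch.int64)
    out = torch.ops.aten.randint_like.Tensor(x, torch.tensor(7))
    print(torch.ops.aten.randint_like.Tensor._schema)  # prints the func signature above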


@@ -12603,6 +12603,17 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
        res = opt_f()
        self.assertEqual(ref, res)

    def test_randint_no_graphbreak(self):
        @torch.compile(backend="aot_eager", fullgraph=True)
        def f(actions, n_act, epsilon=0.1):
            actions_random = torch.randint_like(actions, n_act)
            return actions_random

        x = torch.ones([1], dtype=torch.int64)
        y = torch.tensor(5)
        f(x, y)


devices = ("cuda", "hpu")
instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)


@@ -1094,6 +1094,8 @@ aten::randint.low_generator_out
aten::randint.low_out
aten::randint.out
aten::randint_like
aten::randint_like.Tensor
aten::randint_like.Tensor_out
aten::randint_like.low_dtype
aten::randint_like.low_dtype_out
aten::randint_like.out


@@ -3505,6 +3505,11 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
    return self.new_empty(self.size())


@register_meta([aten.randint_like.Tensor])
def meta_randint_like(self, high, **kwargs):
    return self.new_empty(self.size())


@register_meta([aten._fused_adam_.default, aten._fused_adamw_.default])
def meta__fused_adam_(
    self,
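
The meta kernel only propagates shape, dtype, and device; it never reads the value of high, so no .item() call (and no host sync) is needed. A minimal sketch under stated assumptions about meta-device dispatch:

    import torch

    x = torch.empty(3, dtype=torch.int64, device="meta")
    high = torch.empty((), dtype=torch.int64, device="meta")  # value is irrelevant on meta
    out = torch.randint_like(x, high)
    print(out.shape, out.device)  # torch.Size([3]) meta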


@@ -68,6 +68,8 @@ _like_tensor_constructors = ordered_set(
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.Tensor,
    aten.randint_like.Tensor_out,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
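
With the two new overloads listed in _like_tensor_constructors, fake tensor treats them like the other *_like factories. A hedged sketch of exercising the op under FakeTensorMode (an internal API, shown only for illustration):

    import torch
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode():
        actions = torch.ones(4, dtype=torch.int64)
        high = torch.tensor(5)
        out = torch.randint_like(actions, high)
        print(out.shape, out.dtype)  # torch.Size([4]) torch.int64, no real data allocated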