mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add randint_like tensor overload for high (#154899)
Fixes #135664 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154899 Approved by: https://github.com/StrongerXi
This commit is contained in:
parent
7e4c097b07
commit
fc77269262
|
|
@ -1212,6 +1212,28 @@ Tensor randint_like(
|
|||
return result.random_(0, high, std::nullopt);
|
||||
}
|
||||
|
||||
Tensor randint_like(
|
||||
const Tensor& self,
|
||||
const Tensor& high,
|
||||
std::optional<ScalarType> dtype,
|
||||
std::optional<Layout> layout,
|
||||
std::optional<Device> device,
|
||||
std::optional<bool> pin_memory,
|
||||
std::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(
|
||||
high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(),
|
||||
"high must be a scalar tensor and on CPU");
|
||||
int64_t high_scalar = high.item<int64_t>();
|
||||
return at::native::randint_like(
|
||||
self,
|
||||
high_scalar,
|
||||
dtype,
|
||||
layout,
|
||||
device,
|
||||
pin_memory,
|
||||
optional_memory_format);
|
||||
}
|
||||
|
||||
Tensor randint_like(
|
||||
const Tensor& self,
|
||||
int64_t low,
|
||||
|
|
|
|||
|
|
@ -4781,6 +4781,14 @@
|
|||
CompositeExplicitAutograd: randint_like
|
||||
autogen: randint_like.out
|
||||
|
||||
- func: randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
|
||||
tags: nondeterministic_seeded
|
||||
dispatch:
|
||||
# NB: Although this composite mutates on the inside, it is
|
||||
# non-differentiable so NonFunctional doesn't apply
|
||||
CompositeExplicitAutograd: randint_like
|
||||
autogen: randint_like.Tensor_out
|
||||
|
||||
- func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
|
||||
tags: nondeterministic_seeded
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -12603,6 +12603,17 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
|
|||
res = opt_f()
|
||||
self.assertEqual(ref, res)
|
||||
|
||||
def test_randint_no_graphbreak(self):
|
||||
@torch.compile(backend="aot_eager", fullgraph=True)
|
||||
def f(actions, n_act, epsilon=0.1):
|
||||
actions_random = torch.randint_like(actions, n_act)
|
||||
|
||||
return actions_random
|
||||
|
||||
x = torch.ones([1], dtype=torch.int64)
|
||||
y = torch.tensor(5)
|
||||
f(x, y)
|
||||
|
||||
|
||||
devices = ("cuda", "hpu")
|
||||
instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)
|
||||
|
|
|
|||
|
|
@ -1094,6 +1094,8 @@ aten::randint.low_generator_out
|
|||
aten::randint.low_out
|
||||
aten::randint.out
|
||||
aten::randint_like
|
||||
aten::randint_like.Tensor
|
||||
aten::randint_like.Tensor_out
|
||||
aten::randint_like.low_dtype
|
||||
aten::randint_like.low_dtype_out
|
||||
aten::randint_like.out
|
||||
|
|
|
|||
|
|
@ -3505,6 +3505,11 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
|
|||
return self.new_empty(self.size())
|
||||
|
||||
|
||||
@register_meta([aten.randint_like.Tensor])
|
||||
def meta_randint_like(self, high, **kwargs):
|
||||
return self.new_empty(self.size())
|
||||
|
||||
|
||||
@register_meta([aten._fused_adam_.default, aten._fused_adamw_.default])
|
||||
def meta__fused_adam_(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -68,6 +68,8 @@ _like_tensor_constructors = ordered_set(
|
|||
aten.randn_like.default,
|
||||
aten.randn_like.out,
|
||||
aten.randint_like.default,
|
||||
aten.randint_like.Tensor,
|
||||
aten.randint_like.Tensor_out,
|
||||
aten.randint_like.out,
|
||||
aten.randint_like.low_dtype,
|
||||
aten.randint_like.low_dtype_out,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user