diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp
index 006e54c2495..1aab4b11c96 100644
--- a/aten/src/ATen/native/TensorFactories.cpp
+++ b/aten/src/ATen/native/TensorFactories.cpp
@@ -1212,6 +1212,28 @@ Tensor randint_like(
   return result.random_(0, high, std::nullopt);
 }
 
+Tensor randint_like(
+    const Tensor& self,
+    const Tensor& high,
+    std::optional<ScalarType> dtype,
+    std::optional<Layout> layout,
+    std::optional<Device> device,
+    std::optional<bool> pin_memory,
+    std::optional<c10::MemoryFormat> optional_memory_format) {
+  TORCH_CHECK(
+      high.numel() == 1 && high.ndimension() == 0 && high.device().is_cpu(),
+      "high must be a scalar tensor and on CPU");
+  int64_t high_scalar = high.item<int64_t>();
+  return at::native::randint_like(
+      self,
+      high_scalar,
+      dtype,
+      layout,
+      device,
+      pin_memory,
+      optional_memory_format);
+}
+
 Tensor randint_like(
     const Tensor& self,
     int64_t low,
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 7c53e087d40..05992461ae6 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -4781,6 +4781,14 @@
     CompositeExplicitAutograd: randint_like
   autogen: randint_like.out
 
+- func: randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
+  tags: nondeterministic_seeded
+  dispatch:
+    # NB: Although this composite mutates on the inside, it is
+    # non-differentiable so NonFunctional doesn't apply
+    CompositeExplicitAutograd: randint_like
+  autogen: randint_like.Tensor_out
+
 - func: randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
   tags: nondeterministic_seeded
   dispatch:
diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py
index 95f577642c4..99ab4d27f3a 100644
--- a/test/dynamo/test_misc.py
+++ b/test/dynamo/test_misc.py
@@ -12603,6 +12603,17 @@ class MiscTestsDevice(torch._inductor.test_case.TestCase):
         res = opt_f()
         self.assertEqual(ref, res)
 
+    def test_randint_no_graphbreak(self):
+        @torch.compile(backend="aot_eager", fullgraph=True)
+        def f(actions, n_act, epsilon=0.1):
+            actions_random = torch.randint_like(actions, n_act)
+
+            return actions_random
+
+        x = torch.ones([1], dtype=torch.int64)
+        y = torch.tensor(5)
+        f(x, y)
+
 
 devices = ("cuda", "hpu")
 instantiate_device_type_tests(MiscTestsDevice, globals(), only_for=devices)
diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index e98eb91de8b..042959c22cd 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -1094,6 +1094,8 @@ aten::randint.low_generator_out
 aten::randint.low_out
 aten::randint.out
 aten::randint_like
+aten::randint_like.Tensor
+aten::randint_like.Tensor_out
 aten::randint_like.low_dtype
 aten::randint_like.low_dtype_out
 aten::randint_like.out
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 35b0fc9abd3..0e6d4ebf92e 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -3505,6 +3505,11 @@ def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
     return self.new_empty(self.size())
 
 
+@register_meta([aten.randint_like.Tensor])
+def meta_randint_like(self, high, **kwargs):
+    return self.new_empty(self.size())
+
+
 @register_meta([aten._fused_adam_.default, aten._fused_adamw_.default])
 def meta__fused_adam_(
     self,
diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py
index 2fa02c3eb2e..db28aad88eb 100644
--- a/torch/_subclasses/fake_impls.py
+++ b/torch/_subclasses/fake_impls.py
@@ -68,6 +68,8 @@ _like_tensor_constructors = ordered_set(
     aten.randn_like.default,
     aten.randn_like.out,
     aten.randint_like.default,
+    aten.randint_like.Tensor,
+    aten.randint_like.Tensor_out,
     aten.randint_like.out,
     aten.randint_like.low_dtype,
     aten.randint_like.low_dtype_out,
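
Usage note (not part of the patch): a minimal sketch of how the new randint_like.Tensor overload added above can be exercised, mirroring the test in test_misc.py. The function name pick_random_actions and the concrete shapes/values are illustrative only; on builds without this overload the call with a tensor `high` will not dispatch this way.

import torch

# Sketch: pass a 0-dim CPU integer tensor as `high`. With the overload added
# in this patch, torch.compile(fullgraph=True) can trace the call directly
# instead of needing high.item(), which previously caused a graph break.
@torch.compile(backend="aot_eager", fullgraph=True)
def pick_random_actions(actions, n_act):
    return torch.randint_like(actions, n_act)

actions = torch.zeros(4, dtype=torch.int64)
n_act = torch.tensor(5)  # scalar tensor on CPU, as required by the TORCH_CHECK above
print(pick_random_actions(actions, n_act))  # values drawn from [0, 5)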