From 63632fc7eee1eaf5e63209777de01c08cfbe159c Mon Sep 17 00:00:00 2001
From: Jane Xu
Date: Thu, 28 Aug 2025 13:57:24 +0000
Subject: [PATCH] Add new_zeros dtype variant to the shim and as a stable op (#161597)

In case we want this before 2.9

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161597
Approved by: https://github.com/mikaylagawarecki
---
 .../libtorch_agnostic/csrc/kernel.cpp         | 15 +++++++-
 .../libtorch_agnostic/ops.py                  | 12 ++++++
 .../test/test_libtorch_agnostic.py            |  8 ++++
 .../aoti_torch/generated/c_shim_aten.h        |  1 +
 torch/csrc/stable/ops.h                       | 38 +++++++++++++++++++
 torchgen/aoti/fallback_ops.py                 |  1 +
 6 files changed, 74 insertions(+), 1 deletion(-)

diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
index 943af3c3575..306a882627d 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
+++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/csrc/kernel.cpp
@@ -343,7 +343,7 @@ void boxed_my_narrow(
 
 Tensor my_new_empty_dtype_variant(Tensor t) {
   std::vector<int64_t> sizes = {2, 5};
-  auto dtype = std::make_optional(at::ScalarType::BFloat16);
+  auto dtype = std::make_optional(torch::headeronly::ScalarType::BFloat16);
   return new_empty(t, sizes, dtype);
 }
 
@@ -352,6 +352,17 @@ void boxed_my_new_empty_dtype_variant(StableIValue* stack, uint64_t num_args, ui
   stack[0] = from(res);
 }
 
+Tensor my_new_zeros_dtype_variant(Tensor t) {
+  std::vector<int64_t> sizes = {2, 5};
+  auto dtype = std::make_optional(at::ScalarType::Float);
+  return new_zeros(t, sizes, dtype);
+}
+
+void boxed_my_new_zeros_dtype_variant(StableIValue* stack, uint64_t num_args, uint64_t num_outputs) {
+  auto res = my_new_zeros_dtype_variant(to<Tensor>(stack[0]));
+  stack[0] = from(res);
+}
+
 STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_transpose(Tensor t, int dim0, int dim1) -> Tensor");
   m.def("my_empty_like(Tensor t) -> Tensor");
@@ -359,6 +370,7 @@ STABLE_TORCH_LIBRARY_FRAGMENT(libtorch_agnostic, m) {
   m.def("my_pad(Tensor t) -> Tensor");
   m.def("my_narrow(Tensor t, int dim, int start, int length) -> Tensor");
   m.def("my_new_empty_dtype_variant(Tensor t) -> Tensor");
+  m.def("my_new_zeros_dtype_variant(Tensor t) -> Tensor");
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
@@ -367,6 +379,7 @@ STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeExplicitAutograd, m) {
   m.impl("fill_infinity", &boxed_fill_infinity);
   m.impl("my_is_cpu", &boxed_my_is_cpu);
   m.impl("my_new_empty_dtype_variant", &boxed_my_new_empty_dtype_variant);
+  m.impl("my_new_zeros_dtype_variant", &boxed_my_new_zeros_dtype_variant);
 }
 
 STABLE_TORCH_LIBRARY_IMPL(libtorch_agnostic, CompositeImplicitAutograd, m) {
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
index ebb4ba58249..074461d3527 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/libtorch_agnostic/ops.py
@@ -295,3 +295,15 @@ def my_new_empty_dtype_variant(t) -> Tensor:
     Returns: New empty tensor with shape [2, 5] and dtype bfloat16
     """
     return torch.ops.libtorch_agnostic.my_new_empty_dtype_variant.default(t)
+
+
+def my_new_zeros_dtype_variant(t) -> Tensor:
+    """
+    Returns a new tensor filled with 0s with shape [2, 5] and dtype Float
+
+    Args:
+        t: Input tensor used as a reference for device and other properties
+
+    Returns: New zeros tensor
+    """
+    return torch.ops.libtorch_agnostic.my_new_zeros_dtype_variant.default(t)
diff --git a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
index 6783f040bcd..0f471e8132a 100644
--- a/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
+++ b/test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
@@ -337,6 +337,14 @@ if not IS_WINDOWS:
             finally:
                 torch.use_deterministic_algorithms(deterministic)
 
+        def test_my_new_zeros_dtype_variant(self, device):
+            import libtorch_agnostic
+
+            t = torch.randn(3, 4, device=device)
+            out = libtorch_agnostic.ops.my_new_zeros_dtype_variant(t)
+            ref_out = t.new_zeros((2, 5), dtype=torch.float)
+            self.assertEqual(out, ref_out, exact_device=True)
+
     instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
 
 if __name__ == "__main__":
diff --git a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h
index c262b91ab47..4672e3293c5 100644
--- a/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h
+++ b/torch/csrc/inductor/aoti_torch/generated/c_shim_aten.h
@@ -18,6 +18,7 @@ AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_amax(AtenTensorHandle self, con
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_fill__Scalar(AtenTensorHandle self, double value);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_narrow(AtenTensorHandle self, int64_t dim, int64_t start, int64_t length, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_empty(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
+AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_new_zeros(AtenTensorHandle self, const int64_t* size, int64_t size_len_, int32_t* dtype, int32_t* layout, int32_t* device, int32_t device_index_, int32_t* pin_memory, AtenTensorHandle* ret0);
 AOTI_TORCH_EXPORT AOTITorchError aoti_torch_aten_pad(AtenTensorHandle self, const int64_t* pad, int64_t pad_len_, const char* mode, double* value, AtenTensorHandle* ret0);
 
 #ifdef __cplusplus
diff --git a/torch/csrc/stable/ops.h b/torch/csrc/stable/ops.h
index d4bb5947abc..669007fcf9f 100644
--- a/torch/csrc/stable/ops.h
+++ b/torch/csrc/stable/ops.h
@@ -90,6 +90,44 @@ inline Tensor new_empty(
   return Tensor(ret0);
 }
 
+// We expect this to be a stable version of the new_zeros op that takes in
+// only dtype information.
+inline Tensor new_zeros(
+    const Tensor& self,
+    std::vector<int64_t> size,
+    std::optional<torch::headeronly::ScalarType> dtype = std::nullopt) {
+  int32_t device_type;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_device_type(self.get(), &device_type));
+
+  int32_t device_index;
+  TORCH_ERROR_CODE_CHECK(
+      aoti_torch_get_device_index(self.get(), &device_index));
+
+  int32_t target_dtype;
+  if (dtype.has_value()) {
+    target_dtype = to<int32_t>(from(dtype.value()));
+  } else {
+    TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype(self.get(), &target_dtype));
+  }
+
+  int32_t layout;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_get_layout(self.get(), &layout));
+
+  AtenTensorHandle ath;
+  TORCH_ERROR_CODE_CHECK(aoti_torch_aten_new_zeros(
+      self.get(),
+      size.data(),
+      static_cast<int64_t>(size.size()),
+      &target_dtype,
+      &layout,
+      &device_type,
+      device_index,
+      nullptr, // pin_memory (nullptr for default)
+      &ath));
+
+  return Tensor(ath);
+}
+
 // We expect this to be the stable version of the pad.default op.
 // pad.default takes in a SymInt[] as the pad argument however pad is typed as
 // use std::vector<int64_t> because
diff --git a/torchgen/aoti/fallback_ops.py b/torchgen/aoti/fallback_ops.py
index b1e4618ef0d..611400d271d 100644
--- a/torchgen/aoti/fallback_ops.py
+++ b/torchgen/aoti/fallback_ops.py
@@ -187,4 +187,5 @@ aten_shimified_ops: dict[str, dict[str, list[str]]] = {
     "aten.narrow.default": {},
     "aten.amax.default": {},
     "aten.new_empty.default": {},
+    "aten.new_zeros.default": {},
 }
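
Usage note (reviewer sketch, not part of the patch): the snippet below shows how a third-party
libtorch-agnostic extension might call the new stable new_zeros, following the same boxed-kernel
pattern as my_new_zeros_dtype_variant above. The include paths, the my_ext library name, and the
zeros_like_2d helper are illustrative assumptions, not part of this PR.

// Hypothetical extension kernel built against the stable ABI headers
// (include paths assumed to match the libtorch_agnostic test extension).
#include <torch/csrc/stable/library.h> // STABLE_TORCH_LIBRARY_* macros, to/from (assumed path)
#include <torch/csrc/stable/ops.h>     // new_zeros (added by this PR)
#include <torch/csrc/stable/tensor.h>  // stable Tensor wrapper (assumed path)

#include <cstdint>
#include <optional>
#include <vector>

using torch::stable::Tensor;

// Returns a [3, 3] zeros tensor on t's device. Passing a dtype overrides
// t's dtype; pass std::nullopt (the default) to inherit it instead.
Tensor zeros_like_2d(Tensor t) {
  std::vector<int64_t> sizes = {3, 3};
  auto dtype = std::make_optional(torch::headeronly::ScalarType::Float);
  return new_zeros(t, sizes, dtype);
}

// Boxed wrapper so the kernel can be registered through the stable ABI.
void boxed_zeros_like_2d(
    StableIValue* stack,
    uint64_t num_args,
    uint64_t num_outputs) {
  Tensor res = zeros_like_2d(to<Tensor>(stack[0]));
  stack[0] = from(res);
}

STABLE_TORCH_LIBRARY_FRAGMENT(my_ext, m) {
  m.def("zeros_like_2d(Tensor t) -> Tensor");
}

STABLE_TORCH_LIBRARY_IMPL(my_ext, CompositeExplicitAutograd, m) {
  m.impl("zeros_like_2d", &boxed_zeros_like_2d);
}

From Python, once such an extension is loaded, the op would be reachable as
torch.ops.my_ext.zeros_like_2d(t), mirroring how ops.py wraps my_new_zeros_dtype_variant above.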