From 0d62fd5c3cc839a95e652c81ef2157cb8437d13d Mon Sep 17 00:00:00 2001
From: "Andy (An) Wang"
Date: Fri, 23 May 2025 17:59:47 +0000
Subject: [PATCH] [MTIA Aten Backend][2/n] Migrate clamp ops
 (clamp.out/clamp_min.out/clamp_max.out) from out-of-tree to in-tree (#154015)

Summary:

# Context
See the first PR: https://github.com/pytorch/pytorch/pull/153670

# This PR
1. Migrate three clamp ops from out-of-tree to in-tree (they had to be
   migrated together because clamp.out calls all three stubs, which are also
   called by the other two ops):
   - clamp.out
   - clamp_min.out
   - clamp_max.out
2. Enable structured kernel codegen for MTIA, which clamp requires.
3. Introduce a `--mtia` flag in torchgen so that OSS builds do not generate
   MTIA code (otherwise we hit link errors such as
   `lib/libtorch_cpu.so: undefined reference to at::detail::empty_mtia`).

Differential Revision: D74674418

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154015
Approved by: https://github.com/albanD, https://github.com/nautsimon
---
 aten/src/ATen/TensorIterator.cpp            |  1 -
 aten/src/ATen/native/mtia/EmptyTensor.cpp   | 86 ++++++++++++++++++++++
 aten/src/ATen/native/mtia/EmptyTensor.h     | 42 +++++++++++
 aten/src/ATen/native/native_functions.yaml  |  6 +-
 build.bzl                                   |  2 +-
 torchgen/dest/register_dispatch_key.py      |  6 ++
 torchgen/gen.py                             | 11 +++
 torchgen/model.py                           |  1 +
 8 files changed, 150 insertions(+), 5 deletions(-)
 create mode 100644 aten/src/ATen/native/mtia/EmptyTensor.cpp
 create mode 100644 aten/src/ATen/native/mtia/EmptyTensor.h

diff --git a/aten/src/ATen/TensorIterator.cpp b/aten/src/ATen/TensorIterator.cpp
index 805f1f2f6c2..28c5fd6012f 100644
--- a/aten/src/ATen/TensorIterator.cpp
+++ b/aten/src/ATen/TensorIterator.cpp
@@ -1535,7 +1535,6 @@ void TensorIteratorBase::build(TensorIteratorConfig& config) {
   // Nothing beyond this point is important for meta functions, so it's fine to exit early here.
   // Extend the condition to MAIA tesnors as MAIA tensors also don't have storage.
   if (privateuse1_without_storage ||
-      common_device_.type() == DeviceType::MTIA ||
       common_device_.type() == DeviceType::XLA ||
       common_device_.type() == DeviceType::IPU ||
       common_device_.type() == DeviceType::Lazy ||
diff --git a/aten/src/ATen/native/mtia/EmptyTensor.cpp b/aten/src/ATen/native/mtia/EmptyTensor.cpp
new file mode 100644
index 00000000000..e7ff8b719aa
--- /dev/null
+++ b/aten/src/ATen/native/mtia/EmptyTensor.cpp
@@ -0,0 +1,86 @@
+#include
+#include
+#include
+#include
+#include
+
+namespace at::detail {
+
+at::Allocator* GetMTIAAllocator() {
+  return GetAllocator(DeviceType::MTIA);
+}
+
+TensorBase empty_mtia(
+    IntArrayRef size,
+    ScalarType dtype,
+    std::optional<Device> device_opt,
+    std::optional<c10::MemoryFormat> memory_format_opt) {
+  at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
+  const auto device = device_or_default(device_opt);
+  TORCH_INTERNAL_ASSERT(device.is_mtia());
+  const DeviceGuard device_guard(device);
+  auto* allocator = GetMTIAAllocator();
+  constexpr c10::DispatchKeySet mtia_dks(c10::DispatchKey::MTIA);
+  return at::detail::empty_generic(
+      size, allocator, mtia_dks, dtype, memory_format_opt);
+}
+
+TensorBase empty_mtia(
+    IntArrayRef size,
+    std::optional<ScalarType> dtype_opt,
+    std::optional<Layout> layout_opt,
+    std::optional<Device> device_opt,
+    std::optional<bool> pin_memory_opt,
+    std::optional<c10::MemoryFormat> memory_format_opt) {
+  const auto dtype = dtype_or_default(dtype_opt);
+  return at::detail::empty_mtia(size, dtype, device_opt, memory_format_opt);
+}
+
+TensorBase empty_mtia(IntArrayRef size, const TensorOptions& options) {
+  return at::detail::empty_mtia(
+      size,
+      optTypeMetaToScalarType(options.dtype_opt()),
+      options.layout_opt(),
+      options.device_opt(),
+      options.pinned_memory_opt(),
+      options.memory_format_opt());
+}
+
+TensorBase empty_strided_mtia(
+    IntArrayRef size,
+    IntArrayRef stride,
+    ScalarType dtype,
+    std::optional<Device> device_opt) {
+  at::globalContext().lazyInitDevice(c10::DeviceType::MTIA);
+  const auto device = device_or_default(device_opt);
+  const DeviceGuard device_guard(device);
+  auto* allocator = GetMTIAAllocator();
+  constexpr c10::DispatchKeySet mtia_dks(c10::DispatchKey::MTIA);
+  return at::detail::empty_strided_generic(
+      size, stride, allocator, mtia_dks, dtype);
+}
+
+TensorBase empty_strided_mtia(
+    IntArrayRef size,
+    IntArrayRef stride,
+    std::optional<ScalarType> dtype_opt,
+    std::optional<Layout> layout_opt,
+    std::optional<Device> device_opt,
+    std::optional<bool> pin_memory_opt) {
+  const auto dtype = dtype_or_default(dtype_opt);
+  return at::detail::empty_strided_mtia(size, stride, dtype, device_opt);
+}
+
+TensorBase empty_strided_mtia(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions& options) {
+  return at::detail::empty_strided_mtia(
+      size,
+      stride,
+      optTypeMetaToScalarType(options.dtype_opt()),
+      options.layout_opt(),
+      options.device_opt(),
+      options.pinned_memory_opt());
+}
+} // namespace at::detail
diff --git a/aten/src/ATen/native/mtia/EmptyTensor.h b/aten/src/ATen/native/mtia/EmptyTensor.h
new file mode 100644
index 00000000000..afd1e58d40f
--- /dev/null
+++ b/aten/src/ATen/native/mtia/EmptyTensor.h
@@ -0,0 +1,42 @@
+
+#pragma once
+#include
+
+namespace at::detail {
+
+TensorBase empty_mtia(
+    IntArrayRef size,
+    ScalarType dtype,
+    std::optional<Device> device_opt,
+    std::optional<c10::MemoryFormat> memory_format_opt);
+
+TensorBase empty_mtia(
+    IntArrayRef size,
+    std::optional<ScalarType> dtype_opt,
+    std::optional<Layout> layout_opt,
+    std::optional<Device> device_opt,
+    std::optional<bool> pin_memory_opt,
+    std::optional<c10::MemoryFormat> memory_format_opt);
+
+TensorBase empty_mtia(IntArrayRef size, const TensorOptions& options);
+
+TensorBase empty_strided_mtia(
+    IntArrayRef size,
+    IntArrayRef stride,
+    ScalarType dtype,
+    std::optional<Device> device_opt);
+
+TensorBase empty_strided_mtia(
+    IntArrayRef size,
+    IntArrayRef stride,
+    std::optional<ScalarType> dtype_opt,
+    std::optional<Layout> layout_opt,
+    std::optional<Device> device_opt,
+    std::optional<bool> pin_memory_opt);
+
+TensorBase empty_strided_mtia(
+    IntArrayRef size,
+    IntArrayRef stride,
+    const TensorOptions& options);
+
+} // namespace at::detail
diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index 91358def69a..8ba40757482 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1548,7 +1548,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: clamp_out
+    CPU, CUDA, MTIA: clamp_out
     MPS: clamp_out_mps
   tags: pointwise
 
@@ -1588,7 +1588,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: clamp_max_out
+    CPU, CUDA, MTIA: clamp_max_out
     MPS: clamp_max_out_mps
   tags: pointwise
 
@@ -1628,7 +1628,7 @@
   structured: True
   structured_inherits: TensorIteratorBase
   dispatch:
-    CPU, CUDA: clamp_min_out
+    CPU, CUDA, MTIA: clamp_min_out
     MPS: clamp_min_out_mps
   tags: pointwise
 
diff --git a/build.bzl b/build.bzl
index 28a7cedbee3..7c2c3e24dc5 100644
--- a/build.bzl
+++ b/build.bzl
@@ -72,7 +72,7 @@ def define_targets(rules):
         "--install_dir=$(RULEDIR)",
         "--source-path aten/src/ATen",
         "--aoti_install_dir=$(RULEDIR)/torch/csrc/inductor/aoti_torch/generated"
-    ] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else []))
+    ] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else []) + ["--mtia"])
 
     gen_aten_outs_cuda = (
         GENERATED_H_CUDA + GENERATED_CPP_CUDA + GENERATED_AOTI_CUDA_CPP +
diff --git a/torchgen/dest/register_dispatch_key.py b/torchgen/dest/register_dispatch_key.py
index af3d4c9ca74..ffe90bcaba8 100644
--- a/torchgen/dest/register_dispatch_key.py
+++ b/torchgen/dest/register_dispatch_key.py
@@ -66,6 +66,8 @@ def gen_registration_headers(
     elif backend_index.dispatch_key == DispatchKey.XPU:
         # XPU specific, this header resides in third_party/torch-xpu-ops
         headers.append("#include ")
+    elif backend_index.dispatch_key == DispatchKey.MTIA:
+        headers.append("#include ")
     elif per_operator_headers:
         headers += [
             "#include ",
@@ -92,6 +94,7 @@ def gen_empty_impl_names(
         DispatchKey.CUDA,
         DispatchKey.MPS,
         DispatchKey.XPU,
+        DispatchKey.MTIA,
     ):
         dispatch = str(backend_index.dispatch_key).lower()
         empty_impl = f"at::detail::empty_{dispatch}"
@@ -645,6 +648,7 @@ if (C10_UNLIKELY(maybe_proxy.has_value())) {
             DispatchKey.CUDA,
             DispatchKey.MPS,
             DispatchKey.XPU,
+            DispatchKey.MTIA,
             DispatchKey.CompositeExplicitAutogradNonFunctional,
         )
         return f"""{maybe_set_guard_line}
@@ -724,6 +728,8 @@ resize_out(out, sizes, strides, options);
             guard_field = "c10::OptionalDeviceGuard guard_;"
         elif self.backend_index.dispatch_key == DispatchKey.XPU:
             guard_field = "c10::OptionalDeviceGuard guard_;"
+        elif self.backend_index.dispatch_key == DispatchKey.MTIA:
+            guard_field = "c10::OptionalDeviceGuard guard_;"
         else:
             guard_field = ""
diff --git a/torchgen/gen.py b/torchgen/gen.py
index 025a1dde421..b584a87880f 100644
--- a/torchgen/gen.py
+++ b/torchgen/gen.py
@@ -2820,6 +2820,11 @@ def main() -> None:
         action="store_true",
         help="Generate XPU registration code when set",
     )
+    parser.add_argument(
+        "--mtia",
+        action="store_true",
+        help="Generate MTIA registration code when set",
+    )
     # TODO: --op-registration-whitelist will be removed when all call-sites
     # for gen.py are moved over to using the operator YAML file for mobile
@@ -2918,6 +2923,12 @@
         if DispatchKey.XPU in dispatch_keys:
             del dispatch_keys[dispatch_keys.index(DispatchKey.XPU)]
 
+    if not options.mtia:
+        ignore_keys.add(DispatchKey.MTIA)
+
+        if DispatchKey.MTIA in dispatch_keys:
+            del dispatch_keys[dispatch_keys.index(DispatchKey.MTIA)]
+
     parsed_yaml = parse_native_yaml(native_yaml_path, tags_yaml_path, ignore_keys)
     valid_tags = _GLOBAL_PARSE_TAGS_YAML_CACHE[tags_yaml_path]
     native_functions, backend_indices = (
diff --git a/torchgen/model.py b/torchgen/model.py
index a3c0542788b..89a56d98e74 100644
--- a/torchgen/model.py
+++ b/torchgen/model.py
@@ -271,6 +271,7 @@ STRUCTURED_DISPATCH_KEYS = {
     DispatchKey.CUDA,
     DispatchKey.CPU,
     DispatchKey.XPU,
+    DispatchKey.MTIA,
 }
 
 UFUNC_DISPATCH_KEYS = {DispatchKey.CUDA, DispatchKey.CPU}