From daff26306202e843cb15d402b17ed9f320139504 Mon Sep 17 00:00:00 2001 From: "fan.mo" Date: Sat, 31 May 2025 07:28:42 +0000 Subject: [PATCH] [Functorch] Support Functorch for PrivateUse1 backend (#154700) This PR enables functorch to be used in 3rd party backends. Pull Request resolved: https://github.com/pytorch/pytorch/pull/154700 Approved by: https://github.com/zou3519 --- aten/src/ATen/functorch/BatchedTensorImpl.h | 1 + aten/src/ATen/functorch/PyTorchOperatorHacks.cpp | 2 +- aten/src/ATen/functorch/TensorWrapper.cpp | 6 ++++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/aten/src/ATen/functorch/BatchedTensorImpl.h b/aten/src/ATen/functorch/BatchedTensorImpl.h index e42f8dd87b5..ce3d2900841 100644 --- a/aten/src/ATen/functorch/BatchedTensorImpl.h +++ b/aten/src/ATen/functorch/BatchedTensorImpl.h @@ -159,6 +159,7 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({ DispatchKey::XLA, DispatchKey::CUDA, DispatchKey::CPU, + DispatchKey::PrivateUse1, }); inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) { diff --git a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp index 7bc3a3cbfe4..ecedc729ccd 100644 --- a/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp +++ b/aten/src/ATen/functorch/PyTorchOperatorHacks.cpp @@ -143,7 +143,7 @@ static Tensor make_feature_noise(const Tensor& input) { } static bool is_fused_kernel_acceptable(const Tensor& input, double p) { - return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0; + return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.numel() > 0; } // NB: sure, we could have used different overloads here, but I would feel insecure diff --git a/aten/src/ATen/functorch/TensorWrapper.cpp b/aten/src/ATen/functorch/TensorWrapper.cpp index 4f50a1fe2b4..65de9268927 100644 --- 
a/aten/src/ATen/functorch/TensorWrapper.cpp +++ b/aten/src/ATen/functorch/TensorWrapper.cpp @@ -56,7 +56,8 @@ void dumpTensorCout(const Tensor& tensor) { static c10::intrusive_ptr makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr& life_handle) { auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({ - DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA}); + DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA, + DispatchKey::AutogradPrivateUse1}); auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate); key_set = key_set.add(DispatchKey::FuncTorchGradWrapper); return c10::make_intrusive(key_set, tensor, level, life_handle); @@ -76,7 +77,8 @@ static Tensor unsafeMakeTensorWrapper( } auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({ - DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA}); + DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA, + DispatchKey::AutogradPrivateUse1}); auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate); key_set = key_set.add(DispatchKey::FuncTorchGradWrapper); auto result = at::detail::make_tensor(