[Functorch] Support Functorch for PrivateUse1 backend (#154700)

This PR enables functorch to be used with 3rd-party backends.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154700
Approved by: https://github.com/zou3519
This commit is contained in:
fan.mo 2025-05-31 07:28:42 +00:00 committed by PyTorch MergeBot
parent 15e9119a69
commit daff263062
3 changed files with 6 additions and 3 deletions

View File

@ -159,6 +159,7 @@ constexpr DispatchKeySet kKeysToPropagateToWrapper({
DispatchKey::XLA,
DispatchKey::CUDA,
DispatchKey::CPU,
DispatchKey::PrivateUse1,
});
inline DispatchKeySet getKeysToPropagateToWrapper(const Tensor& tensor, DispatchKeySet to_propagate=kKeysToPropagateToWrapper) {

View File

@ -143,7 +143,7 @@ static Tensor make_feature_noise(const Tensor& input) {
}
static bool is_fused_kernel_acceptable(const Tensor& input, double p) {
return (input.is_cuda() || input.is_xpu() || input.is_lazy()) && p > 0 && p < 1 && input.numel() > 0;
return (input.is_cuda() || input.is_xpu() || input.is_lazy() || input.is_privateuseone()) && p > 0 && p < 1 && input.numel() > 0;
}
// NB: sure, we could have used different overloads here, but I would feel insecure

View File

@ -56,7 +56,8 @@ void dumpTensorCout(const Tensor& tensor) {
static c10::intrusive_ptr<TensorWrapper> makeTensorWrapperPtr(const Tensor& tensor, int64_t level, const std::shared_ptr<bool>& life_handle) {
auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA,
DispatchKey::AutogradPrivateUse1});
auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
return c10::make_intrusive<TensorWrapper>(key_set, tensor, level, life_handle);
@ -76,7 +77,8 @@ static Tensor unsafeMakeTensorWrapper(
}
auto keys_to_propagate = kKeysToPropagateToWrapper | DispatchKeySet({
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA});
DispatchKey::AutogradCPU, DispatchKey::AutogradCUDA, DispatchKey::AutogradXLA,
DispatchKey::AutogradPrivateUse1});
auto key_set = getKeysToPropagateToWrapper(tensor, keys_to_propagate);
key_set = key_set.add(DispatchKey::FuncTorchGradWrapper);
auto result = at::detail::make_tensor<TensorWrapper>(