[Autograd] Add Default Autograd Fallback for PrivateUse1 in PyTorch (#165315)

Please refer to this [link](https://github.com/pytorch/pytorch/issues/163979) for more background.

- Allow register fallback for AutogradPrivateUse1 multiple.
- Add Autograd fallback implemetation for AutogradPrivateUse1

PyTorch can privide a common implementation for AutogradPrivateUse1, and the user can override it based on the need of specififc accelerator.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165315
Approved by: https://github.com/albanD
This commit is contained in:
FFFrog 2025-10-24 13:53:33 +00:00 committed by PyTorch MergeBot
parent 79a4a9c02e
commit 0c9763a5a0
2 changed files with 14 additions and 4 deletions

View File

@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
m.fallback(AUTOGRAD_FALLBACK);
}
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
m.fallback(AUTOGRAD_FALLBACK);
}
#undef AUTOGRAD_FALLBACK
} // namespace

View File

@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
TORCH_CHECK(
!backendFallbackKernels_[idx].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
backendFallbackKernels_[idx].debug, ", new registration ", debug
);
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
!backendFallbackKernels_[idx].kernel.isValid(),
"Tried to register multiple backend fallbacks for the same dispatch key ",
dispatchKey,
"; previous registration ",
backendFallbackKernels_[idx].debug,
", new registration ",
debug);
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
// cannot be unboxed
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));