mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
[Autograd] Add Default Autograd Fallback for PrivateUse1 in PyTorch (#165315)
Please refer to this [link](https://github.com/pytorch/pytorch/issues/163979) for more background. - Allow register fallback for AutogradPrivateUse1 multiple. - Add Autograd fallback implemetation for AutogradPrivateUse1 PyTorch can privide a common implementation for AutogradPrivateUse1, and the user can override it based on the need of specififc accelerator. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165315 Approved by: https://github.com/albanD
This commit is contained in:
parent
79a4a9c02e
commit
0c9763a5a0
|
|
@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
|
|||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
||||
m.fallback(AUTOGRAD_FALLBACK);
|
||||
}
|
||||
|
||||
#undef AUTOGRAD_FALLBACK
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
|
|
@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
|
|||
|
||||
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
||||
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
||||
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
|
||||
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
|
||||
TORCH_CHECK(
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
|
||||
backendFallbackKernels_[idx].debug, ", new registration ", debug
|
||||
);
|
||||
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
|
||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||
"Tried to register multiple backend fallbacks for the same dispatch key ",
|
||||
dispatchKey,
|
||||
"; previous registration ",
|
||||
backendFallbackKernels_[idx].debug,
|
||||
", new registration ",
|
||||
debug);
|
||||
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
||||
// cannot be unboxed
|
||||
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user