mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Autograd] Add Default Autograd Fallback for PrivateUse1 in PyTorch (#165315)
Please refer to this [link](https://github.com/pytorch/pytorch/issues/163979) for more background. - Allow register fallback for AutogradPrivateUse1 multiple. - Add Autograd fallback implemetation for AutogradPrivateUse1 PyTorch can privide a common implementation for AutogradPrivateUse1, and the user can override it based on the need of specififc accelerator. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165315 Approved by: https://github.com/albanD
This commit is contained in:
parent
79a4a9c02e
commit
0c9763a5a0
|
|
@ -109,6 +109,10 @@ TORCH_LIBRARY_IMPL(_, AutogradHPU, m) {
|
||||||
m.fallback(AUTOGRAD_FALLBACK);
|
m.fallback(AUTOGRAD_FALLBACK);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
||||||
|
m.fallback(AUTOGRAD_FALLBACK);
|
||||||
|
}
|
||||||
|
|
||||||
#undef AUTOGRAD_FALLBACK
|
#undef AUTOGRAD_FALLBACK
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
|
||||||
|
|
@ -442,11 +442,17 @@ RegistrationHandleRAII Dispatcher::registerFallback(DispatchKey dispatchKey, Ker
|
||||||
|
|
||||||
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
auto idx = getDispatchTableIndexForDispatchKey(dispatchKey);
|
||||||
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
TORCH_CHECK(idx >= 0 && static_cast<uint64_t>(idx) < backendFallbackKernels_.size(), "idx=", idx);
|
||||||
|
// NB: Perserve BC for registering fallback for AutogradPrivateUse1 multiple time,
|
||||||
|
// refer to https://github.com/pytorch/pytorch/issues/163979 for more informations.
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
!backendFallbackKernels_[idx].kernel.isValid(),
|
dispatchKey == DispatchKey::AutogradPrivateUse1 ||
|
||||||
"Tried to register multiple backend fallbacks for the same dispatch key ", dispatchKey, "; previous registration ",
|
!backendFallbackKernels_[idx].kernel.isValid(),
|
||||||
backendFallbackKernels_[idx].debug, ", new registration ", debug
|
"Tried to register multiple backend fallbacks for the same dispatch key ",
|
||||||
);
|
dispatchKey,
|
||||||
|
"; previous registration ",
|
||||||
|
backendFallbackKernels_[idx].debug,
|
||||||
|
", new registration ",
|
||||||
|
debug);
|
||||||
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
// NB: inferred function schema is always nullptr for fallbacks, as fallbacks
|
||||||
// cannot be unboxed
|
// cannot be unboxed
|
||||||
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
backendFallbackKernels_[idx] = impl::AnnotatedKernel(std::move(kernel), nullptr, std::move(debug));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user