mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[OpenReg] Remove the Unnecessary Fallback Implementation for AutogradPrivate1 (#165316)
As the title stated. The fallback for AutogradPrivateUse1 is builtin in PyTorch, so it is no need to register general implementation for out of tree backend. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165316 Approved by: https://github.com/ezyang, https://github.com/albanD ghstack dependencies: #165315
This commit is contained in:
parent
0c9763a5a0
commit
1d13c314b3
|
|
@ -272,7 +272,7 @@ Here, we'll briefly introduce the implementation process of custom operators, fo
|
|||
* Name: `input`
|
||||
* Output Type: `Tensor`
|
||||
|
||||
2. **Register Operator&Autograd Fallback:**
|
||||
2. **Register Operator**
|
||||
|
||||
::::{tab-set}
|
||||
|
||||
|
|
@ -285,19 +285,11 @@ Here, we'll briefly introduce the implementation process of custom operators, fo
|
|||
:end-before: LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT
|
||||
:linenos:
|
||||
|
||||
.. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
|
||||
:language: c++
|
||||
:start-after: LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK
|
||||
:end-before: LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK
|
||||
:emphasize-lines: 2
|
||||
:linenos:
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
::::
|
||||
|
||||
Use `TORCH_LIBRARY_IMPL` to register the `wrapper_custom_abs` implementation for the `custom_abs` operator in `PrivateUse1`. However, because `Autograd` is always enabled in PyTorch, PyTorch defaults to finding and executing the corresponding backward implementation even if only forward computation is required(will fallthrough in backward implementation). Therefore, we also need to register the corresponding implementation for `AutogradPrivateUse1` of the `custom_abs` operator. Fortunately, PyTorch also provides a general `Autograd Fallback` mechanism named `torch::autograd::autogradNotImplementedFallback`, if only forward computation is involved, it is equivalent to a fallthrough operation, selecting the next DispatchKey for computation; if backward computation is involved, an error is thrown.
|
||||
Use `TORCH_LIBRARY_IMPL` to register the `wrapper_custom_abs` implementation for the `custom_abs` operator in `PrivateUse1`. Because `Autograd` is always enabled in PyTorch, PyTorch defaults to finding and executing the corresponding backward implementation even if only forward computation is required(will fallthrough in backward implementation). Fortunately, PyTorch have implemented a general `Autograd Fallback` for PrivateUse1 as well, if only forward computation is involved, it is equivalent to a fallthrough operation, selecting the next DispatchKey for computation; if backward computation is involved, an error is thrown.
|
||||
|
||||
3. **Register Metadata(optional, but required by the graph mode, etc.):**
|
||||
|
||||
|
|
|
|||
|
|
@ -156,12 +156,6 @@ TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) {
|
|||
}
|
||||
// LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT
|
||||
|
||||
// LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK
|
||||
TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
|
||||
m.fallback(torch::autograd::autogradNotImplementedFallback());
|
||||
}
|
||||
// LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK
|
||||
|
||||
// The rest is for testing purposes
|
||||
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
|
||||
/*
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user