diff --git a/docs/source/accelerator/operators.md b/docs/source/accelerator/operators.md
index d5ae2aa5a2c..a172bd7d93f 100644
--- a/docs/source/accelerator/operators.md
+++ b/docs/source/accelerator/operators.md
@@ -272,7 +272,7 @@ Here, we'll briefly introduce the implementation process of custom operators, fo
    * Name: `input`
    * Output Type: `Tensor`
 
-2. **Register Operator&Autograd Fallback:**
+2. **Register Operator:**
 
    ::::{tab-set}
 
@@ -285,19 +285,11 @@ Here, we'll briefly introduce the implementation process of custom operators, fo
          :end-before: LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT
          :linenos:
 
-      .. literalinclude:: ../../../test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
-         :language: c++
-         :start-after: LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK
-         :end-before: LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK
-         :emphasize-lines: 2
-         :linenos:
-
    ```
 
    :::
 
    ::::
 
-   Use `TORCH_LIBRARY_IMPL` to register the `wrapper_custom_abs` implementation for the `custom_abs` operator in `PrivateUse1`. However, because `Autograd` is always enabled in PyTorch, PyTorch defaults to finding and executing the corresponding backward implementation even if only forward computation is required(will fallthrough in backward implementation). Therefore, we also need to register the corresponding implementation for `AutogradPrivateUse1` of the `custom_abs` operator. Fortunately, PyTorch also provides a general `Autograd Fallback` mechanism named `torch::autograd::autogradNotImplementedFallback`, if only forward computation is involved, it is equivalent to a fallthrough operation, selecting the next DispatchKey for computation; if backward computation is involved, an error is thrown.
+   Use `TORCH_LIBRARY_IMPL` to register the `wrapper_custom_abs` implementation for the `custom_abs` operator in `PrivateUse1`. Because `Autograd` is always enabled in PyTorch, PyTorch defaults to finding and executing the corresponding backward implementation even if only forward computation is required (it falls through in the backward implementation). Fortunately, PyTorch already provides a general `Autograd Fallback` for `PrivateUse1` as well, so no extra registration is needed: if only forward computation is involved, it is equivalent to a fallthrough operation, selecting the next DispatchKey for computation; if backward computation is involved, an error is thrown.
 
 3. **Register Metadata(optional, but required by the graph mode, etc.):**
 
diff --git a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
index 04ba6d48e89..1ba82564cc3 100644
--- a/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
+++ b/test/cpp_extensions/open_registration_extension/torch_openreg/csrc/aten/OpenRegExtra.cpp
@@ -156,12 +156,6 @@ TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) {
 }
 // LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT
 
-// LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK
-TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) {
-  m.fallback(torch::autograd::autogradNotImplementedFallback());
-}
-// LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK
-
 // The rest is for testing purposes
 TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
   /*
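
For readers following along outside the OpenReg sources, the registration pattern the updated paragraph describes boils down to the sketch below. It is a minimal, self-contained approximation: the `openreg` namespace, the `custom_abs` schema, and the `wrapper_custom_abs` name come from the documentation above, while the kernel body (simply delegating to `at::abs`) is purely illustrative and not the real OpenReg implementation. Note that no `AutogradPrivateUse1` registration appears, which is exactly the explicit fallback block this diff removes now that PyTorch provides it by default.

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

namespace {

// Hypothetical kernel body for illustration only; the real OpenReg kernel
// lives in OpenRegExtra.cpp and runs on the PrivateUse1 device.
at::Tensor wrapper_custom_abs(const at::Tensor& input) {
  return at::abs(input);
}

} // namespace

// Declare the operator schema in the custom "openreg" namespace
// (schema string assumed from the documented name/input/output types).
TORCH_LIBRARY(openreg, m) {
  m.def("custom_abs(Tensor input) -> Tensor");
}

// Register the PrivateUse1 (device backend) kernel. With the general
// Autograd fallback PyTorch supplies for PrivateUse1, no explicit
// AutogradPrivateUse1 registration is needed for forward-only use.
TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) {
  m.impl("custom_abs", &wrapper_custom_abs);
}
```

Once such a library is loaded, the operator would be reachable from Python as `torch.ops.openreg.custom_abs(t)` for a tensor `t` on the `PrivateUse1`-backed device; calling `backward()` through it would raise an error unless an autograd formula is registered.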