mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Currently the C++ API and C++ extensions are effectively two different, entirely orthogonal code paths. This PR unifies the C++ API with the C++ extension API by adding an element of Python binding support to the C++ API. This means the `torch/torch.h` included by C++ extensions, which currently routes to `torch/csrc/torch.h`, can now be rerouted to `torch/csrc/api/include/torch/torch.h` -- i.e. the main C++ API header. This header then includes Python binding support conditioned on a define (`TORCH_WITH_PYTHON_BINDINGS`), *which is only passed when building a C++ extension*. Currently stacked on top of https://github.com/pytorch/pytorch/pull/11498 Why is this useful? 1. One less codepath. In particular, there has been trouble again and again due to the two `torch/torch.h` header files and ambiguity when both ended up in the include path. This is now fixed. 2. I have found that it is quite common to want to bind a C++ API module back into Python. This could be for simple experimentation, or to have your training loop in Python but your models in C++. This PR makes this easier by adding pybind11 support to the C++ API. 3. The C++ extension API simply becomes richer by gaining access to the C++ API headers. soumith ezyang apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/11510 Reviewed By: ezyang Differential Revision: D9998835 Pulled By: goldsborough fbshipit-source-id: 7a94b44a9d7e0377b7f1cfc99ba2060874d51535
81 lines
3.2 KiB
C++
81 lines
3.2 KiB
C++
/*
 * CuDNN ReLU extension. Simple function but contains the general structure of
 * most CuDNN extensions:
 * 1) Check arguments. at::check* functions provide a standard way to validate
 *    input and provide pretty errors.
 * 2) Create descriptors. Most CuDNN functions require creating and setting a
 *    variety of descriptors.
 * 3) Apply the CuDNN function.
 * 4) Destroy your descriptors.
 * 5) Return something (optional).
 */
|
|
|
|
#include <torch/extension.h>

#include <ATen/cudnn/Descriptors.h> // for TensorDescriptor
#include <ATen/cuda/Exceptions.h>   // for AT_CUDNN_CHECK
#include <ATen/cudnn/Handle.h>      // for getCudnnHandle
// Name of the function in the Python module, and the name used for error
// messages by the at::check* functions.
// constexpr (rather than a plain `const char*`) makes the pointer itself
// immutable and compile-time constant; previously the global pointer was
// mutable and could be reseated at runtime.
constexpr const char* cudnn_relu_name = "cudnn_relu";
|
|
|
|
// Check arguments to cudnn_relu
|
|
void cudnn_relu_check(const at::Tensor& inputs, const at::Tensor& outputs) {
|
|
// Create TensorArgs. These record the names and positions of each tensor as a
|
|
// parameter.
|
|
at::TensorArg arg_inputs(inputs, "inputs", 0);
|
|
at::TensorArg arg_outputs(outputs, "outputs", 1);
|
|
// Check arguments. No need to return anything. These functions with throw an
|
|
// error if they fail. Messages are populated using information from
|
|
// TensorArgs.
|
|
at::checkContiguous(cudnn_relu_name, arg_inputs);
|
|
at::checkScalarType(cudnn_relu_name, arg_inputs, at::kFloat);
|
|
at::checkBackend(cudnn_relu_name, arg_inputs.tensor, at::Backend::CUDA);
|
|
at::checkContiguous(cudnn_relu_name, arg_outputs);
|
|
at::checkScalarType(cudnn_relu_name, arg_outputs, at::kFloat);
|
|
at::checkBackend(cudnn_relu_name, arg_outputs.tensor, at::Backend::CUDA);
|
|
at::checkSameSize(cudnn_relu_name, arg_inputs, arg_outputs);
|
|
}
|
|
|
|
// Apply CuDNN ReLU, reading `inputs` and writing the result into `outputs`.
// Both tensors must be contiguous float CUDA tensors of the same size
// (validated by cudnn_relu_check). Most CuDNN extensions will follow a
// similar pattern: check inputs, create descriptors, apply the CuDNN
// function, destroy descriptors, optionally return something.
void cudnn_relu(const at::Tensor& inputs, const at::Tensor& outputs) {
  // Step 1: Check inputs. This will throw an error if inputs are invalid, so
  // no need to check return codes here.
  cudnn_relu_check(inputs, outputs);

  // Step 2: Create descriptors
  cudnnHandle_t cuDnn = at::native::getCudnnHandle();
  // Note: 4 is minimum dim for a TensorDescriptor. Input and output are same
  // size and type and contiguous, so one descriptor is sufficient.
  at::native::TensorDescriptor input_tensor_desc(inputs, 4);

  // RAII guard for the activation descriptor. AT_CUDNN_CHECK throws on
  // failure, so without a guard any error in the calls below would leak the
  // descriptor created by cudnnCreateActivationDescriptor.
  struct ActivationDescriptorGuard {
    cudnnActivationDescriptor_t desc = nullptr;
    ~ActivationDescriptorGuard() {
      if (desc != nullptr) {
        // Destructors must not throw; ignore the status on cleanup.
        cudnnDestroyActivationDescriptor(desc);
      }
    }
  } activation;

  // Note: Always check return value of cudnn functions using AT_CUDNN_CHECK
  AT_CUDNN_CHECK(cudnnCreateActivationDescriptor(&activation.desc));
  AT_CUDNN_CHECK(cudnnSetActivationDescriptor(
      activation.desc,
      /*mode=*/CUDNN_ACTIVATION_RELU,
      /*reluNanOpt=*/CUDNN_PROPAGATE_NAN,
      /*coef=*/1.));

  // Step 3: Apply CuDNN function. Per the cuDNN docs, the result is blended
  // as out = alpha * relu(in) + beta * out; alpha=1, beta=0 overwrites out.
  float alpha = 1.;
  float beta = 0.;
  AT_CUDNN_CHECK(cudnnActivationForward(
      cuDnn,
      activation.desc,
      &alpha,
      input_tensor_desc.desc(),
      inputs.data_ptr(),
      &beta,
      input_tensor_desc.desc(), // output descriptor same as input
      outputs.data_ptr()));

  // Step 4: Destroy descriptors — handled by the RAII guard above (and by
  // TensorDescriptor's own destructor), on both success and error paths.
  // Step 5: Return something (optional) — nothing to return here.
}
|
|
|
|
// Create the pybind11 module
|
|
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
|
// Use the same name as the check functions so error messages make sense
|
|
m.def(cudnn_relu_name, &cudnn_relu, "CuDNN ReLU");
|
|
}
|