Revert D34803275: [Quant][core][gpu][improvement] Refactored implementation for conv2d_cudnn to use packed parameters

Test Plan: revert-hammer

Differential Revision: D34803275
Original commit changeset: 299479c0315f
Original Phabricator Diff: D34803275
fbshipit-source-id: bcdd615d7910ad150aed6a43f8812c385322d491
(cherry picked from commit edca65c3a6a6e197dadcafadd559faf5bcbb9deb)

Parent: dab5659d74
Commit: 22b876782f
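
For context, the user-visible effect of this revert is that the cuDNN int8 convolution path again goes through the raw quantized::conv2d_cudnn / quantized::conv2d_relu_cudnn ops instead of the packed-parameter quantized::conv2d_prepack + quantized::conv2d path. The sketch below is illustrative only and is inferred from the op schemas and test changes in the diff that follows; the shapes, scales, and zero points are invented, and it assumes a CUDA build with the cuDNN v8 frontend so the QuantizedCUDA kernels in this diff are registered.

# Illustrative sketch, not part of the commit: before/after call patterns implied by this revert.
import torch

act = torch.randn(1, 8, 32, 32, device="cuda")
weight = torch.randn(16, 8, 3, 3, device="cuda")
act_int8 = torch.quantize_per_tensor(act, 0.05, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
weight_int8 = torch.quantize_per_tensor(weight, 0.02, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
stride, padding, dilation, groups = [1, 1], [1, 1], [1, 1], 1
out_scale, out_zero_point = 0.1, 0

# Packed-parameter path removed by this revert (see the deleted prepack code and test changes below):
#   w_packed = torch.ops.quantized.conv2d_prepack(weight_int8, None, stride, padding, dilation, groups)
#   out = torch.ops.quantized.conv2d(act_int8, w_packed, out_scale, out_zero_point)

# Raw path restored by this revert (matches the quantized::conv2d_cudnn schema re-added below):
out = torch.ops.quantized.conv2d_cudnn(
    act_int8, weight_int8, None,        # act, weight, optional bias
    stride, padding, dilation, groups,
    out_scale, out_zero_point)
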
@@ -8,22 +8,25 @@

#if HAS_CUDNN_V8()

#include <cudnn_frontend.h>
#include <ATen/ATen.h>
#include <ATen/TensorUtils.h>
#include <ATen/cuda/Exceptions.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/ConvUtils.h>
#include <ATen/native/cudnn/ConvShared.h>
#include <ATen/native/quantized/cudnn/cudnnpack_utils.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/native/utils/ParamsHash.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/TensorUtils.h>
#include <cudnn_frontend.h>
#include <torch/library.h>

#include <iostream>
#include <unordered_map>
#include <iostream>

uint8_t getAlignment(const at::Tensor &t) {
namespace at { namespace native{

namespace {

uint8_t getAlignment(const Tensor &t) {
// alignment are in bytes
uint8_t alignment = 1;
uintptr_t address = reinterpret_cast<uintptr_t>(t.data_ptr());
@@ -31,7 +34,7 @@ uint8_t getAlignment(const at::Tensor &t) {
return alignment;
}

cudnn_frontend::Tensor getTensorDescriptor(const at::Tensor &t, int64_t id, uint8_t alignment) {
cudnn_frontend::Tensor getTensorDescriptor(const Tensor &t, int64_t id, uint8_t alignment) {
auto shape = t.sizes();
auto strides = t.strides();
return cudnn_frontend::TensorBuilder()
@@ -39,11 +42,11 @@ cudnn_frontend::Tensor getTensorDescriptor(const at::Tensor &t, int64_t id, uint
.setStrides(strides.size(), strides.data())
.setId(id)
.setAlignment(alignment)
.setDataType(at::native::getCudnnDataType(t))
.setDataType(getCudnnDataType(t))
.build();
}

cudnn_frontend::Tensor getTensorDescriptor(const c10::IntArrayRef& shape, const c10::IntArrayRef& strides, cudnnDataType_t cudnn_dtype, int64_t id, uint8_t alignment) {
cudnn_frontend::Tensor getTensorDescriptor(const IntArrayRef& shape, const IntArrayRef& strides, cudnnDataType_t cudnn_dtype, int64_t id, uint8_t alignment) {
return cudnn_frontend::TensorBuilder()
.setDim(shape.size(), shape.data())
.setStrides(strides.size(), strides.data())
@@ -55,7 +58,7 @@ cudnn_frontend::Tensor getTensorDescriptor(const c10::IntArrayRef& shape, const

// TODO: there is a table from input dtype and weight dtype to operator dtype,
// we can derive the operator dtype based on input dtype
cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, c10::IntArrayRef padding, c10::IntArrayRef stride, c10::IntArrayRef dilation) {
cudnn_frontend::ConvDesc_v8 getConvDescriptor(cudnnDataType_t dataType, IntArrayRef padding, IntArrayRef stride, IntArrayRef dilation) {
uint64_t convDim = stride.size();
return cudnn_frontend::ConvDescBuilder()
.setDataType(dataType)
@@ -104,7 +107,7 @@ void filterEngineConfigs(
if (deterministic) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC>(c)) return true;
}
if (scalar_type == at::kFloat || scalar_type == at::kChar || !allow_tf32) {
if (scalar_type == kFloat || scalar_type == kChar || !allow_tf32) {
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS>(c)) return true;
if (cudnn_frontend::hasNumericalNote<CUDNN_NUMERICAL_NOTE_TENSOR_CORE>(c)) return true;
}
@@ -150,7 +153,7 @@ get_execplan_from_heuristics_else_fall_back(cudnn_frontend::OperationGraph&& opG
}

struct CacheKey {
at::native::ConvolutionParams params;
ConvolutionParams params;
uint8_t input_alignment;
uint8_t weight_alignment;
uint8_t output_alignment;
@@ -159,9 +162,8 @@ struct CacheKey {
};

// FIXME: make this thread-safe by reusing the benchmark cache in Conv_v7.cpp
namespace {
std::unordered_map<CacheKey, cudnn_frontend::ManagedOpaqueDescriptor, at::native::ParamsHash<CacheKey>, at::native::ParamsEqual<CacheKey>> execution_plan_cache;
}
std::unordered_map<CacheKey, cudnn_frontend::ManagedOpaqueDescriptor, ParamsHash<CacheKey>, ParamsEqual<CacheKey>> execution_plan_cache;

// TODO: we can use cudnn_frontend::ExecutionPlanCache when it supports caching
// multiple operators
// reference: https://github.com/NVIDIA/cudnn-frontend/blob/main/samples/conv_sample.cpp#L293
@@ -173,9 +175,9 @@ at::SmallVector<int64_t, kSpatialDim + 2> MakeConvOutputShape(
int M, // output channels
const std::array<int64_t, kSpatialDim>& input_image_shape,
const std::vector<int64_t>& kernel,
const torch::List<int64_t>& stride,
const torch::List<int64_t>& padding,
const torch::List<int64_t>& dilation);
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation);

template <>
at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
@@ -183,9 +185,9 @@ at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
int M, // output channels
const std::array<int64_t, 2>& input_image_shape,
const std::vector<int64_t>& kernel,
const torch::List<int64_t>& stride,
const torch::List<int64_t>& padding,
const torch::List<int64_t>& dilation) {
IntArrayRef stride,
IntArrayRef padding,
IntArrayRef dilation) {
const int H = input_image_shape[0];
const int W = input_image_shape[1];
const int64_t Y_H =
@@ -195,33 +197,45 @@ at::SmallVector<int64_t, 4> MakeConvOutputShape<2>(
return {N, M, Y_H, Y_W};
}


// the parameter quantized_output is a quantized tensor
template <int kSpatialDim>
template <bool kReluFused>
void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& quantized_output, const at::Tensor& input,
double bias_multiplier, double requantize_multiplier) {
void raw_cudnn_convolution_forward_out(
const Tensor& quantized_output,
const Tensor& input,
const Tensor& weight,
const c10::optional<Tensor> &bias,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool benchmark,
bool deterministic,
bool allow_tf32,
float bias_multiplier,
float requantize_multiplier
) {
TORCH_CHECK(!benchmark, "not supported yet");
if (quantized_output.numel() == 0) {
return;
}
at::Tensor conv_output = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);

Tensor conv_output = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);
// TODO: combine empty & fill_ using full_like or full
at::Tensor requantize_multiplier_tensor = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);
Tensor requantize_multiplier_tensor = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);
requantize_multiplier_tensor.fill_(requantize_multiplier);
c10::optional<at::Tensor> bias_multiplier_tensor;
c10::optional<at::Tensor> after_scales_bias;
c10::optional<at::Tensor> after_add;
c10::optional<at::Tensor> broadcasted_bias;
c10::optional<at::Tensor> after_relu;
auto weight = orig_weight_.int_repr();
if (bias_.has_value()) {
if (bias.has_value()) {
// the input bias is a 1-D tensor whose size is the same as the size of the second dimension of quantized_output.
// we need to add trailing dimensions in order to properly broadcast bias, otherwise broadcast_to will fail.
// the number of trailling dimensions is quantized_output.dim() - 2, so the new size of the broadcast_bias
// becomes quantized_output.dim() - 2 + 1. nothing needs to be done for the leading dimensions
std::vector<int64_t> new_size(quantized_output.dim() - 1, 1);
new_size[0] = bias_.value().size(0);
broadcasted_bias = bias_.value().reshape(new_size);
new_size[0] = bias.value().size(0);
broadcasted_bias = bias.value().reshape(new_size);
broadcasted_bias.value() = broadcasted_bias.value().broadcast_to(quantized_output.sizes());
broadcasted_bias.value() = broadcasted_bias.value().contiguous(c10::MemoryFormat::ChannelsLast);
bias_multiplier_tensor = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);
@@ -233,21 +247,16 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
after_relu = at::empty(quantized_output.sizes(), at::device(at::kCUDA).dtype(at::kFloat), at::MemoryFormat::ChannelsLast);
}

cudnnHandle_t handle = at::native::getCudnnHandle();
cudnnHandle_t handle = getCudnnHandle();
CacheKey key;
bool deterministic{true};
bool allow_tf32{false};
auto padding_vec = padding_.vec();
auto stride_vec = stride_.vec();
auto dilation_vec = dilation_.vec();
setConvolutionParams(&key.params, input, weight, padding_vec, stride_vec, dilation_vec, groups_, deterministic, allow_tf32);
setConvolutionParams(&key.params, input, weight, padding, stride, dilation, groups, deterministic, allow_tf32);
// operator datatype needs to be int32 for int8 convolution, but we can
// set the datatype for output tensor to int32 or fp32
key.params.dataType = CUDNN_DATA_INT32;
key.input_alignment = getAlignment(input);
key.output_alignment = getAlignment(conv_output);
key.weight_alignment = getAlignment(weight);
if (bias_.has_value()) {
if (bias.has_value()) {
key.bias_alignment = getAlignment(broadcasted_bias.value());
} else {
key.bias_alignment = -1;
@@ -255,7 +264,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua

auto run = [&](cudnn_frontend::ManagedOpaqueDescriptor plan_desc) {
auto workspace_size = 0;
auto workspace = at::empty({workspace_size}, input.options().dtype(at::kByte));
auto workspace = at::empty({workspace_size}, input.options().dtype(kByte));
std::vector<void *> data_ptrs;
std::vector<int64_t> uids;
data_ptrs.reserve(10);
@@ -265,7 +274,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
requantize_multiplier_tensor.data_ptr(),
reinterpret_cast<int8_t*>(quantized_output.data_ptr())};
uids = {'x', 'y', 'w', 's', 'r'};
if (bias_.has_value()) {
if (bias.has_value()) {
data_ptrs.insert(data_ptrs.end(), {broadcasted_bias.value().data_ptr(), bias_multiplier_tensor.value().data_ptr(),
after_scales_bias.value().data_ptr(), after_add.value().data_ptr()});
uids.insert(uids.end(), {'b', 'c', 'd', 'e'});
@@ -301,13 +310,13 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
.setxDesc(getTensorDescriptor(input, 'x', key.input_alignment))
.setyDesc(getTensorDescriptor(conv_output, 'y', key.output_alignment))
.setwDesc(getTensorDescriptor(weight, 'w', key.weight_alignment))
.setcDesc(getConvDescriptor(key.params.dataType, padding_vec, stride_vec, dilation_vec))
.setcDesc(getConvDescriptor(key.params.dataType, padding, stride, dilation))
.build();
// std::cout << "operator:" << conv_op.describe() << std::endl;

c10::optional<cudnn_frontend::Operation> bias_mult_op;
c10::optional<cudnn_frontend::Operation> sum_conv_bias_op;
if (bias_.has_value()) {
if (bias.has_value()) {
// we can't directly assign bias_mult_op becauase operator= is deleted for cudnn_frontend::Operation;
// alternatively, I think we can use std::unique_ptr and dynamically allocate these builder ops
// but here, we chose to do it statically. c10::optional<T>::emplace() enables this approach
@@ -320,7 +329,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
.setxDesc(getTensorDescriptor(broadcasted_bias.value(), 'b', getAlignment(broadcasted_bias.value())))
.setbDesc(getTensorDescriptor(bias_multiplier_tensor.value(), 'c', getAlignment(bias_multiplier_tensor.value())))
.setyDesc(getTensorDescriptor(after_scales_bias.value(), 'd', getAlignment(after_scales_bias.value())))
.setpwDesc(getPointWiseMulDescriptor(at::native::getCudnnDataType(bias_multiplier_tensor.value())))
.setpwDesc(getPointWiseMulDescriptor(getCudnnDataType(bias_multiplier_tensor.value())))
.build());

// TODO: can we assign the result back into conv_output and get rid of after_add?
@@ -332,7 +341,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
.setxDesc(conv_op.getOutputTensor())
.setbDesc(getTensorDescriptor(after_scales_bias.value(), 'd', getAlignment(after_scales_bias.value())))
.setyDesc(getTensorDescriptor(after_add.value(), 'e', getAlignment(after_add.value())))
.setpwDesc(getPointWiseAddDescriptor(at::native::getCudnnDataType(after_scales_bias.value())))
.setpwDesc(getPointWiseAddDescriptor(getCudnnDataType(after_scales_bias.value())))
.build());
}

@@ -340,13 +349,13 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
// or relu(act_int8 * w_int8) if bias is not present.
// output is a fp32 tensor
c10::optional<cudnn_frontend::Operation> relu_op;
std::shared_ptr<cudnn_frontend::OpaqueBackendPointer> tensor2requant_ptr = bias_.has_value() ? sum_conv_bias_op.value().getOutputTensor() : conv_op.getOutputTensor();
std::shared_ptr<cudnn_frontend::OpaqueBackendPointer> tensor2requant_ptr = bias.has_value() ? sum_conv_bias_op.value().getOutputTensor() : conv_op.getOutputTensor();
if (kReluFused) {
// TODO: can we assign the result back into conv_output and get rid of after_relu?
relu_op.emplace(cudnn_frontend::OperationBuilder(CUDNN_BACKEND_OPERATION_POINTWISE_DESCRIPTOR)
.setxDesc(tensor2requant_ptr)
.setyDesc(getTensorDescriptor(after_relu.value(), 'f', getAlignment(after_relu.value())))
.setpwDesc(getPointWiseReluDescriptor(at::native::getCudnnDataType(after_relu.value())))
.setpwDesc(getPointWiseReluDescriptor(getCudnnDataType(after_relu.value())))
.build());
}

@@ -357,12 +366,12 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
.setxDesc(kReluFused ? relu_op.value().getOutputTensor() : tensor2requant_ptr)
.setbDesc(getTensorDescriptor(requantize_multiplier_tensor, 's', getAlignment(requantize_multiplier_tensor)))
.setyDesc(getTensorDescriptor(quantized_output.sizes(), quantized_output.strides(), CUDNN_DATA_INT8, 'r', getAlignment(quantized_output)))
.setpwDesc(getPointWiseMulDescriptor(at::native::getCudnnDataType(requantize_multiplier_tensor)))
.setpwDesc(getPointWiseMulDescriptor(getCudnnDataType(requantize_multiplier_tensor)))
.build();
// std::cout << "operator:" << requant_op.describe() << std::endl;

std::vector<cudnn_frontend::Operation const *> ops{&conv_op};
if (bias_.has_value()) {
if (bias.has_value()) {
ops.emplace_back(&(bias_mult_op.value()));
ops.emplace_back(&(sum_conv_bias_op.value()));
}
@@ -403,7 +412,7 @@ void PackedConvWeightCudnn<kSpatialDim>::apply_impl_helper(const at::Tensor& qua
run(plan_desc);
execution_plan_cache[key] = plan_desc;
return;
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(c10::CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
} catch (cudnn_frontend::cudnnException &e) {std::cout << "cudnn error:" << e.what() << std::endl;} catch(CuDNNError &e) { std::cout << "other error" << e.what() << std::endl;}
}

TORCH_CHECK(false, "Unable to find an engine to execute this computation");
@@ -427,96 +436,94 @@ out_int8 = (act_fp32 * w_fp32 + [bias_fp32]) / out_scale + out_zero_point
= (act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]) / (out_scale / (act_scale * w_scale))
= requantize((act_int8 * w_int8 + [bias_fp32/(act_scale * w_scale)]), out_scale / (act_scale * w_scale))
*/
template <int kSpatialDim>
template <bool kReluFused>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply_impl(
const at::Tensor& act,
template <int kSpatialDim, bool kReluFused>
Tensor raw_cudnn_convolution_forward(
const Tensor& act,
const Tensor& weight,
c10::optional<Tensor> bias,
IntArrayRef padding,
IntArrayRef stride,
IntArrayRef dilation,
int64_t groups,
bool benchmark,
bool deterministic,
bool allow_tf32,
float bias_multiplier,
float requantize_multiplier,
double output_scale,
int64_t output_zero_point) {
// TODO: add dimension validations for input/weight/bias
const int N = act.size(0);
const int D = kSpatialDim == 3 ? act.size(2) : 1;
const int H = act.size(kSpatialDim);
const int W = act.size(kSpatialDim + 1);
const int M = orig_weight_.size(0); // output channels
std::vector<int64_t> kernel_size = {orig_weight_.size(2), orig_weight_.size(3)};
at::SmallVector<int64_t, kSpatialDim + 2> output_shape = MakeConvOutputShape<kSpatialDim>(N, M, {H, W},
kernel_size, stride_, padding_, dilation_);
at::Tensor quantized_output = at::_empty_affine_quantized(
const int M = weight.size(0); // output channels
std::vector<int64_t> kernel_size = {weight.size(2), weight.size(3)};
at::SmallVector<int64_t, kSpatialDim + 2> output_shape{MakeConvOutputShape<kSpatialDim>(N, M, {H, W},
kernel_size, stride, padding, dilation)};
Tensor quantized_output = at::_empty_affine_quantized(
output_shape,
at::device(at::kCUDA).dtype(at::ScalarType::QInt8),
at::device(at::kCUDA).dtype(ScalarType::QInt8),
output_scale,
output_zero_point,
at::MemoryFormat::ChannelsLast);
// requantization
// out_int8 = act_int8 * weight_int8 * act_scale * w_scale / output_scale
// TODO: note we will remove the int_repr() in a subsequent PR, so we can move the computations for
// the multipliers into the helper function
auto act_scale = act.q_scale();
auto weight_scale = orig_weight_.q_scale();
auto requantize_multiplier = act_scale * weight_scale / output_scale;
auto bias_multiplier = 1.0 / (act_scale * weight_scale);
apply_impl_helper<kReluFused>(
quantized_output, act.int_repr(), bias_multiplier, requantize_multiplier);
raw_cudnn_convolution_forward_out<kReluFused>(
quantized_output, act, weight, bias,
padding, stride, dilation, groups,
benchmark,
deterministic,
allow_tf32,
bias_multiplier,
requantize_multiplier);

return quantized_output;
}

template <int kSpatialDim>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) {
return apply_impl<false>(input, output_scale, output_zero_point);
}

template <int kSpatialDim>
at::Tensor PackedConvWeightCudnn<kSpatialDim>::apply_relu(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) {
return apply_impl<true>(input, output_scale, output_zero_point);
}

template at::Tensor PackedConvWeightCudnn<2>::apply(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point);

template at::Tensor PackedConvWeightCudnn<2>::apply_relu(
const at::Tensor& act,
double output_scale,
int64_t output_zero_point);

namespace at {
namespace native {
namespace {

template <int kSpatialDim, bool kReluFused>
class QConvInt8 final {
public:
static at::Tensor run(
at::Tensor act,
const c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>>& packed_weight,
static Tensor run(
Tensor act,
Tensor weight,
c10::optional<Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups,
double output_scale,
int64_t output_zero_point) {
act = act.contiguous(c10::MemoryFormat::ChannelsLast);
weight = weight.contiguous(c10::MemoryFormat::ChannelsLast);
// requantization
// out_int8 = act_int8 * weight_int8 * act_scale * w_scale / output_scale
auto act_scale = act.q_scale();
auto weight_scale = weight.q_scale();
auto requantize_multiplier = act_scale * weight_scale / output_scale;
auto bias_multiplier = 1.0 / (act_scale * weight_scale);

// TODO: check all zero_points are zero/all tensors are symmetrically quantized
if (kReluFused) {
return packed_weight->apply_relu(act, output_scale, output_zero_point);
} else {
return packed_weight->apply(act, output_scale, output_zero_point);
}
return raw_cudnn_convolution_forward<kSpatialDim, kReluFused>(
act.int_repr(), weight.int_repr(), bias,
IntArrayRef(padding.vec()), IntArrayRef(stride.vec()), IntArrayRef(dilation.vec()), groups,
false /* benchmark */,
true /* deterministic */,
false /* allow_tf32 */,
bias_multiplier,
requantize_multiplier,
output_scale,
output_zero_point
);
}
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d.new"), QConvInt8<2, false>::run);
m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu.new"), QConvInt8<2, true>::run);
m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_cudnn"), QConvInt8<2, false>::run);
m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_relu_cudnn"), QConvInt8<2, true>::run);
}

} // namespace
} // namespace native
} // namespace at

}} // at::native

#endif // HAS_CUDNN_V8
#endif // AT_CUDNN_ENABLED
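
The requantization comment near the top of the @@ -427,96 +436,94 @@ hunk above states the identity out_int8 = (act_fp32 * w_fp32 + bias_fp32) / out_scale + out_zero_point = requantize(act_int8 * w_int8 + bias_fp32 / (act_scale * w_scale), out_scale / (act_scale * w_scale)). A tiny numeric sketch of the two multipliers that raw_cudnn_convolution_forward computes; all values below are invented for illustration and are not taken from the commit.

# Invented numbers; mirrors the multiplier computation shown in the diff above.
act_scale, w_scale, out_scale = 0.05, 0.02, 0.1

requantize_multiplier = act_scale * w_scale / out_scale   # 0.01: rescales the int32 accumulator to the output scale
bias_multiplier = 1.0 / (act_scale * w_scale)             # 1000.0: lifts the fp32 bias into the accumulator domain

acc_int32 = 1234      # toy value for sum(act_int8 * w_int8) at one output position
bias_fp32 = 0.5
acc = acc_int32 + bias_fp32 * bias_multiplier             # act*w + bias/(act_scale*w_scale) = 1734.0
out_int8 = round(acc * requantize_multiplier)             # 17, plus output_zero_point (0 here)

# Cross-check against the fp32 form of the identity:
assert out_int8 == round((acc_int32 * act_scale * w_scale + bias_fp32) / out_scale)
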
@@ -1,151 +0,0 @@
#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#if HAS_CUDNN_V8()

#include <ATen/ATen.h>
#include <torch/library.h>
#include <ATen/native/quantized/cudnn/cudnnpack_utils.h>
#include <ATen/native/quantized/packed_params.h>
#include <ATen/quantized/Quantizer.h>
#include <c10/core/QScheme.h>
#include <c10/util/irange.h>
#include <torch/library.h>

#include <array>
#include <vector>

template <int kSpatialDim>
c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> PackedConvWeightCudnn<
kSpatialDim>::
prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
TORCH_CHECK(weight.qscheme() == c10::kPerTensorAffine, "Unsupported qscheme: ", toString(weight.qscheme()));
TORCH_CHECK(
weight.ndimension() == kSpatialDim + 2,
"Weights are expected to have ",
kSpatialDim + 2,
" dimensions");
TORCH_CHECK(
stride.size() == kSpatialDim,
"stride should contain ",
kSpatialDim,
" elements for ",
kSpatialDim,
"D convolution.");
TORCH_CHECK(
padding.size() == kSpatialDim,
"quantized::conv_prepack (cudnn): Specify front/top/left padding only. "
"end/bottom/right padding assumed to be equal to front/top/left");
TORCH_CHECK(
!transpose || output_padding.size() == kSpatialDim,
"quantized::conv_prepack: Specify top/left output padding "
"only. bottom/right padding assumed to be equal to top/left");
TORCH_CHECK(
dilation.size() == kSpatialDim,
"quantized::conv_prepack (cudnn): dilation should contain ",
kSpatialDim,
" elements for ",
kSpatialDim,
"D convolution.");
const int output_channels = transpose ? weight.size(1) * groups
: weight.size(0);
const auto qtype = weight.qscheme();
if (bias.has_value()) {
TORCH_CHECK(bias.value().dim() == 1, "bias should be a vector (1D Tensor)");
TORCH_CHECK(
bias.value().size(0) == output_channels,
"bias should have K elements: " + std::to_string(output_channels));
// TODO: we create a broadcasted_bias tensor later so I think we don't need to make this contiguous here.
// we will revisit this when nvidia adds proper support for broadcasting
// bias_contig = bias->contiguous();
}

auto ret_ptr = c10::make_intrusive<PackedConvWeightCudnn<kSpatialDim>>(
weight.contiguous(c10::MemoryFormat::ChannelsLast), // TODO: this assumes 2D I think. make it more general?
bias,
stride,
padding,
output_padding,
dilation,
groups,
transpose,
qtype);
return ret_ptr;
}

template
c10::intrusive_ptr<ConvPackedParamsBase<2>> PackedConvWeightCudnn<
2>::
prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias_in,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose);

namespace at {
namespace native {
namespace {

template <int kSpatialDim = 2>
class QConvPackWeightInt8Cudnn final {
public:
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> run_conv(
Tensor weight,
c10::optional<Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups) {
torch::List<int64_t> output_padding;
output_padding.reserve(kSpatialDim);
for (const auto idx : c10::irange(kSpatialDim)) {
(void)idx; //Suppress unused variable warning
output_padding.push_back((int64_t)0);
}
return _run(weight, bias, stride, padding, output_padding, dilation, groups,
/*transpose=*/false);
}

private:
static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> _run(
Tensor weight,
c10::optional<Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose) {
return PackedConvWeightCudnn<kSpatialDim>::prepack(
weight, bias, stride, padding, output_padding, dilation, groups,
transpose);
}
};

TORCH_LIBRARY_IMPL(quantized, QuantizedCUDA, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::conv2d_prepack"), TORCH_FN(QConvPackWeightInt8Cudnn<2>::run_conv));
}

} // namespace
} // namespace native
} // namespace at

#endif // HAS_CUDNN_V8
#endif // AT_CUDNN_ENABLED
#endif // USE_CUDA
@@ -1,28 +0,0 @@
#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#if HAS_CUDNN_V8()

#include <ATen/ATen.h>
#include <ATen/native/quantized/cudnn/cudnnpack_utils.h>
#include <ATen/native/quantized/packed_params.h>
#include <torch/library.h>

#include <tuple>

template <int kSpatialDim>
std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightCudnn<
kSpatialDim>::unpack() {
return std::tuple<at::Tensor, c10::optional<at::Tensor>>{orig_weight_, bias_};
}

template std::tuple<at::Tensor, c10::optional<at::Tensor>> PackedConvWeightCudnn<
2>::unpack();

#endif // HAS_CUDNN_V8
#endif // AT_CUDNN_ENABLED
#endif // USE_CUDA
@@ -1,125 +0,0 @@
#pragma once

#ifdef USE_CUDA
#include <ATen/cuda/CUDAConfig.h> // for the definition of AT_CUDNN_ENABLED

#if AT_CUDNN_ENABLED()

#include <ATen/native/cudnn/Macros.h>

#if HAS_CUDNN_V8()

#include <ATen/Tensor.h>
#include <ATen/native/quantized/packed_params.h>
#include <c10/core/QScheme.h>

template <int kSpatialDim = 2>
struct TORCH_API PackedConvWeightCudnn : public ConvPackedParamsBase<kSpatialDim> {
PackedConvWeightCudnn(
at::Tensor orig_weight,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose,
c10::QScheme q_scheme)
: orig_weight_(std::move(orig_weight)),
bias_(std::move(bias)),
stride_(std::move(stride)),
padding_(std::move(padding)),
output_padding_(std::move(output_padding)),
dilation_(std::move(dilation)),
groups_(groups),
transpose_(transpose),
q_scheme_(q_scheme) {}

at::Tensor apply(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) override;

at::Tensor apply_relu(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point) override;

at::Tensor apply_dynamic(
const at::Tensor& input,
bool reduce_range) {
TORCH_CHECK(false, "apply_dynamic is currently not reported");
}

at::Tensor apply_dynamic_relu(
const at::Tensor& input,
bool reduce_range) {
TORCH_CHECK(false, "apply_dynamic_relu is currently not reported");
}

std::tuple<at::Tensor, c10::optional<at::Tensor>> unpack() override;

static c10::intrusive_ptr<ConvPackedParamsBase<kSpatialDim>> prepack(
at::Tensor weight,
c10::optional<at::Tensor> bias,
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> output_padding,
torch::List<int64_t> dilation,
int64_t groups,
bool transpose);

const float* GetBiasData(at::Tensor* bias);

torch::List<int64_t> stride() const override {
return stride_;
}

torch::List<int64_t> padding() const override {
return padding_;
}

torch::List<int64_t> output_padding() const override {
return output_padding_;
}

torch::List<int64_t> dilation() const override {
return dilation_;
}

int64_t groups() const override {
return groups_;
}

bool transpose() const override {
return transpose_;
}

private:
at::Tensor orig_weight_;
c10::optional<at::Tensor> bias_;
torch::List<int64_t> stride_;
torch::List<int64_t> padding_;
torch::List<int64_t> output_padding_;
torch::List<int64_t> dilation_;
int64_t groups_;
bool transpose_;
c10::QScheme q_scheme_;

template <bool ReluFused>
at::Tensor apply_impl(
const at::Tensor& input,
double output_scale,
int64_t output_zero_point);

template <bool ReluFused>
void apply_impl_helper(
const at::Tensor& quantized_output,
const at::Tensor& input,
double bias_multiplier,
double requantize_multiplier);
};

#endif // HAS_CUDNN_V8
#endif // AT_CUDNN_ENABLED
#endif // USE_CUDA
@@ -188,6 +188,11 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::leaky_relu(Tensor qx, Scalar negative_slope, bool inplace, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::sigmoid(Tensor qx, float output_scale, int output_zero_point) -> Tensor"));

// quantized ops implemented in cudnn, with QuantizedCUDA dispatch
// TODO: use the same signature as quantized::conv2d
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_cudnn(Tensor act, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::conv2d_relu_cudnn(Tensor act, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, int groups, float output_scale, int output_zero_point) -> Tensor"));
}

// According to #33294: The "_" prefix registration will be
@@ -4225,9 +4225,9 @@ class TestQuantizedConv(TestCase):
dilations = (dilation, dilation)

if use_relu:
qconv = torch.ops.quantized.conv2d_relu
qconv = torch.ops.quantized.conv2d_relu_cudnn
else:
qconv = torch.ops.quantized.conv2d
qconv = torch.ops.quantized.conv2d_cudnn
conv_op = torch.nn.Conv2d(
input_channels,
output_channels,
@@ -4238,7 +4238,7 @@ class TestQuantizedConv(TestCase):
groups,
).to(torch.device("cuda"))
self._test_qconv_impl(
qconv, torch.ops.quantized.conv2d_prepack, conv_op, batch_size,
qconv, None, conv_op, batch_size,
input_channels_per_group, (height, width),
output_channels_per_group, groups, kernels, strides, pads, None,
dilations, X_scale, X_zero_point, W_scale, W_zero_point,
@@ -4314,14 +4314,13 @@ class TestQuantizedConv(TestCase):
weight_int8 = torch.quantize_per_tensor(weight, 1, 0, torch.qint8).contiguous(memory_format=torch.channels_last)
scale = 1.0
zero_point = 0
conv_op = torch.ops.quantized.conv2d
weight_prepacked = torch.ops.quantized.conv2d_prepack(weight_int8, None, stride, padding, dilation, groups)
conv_op = torch.ops.quantized.conv2d_cudnn
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=my_schedule,
on_trace_ready=trace_handler) as prof:
for i in range(30):
conv_op(input_int8, weight_prepacked, scale, zero_point)
conv_op(input_int8, weight_int8, None, stride, padding, dilation, groups, scale, zero_point)
prof.step()

print("int8 benchmark result:")