Remove deprecated fbgemm operators (#104535)

These operators are not used and have been deprecated since #72690 (Feb 2022). Additionally, the `torch.jit.quantized` interface has been deprecated since #40102 (June 2020).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104535
Approved by: https://github.com/ezyang
Peter Bell 2023-10-16 15:09:17 -07:00 committed by PyTorch MergeBot
parent bf01a7b023
commit 57c7aa12db
12 changed files with 55 additions and 1808 deletions
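
For downstream users, the removal means any remaining calls into the `torch.fbgemm_*` operators or the `torch.jit.quantized` wrappers must move to the maintained dynamic-quantization path; the deprecation messages in the deleted code below point at `torch.ao.nn.quantized.dynamic`. The following migration sketch is illustrative only and is not part of this diff; it assumes the `torch.ao.quantization.quantize_dynamic` entry point available in current releases.

# Illustrative migration sketch (not part of this PR).
import torch
import torch.nn as nn
from torch.ao.quantization import quantize_dynamic

fp32_model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 4))

# quantize_dynamic swaps nn.Linear for torch.ao.nn.quantized.dynamic.Linear,
# which packs its int8 (or fp16) weights internally -- the role the removed
# fbgemm_pack_* / fbgemm_linear_* operators played for torch.jit.quantized.
int8_model = quantize_dynamic(fp32_model, {nn.Linear}, dtype=torch.qint8)
fp16_model = quantize_dynamic(fp32_model, {nn.Linear}, dtype=torch.float16)

x = torch.randn(2, 16)
print(int8_model(x).shape, fp16_model(x).shape)

The same entry point also handles nn.LSTM and nn.GRU modules, which the error messages in the deleted torch/jit/quantized.py below name as the replacement for the quantized RNN wrappers.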


@@ -1,585 +0,0 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <vector>
#include <ATen/core/Tensor.h>
#include <ATen/Parallel.h>
#include <ATen/WrapDimUtilsMulti.h>
#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/PackedParams.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like_native.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_native.h>
#include <ATen/ops/fbgemm_linear_int8_weight_native.h>
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
#endif
#include <c10/util/irange.h>
#ifdef USE_FBGEMM
#include <fbgemm/Fbgemm.h>
#include <fbgemm/FbgemmFP16.h>
#include <fbgemm/QuantUtils.h>
#endif // USE_FBGEMM
namespace caffe2 {
CAFFE_KNOWN_TYPE(c10::intrusive_ptr<LinearPackedParamsBase>);
} // namespace caffe2
#ifdef USE_FBGEMM
namespace caffe2 {
// Required for cpp_custom_type_hack to work
CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
CAFFE_KNOWN_TYPE(c10::intrusive_ptr<PackedLinearWeightFp16>);
} // namespace caffe2
#endif // USE_FBGEMM
namespace at {
namespace native {
#ifdef USE_FBGEMM
Tensor fbgemm_linear_int8_weight_fp32_activation(
const Tensor& input,
const Tensor& weight,
const Tensor& packed,
const Tensor& col_offsets,
const Scalar& weight_scale,
const Scalar& weight_zero_point,
const Tensor& bias) {
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
TORCH_WARN_ONCE("fbgemm_linear_int8_weight_fp32_activation is deprecated "
"and will be removed in a future PyTorch release.")
const Tensor input_contig = input.contiguous();
const float* input_ptr = input_contig.data_ptr<float>();
TORCH_CHECK(input.dim() >= 2);
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
const int64_t K = input.size(input.dim() - 1);
TORCH_CHECK(weight.dim() == 2);
TORCH_CHECK(K == weight.size(1));
const int64_t N = weight.size(0);
TORCH_CHECK(bias.dim() == 1);
TORCH_CHECK(bias.size(0) == N);
TORCH_CHECK(weight_scale.isFloatingPoint());
TORCH_CHECK(weight_zero_point.isIntegral(false));
// Calculate statistics for quantization of the input Tensor
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float x_min;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float x_max;
fbgemm::FindMinMax(
/*m=*/input_ptr,
/*min=*/&x_min,
/*max=*/&x_max,
/*len=*/input.numel());
// Input tensor is quantized as 8-bit unsigned values
constexpr int kPrecision = 8;
constexpr bool kIsSigned = false;
constexpr int kBound = (1 << (kPrecision - 1));
// Calculate scale and zero point for quantization of input tensor
auto q_params = fbgemm::ChooseQuantizationParams(
/*min=*/x_min,
/*max=*/x_max,
/*qmin=*/kIsSigned ? -kBound : 0,
/*qmax=*/kIsSigned ? (kBound - 1) : (1 << kPrecision) - 1,
/*preserve_sparsity=*/false);
q_params.precision = kPrecision;
// ReQuantizeForFloat requires pointers to the scale and zero point values,
// since in the case of rowwise quantization these will be arrays rather than
// scalars. But in this case, we're doing whole-tensor quantization so we just
// pass a pointer to the scale values (and internally ReQuantizeForFloat
// won't index past 0).
const float weight_scale_float =
static_cast<float>(weight_scale.to<double>());
const int32_t weight_zero_point_int32 =
static_cast<int32_t>(weight_zero_point.to<int64_t>());
const Tensor bias_contig = bias.contiguous();
// Allocate output Tensor and a buffer for fbgemmPacked to use
std::vector<int64_t> output_size = input.sizes().vec();
output_size.back() = N;
Tensor output = at::empty(output_size, input.options().dtype(at::kFloat), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
Tensor buffer = at::empty(output_size, input.options().dtype(at::kInt), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// Pull out the PackBMatrix instance from the owning tensor
auto& pack_b =
cpp_custom_type_hack::cast<fbgemm::PackBMatrix<int8_t>>(packed);
const int num_tasks = at::get_num_threads();
at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
// This operation does the following:
// 1) Quantizes the input matrix given the statistics we've calculated
// above.
// 2) Creates a "row buffer" vector with offset values that must be added
// to the integer matrix multiplication operation to ensure correctness.
// 3) Packs the resulting quantized matrix into vector-register and cache
// friendly tiles.
//
// Note this is not executed eagerly, but rather within the fbgemmPacked
// call below.
fbgemm::PackAWithQuantRowOffset<uint8_t> pack_a(
/*trans=*/fbgemm::matrix_op_t::NoTranspose,
/*nRow=*/M,
/*nCol=*/K,
/*smat=*/input_ptr,
/*ld=*/K,
/*pmat=*/nullptr, // pack_a manages ownership of `pmat`
/*scale=*/q_params.scale,
/*zero_pt=*/q_params.zero_point);
// This is the end of the pipeline, pass the resulting matrix through
fbgemm::DoNothing<float, float> kDoNothingObj{};
for (const auto task_id : c10::irange(begin, end)) {
// After the uint8 * int8 matrix multiplication is performed, this
// operation does:
// 1) Add in row and column offsets to the rows and columns, respectively
// 2) Dequantize the results into floating point
// 3) Add in the bias term
fbgemm::ReQuantizeForFloat</* FUSE_RELU */ false> output_proc_obj(
/*nextop=*/kDoNothingObj,
/*Aq_scale=*/q_params.scale,
/*Bq_scale=*/&weight_scale_float,
/*Aq_zero_point=*/q_params.zero_point,
/*Bq_zero_point=*/&weight_zero_point_int32,
/*row_offsets=*/pack_a.getRowOffsetBuffer(),
/*col_offsets=*/col_offsets.data_ptr<int32_t>(),
/*bias=*/bias_contig.data_ptr<float>(),
/*nCol=*/N);
// Do the GEMM
fbgemm::fbgemmPacked(
/*packA=*/pack_a,
/*packB=*/pack_b,
/*C=*/output.data_ptr<float>(),
/*C_buffer=*/buffer.data_ptr<int32_t>(),
/*ldc=*/N,
/*outProcess=*/output_proc_obj,
/*thread_id=*/task_id,
/*num_threads=*/num_tasks);
}
});
return output;
}
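// ---------------------------------------------------------------------------
The comments in the function above describe the activation path: FindMinMax scans the fp32 input, ChooseQuantizationParams maps [min, max] onto the unsigned 8-bit range, and the packed int8 GEMM result is requantized back to fp32 with the bias added. The sketch below illustrates only the standard affine parameter choice; it is not the fbgemm implementation, which additionally nudges the zero point and supports sparsity preservation.

# Simplified affine quantization-parameter selection (illustration only).
import numpy as np

def choose_quant_params(x_min, x_max, qmin=0, qmax=255):
    # Extend the range to cover zero so a zero activation is exactly representable.
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0
    zero_point = int(np.clip(round(qmin - x_min / scale), qmin, qmax))
    return scale, zero_point

x = np.random.randn(4, 8).astype(np.float32)
scale, zp = choose_quant_params(float(x.min()), float(x.max()))
x_q = np.clip(np.round(x / scale) + zp, 0, 255).astype(np.uint8)
x_dq = (x_q.astype(np.float32) - zp) * scale
print(scale, zp, np.abs(x - x_dq).max())  # error stays within about one quantization step
// ---------------------------------------------------------------------------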
Tensor fbgemm_linear_int8_weight(
const Tensor& input,
const Tensor& weight,
const Tensor& packed,
const Tensor& col_offsets,
const Scalar& weight_scale,
const Scalar& weight_zero_point,
const Tensor& bias) {
return at::native::fbgemm_linear_int8_weight_fp32_activation(
input,
weight,
packed,
col_offsets,
weight_scale,
weight_zero_point,
bias);
}
namespace {
// Calculate the column offsets
// Note this includes the sum of the columns as well as the scalar term
// B_zero_point * K, whereas the row_offsets created by
// PackAWithQuantRowOffset is only the sum of the A rows.
void CalcColOffsetsTranspose(
int K,
int N,
const int8_t* Bint8,
int32_t B_zero_point,
int32_t* col_offsets) {
for (const auto i : c10::irange(N)) {
int32_t sum = 0;
for (const auto j : c10::irange(K)) {
sum += Bint8[i * K + j];
}
col_offsets[i] = sum - B_zero_point * K;
}
}
} // namespace
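// ---------------------------------------------------------------------------
The helper above matches its comment: for each output channel the offset is the sum of that row of the int8 weight minus B_zero_point * K, complementing the per-row offsets that PackAWithQuantRowOffset accumulates for A. As a cross-check, the same quantity in NumPy (illustration only, not code from this PR):

import numpy as np

K, N, B_zero_point = 8, 4, 3
B_int8 = np.random.randint(-128, 128, size=(N, K), dtype=np.int8)
# Per output channel: sum of the int8 weights minus B_zero_point * K,
# equivalent to CalcColOffsetsTranspose above.
col_offsets = B_int8.astype(np.int32).sum(axis=1) - B_zero_point * K
print(col_offsets)
// ---------------------------------------------------------------------------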
std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
const Tensor& weight) {
TORCH_WARN_ONCE("fbgemm_linear_quantize_weight is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
const Tensor weight_contig = weight.contiguous();
// Calculate weight statistics
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float w_min;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
float w_max;
fbgemm::FindMinMax(
/*m=*/weight_contig.data_ptr<float>(),
/*min=*/&w_min,
/*max=*/&w_max,
/*len=*/weight_contig.numel());
// Choose parameters for quantizing the weight as 8-bit signed integer
constexpr bool kIsSigned = true;
constexpr int kPrecision = 8;
constexpr int kBound = (1 << (kPrecision - 1));
auto q_params = fbgemm::ChooseQuantizationParams(
/*min=*/w_min,
/*max=*/w_max,
/*qmin=*/kIsSigned ? -kBound : 0,
/*qmax=*/kIsSigned ? (kBound - 1) : (1 << kPrecision) - 1,
/*preserve_sparsity=*/false);
q_params.precision = kPrecision;
Tensor quantized = at::native::empty_like(
weight_contig,
at::kChar,
weight_contig.options().layout_opt(),
weight_contig.options().device_opt(),
weight_contig.options().pinned_memory_opt(),
LEGACY_CONTIGUOUS_MEMORY_FORMAT);
// Tensor quantized = at::native::empty_cpu(
// weight_contig.sizes(), weight_contig.options().dtype(at::kChar));
fbgemm::Quantize<int8_t, false /*LEGACY*/>(
/*src=*/weight_contig.data_ptr<float>(),
/*dst=*/quantized.data_ptr<int8_t>(),
/*len=*/weight_contig.numel(),
/*qparams=*/q_params);
// Calculate column offsets of the weight and store them away in a tensor.
// Similarly to quantization, this can be done once and cached.
Tensor col_offsets = at::empty(
{weight_contig.size(0)},
at::kInt,
weight_contig.options().layout_opt(),
weight_contig.options().device_opt(),
weight_contig.options().pinned_memory_opt(),
LEGACY_CONTIGUOUS_MEMORY_FORMAT);
CalcColOffsetsTranspose(
/*K=*/quantized.size(1),
/*N=*/quantized.size(0),
/*Bint8=*/quantized.data_ptr<int8_t>(),
/*B_zero_point=*/q_params.zero_point,
/*col_offsets=*/col_offsets.data_ptr<int32_t>());
return std::make_tuple(
quantized, col_offsets, q_params.scale, q_params.zero_point);
}
Tensor fbgemm_pack_quantized_matrix(const Tensor& weight) {
TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
const int64_t K = weight.size(1);
const int64_t N = weight.size(0);
const Tensor weight_contig = weight.contiguous();
const int8_t* weight_ptr = weight_contig.data_ptr<int8_t>();
auto ptr = std::make_unique<fbgemm::PackBMatrix<int8_t>>(
/*trans=*/fbgemm::matrix_op_t::Transpose,
/*nRow=*/K,
/*nCol=*/N,
/*smat=*/weight_ptr,
/*ld=*/K,
/*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
/*groups=*/1);
return cpp_custom_type_hack::create(std::move(ptr), weight.options());
}
Tensor fbgemm_pack_quantized_matrix(
const Tensor& weight,
int64_t K,
int64_t N) {
// Replace after https://github.com/pytorch/pytorch/issues/24354 is fixed
// TORCH_WARN(
// "fbgemm_pack_quantized_matrix(weight, K, N) will be deprecated soon."
// "Please use fbgemm_pack_quantized_matrix(weight) instead.");
return at::native::fbgemm_pack_quantized_matrix(weight);
}
namespace {
float RawUint16ToFp16(unsigned short value) {
// Convert raw 16 bits half precision floating point number
// to single precision floating point number.
const unsigned short sign_bits = value >> 15;
const unsigned short exponent_bits = value >> 10 & 0x1f;
const unsigned short significand_bits = value & 0x3ff;
const float sign = sign_bits ? -1 : 1;
const float significand =
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const float exponent = exponent_bits - 0xf;
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
return sign * std::ldexp(significand, exponent);
}
template <typename T>
bool CheckAndSaturate(T max_val, T* element) {
if (*element > max_val) {
*element = max_val;
return true;
}
if (*element < -max_val) {
*element = -max_val;
return true;
}
return false;
}
// The range for using FP16 quantization of weights requires that the elements
// should be in the range of [5.96e-8, 65504]. If it is out of range, then the
// number will be saturated to max or min representable values by FP16.
void HandleWeightsSaturation(int64_t N, float* weight) {
const float kFp16Max = RawUint16ToFp16(0x7BFF);
bool found_out_of_range = false;
for (const auto i : c10::irange(N)) {
if (CheckAndSaturate<float>(kFp16Max, weight + i)) {
found_out_of_range = true;
}
}
if (found_out_of_range) {
TORCH_WARN("FOUND weight out of range ");
}
}
} // namespace
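// ---------------------------------------------------------------------------
The helpers above clamp fp32 weights into the finite fp16 range before packing; kFp16Max is constructed from the raw bit pattern 0x7BFF, the largest finite half-precision value (65504). The snippet below is only a sanity check of that constant and of the clamping behaviour, not code from this diff:

import numpy as np

# 0x7BFF reinterpreted as fp16 is the largest finite value, 65504.0.
fp16_max = float(np.array([0x7BFF], dtype=np.uint16).view(np.float16)[0])
print(fp16_max)  # 65504.0

# Mirror HandleWeightsSaturation: clamp out-of-range fp32 weights to +/- fp16_max.
w = np.array([1.0, 1e6, -1e6], dtype=np.float32)
print(np.clip(w, -fp16_max, fp16_max))  # clamps the large entries to +/-65504.0
// ---------------------------------------------------------------------------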
Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
TORCH_WARN_ONCE("fbgemm_pack_gemm_matrix_fp16 is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
const int64_t K = weight.size(1);
const int64_t N = weight.size(0);
Tensor weight_contig = weight.contiguous();
float* weight_contig_ptr = weight_contig.data_ptr<float>();
HandleWeightsSaturation(K * N, weight_contig_ptr);
// TODO(mingzhe09088):
// Consider using a functor here in PackedGemmMatrixFP16
// Comments from (XQ): Not entirely sure this make_unique is safe. make_unique
// is created with regular "new", and freed through TypeMetaData::deleteFn in
// this function. This is perfectly fine if the tensors are created and freed
// within this translation unit. It might be very problematic if that tensor
// flows across dll boundaries.
auto ptr = std::make_unique<fbgemm::PackedGemmMatrixFP16>(
fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr);
c10::intrusive_ptr<LinearPackedParamsBase> packed_weight =
c10::make_intrusive<PackedLinearWeightFp16>(std::move(ptr), c10::nullopt);
auto unique_ptr_wrapper =
std::make_unique<decltype(packed_weight)>(std::move(packed_weight));
return cpp_custom_type_hack::create(
std::move(unique_ptr_wrapper), weight.options());
}
Tensor fbgemm_linear_fp16_weight_fp32_activation(
const Tensor& input,
const Tensor& packed_weight,
const Tensor& bias) {
TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
const Tensor input_contig = input.contiguous();
const float* input_ptr = input_contig.data_ptr<float>();
// Pull out the PackedGemmMatrixFP16 instance from the owning tensor
const fbgemm::PackedGemmMatrixFP16& packed_weight_fp16 =
*c10::dynamic_intrusive_pointer_cast<PackedLinearWeightFp16>(
cpp_custom_type_hack::cast<
c10::intrusive_ptr<LinearPackedParamsBase>>(packed_weight))
->w;
TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
TORCH_CHECK(input.dim() >= 2);
TORCH_CHECK(bias.dim() == 1);
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
const int64_t N = packed_weight_fp16.numCols();
std::vector<int64_t> output_size = input.sizes().vec();
output_size.back() = N;
Tensor output = at::empty(output_size, input.options().dtype(at::kFloat));
// Call the fp16 gemm interface
fbgemm::cblas_gemm_compute(
fbgemm::matrix_op_t::NoTranspose,
M,
input_ptr,
packed_weight_fp16,
0.0f,
output.data_ptr<float>());
// Add bias term
output.add_(bias);
return output;
}
Tensor fbgemm_linear_fp16_weight(
const Tensor& input,
const Tensor& packed_weight,
const Tensor& bias) {
return at::native::fbgemm_linear_fp16_weight_fp32_activation(
input, packed_weight, bias);
}
#else // USE_FBGEMM
Tensor fbgemm_linear_int8_weight_fp32_activation(
const Tensor& /*input*/,
const Tensor& /*weight*/,
const Tensor& /*packed*/,
const Tensor& /*col_offsets*/,
const Scalar& /*weight_scale*/,
const Scalar& /*weight_zero_point*/,
const Tensor& /*bias*/) {
TORCH_WARN_ONCE("fbgemm_linear_int8_weight_fp32_activation is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
Tensor fbgemm_linear_int8_weight(
const Tensor& /*input*/,
const Tensor& /*weight*/,
const Tensor& /*packed*/,
const Tensor& /*col_offsets*/,
const Scalar& /*weight_scale*/,
const Scalar& /*weight_zero_point*/,
const Tensor& /*bias*/) {
TORCH_WARN_ONCE("fbgemm_linear_int8_weight is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
const Tensor& /*weight*/) {
TORCH_WARN_ONCE("fbgemm_linear_quantize_weight is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
Tensor fbgemm_pack_quantized_matrix(const Tensor& /*input*/) {
TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
Tensor fbgemm_pack_quantized_matrix(
const Tensor& /*input*/,
int64_t /*K*/,
int64_t /*N*/) {
TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
TORCH_WARN_ONCE("fbgemm_pack_gemm_matrix_fp16 is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
Tensor fbgemm_linear_fp16_weight_fp32_activation(
const Tensor& input,
const Tensor& packed_weight,
const Tensor& bias) {
TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
Tensor fbgemm_linear_fp16_weight(
const Tensor& input,
const Tensor& packed_weight,
const Tensor& bias) {
TORCH_WARN_ONCE("fbgemm_linear_fp16_weight is deprecated "
"and will be removed in a future PyTorch release.")
// We make a strong guarantee that models using these operators will have the
// same numerics across different machines. Therefore, we do not provide a
// fallback path and rather fail loudly if we cannot run FBGEMM.
TORCH_CHECK(
false, "This PyTorch installation was not built with FBGEMM operators");
}
#endif // USE_FBGEMM
} // namespace native
} // namespace at


@@ -32,19 +32,12 @@
#include <ATen/ops/cat.h>
#include <ATen/ops/cudnn_is_acceptable.h>
#include <ATen/ops/dropout.h>
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
#include <ATen/ops/gru_cell_native.h>
#include <ATen/ops/gru_native.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/lstm_cell_native.h>
#include <ATen/ops/lstm_native.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/quantized_gru_cell_native.h>
#include <ATen/ops/quantized_lstm_cell_native.h>
#include <ATen/ops/quantized_rnn_relu_cell_native.h>
#include <ATen/ops/quantized_rnn_tanh_cell_native.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/rnn_relu_cell_native.h>
#include <ATen/ops/rnn_relu_native.h>
@@ -208,158 +201,6 @@ struct CellParams : public CellParamsBase {
}
};
c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
const at::Tensor& w_ih,
const at::Tensor& w_hh,
at::Tensor bias_ih,
at::Tensor bias_hh);
struct QuantizedCellParams : public CellParamsBase {
QuantizedCellParams(
Tensor _w_ih,
Tensor _w_hh,
Tensor _b_ih,
Tensor _b_hh,
Tensor _packed_ih,
Tensor _packed_hh,
Tensor _col_offsets_ih,
Tensor _col_offsets_hh,
Scalar _scale_ih,
Scalar _scale_hh,
Scalar _zero_point_ih,
Scalar _zero_point_hh)
: w_ih(std::move(_w_ih)),
w_hh(std::move(_w_hh)),
b_ih_(std::move(_b_ih)),
b_hh_(std::move(_b_hh)),
packed_ih(std::move(_packed_ih)),
packed_hh(std::move(_packed_hh)),
col_offsets_ih(std::move(_col_offsets_ih)),
col_offsets_hh(std::move(_col_offsets_hh)),
scale_ih(std::move(_scale_ih)),
scale_hh(std::move(_scale_hh)),
zero_point_ih(std::move(_zero_point_ih)),
zero_point_hh(std::move(_zero_point_hh)) {}
const Tensor w_ih;
const Tensor w_hh;
const Tensor b_ih_;
const Tensor b_hh_;
const Tensor packed_ih;
const Tensor packed_hh;
const Tensor col_offsets_ih;
const Tensor col_offsets_hh;
const Scalar scale_ih;
const Scalar scale_hh;
const Scalar zero_point_ih;
const Scalar zero_point_hh;
Tensor matmul_ih(const Tensor& input) const override {
TORCH_CHECK(false, "matmul is not supported with quantized cell params");
}
Tensor matmul_hh(const Tensor& h) const override {
TORCH_CHECK(false, "matmul is not supported with quantized cell params");
}
Tensor linear_ih(const Tensor& input) const override {
return at::fbgemm_linear_int8_weight_fp32_activation(
input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih_);
}
Tensor linear_hh(const Tensor& h) const override {
return at::fbgemm_linear_int8_weight_fp32_activation(
h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh_);
}
const Tensor& b_ih() const override {
return b_ih_;
}
const Tensor& b_hh() const override {
return b_hh_;
}
CellParamsSerializationType __getstate__() const override {
std::vector<at::Tensor> tensors_to_serialize = {
w_ih, w_hh, b_ih_, b_hh_, col_offsets_ih, col_offsets_hh};
std::vector<double> doubles_to_serialize = {scale_ih.toDouble(),
scale_hh.toDouble()};
std::vector<int64_t> longs_to_serialize = {zero_point_ih.toLong(),
zero_point_hh.toLong()};
return CellParamsSerializationType(
"quantized",
std::move(tensors_to_serialize),
std::move(doubles_to_serialize),
std::move(longs_to_serialize),
{});
}
static c10::intrusive_ptr<CellParamsBase> __setstate__(
CellParamsSerializationType state) {
std::vector<at::Tensor> tensors;
std::vector<double> doubles;
std::vector<int64_t> longs;
std::tie(std::ignore, tensors, doubles, longs, std::ignore) =
std::move(state);
TORCH_INTERNAL_ASSERT(tensors.size() == 6);
TORCH_INTERNAL_ASSERT(doubles.size() == 2);
TORCH_INTERNAL_ASSERT(longs.size() == 2);
at::Tensor qw_ih = std::move(tensors[0]), qw_hh = std::move(tensors[1]),
b_ih = std::move(tensors[2]), b_hh = std::move(tensors[3]),
col_offsets_ih = std::move(tensors[4]),
col_offsets_hh = std::move(tensors[5]);
double scale_ih = doubles[0], scale_hh = doubles[1];
int64_t zero_point_ih = longs[0], zero_point_hh = longs[1];
at::Tensor packed_ih = at::native::fbgemm_pack_quantized_matrix(qw_ih);
at::Tensor packed_hh = at::native::fbgemm_pack_quantized_matrix(qw_hh);
return c10::make_intrusive<QuantizedCellParams>(
/*w_ih=*/std::move(qw_ih),
/*w_hh=*/std::move(qw_hh),
/*b_ih_=*/std::move(b_ih),
/*b_hh_=*/std::move(b_hh),
/*packed_ih=*/std::move(packed_ih),
/*packed_hh=*/std::move(packed_hh),
/*col_offsets_ih=*/std::move(col_offsets_ih),
/*col_offsets_hh=*/std::move(col_offsets_hh),
/*scale_ih=*/scale_ih,
/*scale_hh=*/scale_hh,
/*zero_point_ih=*/zero_point_ih,
/*zero_point_hh=*/zero_point_hh);
}
};
c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
const at::Tensor& w_ih,
const at::Tensor& w_hh,
at::Tensor b_ih,
at::Tensor b_hh) {
auto make_vals = [&](const at::Tensor& W) {
auto params = at::native::fbgemm_linear_quantize_weight(W);
at::Tensor packed_weight =
at::native::fbgemm_pack_quantized_matrix(std::get<0>(params));
return std::tuple_cat(
std::make_tuple(std::move(packed_weight)), std::move(params));
};
at::Tensor qw_ih, qw_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh;
at::Scalar scale_ih, scale_hh, zero_point_ih, zero_point_hh;
std::tie(packed_ih, qw_ih, col_offsets_ih, scale_ih, zero_point_ih) =
make_vals(w_ih);
std::tie(packed_hh, qw_hh, col_offsets_hh, scale_hh, zero_point_hh) =
make_vals(w_hh);
return c10::make_intrusive<QuantizedCellParams>(
/*qw_ih=*/std::move(qw_ih),
/*qw_hh=*/std::move(qw_hh),
/*b_ih=*/std::move(b_ih),
/*b_hh=*/std::move(b_hh),
/*packed_ih=*/std::move(packed_ih),
/*packed_hh=*/std::move(packed_hh),
/*col_offsets_ih=*/std::move(col_offsets_ih),
/*col_offsets_hh=*/std::move(col_offsets_hh),
/*scale_ih=*/std::move(scale_ih),
/*scale_hh=*/std::move(scale_hh),
/*zero_point_ih=*/std::move(zero_point_ih),
/*zero_point_hh=*/std::move(zero_point_hh));
}
// QuantizedCellParams vs. QuantizedCellParamsDynamic
//
@@ -536,7 +377,6 @@ static std::unordered_map<
std::string,
c10::intrusive_ptr<CellParamsBase> (*)(CellParamsSerializationType)>
cell_params_deserializers = {
{"quantized", &QuantizedCellParams::__setstate__},
{"quantized_dynamic", &QuantizedCellParamsDynamic::__setstate__},
{"quantized_fp16", &QuantizedCellParamsFP16::__setstate__}};
@@ -1841,38 +1681,6 @@ static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data_legacy(
"using the newer definitions in torch.jit.quantized");
}
#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \
return_type name( \
const Tensor& input, \
hx_type hx, \
const Tensor& w_ih, \
const Tensor& w_hh, \
const Tensor& b_ih, \
const Tensor& b_hh, \
const Tensor& packed_ih, \
const Tensor& packed_hh, \
const Tensor& col_offsets_ih, \
const Tensor& col_offsets_hh, \
const Scalar& scale_ih, \
const Scalar& scale_hh, \
const Scalar& zero_point_ih, \
const Scalar& zero_point_hh) { \
QuantizedCellParams params( \
w_ih, \
w_hh, \
b_ih, \
b_hh, \
packed_ih, \
packed_hh, \
col_offsets_ih, \
col_offsets_hh, \
scale_ih, \
scale_hh, \
zero_point_ih, \
zero_point_hh); \
return cell_type{}( \
input, prepare_hx_fn(hx), params); \
}
// Set reduced range to be True for all RNN Cells by default. This flag is used only for FBGEMM kernels
// QNNPACK does not reduce range for activations
#define DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(name, hx_type, cell_type, return_type, prepare_hx_fn) \
@@ -1895,7 +1703,6 @@ return_type name( \
}
// Quantized LSTM cell
using quantized_lstm_cell_type = LSTMCell<QuantizedCellParams>;
using quantized_lstm_return_type = std::tuple<Tensor, Tensor>;
static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
return std::make_tuple(hx[0], hx[1]);
@@ -1904,7 +1711,6 @@ static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
// Quantized LSTM cell
using quantized_lstm_cell_dynamic_type = LSTMCell<QuantizedCellParamsDynamic>;
DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
@@ -1915,22 +1721,15 @@ static simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
}
// Quantized GRU cell
using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
using quantized_gru_cell_dynamic_type = GRUCell<QuantizedCellParamsDynamic>;
DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx);
// Quantized RNN w/ ReLU cell
using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
using quantized_rnn_relu_cell_dynamic_type = SimpleCell<relu_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx);
// Quantized RNN w/ tanh cell
using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
using quantized_rnn_tanh_cell_dynamic_type = SimpleCell<tanh_f, QuantizedCellParamsDynamic>;
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx);
@@ -1972,7 +1771,6 @@ TORCH_LIBRARY_FRAGMENT(aten, m) {
TORCH_LIBRARY_FRAGMENT(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
@@ -1992,7 +1790,6 @@ TORCH_LIBRARY_IMPL(aten, CPU, m) {
TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params_dynamic"), TORCH_FN(make_quantized_cell_params_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params"), TORCH_FN(make_quantized_cell_params));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_lstm_cell_dynamic"), TORCH_FN(quantized_lstm_cell_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_gru_cell_dynamic"), TORCH_FN(quantized_gru_cell_dynamic));
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_rnn_relu_cell_dynamic"), TORCH_FN(quantized_rnn_relu_cell_dynamic));


@@ -3285,22 +3285,6 @@
dispatch:
CUDA: _mixed_dtypes_linear
- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)
- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor
- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor
- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor
- func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor
variants: function, method
@@ -7611,15 +7595,6 @@
# - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
#
# Quantized RNN cells
- func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)
- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
# PackedSequence utilities
- func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
dispatch:


@@ -1324,7 +1324,6 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/PointwiseOps.cpp",
"aten/src/ATen/native/Pooling.cpp",
"aten/src/ATen/native/Pow.cpp",
"aten/src/ATen/native/QuantizedLinear.cpp",
"aten/src/ATen/native/RNN.cpp",
"aten/src/ATen/native/RangeFactories.cpp",
"aten/src/ATen/native/ReduceAllOps.cpp",


@@ -1131,7 +1131,6 @@ set(ATen_CPU_INCLUDE
${CMAKE_BINARY_DIR}/aten/src)
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/QuantizedLinear.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/RNN.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/qlinear_unpack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)


@@ -312,6 +312,18 @@ ALLOW_LIST = [
("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_int8_weight_fp32_activation", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_int8_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_quantize_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_pack_gemm_matrix_fp16", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_fp16_weight_fp32_activation", datetime.date(2023, 12, 31)),
("aten::fbgemm_linear_fp16_weight", datetime.date(2023, 12, 31)),
("aten::fbgemm_pack_quantized_matrix", datetime.date(2023, 12, 31)),
("aten::quantized_lstm_cell", datetime.date(2023, 12, 31)),
("aten::quantized_gru_cell", datetime.date(2023, 12, 31)),
("aten::quantized_rnn_relu_cell", datetime.date(2023, 12, 31)),
("aten::quantized_rnn_tanh_cell", datetime.date(2023, 12, 31)),
("quantized::make_quantized_cell_params", datetime.date(2023, 12, 31)),
]
ALLOW_LIST_COMPILED = [


@@ -16,7 +16,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
from torch.testing._internal.common_utils import slowTest, suppress_warnings
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
if __name__ == '__main__':
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
@@ -407,12 +406,6 @@ class TestModels(JitTestCase):
def test_snli(self):
self._test_snli(self, device='cpu')
@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_snli_quantized(self):
self._test_snli(self, device='cpu', quantized=True)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_snli_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)
@@ -553,12 +546,6 @@ class TestModels(JitTestCase):
def test_vae(self):
self._test_vae(self, device='cpu')
@skipIfNoFBGEMM
# Suppression: this exercises a deprecated API
@suppress_warnings
def test_vae_quantized(self):
self._test_vae(self, device='cpu', quantized=True)
@unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_vae_cuda(self):
# XXX: export_import on CUDA modules doesn't work (#11480)


@@ -3200,81 +3200,6 @@ class TestDynamicQuantizedOps(TestCase):
self.assertEqual(Y_fp32, Y_fp32_ref,
msg="torch.ops.quantized.linear_dynamic results are off")
@skipIfNoFBGEMM
@given(
batch_size=st.integers(1, 4),
input_channels=st.integers(16, 32),
output_channels=st.integers(4, 8),
)
def test_qlinear_legacy(self, batch_size, input_channels, output_channels):
X_scale = 1.0
X_zp = 0
X_value_min = 0
X_value_max = 255
X_q0 = np.round(np.random.rand(batch_size, input_channels) * (
X_value_max - X_value_min) + X_value_min
).astype(np.uint8)
X_q0[0, 0] = X_value_min
X_q0[0, 1] = X_value_max
W_scale = 1.0
W_zp = 0
W_value_min = -128
W_value_max = 127
W_q0 = np.round(
np.random.rand(output_channels, input_channels)
* (W_value_max - W_value_min)
+ W_value_min
).astype(np.int8)
W_q0[0, 0] = W_value_min
W_q0[1, 0] = W_value_max
b_value_min = -10
b_value_max = 10
b_q0 = np.round(
np.random.rand(output_channels) * (b_value_max - b_value_min) +
b_value_min
).astype(np.int32)
avoid_vpmaddubsw_overflow_linear(
batch_size,
input_channels,
output_channels,
X_q0,
X_value_min,
X_value_max,
W_q0,
W_value_min,
W_value_max,
)
X_fp32 = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float)
W_fp32 = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float)
b_fp32 = torch.from_numpy(
_dequantize(b_q0, X_scale * W_scale, 0)
).to(dtype=torch.float)
W_scale, W_zp = _calculate_dynamic_qparams(W_fp32, torch.qint8)
W_q = torch.quantize_per_tensor(W_fp32, scale=W_scale, zero_point=W_zp, dtype=torch.qint8)
# Observe X_fp32 and determine X_scale and X_zero_point, this should match
# internals of dynamic linear.
X_scale, X_zp = _calculate_dynamic_qparams(X_fp32, torch.quint8)
X_q = torch.quantize_per_tensor(X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
W_int8, col_offsets, W_scale, W_zp = torch.fbgemm_linear_quantize_weight(W_q.dequantize())
W_prepack = torch.fbgemm_pack_quantized_matrix(W_int8.clone(), W_int8.size(1), W_int8.size(0))
# Quantized Linear operator with prepacked weight
Y_fp32 = torch.fbgemm_linear_int8_weight(
X_q.dequantize(), W_q.dequantize(), W_prepack, col_offsets,
W_scale, W_zp, b_fp32)
Y_fp32_ref = F.linear(X_q.dequantize(), W_q.dequantize(), b_fp32)
# Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32)
self.assertEqual(Y_fp32, Y_fp32_ref,
msg="torch.ops.quantized.fbgemm_linear_dynamic results are off")
@skipIfNoFBGEMM
@given(
input_channels=st.integers(16, 32),


@@ -4,11 +4,8 @@ import torch
from torch.testing._internal.common_quantization import (
skipIfNoFBGEMM
)
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase
from typing import Tuple
import copy
class TestDeprecatedJitQuantized(JitTestCase):
@skipIfNoFBGEMM
@@ -54,54 +51,8 @@ class TestDeprecatedJitQuantized(JitTestCase):
torch.tensor(vals, dtype=torch.float),
requires_grad=False)
ref = copy.deepcopy(cell)
cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
x = torch.tensor([[100, -155],
[-155, 100],
[100, -155]], dtype=torch.float)
h0_vals = [[-155, 100],
[-155, 155],
[100, -155]]
hx = torch.tensor(h0_vals, dtype=torch.float)
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
cx = torch.tensor(h0_vals, dtype=torch.float)
hiddens = (hx, cx)
else:
hiddens = hx
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super().__init__()
self.cell = cell
@torch.jit.script_method
def forward(self, x: torch.Tensor,
hiddens: Tuple[torch.Tensor, torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor]:
return self.cell(x, hiddens)
else:
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super().__init__()
self.cell = cell
@torch.jit.script_method
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor:
return self.cell(x, hiddens)
cell = ScriptWrapper(cell)
outs = cell(x, hiddens)
cell = self.getExportImportCopyWithPacking(cell)
outs = cell(x, hiddens)
ref_outs = ref(x, hiddens)
self.assertEqual(len(outs), len(ref_outs))
for out, ref_out in zip(outs, ref_outs):
torch.testing.assert_close(out, ref_out)
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_cell_modules function is no longer supported"):
cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
@skipIfNoFBGEMM
def test_rnn_quantized(self):
@@ -143,85 +94,14 @@ class TestDeprecatedJitQuantized(JitTestCase):
torch.tensor(vals, dtype=torch.float),
requires_grad=False)
ref = copy.deepcopy(cell)
cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8)
cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16)
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_modules function is no longer supported"):
cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8)
niter = 10
x = torch.tensor([[100, -155],
[-155, 100],
[100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
h0_vals = [[-155, 100],
[-155, 155],
[100, -155]]
hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_modules function is no longer supported"):
cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16)
if isinstance(ref, torch.nn.LSTM):
hiddens = (hx, cx)
elif isinstance(ref, torch.nn.GRU):
hiddens = hx
ref_out, ref_hid = ref(x, hiddens)
# Compare int8 quantized to unquantized
output_int8, final_hiddens_int8 = cell_int8(x, hiddens)
torch.testing.assert_close(output_int8, ref_out)
for out, ref in zip(final_hiddens_int8, ref_hid):
torch.testing.assert_close(out, ref)
# Compare fp16 quantized to unquantized
output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)
torch.testing.assert_close(output_fp16, ref_out)
for out, ref in zip(final_hiddens_fp16, ref_hid):
torch.testing.assert_close(out, ref)
def compare_quantized_unquantized(ScriptWrapper, cell):
wrapper = ScriptWrapper(cell)
# Compare quantize scripted module to unquantized
script_out, script_hid = wrapper(x, hiddens)
torch.testing.assert_close(script_out, ref_out)
for out, ref in zip(script_hid, ref_hid):
torch.testing.assert_close(out, ref)
# Compare export/import to unquantized
export_import_wrapper = self.getExportImportCopyWithPacking(wrapper)
ei_out, ei_hid = export_import_wrapper(x, hiddens)
torch.testing.assert_close(ei_out, ref_out)
for out, ref in zip(ei_hid, ref_hid):
torch.testing.assert_close(out, ref)
if isinstance(cell, torch.jit.quantized.QuantizedGRU):
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super().__init__()
self.cell = cell
@torch.jit.script_method
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
return self.cell(x, hiddens)
compare_quantized_unquantized(ScriptWrapper, cell)
elif isinstance(cell, torch.jit.quantized.QuantizedLSTM):
for cell in [cell_int8, cell_fp16]:
class ScriptWrapper(torch.jit.ScriptModule):
def __init__(self, cell):
super().__init__()
self.cell = cell
@torch.jit.script_method
def forward(self, x, hiddens):
# type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor])
# -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
return self.cell(x, hiddens)
compare_quantized_unquantized(ScriptWrapper, cell)
if 'fbgemm' in torch.backends.quantized.supported_engines:
# Suppression: using deprecated quant api
@suppress_warnings
def test_quantization_modules(self):
K1, N1 = 2, 2
@@ -244,18 +124,11 @@ class TestDeprecatedJitQuantized(JitTestCase):
y_ref = fb(value)
fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)
traced_int8 = torch.jit.trace(fb_int8, (x,))
fb_int8 = self.getExportImportCopyWithPacking(traced_int8)
y_int8 = fb_int8(value)
with self.assertRaisesRegex(RuntimeError, "quantize_linear_modules function is no longer supported"):
fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)
fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)
traced_fp16 = torch.jit.trace(fb_fp16, (x,))
fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16)
y_fp16 = fb_fp16(value)
torch.testing.assert_close(y_int8, y_ref, rtol=0.0001, atol=1e-3)
torch.testing.assert_close(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
with self.assertRaisesRegex(RuntimeError, "quantize_linear_modules function is no longer supported"):
fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)
@skipIfNoFBGEMM
def test_erase_class_tensor_shapes(self):


@@ -2112,25 +2112,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
res_bf16 = F.threshold(x.to(dtype=dtype), threshold, 0).float()
self.assertEqual(res_bf16, expected)
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
' with instruction set support avx2 or newer.')
def test_fb_fc_packed(self):
X = np.random.rand(16, 16).astype(np.float32) - 0.5
W = np.random.rand(16, 16).astype(np.float32) - 0.5
b = np.random.rand(16).astype(np.float32) - 0.5
def fc_op(X, W, b):
return np.dot(X, W.T) + b
x_tensor = torch.tensor(X)
w_tensor = torch.tensor(W)
b_tensor = torch.tensor(b)
packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor)
actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor)
expected_output = fc_op(X, W, b)
torch.testing.assert_close(torch.from_numpy(expected_output), actual_output.cpu(), atol=1e-3, rtol=1e-3)
def test_pad_scalar_error(self):
inputs = torch.tensor(0., requires_grad=True)
self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1, 1)))


@@ -1,798 +1,99 @@
import warnings
from typing import List, Optional, Tuple
import torch
from torch import _VF, Tensor # noqa: F401
from torch.nn.utils.rnn import PackedSequence
class QuantizedLinear(torch.jit.ScriptModule):
__constants__ = ["scale", "zero_point"]
def __init__(self, other):
super().__init__()
warnings.warn(
"torch.jit.QuantizedLinear is deprecated and will be removed in an upcoming "
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead."
raise RuntimeError(
"torch.jit.QuantizedLinear is no longer supported. Please use "
"torch.ao.nn.quantized.dynamic.Linear instead."
)
self.in_features = other.in_features
self.out_features = other.out_features
# Quantize weight and discard the original
(
self.weight,
self.col_offsets,
self.scale,
self.zero_point,
) = torch.fbgemm_linear_quantize_weight(
other.weight.clone(memory_format=torch.contiguous_format).float()
)
self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
assert other.bias is not None, "QuantizedLinear requires a bias"
self.bias = torch.nn.Parameter(
other.bias.clone(memory_format=torch.contiguous_format).float(),
requires_grad=False,
)
self.register_buffer(
"packed_tensor_ptr",
torch.fbgemm_pack_quantized_matrix(
self.weight.clone(memory_format=torch.contiguous_format)
),
)
@torch.jit.script_method
def _unpack(self):
self.packed_tensor_ptr.set_(torch.fbgemm_pack_quantized_matrix(self.weight))
@torch.jit.script_method
def _pack(self):
self.packed_tensor_ptr.set_(
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
)
@torch.jit.script_method
def forward(self, input):
out = torch.fbgemm_linear_int8_weight_fp32_activation(
input.float(),
self.weight,
self.packed_tensor_ptr,
self.col_offsets,
self.scale,
self.zero_point,
self.bias,
)
return out.to(input.dtype)
def extra_repr(self):
repr = (
"in_features={in_features}, out_features={out_features}, "
"scale={scale}, zero_point={zero_point}".format(**self.__dict__)
)
return repr
# FP16 weights
class QuantizedLinearFP16(torch.jit.ScriptModule):
def __init__(self, other):
super().__init__()
warnings.warn(
"torch.jit.QuantizedLinearFP16 is deprecated and will be removed in an upcoming "
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead."
raise RuntimeError(
"torch.jit.QuantizedLinearFP16 is no longer supported. "
"Please use the torch.ao.nn.quantized.dynamic.Linear instead."
)
self.in_features = other.in_features
self.out_features = other.out_features
self.original_weight = other.weight
self.weight = torch.fbgemm_pack_gemm_matrix_fp16(
other.weight.clone(memory_format=torch.contiguous_format).float()
)
assert other.bias is not None, "QuantizedLinearFP16 requires a bias"
self.bias = torch.nn.Parameter(
other.bias.clone(memory_format=torch.contiguous_format).float(),
requires_grad=False,
)
self.register_buffer("packed_weight", self.weight)
@torch.jit.script_method
def _unpack(self):
self.packed_weight.set_(
torch.fbgemm_pack_gemm_matrix_fp16(self.original_weight)
)
@torch.jit.script_method
def _pack(self):
self.packed_weight.set_(
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
)
@torch.jit.script_method
def forward(self, input):
out = torch.fbgemm_linear_fp16_weight_fp32_activation(
input.float(), self.packed_weight, self.bias
)
return out
def extra_repr(self):
repr = "in_features={in_features}, out_features={out_features}, ".format(
**self.__dict__
)
return repr
# Quantized RNN cell implementations
class QuantizedRNNCellBase(torch.jit.ScriptModule):
__constants__ = [
"input_size",
"hidden_size",
"bias",
"scale_hh",
"scale_ih",
"zero_point_ih",
"zero_point_hh",
]
def __init__(self, other):
super().__init__()
warnings.warn(
"torch.jit.QuantizedRNNCellBase is deprecated and will be removed in an upcoming "
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
)
self.input_size = other.input_size
self.hidden_size = other.hidden_size
self.bias = other.bias
if not self.bias:
raise ValueError("Quantized RNN cells require bias terms")
(
weight_ih,
col_offsets_ih,
self.scale_ih,
self.zero_point_ih,
) = torch.fbgemm_linear_quantize_weight(
other.weight_ih.clone(memory_format=torch.contiguous_format).float()
)
self.register_buffer("weight_ih", weight_ih)
self.register_buffer("col_offsets_ih", col_offsets_ih)
(
weight_hh,
col_offsets_hh,
self.scale_hh,
self.zero_point_hh,
) = torch.fbgemm_linear_quantize_weight(
other.weight_hh.clone(memory_format=torch.contiguous_format).float()
)
self.register_buffer("weight_hh", weight_hh)
self.register_buffer("col_offsets_hh", col_offsets_hh)
packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih)
self.register_buffer("packed_ih", packed_ih)
packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh)
self.register_buffer("packed_hh", packed_hh)
self.bias_ih = torch.nn.Parameter(
other.bias_ih.clone(memory_format=torch.contiguous_format).float(),
requires_grad=False,
)
self.bias_hh = torch.nn.Parameter(
other.bias_hh.clone(memory_format=torch.contiguous_format).float(),
requires_grad=False,
)
def extra_repr(self):
s = "{input_size}, {hidden_size}"
if "bias" in self.__dict__ and self.bias is not True:
s += ", bias={bias}"
if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
s += ", nonlinearity={nonlinearity}"
return s.format(**self.__dict__)
@torch.jit.script_method
def check_forward_input(self, input):
if input.size(1) != self.input_size:
raise RuntimeError(
f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}"
)
@torch.jit.script_method
def check_forward_hidden(
self, input: Tensor, hx: Tensor, hidden_label: str = ""
) -> None:
if input.size(0) != hx.size(0):
raise RuntimeError(
f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}"
)
if hx.size(1) != self.hidden_size:
raise RuntimeError(
f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}"
)
# TODO: for some reason weak_script_method causes a destruction of the
# module to occur, which in turn frees the packed_ih object via its DataPtr
# deleter. This is bizarre and should probably get fixed.
# @torch._jit_internal.weak_script_method
@torch.jit.script_method
def _unpack(self):
self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih))
self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh))
# @torch._jit_internal.weak_script_method
@torch.jit.script_method
def _pack(self):
self.packed_ih.set_(
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
)
self.packed_hh.set_(
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
raise RuntimeError(
"torch.jit.QuantizedRNNCellBase is no longer supported. "
"Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
)
class QuantizedRNNCell(QuantizedRNNCellBase):
__constants__ = [
"input_size",
"hidden_size",
"bias",
"scale_hh",
"scale_ih",
"zero_point_ih",
"zero_point_hh",
"nonlinearity",
]
def __init__(self, other):
super().__init__(other)
warnings.warn(
"torch.jit.QuantizedRNNCell is deprecated and will be removed in an upcoming "
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
raise RuntimeError(
"torch.jit.QuantizedRNNCell is no longer supported. "
"Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
)
self.nonlinearity = other.nonlinearity
@torch.jit.script_method
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
self.check_forward_input(input)
if hx is None:
hx = torch.zeros(
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
)
self.check_forward_hidden(input, hx, "")
if self.nonlinearity == "tanh":
ret = _VF.quantized_rnn_tanh_cell(
input,
hx,
self.weight_ih,
self.weight_hh,
self.bias_ih,
self.bias_hh,
self.packed_ih,
self.packed_hh,
self.col_offsets_ih,
self.col_offsets_hh,
self.scale_ih,
self.scale_hh,
self.zero_point_ih,
self.zero_point_hh,
)
elif self.nonlinearity == "relu":
ret = _VF.quantized_rnn_relu_cell(
input,
hx,
self.weight_ih,
self.weight_hh,
self.bias_ih,
self.bias_hh,
self.packed_ih,
self.packed_hh,
self.col_offsets_ih,
self.col_offsets_hh,
self.scale_ih,
self.scale_hh,
self.zero_point_ih,
self.zero_point_hh,
)
else:
ret = input # TODO: remove when jit supports exception flow
raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
return ret
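# Cells like this one are driven a single time step at a time; the module owns the
# packed weights and the caller threads the hidden state through the loop. A sketch
# with hypothetical shapes (seq_len=5, batch=3, input_size=16, hidden_size=32), where
# `inputs` and `cell` are placeholders:
#
#     hx = torch.zeros(3, 32)
#     for t in range(5):
#         hx = cell(inputs[t], hx)    # inputs: Tensor of shape (5, 3, 16)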
class QuantizedLSTMCell(QuantizedRNNCellBase):
def __init__(self, other):
super().__init__(other)
        warnings.warn(
            "torch.jit.QuantizedLSTMCell is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead."
        )
        raise RuntimeError(
            "torch.jit.QuantizedLSTMCell is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead."
        )
@torch.jit.script_method
def forward(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tensor]:
self.check_forward_input(input)
if hx is None:
zeros = torch.zeros(
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
)
hx = (zeros, zeros)
self.check_forward_hidden(input, hx[0], "[0]")
self.check_forward_hidden(input, hx[1], "[1]")
return _VF.quantized_lstm_cell(
input,
hx,
self.weight_ih,
self.weight_hh,
self.bias_ih,
self.bias_hh,
self.packed_ih,
self.packed_hh,
self.col_offsets_ih,
self.col_offsets_hh,
self.scale_ih,
self.scale_hh,
self.zero_point_ih,
self.zero_point_hh,
        )
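# The supported replacement is the dynamically quantized cell from torch.ao. A sketch,
# assuming its constructor mirrors nn.LSTMCell (sizes here are arbitrary):
#
#     qcell = torch.ao.nn.quantized.dynamic.LSTMCell(16, 32)
#     h, c = qcell(torch.randn(3, 16), (torch.zeros(3, 32), torch.zeros(3, 32)))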
class QuantizedGRUCell(QuantizedRNNCellBase):
def __init__(self, other):
super().__init__(other)
        warnings.warn(
            "torch.jit.QuantizedGRUCell is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRUCell instead."
        )
        raise RuntimeError(
            "torch.jit.QuantizedGRUCell is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.GRUCell instead."
        )
@torch.jit.script_method
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
self.check_forward_input(input)
if hx is None:
hx = torch.zeros(
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
)
self.check_forward_hidden(input, hx, "")
return _VF.quantized_gru_cell(
input,
hx,
self.weight_ih,
self.weight_hh,
self.bias_ih,
self.bias_hh,
self.packed_ih,
self.packed_hh,
self.col_offsets_ih,
self.col_offsets_hh,
self.scale_ih,
self.scale_hh,
self.zero_point_ih,
self.zero_point_hh,
)
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
return tensor.index_select(dim, permutation)
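# index_select along the batch dimension realigns a hidden state with the ordering a
# PackedSequence imposes. A tiny sketch with hypothetical shapes:
#
#     hx = torch.randn(2, 4, 8)                 # (num_layers * num_directions, batch, hidden)
#     perm = torch.tensor([3, 1, 0, 2])         # e.g. sorted_indices from packing
#     hx_sorted = apply_permutation(hx, perm)   # same result as hx.index_select(1, perm)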
class QuantizedRNNBase(torch.jit.ScriptModule):
__constants__ = [
"mode",
"input_size",
"hidden_size",
"num_layers",
"bias",
"batch_first",
"dropout",
"bidirectional",
"dtype",
]
def __init__(self, other, dtype=torch.int8):
super().__init__()
        warnings.warn(
            "torch.jit.QuantizedRNNBase is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic instead."
        )
        raise RuntimeError(
            "torch.jit.QuantizedRNNBase is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic instead."
        )
self.mode = other.mode
self.input_size = other.input_size
self.hidden_size = other.hidden_size
self.num_layers = other.num_layers
self.bias = other.bias
self.batch_first = other.batch_first
if self.mode != "GRU":
assert not self.batch_first
self.dropout = other.dropout
self.bidirectional = other.bidirectional
num_directions = 2 if self.bidirectional else 1
self.dtype = dtype
assert self.bias
# TODO: support more than just LSTM
if self.mode != "LSTM" and self.mode != "GRU":
raise RuntimeError("Only LSTM or GRU is supported for QuantizedRNN")
if dtype != torch.int8 and dtype != torch.float16:
raise RuntimeError(f"Unsupported dtype: {dtype}")
self.all_weights = []
for layer in range(self.num_layers):
for direction in range(num_directions):
layer_input_size = (
self.input_size if layer == 0 else self.hidden_size * num_directions
)
suffix = "_reverse" if direction == 1 else ""
def get_weight_bias(ihhh):
weight_name = f"weight_{ihhh}_l{layer}{suffix}"
bias_name = f"bias_{ihhh}_l{layer}{suffix}"
weight = getattr(other, weight_name)
bias = getattr(other, bias_name)
return weight, bias
weight_ih, bias_ih = get_weight_bias("ih")
weight_hh, bias_hh = get_weight_bias("hh")
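                # int8 cells carry per-matrix quantization parameters, so the packing
                # op below also computes scales and zero points; fp16 cells only need
                # the weights repacked into fbgemm's half-precision GEMM layout.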
if dtype == torch.int8:
cell_params = torch.ops.quantized.make_quantized_cell_params(
weight_ih, weight_hh, bias_ih, bias_hh
)
else:
packed_ih = torch.ops.quantized.linear_prepack_fp16(
weight_ih.float(), bias_ih
)
packed_hh = torch.ops.quantized.linear_prepack_fp16(
weight_hh.float(), bias_hh
)
cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
packed_ih, packed_hh
)
setattr(self, f"cell_params_{layer}_{suffix}", cell_params)
self.all_weights.append(cell_params)
@torch.jit.script_method
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
expected_input_dim = 2 if batch_sizes is not None else 3
if input.dim() != expected_input_dim:
raise RuntimeError(
f"input must have {expected_input_dim} dimensions, got {input.dim()}"
)
if self.input_size != input.size(-1):
raise RuntimeError(
f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
)
@torch.jit.script_method
def get_expected_hidden_size(
self, input: Tensor, batch_sizes: Optional[Tensor]
) -> Tuple[int, int, int]:
if batch_sizes is not None:
mini_batch = int(batch_sizes[0])
else:
mini_batch = input.size(0) if self.batch_first else input.size(1)
num_directions = 2 if self.bidirectional else 1
expected_hidden_size = (
self.num_layers * num_directions,
mini_batch,
self.hidden_size,
)
return expected_hidden_size
@torch.jit.script_method
def check_hidden_size(
self,
hx: Tensor,
expected_hidden_size: Tuple[int, int, int],
msg: str = "Expected hidden size {}, got {}",
) -> None:
if hx.size() != expected_hidden_size:
raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
@torch.jit.script_method
def check_forward_args(
self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
) -> None:
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
self.check_hidden_size(
hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
)
@torch.jit.script_method
def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
if permutation is None:
return hx
return apply_permutation(hx, permutation)
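# The hidden-state layout validated above follows the usual RNN convention: layers and
# directions are folded into the leading dimension. For a 2-layer bidirectional module
# with hidden_size=8 and a mini-batch of 4 (hypothetical numbers):
#
#     expected_hidden_size == (2 * 2, 4, 8)   # (num_layers * num_directions, batch, hidden)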
class QuantizedLSTM(QuantizedRNNBase):
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
def __init__(self, other, dtype):
super().__init__(other, dtype)
        warnings.warn(
            "torch.jit.QuantizedLSTM is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
        )
        raise RuntimeError(
            "torch.jit.QuantizedLSTM is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
        )
@torch.jit.script_method
def forward_impl(
self,
input: Tensor,
hx: Optional[Tuple[Tensor, Tensor]],
batch_sizes: Optional[Tensor],
max_batch_size: int,
sorted_indices: Optional[Tensor],
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
if hx is None:
num_directions = 2 if self.bidirectional else 1
zeros = torch.zeros(
self.num_layers * num_directions,
max_batch_size,
self.hidden_size,
dtype=input.dtype,
device=input.device,
)
hx = (zeros, zeros)
else:
# Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
hx = self.permute_hidden(hx, sorted_indices)
self.check_forward_args(input, hx, batch_sizes)
assert batch_sizes is None
result = torch.quantized_lstm(
input,
hx,
self.all_weights,
self.bias,
self.num_layers,
float(self.dropout),
self.training,
self.bidirectional,
self.batch_first,
dtype=self.dtype,
use_dynamic=False,
)
output = result[0]
hidden = result[1:]
return output, hidden
@torch.jit.script_method
def forward_tensor(
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
output, hidden = self.forward_impl(
input, hx, batch_sizes, max_batch_size, sorted_indices
)
return output, self.permute_hidden(hidden, unsorted_indices)
@torch.jit.script_method
def forward_packed(
self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
input_, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = int(batch_sizes[0])
output, hidden = self.forward_impl(
input_, hx, batch_sizes, max_batch_size, sorted_indices
)
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)
@torch.jit.script_method
def permute_hidden(
self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
) -> Tuple[Tensor, Tensor]:
if permutation is None:
return hx
return apply_permutation(hx[0], permutation), apply_permutation(
hx[1], permutation
)
@torch.jit.script_method
def check_forward_args(
self,
input: Tensor,
hidden: Tuple[Tensor, Tensor],
batch_sizes: Optional[Tensor],
) -> None:
self.check_input(input, batch_sizes)
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
self.check_hidden_size(
hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
)
self.check_hidden_size(
hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
)
def forward(self, input, hx=None):
if isinstance(input, PackedSequence):
return self.forward_packed(input, hx)
else:
return self.forward_tensor(input, hx)
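# forward() dispatches on the input type: plain tensors take forward_tensor, while
# packed batches keep their batch_sizes/sorted_indices metadata through forward_packed.
# A sketch of building such an input; `quantized_lstm` is a placeholder instance:
#
#     from torch.nn.utils.rnn import pack_padded_sequence
#     padded = torch.randn(7, 3, 16)                     # (seq, batch, feature)
#     packed = pack_padded_sequence(padded, lengths=[7, 5, 2])
#     output, (h, c) = quantized_lstm(packed)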
class QuantizedGRU(QuantizedRNNBase):
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
        warnings.warn(
            "torch.jit.QuantizedGRU is deprecated and will be removed in an upcoming "
            "PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRU instead."
        )
        raise RuntimeError(
            "torch.jit.QuantizedGRU is no longer supported. "
            "Please use the torch.ao.nn.quantized.dynamic.GRU instead."
        )
@torch.jit.script_method
def forward_impl(
self,
input: Tensor,
hx: Optional[Tensor],
batch_sizes: Optional[Tensor],
max_batch_size: int,
sorted_indices: Optional[Tensor],
) -> Tuple[Tensor, Tensor]:
if hx is None:
num_directions = 2 if self.bidirectional else 1
hx = torch.zeros(
self.num_layers * num_directions,
max_batch_size,
self.hidden_size,
dtype=input.dtype,
device=input.device,
)
else:
# Each batch of the hidden state should match the input sequence that
            # the user believes they are passing in.
hx = self.permute_hidden(hx, sorted_indices)
self.check_forward_args(input, hx, batch_sizes)
if batch_sizes is None:
result = torch.quantized_gru(
input,
hx,
self.all_weights,
self.bias,
self.num_layers,
float(self.dropout),
self.training,
self.bidirectional,
self.batch_first,
)
else:
result = torch.quantized_gru(
input,
batch_sizes,
hx,
self.all_weights,
self.bias,
self.num_layers,
float(self.dropout),
self.training,
self.bidirectional,
)
output = result[0]
hidden = result[1]
return output, hidden
@torch.jit.script_method
def forward_tensor(
self, input: Tensor, hx: Optional[Tensor] = None
) -> Tuple[Tensor, Tensor]:
batch_sizes = None
max_batch_size = input.size(0) if self.batch_first else input.size(1)
sorted_indices = None
unsorted_indices = None
output, hidden = self.forward_impl(
input, hx, batch_sizes, max_batch_size, sorted_indices
)
return output, self.permute_hidden(hidden, unsorted_indices)
@torch.jit.script_method
def forward_packed(
self, input: PackedSequence, hx: Optional[Tensor] = None
) -> Tuple[PackedSequence, Tensor]:
input_, batch_sizes, sorted_indices, unsorted_indices = input
max_batch_size = int(batch_sizes[0])
output, hidden = self.forward_impl(
input_, hx, batch_sizes, max_batch_size, sorted_indices
)
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
return output, self.permute_hidden(hidden, unsorted_indices)
def forward(self, input, hx=None):
if isinstance(input, PackedSequence):
return self.forward_packed(input, hx)
else:
return self.forward_tensor(input, hx)
def quantize_rnn_cell_modules(module):
    warnings.warn(
        "quantize_rnn_cell_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
    raise RuntimeError(
        "quantize_rnn_cell_modules function is no longer supported. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
reassign = {}
for name, mod in module.named_modules():
if mod is module:
continue
new_mod = quantize_rnn_cell_modules(mod)
if new_mod is not mod:
reassign[name] = new_mod
for name, mod in reassign.items():
setattr(module, name, mod)
if isinstance(module, torch.nn.LSTMCell):
return QuantizedLSTMCell(module)
if isinstance(module, torch.nn.GRUCell):
return QuantizedGRUCell(module)
if isinstance(module, torch.nn.RNNCell):
return QuantizedRNNCell(module)
return module
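# quantize_linear_modules and quantize_rnn_modules below follow the same recursive
# pattern: convert children first via named_modules(), reassign them with setattr(),
# then decide whether the module itself should be swapped. The modern one-liner that
# replaces all three helpers (sketch; `model` is a hypothetical float module):
#
#     qmodel = torch.ao.quantization.quantize_dynamic(
#         model,
#         {torch.nn.RNNCell, torch.nn.GRUCell, torch.nn.LSTMCell},
#         dtype=torch.qint8,
#     )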
def quantize_linear_modules(module, dtype=torch.int8):
    warnings.warn(
        "quantize_linear_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
    raise RuntimeError(
        "quantize_linear_modules function is no longer supported. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
reassign = {}
for name, mod in module.named_modules():
if mod is module:
continue
new_mod = quantize_linear_modules(mod, dtype)
if new_mod is not mod:
reassign[name] = new_mod
for name, mod in reassign.items():
setattr(module, name, mod)
if isinstance(module, torch.nn.Linear):
if dtype == torch.int8:
return QuantizedLinear(module)
elif dtype == torch.float16:
return QuantizedLinearFP16(module)
else:
raise RuntimeError(f"Unsupported dtype: {dtype}")
return module
def quantize_rnn_modules(module, dtype=torch.int8):
    warnings.warn(
        "quantize_rnn_modules function has been deprecated. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
    raise RuntimeError(
        "quantize_rnn_modules function is no longer supported. "
        "Please use torch.ao.quantization.quantize_dynamic API instead."
    )
reassign = {}
for name, mod in module.named_modules():
if mod is module:
continue
new_mod = quantize_rnn_modules(mod, dtype)
if new_mod is not mod:
reassign[name] = new_mod
for name, mod in reassign.items():
setattr(module, name, mod)
if isinstance(module, torch.nn.LSTM):
if dtype != torch.int8 and dtype != torch.float16:
raise RuntimeError(f"Unsupported dtype: {dtype}")
return QuantizedLSTM(module, dtype)
if isinstance(module, torch.nn.GRU):
return QuantizedGRU(module)
return module
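# The replacement path named in the errors above is torch.ao.quantization.quantize_dynamic,
# which swaps eligible float submodules for dynamically quantized ones. A sketch on a
# hypothetical model (SeqTagger and its sizes are invented for illustration):
#
#     class SeqTagger(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.lstm = torch.nn.LSTM(16, 32)
#             self.head = torch.nn.Linear(32, 4)
#
#         def forward(self, x):
#             out, _ = self.lstm(x)
#             return self.head(out)
#
#     model = SeqTagger().eval()
#     qmodel = torch.ao.quantization.quantize_dynamic(
#         model, {torch.nn.LSTM, torch.nn.Linear}, dtype=torch.qint8
#     )
#     y = qmodel(torch.randn(7, 3, 16))   # same call signature as the float model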


@@ -585,14 +585,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
running_max, scale, zero_point, quant_min, quant_max, ch_axis,
per_row_fake_quant=False, symmetric_quant=False: -1),
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
weight_zero_point, bias: -1),
torch.fbgemm_linear_quantize_weight: lambda input: -1,
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
torch.feature_alpha_dropout: lambda input, p, train: -1,
torch.feature_dropout: lambda input, p, train: -1,
torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
@@ -977,21 +969,12 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
dilation=(1,), ceil_mode=False: -1),
torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
dilation=(1, 1), ceil_mode=False: -1),
torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
dilation=(1, 1, 1), ceil_mode=False: -1),
torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
torch.rad2deg: lambda input, out=None: -1,
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,