mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Remove deprecated fbgemm operators (#104535)
These operators are not used and have been deprecated since #72690 (Feb 2022). Additionally, the `torch.jit.quantized` interface has been deprecated since #40102 (June 2020). Pull Request resolved: https://github.com/pytorch/pytorch/pull/104535 Approved by: https://github.com/ezyang
This commit is contained in:
parent
bf01a7b023
commit
57c7aa12db
|
|
@ -1,585 +0,0 @@
|
|||
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
|
||||
#include <vector>
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/WrapDimUtilsMulti.h>
|
||||
#include <ATen/cpp_custom_type_hack.h>
|
||||
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
|
||||
#include <ATen/native/quantized/PackedParams.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/empty_like_native.h>
|
||||
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
|
||||
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
|
||||
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation_native.h>
|
||||
#include <ATen/ops/fbgemm_linear_int8_weight_native.h>
|
||||
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
|
||||
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
|
||||
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
|
||||
#endif
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
#include <fbgemm/Fbgemm.h>
|
||||
#include <fbgemm/FbgemmFP16.h>
|
||||
#include <fbgemm/QuantUtils.h>
|
||||
#endif // USE_FBGEMM
|
||||
|
||||
namespace caffe2 {
|
||||
CAFFE_KNOWN_TYPE(c10::intrusive_ptr<LinearPackedParamsBase>);
|
||||
} // namespace caffe2
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
namespace caffe2 {
|
||||
// Required for cpp_custom_type_hack to work
|
||||
CAFFE_KNOWN_TYPE(fbgemm::PackBMatrix<int8_t>);
|
||||
CAFFE_KNOWN_TYPE(c10::intrusive_ptr<PackedLinearWeightFp16>);
|
||||
} // namespace caffe2
|
||||
#endif // USE_FBGEMM
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
#ifdef USE_FBGEMM
|
||||
|
||||
Tensor fbgemm_linear_int8_weight_fp32_activation(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& packed,
|
||||
const Tensor& col_offsets,
|
||||
const Scalar& weight_scale,
|
||||
const Scalar& weight_zero_point,
|
||||
const Tensor& bias) {
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
|
||||
|
||||
TORCH_WARN_ONCE("fbgemm_linear_int8_weight_fp32_activation is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
const Tensor input_contig = input.contiguous();
|
||||
const float* input_ptr = input_contig.data_ptr<float>();
|
||||
|
||||
TORCH_CHECK(input.dim() >= 2);
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||
const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
|
||||
const int64_t K = input.size(input.dim() - 1);
|
||||
TORCH_CHECK(weight.dim() == 2);
|
||||
TORCH_CHECK(K == weight.size(1));
|
||||
const int64_t N = weight.size(0);
|
||||
TORCH_CHECK(bias.dim() == 1);
|
||||
TORCH_CHECK(bias.size(0) == N);
|
||||
TORCH_CHECK(weight_scale.isFloatingPoint());
|
||||
TORCH_CHECK(weight_zero_point.isIntegral(false));
|
||||
|
||||
// Calculate statistics for quantization of the input Tensor
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
float x_min;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
float x_max;
|
||||
fbgemm::FindMinMax(
|
||||
/*m=*/input_ptr,
|
||||
/*min=*/&x_min,
|
||||
/*max=*/&x_max,
|
||||
/*len=*/input.numel());
|
||||
|
||||
// Input tensor is quantized as 8-bit unsigned values
|
||||
constexpr int kPrecision = 8;
|
||||
constexpr bool kIsSigned = false;
|
||||
constexpr int kBound = (1 << (kPrecision - 1));
|
||||
|
||||
// Calculate scale and zero point for quantization of input tensor
|
||||
auto q_params = fbgemm::ChooseQuantizationParams(
|
||||
/*min=*/x_min,
|
||||
/*max=*/x_max,
|
||||
/*qmin=*/kIsSigned ? -kBound : 0,
|
||||
/*qmax=*/kIsSigned ? (kBound - 1) : (1 << kPrecision) - 1,
|
||||
/*preserve_sparsity=*/false);
|
||||
q_params.precision = kPrecision;
|
||||
|
||||
// ReQuantizeForFloat requires pointers to the scale and zero point values,
|
||||
// since in the case of rowwise quantization these will be arrays rather than
|
||||
// scalars. But in this case, we're doing whole-tensor quantization so we just
|
||||
// pass a pointer to the scale values (and internally ReQuantizeFor Float
|
||||
// won't index past 0
|
||||
const float weight_scale_float =
|
||||
static_cast<float>(weight_scale.to<double>());
|
||||
const int32_t weight_zero_point_int32 =
|
||||
static_cast<int32_t>(weight_zero_point.to<int64_t>());
|
||||
|
||||
const Tensor bias_contig = bias.contiguous();
|
||||
|
||||
// Allocate output Tensor and a buffer for fbgemmPacked to use
|
||||
std::vector<int64_t> output_size = input.sizes().vec();
|
||||
output_size.back() = N;
|
||||
Tensor output = at::empty(output_size, input.options().dtype(at::kFloat), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
Tensor buffer = at::empty(output_size, input.options().dtype(at::kInt), LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
|
||||
// Pull out the PackBMatrix instance from the owning tensor
|
||||
auto& pack_b =
|
||||
cpp_custom_type_hack::cast<fbgemm::PackBMatrix<int8_t>>(packed);
|
||||
|
||||
const int num_tasks = at::get_num_threads();
|
||||
at::parallel_for(0, num_tasks, 1, [&](int64_t begin, int64_t end) {
|
||||
// This operation does the following:
|
||||
// 1) Quantizes the input matrix given the statistics we've calculated
|
||||
// above.
|
||||
// 2) Creates a "row buffer" vector with offset values that must be added
|
||||
// to the integer matrix multiplication operation to ensure correctness.
|
||||
// 3) Packs the resulting quantized matrix into vector-register and cache
|
||||
// friendly tiles.
|
||||
//
|
||||
// Note this is not executed eagerly, but rather within the fbgemmPacked
|
||||
// call below.
|
||||
fbgemm::PackAWithQuantRowOffset<uint8_t> pack_a(
|
||||
/*trans=*/fbgemm::matrix_op_t::NoTranspose,
|
||||
/*nRow=*/M,
|
||||
/*nCol=*/K,
|
||||
/*smat=*/input_ptr,
|
||||
/*ld=*/K,
|
||||
/*pmat=*/nullptr, // pack_a manages ownership of `pmat`
|
||||
/*scale=*/q_params.scale,
|
||||
/*zero_pt=*/q_params.zero_point);
|
||||
|
||||
// This is the end of the pipeline, pass the resulting matrix through
|
||||
fbgemm::DoNothing<float, float> kDoNothingObj{};
|
||||
for (const auto task_id : c10::irange(begin, end)) {
|
||||
// After the uint8 * int8 matrix multiplication is performed, this
|
||||
// operation does:
|
||||
// 1) Add in row and column offsets to the rows and columns, respectively
|
||||
// 2) Dequantize the results into floating point
|
||||
// 3) Add in the bias term
|
||||
fbgemm::ReQuantizeForFloat</* FUSE_RELU */ false> output_proc_obj(
|
||||
/*nextop=*/kDoNothingObj,
|
||||
/*Aq_scale=*/q_params.scale,
|
||||
/*Bq_scale=*/&weight_scale_float,
|
||||
/*Aq_zero_point=*/q_params.zero_point,
|
||||
/*Bq_zero_point=*/&weight_zero_point_int32,
|
||||
/*row_offsets=*/pack_a.getRowOffsetBuffer(),
|
||||
/*col_offsets=*/col_offsets.data_ptr<int32_t>(),
|
||||
/*bias=*/bias_contig.data_ptr<float>(),
|
||||
/*nCol=*/N);
|
||||
// Do the GEMM
|
||||
fbgemm::fbgemmPacked(
|
||||
/*packA=*/pack_a,
|
||||
/*packB=*/pack_b,
|
||||
/*C=*/output.data_ptr<float>(),
|
||||
/*C_buffer=*/buffer.data_ptr<int32_t>(),
|
||||
/*ldc=*/N,
|
||||
/*outProcess=*/output_proc_obj,
|
||||
/*thread_id=*/task_id,
|
||||
/*num_threads=*/num_tasks);
|
||||
}
|
||||
});
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor fbgemm_linear_int8_weight(
|
||||
const Tensor& input,
|
||||
const Tensor& weight,
|
||||
const Tensor& packed,
|
||||
const Tensor& col_offsets,
|
||||
const Scalar& weight_scale,
|
||||
const Scalar& weight_zero_point,
|
||||
const Tensor& bias) {
|
||||
return at::native::fbgemm_linear_int8_weight_fp32_activation(
|
||||
input,
|
||||
weight,
|
||||
packed,
|
||||
col_offsets,
|
||||
weight_scale,
|
||||
weight_zero_point,
|
||||
bias);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Calculate the column offsets
|
||||
// Note this includes the sum of the columns as well as the scalar term
|
||||
// B_zero_point * K, whereas the row_offsets created by
|
||||
// PackAWithQuantRowOffset is only the sum of the A rows.
|
||||
void CalcColOffsetsTranspose(
|
||||
int K,
|
||||
int N,
|
||||
const int8_t* Bint8,
|
||||
int32_t B_zero_point,
|
||||
int32_t* col_offsets) {
|
||||
for (const auto i : c10::irange(N)) {
|
||||
int32_t sum = 0;
|
||||
for (const auto j : c10::irange(K)) {
|
||||
sum += Bint8[i * K + j];
|
||||
}
|
||||
col_offsets[i] = sum - B_zero_point * K;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
|
||||
const Tensor& weight) {
|
||||
TORCH_WARN_ONCE("fbgemm_linear_quantize_weight is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
|
||||
const Tensor weight_contig = weight.contiguous();
|
||||
|
||||
// Calculate weight statistics
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
float w_min;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
float w_max;
|
||||
fbgemm::FindMinMax(
|
||||
/*m=*/weight_contig.data_ptr<float>(),
|
||||
/*min=*/&w_min,
|
||||
/*max=*/&w_max,
|
||||
/*len=*/weight_contig.numel());
|
||||
|
||||
// Choose parameters for quantizing the weight as 8-bit signed integer
|
||||
constexpr bool kIsSigned = true;
|
||||
constexpr int kPrecision = 8;
|
||||
constexpr int kBound = (1 << (kPrecision - 1));
|
||||
auto q_params = fbgemm::ChooseQuantizationParams(
|
||||
/*min=*/w_min,
|
||||
/*max=*/w_max,
|
||||
/*qmin=*/kIsSigned ? -kBound : 0,
|
||||
/*qmax=*/kIsSigned ? (kBound - 1) : (1 << kPrecision) - 1,
|
||||
/*preserve_sparsity=*/false);
|
||||
q_params.precision = kPrecision;
|
||||
|
||||
Tensor quantized = at::native::empty_like(
|
||||
weight_contig,
|
||||
at::kChar,
|
||||
weight_contig.options().layout_opt(),
|
||||
weight_contig.options().device_opt(),
|
||||
weight_contig.options().pinned_memory_opt(),
|
||||
LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
// Tensor quantized = at::native::empty_cpu(
|
||||
// weight_contig.sizes(), weight_contig.options().dtype(at::kChar));
|
||||
fbgemm::Quantize<int8_t, false /*LEGACY*/>(
|
||||
/*src=*/weight_contig.data_ptr<float>(),
|
||||
/*dst=*/quantized.data_ptr<int8_t>(),
|
||||
/*len=*/weight_contig.numel(),
|
||||
/*qparams=*/q_params);
|
||||
|
||||
// Calculate column offsets of the weight and store them away in a tensor.
|
||||
// Similarly to quantization, this can be done once and cached.
|
||||
Tensor col_offsets = at::empty(
|
||||
{weight_contig.size(0)},
|
||||
at::kInt,
|
||||
weight_contig.options().layout_opt(),
|
||||
weight_contig.options().device_opt(),
|
||||
weight_contig.options().pinned_memory_opt(),
|
||||
LEGACY_CONTIGUOUS_MEMORY_FORMAT);
|
||||
CalcColOffsetsTranspose(
|
||||
/*K=*/quantized.size(1),
|
||||
/*N=*/quantized.size(0),
|
||||
/*Bint8=*/quantized.data_ptr<int8_t>(),
|
||||
/*B_zero_point=*/q_params.zero_point,
|
||||
/*col_offsets=*/col_offsets.data_ptr<int32_t>());
|
||||
|
||||
return std::make_tuple(
|
||||
quantized, col_offsets, q_params.scale, q_params.zero_point);
|
||||
}
|
||||
|
||||
Tensor fbgemm_pack_quantized_matrix(const Tensor& weight) {
|
||||
TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
|
||||
const int64_t K = weight.size(1);
|
||||
const int64_t N = weight.size(0);
|
||||
const Tensor weight_contig = weight.contiguous();
|
||||
const int8_t* weight_ptr = weight_contig.data_ptr<int8_t>();
|
||||
auto ptr = std::make_unique<fbgemm::PackBMatrix<int8_t>>(
|
||||
/*trans=*/fbgemm::matrix_op_t::Transpose,
|
||||
/*nRow=*/K,
|
||||
/*nCol=*/N,
|
||||
/*smat=*/weight_ptr,
|
||||
/*ld=*/K,
|
||||
/*pmat=*/nullptr, // PackBMatrix manages ownership of pmat
|
||||
/*groups=*/1);
|
||||
return cpp_custom_type_hack::create(std::move(ptr), weight.options());
|
||||
}
|
||||
|
||||
Tensor fbgemm_pack_quantized_matrix(
|
||||
const Tensor& weight,
|
||||
int64_t K,
|
||||
int64_t N) {
|
||||
// Replace after https://github.com/pytorch/pytorch/issues/24354 is fixed
|
||||
// TORCH_WARN(
|
||||
// "fbgemm_pack_quantized_matrix(weight, K, N) will be deprecated soon."
|
||||
// "Please use fbgemm_pack_quantized_matrix(weight) instead.");
|
||||
return at::native::fbgemm_pack_quantized_matrix(weight);
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
float RawUint16ToFp16(unsigned short value) {
|
||||
// Convert raw 16 bits half precision floating point number
|
||||
// to single precision floating point number.
|
||||
const unsigned short sign_bits = value >> 15;
|
||||
const unsigned short exponent_bits = value >> 10 & 0x1f;
|
||||
const unsigned short significand_bits = value & 0x3ff;
|
||||
|
||||
const float sign = sign_bits ? -1 : 1;
|
||||
const float significand =
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||
1 + significand_bits * 0.0009765625f; // 0.0009765625f = 0x1p-10 = 2^-10
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||
const float exponent = exponent_bits - 0xf;
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||
return sign * std::ldexp(significand, exponent);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool CheckAndSaturate(T max_val, T* element) {
|
||||
if (*element > max_val) {
|
||||
*element = max_val;
|
||||
return true;
|
||||
}
|
||||
if (*element < -max_val) {
|
||||
*element = -max_val;
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// The range for using FP16 quantization of weights requires that the elements
|
||||
// should be in the range of [5.96e-8, 65504]. If it is out of range, then the
|
||||
// number will be saturated to max or min representable values by FP16.
|
||||
void HandleWeightsSaturation(int64_t N, float* weight) {
|
||||
const float kFp16Max = RawUint16ToFp16(0x7BFF);
|
||||
bool found_out_of_range = false;
|
||||
for (const auto i : c10::irange(N)) {
|
||||
if (CheckAndSaturate<float>(kFp16Max, weight + i)) {
|
||||
found_out_of_range = true;
|
||||
}
|
||||
}
|
||||
if (found_out_of_range) {
|
||||
TORCH_WARN("FOUND weight out of range ");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
|
||||
TORCH_WARN_ONCE("fbgemm_pack_gemm_matrix_fp16 is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
|
||||
|
||||
const int64_t K = weight.size(1);
|
||||
const int64_t N = weight.size(0);
|
||||
Tensor weight_contig = weight.contiguous();
|
||||
float* weight_contig_ptr = weight_contig.data_ptr<float>();
|
||||
HandleWeightsSaturation(K * N, weight_contig_ptr);
|
||||
|
||||
// TODO(mingzhe09088):
|
||||
// Consider using a functor here in PackedGemmMatrixFP16
|
||||
// Comments from (XQ): Not entirely sure this make_unique is safe. make_unique
|
||||
// is created with regular "new", and freed through TypeMetaData::deleteFn in
|
||||
// this function. This is perfectly fine if the tensors are created and freed
|
||||
// within this translation unit. It might be very problematic if that tensor
|
||||
// flows across dll boundaries.
|
||||
auto ptr = std::make_unique<fbgemm::PackedGemmMatrixFP16>(
|
||||
fbgemm::matrix_op_t::Transpose, K, N, 1, weight_contig_ptr);
|
||||
c10::intrusive_ptr<LinearPackedParamsBase> packed_weight =
|
||||
c10::make_intrusive<PackedLinearWeightFp16>(std::move(ptr), c10::nullopt);
|
||||
auto unique_ptr_wrapper =
|
||||
std::make_unique<decltype(packed_weight)>(std::move(packed_weight));
|
||||
return cpp_custom_type_hack::create(
|
||||
std::move(unique_ptr_wrapper), weight.options());
|
||||
}
|
||||
|
||||
Tensor fbgemm_linear_fp16_weight_fp32_activation(
|
||||
const Tensor& input,
|
||||
const Tensor& packed_weight,
|
||||
const Tensor& bias) {
|
||||
TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(fbgemm::fbgemmSupportedCPU(), "Your CPU doesn't support FBGEMM.");
|
||||
|
||||
const Tensor input_contig = input.contiguous();
|
||||
const float* input_ptr = input_contig.data_ptr<float>();
|
||||
|
||||
// Pull out the PackedGemmMatrixFP16 instance from the owning tensor
|
||||
const fbgemm::PackedGemmMatrixFP16& packed_weight_fp16 =
|
||||
*c10::dynamic_intrusive_pointer_cast<PackedLinearWeightFp16>(
|
||||
cpp_custom_type_hack::cast<
|
||||
c10::intrusive_ptr<LinearPackedParamsBase>>(packed_weight))
|
||||
->w;
|
||||
|
||||
TORCH_CHECK(input.size(input.dim() - 1) == packed_weight_fp16.numRows())
|
||||
TORCH_CHECK(input.dim() >= 2);
|
||||
TORCH_CHECK(bias.dim() == 1);
|
||||
|
||||
// NOLINTNEXTLINE(bugprone-narrowing-conversions,cppcoreguidelines-narrowing-conversions)
|
||||
const int64_t M = size_to_dim_(input.dim() - 1, input.sizes());
|
||||
const int64_t N = packed_weight_fp16.numCols();
|
||||
std::vector<int64_t> output_size = input.sizes().vec();
|
||||
output_size.back() = N;
|
||||
Tensor output = at::empty(output_size, input.options().dtype(at::kFloat));
|
||||
|
||||
// Call the fp16 gemm interface
|
||||
fbgemm::cblas_gemm_compute(
|
||||
fbgemm::matrix_op_t::NoTranspose,
|
||||
M,
|
||||
input_ptr,
|
||||
packed_weight_fp16,
|
||||
0.0f,
|
||||
output.data_ptr<float>());
|
||||
|
||||
// Add bias term
|
||||
output.add_(bias);
|
||||
|
||||
return output;
|
||||
}
|
||||
|
||||
Tensor fbgemm_linear_fp16_weight(
|
||||
const Tensor& input,
|
||||
const Tensor& packed_weight,
|
||||
const Tensor& bias) {
|
||||
return at::native::fbgemm_linear_fp16_weight_fp32_activation(
|
||||
input, packed_weight, bias);
|
||||
}
|
||||
|
||||
#else // USE_FBGEMM
|
||||
|
||||
Tensor fbgemm_linear_int8_weight_fp32_activation(
|
||||
const Tensor& /*input*/,
|
||||
const Tensor& /*weight*/,
|
||||
const Tensor& /*packed*/,
|
||||
const Tensor& /*col_offsets*/,
|
||||
const Scalar& /*weight_scale*/,
|
||||
const Scalar& /*weight_zero_point*/,
|
||||
const Tensor& /*bias*/) {
|
||||
TORCH_WARN_ONCE("fbgemm_linear_int8_weight_fp32_activation is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
}
|
||||
|
||||
Tensor fbgemm_linear_int8_weight(
|
||||
const Tensor& /*input*/,
|
||||
const Tensor& /*weight*/,
|
||||
const Tensor& /*packed*/,
|
||||
const Tensor& /*col_offsets*/,
|
||||
const Scalar& /*weight_scale*/,
|
||||
const Scalar& /*weight_zero_point*/,
|
||||
const Tensor& /*bias*/) {
|
||||
TORCH_WARN_ONCE("fbgemm_linear_int8_weight is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, double, int64_t> fbgemm_linear_quantize_weight(
|
||||
const Tensor& /*weight*/) {
|
||||
TORCH_WARN_ONCE("fbgemm_linear_quantize_weight is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
}
|
||||
|
||||
Tensor fbgemm_pack_quantized_matrix(const Tensor& /*input*/) {
|
||||
TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
}
|
||||
|
||||
Tensor fbgemm_pack_quantized_matrix(
|
||||
const Tensor& /*input*/,
|
||||
int64_t /*K*/,
|
||||
int64_t /*N*/) {
|
||||
TORCH_WARN_ONCE("fbgemm_pack_quantized_matrix is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
}
|
||||
|
||||
Tensor fbgemm_pack_gemm_matrix_fp16(const Tensor& weight) {
|
||||
TORCH_WARN_ONCE("fbgemm_pack_gemm_matrix_fp16 is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
}
|
||||
|
||||
Tensor fbgemm_linear_fp16_weight_fp32_activation(
|
||||
const Tensor& input,
|
||||
const Tensor& packed_weight,
|
||||
const Tensor& bias) {
|
||||
TORCH_WARN_ONCE("fbgemm_linear_fp16_weight_fp32_activation is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
}
|
||||
|
||||
Tensor fbgemm_linear_fp16_weight(
|
||||
const Tensor& input,
|
||||
const Tensor& packed_weight,
|
||||
const Tensor& bias) {
|
||||
TORCH_WARN_ONCE("fbgemm_linear_fp16_weight is deprecated "
|
||||
"and will be removed in a future PyTorch release.")
|
||||
|
||||
// We make a strong guarantee that models using these operators will have the
|
||||
// same numerics across different machines. Therefore, we do not provide a
|
||||
// fallback path and rather fail loudly if we cannot run FBGEMM.
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
}
|
||||
|
||||
#endif // USE_FBGEMM
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
@ -32,19 +32,12 @@
|
|||
#include <ATen/ops/cat.h>
|
||||
#include <ATen/ops/cudnn_is_acceptable.h>
|
||||
#include <ATen/ops/dropout.h>
|
||||
#include <ATen/ops/fbgemm_linear_int8_weight_fp32_activation.h>
|
||||
#include <ATen/ops/fbgemm_linear_quantize_weight_native.h>
|
||||
#include <ATen/ops/fbgemm_pack_quantized_matrix_native.h>
|
||||
#include <ATen/ops/gru_cell_native.h>
|
||||
#include <ATen/ops/gru_native.h>
|
||||
#include <ATen/ops/linear.h>
|
||||
#include <ATen/ops/lstm_cell_native.h>
|
||||
#include <ATen/ops/lstm_native.h>
|
||||
#include <ATen/ops/matmul.h>
|
||||
#include <ATen/ops/quantized_gru_cell_native.h>
|
||||
#include <ATen/ops/quantized_lstm_cell_native.h>
|
||||
#include <ATen/ops/quantized_rnn_relu_cell_native.h>
|
||||
#include <ATen/ops/quantized_rnn_tanh_cell_native.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/rnn_relu_cell_native.h>
|
||||
#include <ATen/ops/rnn_relu_native.h>
|
||||
|
|
@ -208,158 +201,6 @@ struct CellParams : public CellParamsBase {
|
|||
}
|
||||
};
|
||||
|
||||
c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
|
||||
const at::Tensor& w_ih,
|
||||
const at::Tensor& w_hh,
|
||||
at::Tensor bias_ih,
|
||||
at::Tensor bias_hh);
|
||||
|
||||
struct QuantizedCellParams : public CellParamsBase {
|
||||
QuantizedCellParams(
|
||||
Tensor _w_ih,
|
||||
Tensor _w_hh,
|
||||
Tensor _b_ih,
|
||||
Tensor _b_hh,
|
||||
Tensor _packed_ih,
|
||||
Tensor _packed_hh,
|
||||
Tensor _col_offsets_ih,
|
||||
Tensor _col_offsets_hh,
|
||||
Scalar _scale_ih,
|
||||
Scalar _scale_hh,
|
||||
Scalar _zero_point_ih,
|
||||
Scalar _zero_point_hh)
|
||||
: w_ih(std::move(_w_ih)),
|
||||
w_hh(std::move(_w_hh)),
|
||||
b_ih_(std::move(_b_ih)),
|
||||
b_hh_(std::move(_b_hh)),
|
||||
packed_ih(std::move(_packed_ih)),
|
||||
packed_hh(std::move(_packed_hh)),
|
||||
col_offsets_ih(std::move(_col_offsets_ih)),
|
||||
col_offsets_hh(std::move(_col_offsets_hh)),
|
||||
scale_ih(std::move(_scale_ih)),
|
||||
scale_hh(std::move(_scale_hh)),
|
||||
zero_point_ih(std::move(_zero_point_ih)),
|
||||
zero_point_hh(std::move(_zero_point_hh)) {}
|
||||
|
||||
const Tensor w_ih;
|
||||
const Tensor w_hh;
|
||||
const Tensor b_ih_;
|
||||
const Tensor b_hh_;
|
||||
const Tensor packed_ih;
|
||||
const Tensor packed_hh;
|
||||
const Tensor col_offsets_ih;
|
||||
const Tensor col_offsets_hh;
|
||||
const Scalar scale_ih;
|
||||
const Scalar scale_hh;
|
||||
const Scalar zero_point_ih;
|
||||
const Scalar zero_point_hh;
|
||||
|
||||
Tensor matmul_ih(const Tensor& input) const override {
|
||||
TORCH_CHECK(false, "matmul is not supported with quantized cell params");
|
||||
}
|
||||
Tensor matmul_hh(const Tensor& h) const override {
|
||||
TORCH_CHECK(false, "matmul is not supported with quantized cell params");
|
||||
}
|
||||
Tensor linear_ih(const Tensor& input) const override {
|
||||
return at::fbgemm_linear_int8_weight_fp32_activation(
|
||||
input, w_ih, packed_ih, col_offsets_ih, scale_ih, zero_point_ih, b_ih_);
|
||||
}
|
||||
Tensor linear_hh(const Tensor& h) const override {
|
||||
return at::fbgemm_linear_int8_weight_fp32_activation(
|
||||
h, w_hh, packed_hh, col_offsets_hh, scale_hh, zero_point_hh, b_hh_);
|
||||
}
|
||||
const Tensor& b_ih() const override {
|
||||
return b_ih_;
|
||||
}
|
||||
const Tensor& b_hh() const override {
|
||||
return b_hh_;
|
||||
}
|
||||
CellParamsSerializationType __getstate__() const override {
|
||||
std::vector<at::Tensor> tensors_to_serialize = {
|
||||
w_ih, w_hh, b_ih_, b_hh_, col_offsets_ih, col_offsets_hh};
|
||||
std::vector<double> doubles_to_serialize = {scale_ih.toDouble(),
|
||||
scale_hh.toDouble()};
|
||||
std::vector<int64_t> longs_to_serialize = {zero_point_ih.toLong(),
|
||||
zero_point_hh.toLong()};
|
||||
return CellParamsSerializationType(
|
||||
"quantized",
|
||||
std::move(tensors_to_serialize),
|
||||
std::move(doubles_to_serialize),
|
||||
std::move(longs_to_serialize),
|
||||
{});
|
||||
}
|
||||
static c10::intrusive_ptr<CellParamsBase> __setstate__(
|
||||
CellParamsSerializationType state) {
|
||||
std::vector<at::Tensor> tensors;
|
||||
std::vector<double> doubles;
|
||||
std::vector<int64_t> longs;
|
||||
std::tie(std::ignore, tensors, doubles, longs, std::ignore) =
|
||||
std::move(state);
|
||||
TORCH_INTERNAL_ASSERT(tensors.size() == 6);
|
||||
TORCH_INTERNAL_ASSERT(doubles.size() == 2);
|
||||
TORCH_INTERNAL_ASSERT(longs.size() == 2);
|
||||
|
||||
at::Tensor qw_ih = std::move(tensors[0]), qw_hh = std::move(tensors[1]),
|
||||
b_ih = std::move(tensors[2]), b_hh = std::move(tensors[3]),
|
||||
col_offsets_ih = std::move(tensors[4]),
|
||||
col_offsets_hh = std::move(tensors[5]);
|
||||
double scale_ih = doubles[0], scale_hh = doubles[1];
|
||||
int64_t zero_point_ih = longs[0], zero_point_hh = longs[1];
|
||||
|
||||
at::Tensor packed_ih = at::native::fbgemm_pack_quantized_matrix(qw_ih);
|
||||
at::Tensor packed_hh = at::native::fbgemm_pack_quantized_matrix(qw_hh);
|
||||
|
||||
return c10::make_intrusive<QuantizedCellParams>(
|
||||
/*w_ih=*/std::move(qw_ih),
|
||||
/*w_hh=*/std::move(qw_hh),
|
||||
/*b_ih_=*/std::move(b_ih),
|
||||
/*b_hh_=*/std::move(b_hh),
|
||||
/*packed_ih=*/std::move(packed_ih),
|
||||
/*packed_hh=*/std::move(packed_hh),
|
||||
/*col_offsets_ih=*/std::move(col_offsets_ih),
|
||||
/*col_offsets_hh=*/std::move(col_offsets_hh),
|
||||
/*scale_ih=*/scale_ih,
|
||||
/*scale_hh=*/scale_hh,
|
||||
/*zero_point_ih=*/zero_point_ih,
|
||||
/*zero_point_hh=*/zero_point_hh);
|
||||
}
|
||||
};
|
||||
|
||||
c10::intrusive_ptr<CellParamsBase> make_quantized_cell_params(
|
||||
const at::Tensor& w_ih,
|
||||
const at::Tensor& w_hh,
|
||||
at::Tensor b_ih,
|
||||
at::Tensor b_hh) {
|
||||
auto make_vals = [&](const at::Tensor& W) {
|
||||
auto params = at::native::fbgemm_linear_quantize_weight(W);
|
||||
at::Tensor packed_weight =
|
||||
at::native::fbgemm_pack_quantized_matrix(std::get<0>(params));
|
||||
return std::tuple_cat(
|
||||
std::make_tuple(std::move(packed_weight)), std::move(params));
|
||||
};
|
||||
|
||||
at::Tensor qw_ih, qw_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh;
|
||||
at::Scalar scale_ih, scale_hh, zero_point_ih, zero_point_hh;
|
||||
|
||||
std::tie(packed_ih, qw_ih, col_offsets_ih, scale_ih, zero_point_ih) =
|
||||
make_vals(w_ih);
|
||||
std::tie(packed_hh, qw_hh, col_offsets_hh, scale_hh, zero_point_hh) =
|
||||
make_vals(w_hh);
|
||||
|
||||
return c10::make_intrusive<QuantizedCellParams>(
|
||||
/*qw_ih=*/std::move(qw_ih),
|
||||
/*qw_hh=*/std::move(qw_hh),
|
||||
/*b_ih=*/std::move(b_ih),
|
||||
/*b_hh=*/std::move(b_hh),
|
||||
/*packed_ih=*/std::move(packed_ih),
|
||||
/*packed_hh=*/std::move(packed_hh),
|
||||
/*col_offsets_ih=*/std::move(col_offsets_ih),
|
||||
/*col_offsets_hh=*/std::move(col_offsets_hh),
|
||||
/*scale_ih=*/std::move(scale_ih),
|
||||
/*scale_hh=*/std::move(scale_hh),
|
||||
/*zero_point_ih=*/std::move(zero_point_ih),
|
||||
/*zero_point_hh=*/std::move(zero_point_hh));
|
||||
}
|
||||
|
||||
// QuantizedCellParams vs. QuantizedCellParamsDynamic
|
||||
//
|
||||
|
|
@ -536,7 +377,6 @@ static std::unordered_map<
|
|||
std::string,
|
||||
c10::intrusive_ptr<CellParamsBase> (*)(CellParamsSerializationType)>
|
||||
cell_params_deserializers = {
|
||||
{"quantized", &QuantizedCellParams::__setstate__},
|
||||
{"quantized_dynamic", &QuantizedCellParamsDynamic::__setstate__},
|
||||
{"quantized_fp16", &QuantizedCellParamsFP16::__setstate__}};
|
||||
|
||||
|
|
@ -1841,38 +1681,6 @@ static std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data_legacy(
|
|||
"using the newer definitions in torch.jit.quantized");
|
||||
}
|
||||
|
||||
#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \
|
||||
return_type name( \
|
||||
const Tensor& input, \
|
||||
hx_type hx, \
|
||||
const Tensor& w_ih, \
|
||||
const Tensor& w_hh, \
|
||||
const Tensor& b_ih, \
|
||||
const Tensor& b_hh, \
|
||||
const Tensor& packed_ih, \
|
||||
const Tensor& packed_hh, \
|
||||
const Tensor& col_offsets_ih, \
|
||||
const Tensor& col_offsets_hh, \
|
||||
const Scalar& scale_ih, \
|
||||
const Scalar& scale_hh, \
|
||||
const Scalar& zero_point_ih, \
|
||||
const Scalar& zero_point_hh) { \
|
||||
QuantizedCellParams params( \
|
||||
w_ih, \
|
||||
w_hh, \
|
||||
b_ih, \
|
||||
b_hh, \
|
||||
packed_ih, \
|
||||
packed_hh, \
|
||||
col_offsets_ih, \
|
||||
col_offsets_hh, \
|
||||
scale_ih, \
|
||||
scale_hh, \
|
||||
zero_point_ih, \
|
||||
zero_point_hh); \
|
||||
return cell_type{}( \
|
||||
input, prepare_hx_fn(hx), params); \
|
||||
}
|
||||
// Set reduced range to be True for all RNN Cells by default. This flag is used only for FBGEMM kernels
|
||||
// QNNPACK does not reduce range for activations
|
||||
#define DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(name, hx_type, cell_type, return_type, prepare_hx_fn) \
|
||||
|
|
@ -1895,7 +1703,6 @@ return_type name( \
|
|||
}
|
||||
|
||||
// Quantized LSTM cell
|
||||
using quantized_lstm_cell_type = LSTMCell<QuantizedCellParams>;
|
||||
using quantized_lstm_return_type = std::tuple<Tensor, Tensor>;
|
||||
static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
|
||||
return std::make_tuple(hx[0], hx[1]);
|
||||
|
|
@ -1904,7 +1711,6 @@ static std::tuple<Tensor, Tensor> prepare_quantized_lstm_hx(TensorList hx) {
|
|||
// Quantized LSTM cell
|
||||
using quantized_lstm_cell_dynamic_type = LSTMCell<QuantizedCellParamsDynamic>;
|
||||
|
||||
DEFINE_QUANTIZED_RNN_CELL(quantized_lstm_cell, TensorList, quantized_lstm_cell_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
|
||||
|
||||
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_lstm_cell_dynamic, TensorList, quantized_lstm_cell_dynamic_type, quantized_lstm_return_type, prepare_quantized_lstm_hx);
|
||||
|
||||
|
|
@ -1915,22 +1721,15 @@ static simple_hx_type prepare_quantized_hx(simple_hx_type hx) {
|
|||
}
|
||||
|
||||
// Quantized GRU cell
|
||||
using quantized_gru_cell_type = GRUCell<QuantizedCellParams>;
|
||||
using quantized_gru_cell_dynamic_type = GRUCell<QuantizedCellParamsDynamic>;
|
||||
|
||||
DEFINE_QUANTIZED_RNN_CELL(quantized_gru_cell, simple_hx_type, quantized_gru_cell_type, Tensor, prepare_quantized_hx);
|
||||
|
||||
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_gru_cell_dynamic, simple_hx_type, quantized_gru_cell_dynamic_type, Tensor, prepare_quantized_hx);
|
||||
|
||||
// Quantized RNN w/ ReLU cell
|
||||
using quantized_rnn_relu_cell_type = SimpleCell<relu_f, QuantizedCellParams>;
|
||||
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_relu_cell, simple_hx_type, quantized_rnn_relu_cell_type, Tensor, prepare_quantized_hx);
|
||||
using quantized_rnn_relu_cell_dynamic_type = SimpleCell<relu_f, QuantizedCellParamsDynamic>;
|
||||
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_relu_cell_dynamic, simple_hx_type, quantized_rnn_relu_cell_dynamic_type, Tensor, prepare_quantized_hx);
|
||||
|
||||
// Quantized RNN w/ tanh cell
|
||||
using quantized_rnn_tanh_cell_type = SimpleCell<tanh_f, QuantizedCellParams>;
|
||||
DEFINE_QUANTIZED_RNN_CELL(quantized_rnn_tanh_cell, simple_hx_type, quantized_rnn_tanh_cell_type, Tensor, prepare_quantized_hx);
|
||||
using quantized_rnn_tanh_cell_dynamic_type = SimpleCell<tanh_f, QuantizedCellParamsDynamic>;
|
||||
static DEFINE_QUANTIZED_RNN_CELL_DYNAMIC(quantized_rnn_tanh_cell_dynamic, simple_hx_type, quantized_rnn_tanh_cell_dynamic_type, Tensor, prepare_quantized_hx);
|
||||
|
||||
|
|
@ -1972,7 +1771,6 @@ TORCH_LIBRARY_FRAGMENT(aten, m) {
|
|||
TORCH_LIBRARY_FRAGMENT(quantized, m) {
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_dynamic(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh, bool reduce_range=False) -> __torch__.torch.classes.rnn.CellParamsBase"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params_fp16(__torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("quantized::make_quantized_cell_params(Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh) -> __torch__.torch.classes.rnn.CellParamsBase"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_lstm_cell_dynamic(Tensor input, Tensor[] hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor bias_ih, Tensor bias_hh) -> (Tensor, Tensor)"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_gru_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("quantized::quantized_rnn_relu_cell_dynamic(Tensor input, Tensor hx, __torch__.torch.classes.quantized.LinearPackedParamsBase w_ih, __torch__.torch.classes.quantized.LinearPackedParamsBase w_hh, Tensor b_ih, Tensor b_hh) -> Tensor"));
|
||||
|
|
@ -1992,7 +1790,6 @@ TORCH_LIBRARY_IMPL(aten, CPU, m) {
|
|||
|
||||
TORCH_LIBRARY_IMPL(quantized, CPU, m) {
|
||||
m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params_dynamic"), TORCH_FN(make_quantized_cell_params_dynamic));
|
||||
m.impl(TORCH_SELECTIVE_NAME("quantized::make_quantized_cell_params"), TORCH_FN(make_quantized_cell_params));
|
||||
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_lstm_cell_dynamic"), TORCH_FN(quantized_lstm_cell_dynamic));
|
||||
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_gru_cell_dynamic"), TORCH_FN(quantized_gru_cell_dynamic));
|
||||
m.impl(TORCH_SELECTIVE_NAME("quantized::quantized_rnn_relu_cell_dynamic"), TORCH_FN(quantized_rnn_relu_cell_dynamic));
|
||||
|
|
|
|||
|
|
@ -3285,22 +3285,6 @@
|
|||
dispatch:
|
||||
CUDA: _mixed_dtypes_linear
|
||||
|
||||
- func: fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
|
||||
|
||||
- func: fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor
|
||||
|
||||
- func: fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)
|
||||
|
||||
- func: fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor
|
||||
|
||||
- func: fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
|
||||
|
||||
- func: fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor
|
||||
|
||||
- func: fbgemm_pack_quantized_matrix(Tensor input) -> Tensor
|
||||
|
||||
- func: fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor
|
||||
|
||||
- func: ldexp.Tensor(Tensor self, Tensor other) -> Tensor
|
||||
variants: function, method
|
||||
|
||||
|
|
@ -7611,15 +7595,6 @@
|
|||
# - func: quantized_gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)
|
||||
#
|
||||
|
||||
# Quantized RNN cells
|
||||
- func: quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)
|
||||
|
||||
- func: quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
|
||||
|
||||
- func: quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
|
||||
|
||||
- func: quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor
|
||||
|
||||
# PackedSequence utilities
|
||||
- func: _pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -1324,7 +1324,6 @@ aten_native_source_non_codegen_list = [
|
|||
"aten/src/ATen/native/PointwiseOps.cpp",
|
||||
"aten/src/ATen/native/Pooling.cpp",
|
||||
"aten/src/ATen/native/Pow.cpp",
|
||||
"aten/src/ATen/native/QuantizedLinear.cpp",
|
||||
"aten/src/ATen/native/RNN.cpp",
|
||||
"aten/src/ATen/native/RangeFactories.cpp",
|
||||
"aten/src/ATen/native/ReduceAllOps.cpp",
|
||||
|
|
|
|||
|
|
@ -1131,7 +1131,6 @@ set(ATen_CPU_INCLUDE
|
|||
${CMAKE_BINARY_DIR}/aten/src)
|
||||
|
||||
if(CMAKE_CXX_COMPILER_ID MATCHES "Clang" OR CMAKE_CXX_COMPILER_ID STREQUAL "GNU")
|
||||
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/QuantizedLinear.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
|
||||
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/RNN.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
|
||||
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
|
||||
set_source_files_properties(${CMAKE_CURRENT_SOURCE_DIR}/../aten/src/ATen/native/quantized/qlinear_unpack.cpp PROPERTIES COMPILE_FLAGS -Wno-deprecated-declarations)
|
||||
|
|
|
|||
|
|
@ -312,6 +312,18 @@ ALLOW_LIST = [
|
|||
("aten::batch_norm_backward_elemt.out", datetime.date(2023, 12, 31)),
|
||||
("aten::batch_norm_backward_elemt", datetime.date(2023, 12, 31)),
|
||||
("aten::sym_constrain_range", datetime.date(2023, 12, 31)),
|
||||
("aten::fbgemm_linear_int8_weight_fp32_activation", datetime.date(2023, 12, 31)),
|
||||
("aten::fbgemm_linear_int8_weight", datetime.date(2023, 12, 31)),
|
||||
("aten::fbgemm_linear_quantize_weight", datetime.date(2023, 12, 31)),
|
||||
("aten::fbgemm_pack_gemm_matrix_fp16", datetime.date(2023, 12, 31)),
|
||||
("aten::fbgemm_linear_fp16_weight_fp32_activation", datetime.date(2023, 12, 31)),
|
||||
("aten::fbgemm_linear_fp16_weight", datetime.date(2023, 12, 31)),
|
||||
("aten::fbgemm_pack_quantized_matrix", datetime.date(2023, 12, 31)),
|
||||
("aten::quantized_lstm_cell", datetime.date(2023, 12, 31)),
|
||||
("aten::quantized_gru_cell", datetime.date(2023, 12, 31)),
|
||||
("aten::quantized_rnn_relu_cell", datetime.date(2023, 12, 31)),
|
||||
("aten::quantized_rnn_tanh_cell", datetime.date(2023, 12, 31)),
|
||||
("quantized::make_quantized_cell_params", datetime.date(2023, 12, 31)),
|
||||
]
|
||||
|
||||
ALLOW_LIST_COMPILED = [
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
|||
sys.path.append(pytorch_test_dir)
|
||||
from torch.testing._internal.jit_utils import JitTestCase, RUN_CUDA
|
||||
from torch.testing._internal.common_utils import slowTest, suppress_warnings
|
||||
from torch.testing._internal.common_quantization import skipIfNoFBGEMM
|
||||
|
||||
if __name__ == '__main__':
|
||||
raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
|
||||
|
|
@ -407,12 +406,6 @@ class TestModels(JitTestCase):
|
|||
def test_snli(self):
|
||||
self._test_snli(self, device='cpu')
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
# Suppression: this exercises a deprecated API
|
||||
@suppress_warnings
|
||||
def test_snli_quantized(self):
|
||||
self._test_snli(self, device='cpu', quantized=True)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_snli_cuda(self):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
|
|
@ -553,12 +546,6 @@ class TestModels(JitTestCase):
|
|||
def test_vae(self):
|
||||
self._test_vae(self, device='cpu')
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
# Suppression: this exercises a deprecated API
|
||||
@suppress_warnings
|
||||
def test_vae_quantized(self):
|
||||
self._test_vae(self, device='cpu', quantized=True)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||
def test_vae_cuda(self):
|
||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||
|
|
|
|||
|
|
@ -3200,81 +3200,6 @@ class TestDynamicQuantizedOps(TestCase):
|
|||
self.assertEqual(Y_fp32, Y_fp32_ref,
|
||||
msg="torch.ops.quantized.linear_dynamic results are off")
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
@given(
|
||||
batch_size=st.integers(1, 4),
|
||||
input_channels=st.integers(16, 32),
|
||||
output_channels=st.integers(4, 8),
|
||||
)
|
||||
def test_qlinear_legacy(self, batch_size, input_channels, output_channels):
|
||||
X_scale = 1.0
|
||||
X_zp = 0
|
||||
X_value_min = 0
|
||||
X_value_max = 255
|
||||
X_q0 = np.round(np.random.rand(batch_size, input_channels) * (
|
||||
X_value_max - X_value_min) + X_value_min
|
||||
).astype(np.uint8)
|
||||
X_q0[0, 0] = X_value_min
|
||||
X_q0[0, 1] = X_value_max
|
||||
|
||||
W_scale = 1.0
|
||||
W_zp = 0
|
||||
W_value_min = -128
|
||||
W_value_max = 127
|
||||
W_q0 = np.round(
|
||||
np.random.rand(output_channels, input_channels)
|
||||
* (W_value_max - W_value_min)
|
||||
+ W_value_min
|
||||
).astype(np.int8)
|
||||
W_q0[0, 0] = W_value_min
|
||||
W_q0[1, 0] = W_value_max
|
||||
|
||||
b_value_min = -10
|
||||
b_value_max = 10
|
||||
b_q0 = np.round(
|
||||
np.random.rand(output_channels) * (b_value_max - b_value_min) +
|
||||
b_value_min
|
||||
).astype(np.int32)
|
||||
|
||||
avoid_vpmaddubsw_overflow_linear(
|
||||
batch_size,
|
||||
input_channels,
|
||||
output_channels,
|
||||
X_q0,
|
||||
X_value_min,
|
||||
X_value_max,
|
||||
W_q0,
|
||||
W_value_min,
|
||||
W_value_max,
|
||||
)
|
||||
|
||||
X_fp32 = torch.from_numpy(_dequantize(X_q0, X_scale, X_zp)).to(dtype=torch.float)
|
||||
W_fp32 = torch.from_numpy(_dequantize(W_q0, W_scale, W_zp)).to(dtype=torch.float)
|
||||
b_fp32 = torch.from_numpy(
|
||||
_dequantize(b_q0, X_scale * W_scale, 0)
|
||||
).to(dtype=torch.float)
|
||||
|
||||
W_scale, W_zp = _calculate_dynamic_qparams(W_fp32, torch.qint8)
|
||||
W_q = torch.quantize_per_tensor(W_fp32, scale=W_scale, zero_point=W_zp, dtype=torch.qint8)
|
||||
|
||||
# Observe X_fp32 and determine X_scale and X_zero_point, this should match
|
||||
# internals of dynamic linear.
|
||||
X_scale, X_zp = _calculate_dynamic_qparams(X_fp32, torch.quint8)
|
||||
X_q = torch.quantize_per_tensor(X_fp32, scale=X_scale, zero_point=X_zp, dtype=torch.quint8)
|
||||
|
||||
W_int8, col_offsets, W_scale, W_zp = torch.fbgemm_linear_quantize_weight(W_q.dequantize())
|
||||
W_prepack = torch.fbgemm_pack_quantized_matrix(W_int8.clone(), W_int8.size(1), W_int8.size(0))
|
||||
# Quantized Linear operator with prepacked weight
|
||||
Y_fp32 = torch.fbgemm_linear_int8_weight(
|
||||
X_q.dequantize(), W_q.dequantize(), W_prepack, col_offsets,
|
||||
W_scale, W_zp, b_fp32)
|
||||
|
||||
Y_fp32_ref = F.linear(X_q.dequantize(), W_q.dequantize(), b_fp32)
|
||||
# Y_fp32_ref = F.linear(X_fp32, W_fp32, b_fp32)
|
||||
|
||||
self.assertEqual(Y_fp32, Y_fp32_ref,
|
||||
msg="torch.ops.quantized.fbgemm_linear_dynamic results are off")
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
@given(
|
||||
input_channels=st.integers(16, 32),
|
||||
|
|
|
|||
|
|
@ -4,11 +4,8 @@ import torch
|
|||
from torch.testing._internal.common_quantization import (
|
||||
skipIfNoFBGEMM
|
||||
)
|
||||
from torch.testing._internal.common_utils import suppress_warnings
|
||||
from torch.testing._internal.jit_utils import JitTestCase
|
||||
|
||||
from typing import Tuple
|
||||
import copy
|
||||
|
||||
class TestDeprecatedJitQuantized(JitTestCase):
|
||||
@skipIfNoFBGEMM
|
||||
|
|
@ -54,54 +51,8 @@ class TestDeprecatedJitQuantized(JitTestCase):
|
|||
torch.tensor(vals, dtype=torch.float),
|
||||
requires_grad=False)
|
||||
|
||||
ref = copy.deepcopy(cell)
|
||||
|
||||
cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
|
||||
x = torch.tensor([[100, -155],
|
||||
[-155, 100],
|
||||
[100, -155]], dtype=torch.float)
|
||||
h0_vals = [[-155, 100],
|
||||
[-155, 155],
|
||||
[100, -155]]
|
||||
hx = torch.tensor(h0_vals, dtype=torch.float)
|
||||
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
|
||||
cx = torch.tensor(h0_vals, dtype=torch.float)
|
||||
hiddens = (hx, cx)
|
||||
else:
|
||||
hiddens = hx
|
||||
|
||||
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
|
||||
class ScriptWrapper(torch.jit.ScriptModule):
|
||||
def __init__(self, cell):
|
||||
super().__init__()
|
||||
self.cell = cell
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x: torch.Tensor,
|
||||
hiddens: Tuple[torch.Tensor, torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.cell(x, hiddens)
|
||||
else:
|
||||
|
||||
class ScriptWrapper(torch.jit.ScriptModule):
|
||||
def __init__(self, cell):
|
||||
super().__init__()
|
||||
self.cell = cell
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor:
|
||||
return self.cell(x, hiddens)
|
||||
|
||||
cell = ScriptWrapper(cell)
|
||||
outs = cell(x, hiddens)
|
||||
cell = self.getExportImportCopyWithPacking(cell)
|
||||
|
||||
outs = cell(x, hiddens)
|
||||
ref_outs = ref(x, hiddens)
|
||||
|
||||
self.assertEqual(len(outs), len(ref_outs))
|
||||
for out, ref_out in zip(outs, ref_outs):
|
||||
torch.testing.assert_close(out, ref_out)
|
||||
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_cell_modules function is no longer supported"):
|
||||
cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_rnn_quantized(self):
|
||||
|
|
@ -143,85 +94,14 @@ class TestDeprecatedJitQuantized(JitTestCase):
|
|||
torch.tensor(vals, dtype=torch.float),
|
||||
requires_grad=False)
|
||||
|
||||
ref = copy.deepcopy(cell)
|
||||
cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8)
|
||||
cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16)
|
||||
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_modules function is no longer supported"):
|
||||
cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8)
|
||||
|
||||
niter = 10
|
||||
x = torch.tensor([[100, -155],
|
||||
[-155, 100],
|
||||
[100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
|
||||
h0_vals = [[-155, 100],
|
||||
[-155, 155],
|
||||
[100, -155]]
|
||||
hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
|
||||
cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
|
||||
with self.assertRaisesRegex(RuntimeError, "quantize_rnn_modules function is no longer supported"):
|
||||
cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16)
|
||||
|
||||
if isinstance(ref, torch.nn.LSTM):
|
||||
hiddens = (hx, cx)
|
||||
elif isinstance(ref, torch.nn.GRU):
|
||||
hiddens = hx
|
||||
|
||||
ref_out, ref_hid = ref(x, hiddens)
|
||||
|
||||
# Compare int8 quantized to unquantized
|
||||
output_int8, final_hiddens_int8 = cell_int8(x, hiddens)
|
||||
|
||||
torch.testing.assert_close(output_int8, ref_out)
|
||||
for out, ref in zip(final_hiddens_int8, ref_hid):
|
||||
torch.testing.assert_close(out, ref)
|
||||
|
||||
# Compare fp16 quantized to unquantized
|
||||
output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)
|
||||
|
||||
torch.testing.assert_close(output_fp16, ref_out)
|
||||
for out, ref in zip(final_hiddens_fp16, ref_hid):
|
||||
torch.testing.assert_close(out, ref)
|
||||
|
||||
def compare_quantized_unquantized(ScriptWrapper, cell):
|
||||
wrapper = ScriptWrapper(cell)
|
||||
|
||||
# Compare quantize scripted module to unquantized
|
||||
script_out, script_hid = wrapper(x, hiddens)
|
||||
torch.testing.assert_close(script_out, ref_out)
|
||||
for out, ref in zip(script_hid, ref_hid):
|
||||
torch.testing.assert_close(out, ref)
|
||||
|
||||
# Compare export/import to unquantized
|
||||
export_import_wrapper = self.getExportImportCopyWithPacking(wrapper)
|
||||
ei_out, ei_hid = export_import_wrapper(x, hiddens)
|
||||
torch.testing.assert_close(ei_out, ref_out)
|
||||
for out, ref in zip(ei_hid, ref_hid):
|
||||
torch.testing.assert_close(out, ref)
|
||||
|
||||
if isinstance(cell, torch.jit.quantized.QuantizedGRU):
|
||||
class ScriptWrapper(torch.jit.ScriptModule):
|
||||
def __init__(self, cell):
|
||||
super().__init__()
|
||||
self.cell = cell
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
return self.cell(x, hiddens)
|
||||
|
||||
compare_quantized_unquantized(ScriptWrapper, cell)
|
||||
elif isinstance(cell, torch.jit.quantized.QuantizedLSTM):
|
||||
for cell in [cell_int8, cell_fp16]:
|
||||
class ScriptWrapper(torch.jit.ScriptModule):
|
||||
def __init__(self, cell):
|
||||
super().__init__()
|
||||
self.cell = cell
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, x, hiddens):
|
||||
# type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor])
|
||||
# -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
|
||||
return self.cell(x, hiddens)
|
||||
compare_quantized_unquantized(ScriptWrapper, cell)
|
||||
|
||||
if 'fbgemm' in torch.backends.quantized.supported_engines:
|
||||
# Suppression: using deprecated quant api
|
||||
@suppress_warnings
|
||||
def test_quantization_modules(self):
|
||||
K1, N1 = 2, 2
|
||||
|
||||
|
|
@ -244,18 +124,11 @@ class TestDeprecatedJitQuantized(JitTestCase):
|
|||
|
||||
y_ref = fb(value)
|
||||
|
||||
fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)
|
||||
traced_int8 = torch.jit.trace(fb_int8, (x,))
|
||||
fb_int8 = self.getExportImportCopyWithPacking(traced_int8)
|
||||
y_int8 = fb_int8(value)
|
||||
with self.assertRaisesRegex(RuntimeError, "quantize_linear_modules function is no longer supported"):
|
||||
fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)
|
||||
|
||||
fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)
|
||||
traced_fp16 = torch.jit.trace(fb_fp16, (x,))
|
||||
fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16)
|
||||
y_fp16 = fb_fp16(value)
|
||||
|
||||
torch.testing.assert_close(y_int8, y_ref, rtol=0.0001, atol=1e-3)
|
||||
torch.testing.assert_close(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
|
||||
with self.assertRaisesRegex(RuntimeError, "quantize_linear_modules function is no longer supported"):
|
||||
fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_erase_class_tensor_shapes(self):
|
||||
|
|
|
|||
|
|
@ -2112,25 +2112,6 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
|
|||
res_bf16 = F.threshold(x.to(dtype=dtype), threshold, 0).float()
|
||||
self.assertEqual(res_bf16, expected)
|
||||
|
||||
@unittest.skipUnless('fbgemm' in torch.backends.quantized.supported_engines,
|
||||
'Linear_FP16_weight requires FBGEMM. FBGEMM is only optimized for CPUs'
|
||||
' with instruction set support avx2 or newer.')
|
||||
def test_fb_fc_packed(self):
|
||||
X = np.random.rand(16, 16).astype(np.float32) - 0.5
|
||||
W = np.random.rand(16, 16).astype(np.float32) - 0.5
|
||||
b = np.random.rand(16).astype(np.float32) - 0.5
|
||||
|
||||
def fc_op(X, W, b):
|
||||
return np.dot(X, W.T) + b
|
||||
|
||||
x_tensor = torch.tensor(X)
|
||||
w_tensor = torch.tensor(W)
|
||||
b_tensor = torch.tensor(b)
|
||||
packed_w_tensor = torch.fbgemm_pack_gemm_matrix_fp16(w_tensor)
|
||||
actual_output = torch.fbgemm_linear_fp16_weight(x_tensor, packed_w_tensor, b_tensor)
|
||||
expected_output = fc_op(X, W, b)
|
||||
torch.testing.assert_close(torch.from_numpy(expected_output), actual_output.cpu(), atol=1e-3, rtol=1e-3)
|
||||
|
||||
def test_pad_scalar_error(self):
|
||||
inputs = torch.tensor(0., requires_grad=True)
|
||||
self.assertRaises(RuntimeError, lambda: F.pad(inputs, (1, 1)))
|
||||
|
|
|
|||
|
|
@ -1,798 +1,99 @@
|
|||
import warnings
|
||||
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import _VF, Tensor # noqa: F401
|
||||
from torch.nn.utils.rnn import PackedSequence
|
||||
|
||||
|
||||
class QuantizedLinear(torch.jit.ScriptModule):
|
||||
__constants__ = ["scale", "zero_point"]
|
||||
|
||||
def __init__(self, other):
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedLinear is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead."
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedLinear is no longer supported. Please use "
|
||||
"torch.ao.nn.quantized.dynamic.Linear instead."
|
||||
)
|
||||
|
||||
self.in_features = other.in_features
|
||||
self.out_features = other.out_features
|
||||
# Quantize weight and discard the original
|
||||
(
|
||||
self.weight,
|
||||
self.col_offsets,
|
||||
self.scale,
|
||||
self.zero_point,
|
||||
) = torch.fbgemm_linear_quantize_weight(
|
||||
other.weight.clone(memory_format=torch.contiguous_format).float()
|
||||
)
|
||||
self.weight = torch.nn.Parameter(self.weight, requires_grad=False)
|
||||
self.col_offsets = torch.nn.Parameter(self.col_offsets, requires_grad=False)
|
||||
assert other.bias is not None, "QuantizedLinear requires a bias"
|
||||
self.bias = torch.nn.Parameter(
|
||||
other.bias.clone(memory_format=torch.contiguous_format).float(),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
self.register_buffer(
|
||||
"packed_tensor_ptr",
|
||||
torch.fbgemm_pack_quantized_matrix(
|
||||
self.weight.clone(memory_format=torch.contiguous_format)
|
||||
),
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def _unpack(self):
|
||||
self.packed_tensor_ptr.set_(torch.fbgemm_pack_quantized_matrix(self.weight))
|
||||
|
||||
@torch.jit.script_method
|
||||
def _pack(self):
|
||||
self.packed_tensor_ptr.set_(
|
||||
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
out = torch.fbgemm_linear_int8_weight_fp32_activation(
|
||||
input.float(),
|
||||
self.weight,
|
||||
self.packed_tensor_ptr,
|
||||
self.col_offsets,
|
||||
self.scale,
|
||||
self.zero_point,
|
||||
self.bias,
|
||||
)
|
||||
return out.to(input.dtype)
|
||||
|
||||
def extra_repr(self):
|
||||
repr = (
|
||||
"in_features={in_features}, out_features={out_features}, "
|
||||
"scale={scale}, zero_point={zero_point}".format(**self.__dict__)
|
||||
)
|
||||
return repr
|
||||
|
||||
|
||||
# FP16 weights
|
||||
class QuantizedLinearFP16(torch.jit.ScriptModule):
|
||||
def __init__(self, other):
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedLinearFP16 is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.Linear instead."
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedLinearFP16 is no longer supported. "
|
||||
"Please use the torch.ao.nn.quantized.dynamic.Linear instead."
|
||||
)
|
||||
self.in_features = other.in_features
|
||||
self.out_features = other.out_features
|
||||
self.original_weight = other.weight
|
||||
self.weight = torch.fbgemm_pack_gemm_matrix_fp16(
|
||||
other.weight.clone(memory_format=torch.contiguous_format).float()
|
||||
)
|
||||
assert other.bias is not None, "QuantizedLinearFP16 requires a bias"
|
||||
self.bias = torch.nn.Parameter(
|
||||
other.bias.clone(memory_format=torch.contiguous_format).float(),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.register_buffer("packed_weight", self.weight)
|
||||
|
||||
@torch.jit.script_method
|
||||
def _unpack(self):
|
||||
self.packed_weight.set_(
|
||||
torch.fbgemm_pack_gemm_matrix_fp16(self.original_weight)
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def _pack(self):
|
||||
self.packed_weight.set_(
|
||||
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, input):
|
||||
out = torch.fbgemm_linear_fp16_weight_fp32_activation(
|
||||
input.float(), self.packed_weight, self.bias
|
||||
)
|
||||
return out
|
||||
|
||||
def extra_repr(self):
|
||||
repr = "in_features={in_features}, out_features={out_features}, ".format(
|
||||
**self.__dict__
|
||||
)
|
||||
return repr
|
||||
|
||||
|
||||
# Quantized RNN cell implementations
|
||||
class QuantizedRNNCellBase(torch.jit.ScriptModule):
|
||||
__constants__ = [
|
||||
"input_size",
|
||||
"hidden_size",
|
||||
"bias",
|
||||
"scale_hh",
|
||||
"scale_ih",
|
||||
"zero_point_ih",
|
||||
"zero_point_hh",
|
||||
]
|
||||
|
||||
def __init__(self, other):
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedRNNCellBase is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
|
||||
)
|
||||
|
||||
self.input_size = other.input_size
|
||||
self.hidden_size = other.hidden_size
|
||||
self.bias = other.bias
|
||||
if not self.bias:
|
||||
raise ValueError("Quantized RNN cells require bias terms")
|
||||
|
||||
(
|
||||
weight_ih,
|
||||
col_offsets_ih,
|
||||
self.scale_ih,
|
||||
self.zero_point_ih,
|
||||
) = torch.fbgemm_linear_quantize_weight(
|
||||
other.weight_ih.clone(memory_format=torch.contiguous_format).float()
|
||||
)
|
||||
self.register_buffer("weight_ih", weight_ih)
|
||||
self.register_buffer("col_offsets_ih", col_offsets_ih)
|
||||
(
|
||||
weight_hh,
|
||||
col_offsets_hh,
|
||||
self.scale_hh,
|
||||
self.zero_point_hh,
|
||||
) = torch.fbgemm_linear_quantize_weight(
|
||||
other.weight_hh.clone(memory_format=torch.contiguous_format).float()
|
||||
)
|
||||
self.register_buffer("weight_hh", weight_hh)
|
||||
self.register_buffer("col_offsets_hh", col_offsets_hh)
|
||||
|
||||
packed_ih = torch.fbgemm_pack_quantized_matrix(self.weight_ih)
|
||||
self.register_buffer("packed_ih", packed_ih)
|
||||
packed_hh = torch.fbgemm_pack_quantized_matrix(self.weight_hh)
|
||||
self.register_buffer("packed_hh", packed_hh)
|
||||
|
||||
self.bias_ih = torch.nn.Parameter(
|
||||
other.bias_ih.clone(memory_format=torch.contiguous_format).float(),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.bias_hh = torch.nn.Parameter(
|
||||
other.bias_hh.clone(memory_format=torch.contiguous_format).float(),
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def extra_repr(self):
|
||||
s = "{input_size}, {hidden_size}"
|
||||
if "bias" in self.__dict__ and self.bias is not True:
|
||||
s += ", bias={bias}"
|
||||
if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh":
|
||||
s += ", nonlinearity={nonlinearity}"
|
||||
return s.format(**self.__dict__)
|
||||
|
||||
@torch.jit.script_method
|
||||
def check_forward_input(self, input):
|
||||
if input.size(1) != self.input_size:
|
||||
raise RuntimeError(
|
||||
f"input has inconsistent input_size: got {input.size(1)}, expected {self.input_size}"
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def check_forward_hidden(
|
||||
self, input: Tensor, hx: Tensor, hidden_label: str = ""
|
||||
) -> None:
|
||||
if input.size(0) != hx.size(0):
|
||||
raise RuntimeError(
|
||||
f"Input batch size {input.size(0)} doesn't match hidden{hidden_label} batch size {hx.size(0)}"
|
||||
)
|
||||
|
||||
if hx.size(1) != self.hidden_size:
|
||||
raise RuntimeError(
|
||||
f"hidden{hidden_label} has inconsistent hidden_size: got {hx.size(1)}, expected {self.hidden_size}"
|
||||
)
|
||||
|
||||
# TODO: for some reason weak_script_method causes a destruction of the
|
||||
# module to occur, which in turn frees the packed_ih object via its DataPtr
|
||||
# deleter. This is bizarre and should probably get fixed.
|
||||
# @torch._jit_internal.weak_script_method
|
||||
@torch.jit.script_method
|
||||
def _unpack(self):
|
||||
self.packed_ih.set_(torch.fbgemm_pack_quantized_matrix(self.weight_ih))
|
||||
self.packed_hh.set_(torch.fbgemm_pack_quantized_matrix(self.weight_hh))
|
||||
|
||||
# @torch._jit_internal.weak_script_method
|
||||
@torch.jit.script_method
|
||||
def _pack(self):
|
||||
self.packed_ih.set_(
|
||||
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
|
||||
)
|
||||
self.packed_hh.set_(
|
||||
torch.zeros(torch.jit.annotate(List[int], []), dtype=torch.uint8).detach()
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedRNNCellBase is no longer supported. "
|
||||
"Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
|
||||
)
|
||||
|
||||
|
||||
class QuantizedRNNCell(QuantizedRNNCellBase):
|
||||
__constants__ = [
|
||||
"input_size",
|
||||
"hidden_size",
|
||||
"bias",
|
||||
"scale_hh",
|
||||
"scale_ih",
|
||||
"zero_point_ih",
|
||||
"zero_point_hh",
|
||||
"nonlinearity",
|
||||
]
|
||||
|
||||
def __init__(self, other):
|
||||
super().__init__(other)
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedRNNCell is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedRNNCell is no longer supported. "
|
||||
"Please use the torch.ao.nn.quantized.dynamic.RNNCell instead."
|
||||
)
|
||||
self.nonlinearity = other.nonlinearity
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
|
||||
self.check_forward_input(input)
|
||||
if hx is None:
|
||||
hx = torch.zeros(
|
||||
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
|
||||
)
|
||||
self.check_forward_hidden(input, hx, "")
|
||||
if self.nonlinearity == "tanh":
|
||||
ret = _VF.quantized_rnn_tanh_cell(
|
||||
input,
|
||||
hx,
|
||||
self.weight_ih,
|
||||
self.weight_hh,
|
||||
self.bias_ih,
|
||||
self.bias_hh,
|
||||
self.packed_ih,
|
||||
self.packed_hh,
|
||||
self.col_offsets_ih,
|
||||
self.col_offsets_hh,
|
||||
self.scale_ih,
|
||||
self.scale_hh,
|
||||
self.zero_point_ih,
|
||||
self.zero_point_hh,
|
||||
)
|
||||
elif self.nonlinearity == "relu":
|
||||
ret = _VF.quantized_rnn_relu_cell(
|
||||
input,
|
||||
hx,
|
||||
self.weight_ih,
|
||||
self.weight_hh,
|
||||
self.bias_ih,
|
||||
self.bias_hh,
|
||||
self.packed_ih,
|
||||
self.packed_hh,
|
||||
self.col_offsets_ih,
|
||||
self.col_offsets_hh,
|
||||
self.scale_ih,
|
||||
self.scale_hh,
|
||||
self.zero_point_ih,
|
||||
self.zero_point_hh,
|
||||
)
|
||||
else:
|
||||
ret = input # TODO: remove when jit supports exception flow
|
||||
raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}")
|
||||
return ret
|
||||
|
||||
|
||||
class QuantizedLSTMCell(QuantizedRNNCellBase):
|
||||
def __init__(self, other):
|
||||
super().__init__(other)
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedLSTMCell is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead."
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(
|
||||
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
self.check_forward_input(input)
|
||||
if hx is None:
|
||||
zeros = torch.zeros(
|
||||
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
|
||||
)
|
||||
hx = (zeros, zeros)
|
||||
self.check_forward_hidden(input, hx[0], "[0]")
|
||||
self.check_forward_hidden(input, hx[1], "[1]")
|
||||
return _VF.quantized_lstm_cell(
|
||||
input,
|
||||
hx,
|
||||
self.weight_ih,
|
||||
self.weight_hh,
|
||||
self.bias_ih,
|
||||
self.bias_hh,
|
||||
self.packed_ih,
|
||||
self.packed_hh,
|
||||
self.col_offsets_ih,
|
||||
self.col_offsets_hh,
|
||||
self.scale_ih,
|
||||
self.scale_hh,
|
||||
self.zero_point_ih,
|
||||
self.zero_point_hh,
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedLSTMCell is no longer supported. "
|
||||
"Please use the torch.ao.nn.quantized.dynamic.LSTMCell instead."
|
||||
)
|
||||
|
||||
|
||||
class QuantizedGRUCell(QuantizedRNNCellBase):
|
||||
def __init__(self, other):
|
||||
super().__init__(other)
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedGRUCell is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRUCell instead."
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedGRUCell is no longer supported. "
|
||||
"Please use the torch.ao.nn.quantized.dynamic.GRUCell instead."
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor:
|
||||
self.check_forward_input(input)
|
||||
if hx is None:
|
||||
hx = torch.zeros(
|
||||
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device
|
||||
)
|
||||
self.check_forward_hidden(input, hx, "")
|
||||
return _VF.quantized_gru_cell(
|
||||
input,
|
||||
hx,
|
||||
self.weight_ih,
|
||||
self.weight_hh,
|
||||
self.bias_ih,
|
||||
self.bias_hh,
|
||||
self.packed_ih,
|
||||
self.packed_hh,
|
||||
self.col_offsets_ih,
|
||||
self.col_offsets_hh,
|
||||
self.scale_ih,
|
||||
self.scale_hh,
|
||||
self.zero_point_ih,
|
||||
self.zero_point_hh,
|
||||
)
|
||||
|
||||
|
||||
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor:
|
||||
return tensor.index_select(dim, permutation)
|
||||
|
||||
|
||||
class QuantizedRNNBase(torch.jit.ScriptModule):
|
||||
__constants__ = [
|
||||
"mode",
|
||||
"input_size",
|
||||
"hidden_size",
|
||||
"num_layers",
|
||||
"bias",
|
||||
"batch_first",
|
||||
"dropout",
|
||||
"bidirectional",
|
||||
"dtype",
|
||||
]
|
||||
|
||||
def __init__(self, other, dtype=torch.int8):
|
||||
super().__init__()
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedRNNBase is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic instead."
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedRNNBase is no longer supported. "
|
||||
"Please use the torch.ao.nn.quantized.dynamic instead."
|
||||
)
|
||||
self.mode = other.mode
|
||||
self.input_size = other.input_size
|
||||
self.hidden_size = other.hidden_size
|
||||
self.num_layers = other.num_layers
|
||||
self.bias = other.bias
|
||||
self.batch_first = other.batch_first
|
||||
if self.mode != "GRU":
|
||||
assert not self.batch_first
|
||||
self.dropout = other.dropout
|
||||
self.bidirectional = other.bidirectional
|
||||
num_directions = 2 if self.bidirectional else 1
|
||||
self.dtype = dtype
|
||||
|
||||
assert self.bias
|
||||
|
||||
# TODO: support more than just LSTM
|
||||
if self.mode != "LSTM" and self.mode != "GRU":
|
||||
raise RuntimeError("Only LSTM or GRU is supported for QuantizedRNN")
|
||||
|
||||
if dtype != torch.int8 and dtype != torch.float16:
|
||||
raise RuntimeError(f"Unsupported dtype: {dtype}")
|
||||
|
||||
self.all_weights = []
|
||||
for layer in range(self.num_layers):
|
||||
for direction in range(num_directions):
|
||||
layer_input_size = (
|
||||
self.input_size if layer == 0 else self.hidden_size * num_directions
|
||||
)
|
||||
|
||||
suffix = "_reverse" if direction == 1 else ""
|
||||
|
||||
def get_weight_bias(ihhh):
|
||||
weight_name = f"weight_{ihhh}_l{layer}{suffix}"
|
||||
bias_name = f"bias_{ihhh}_l{layer}{suffix}"
|
||||
|
||||
weight = getattr(other, weight_name)
|
||||
bias = getattr(other, bias_name)
|
||||
return weight, bias
|
||||
|
||||
weight_ih, bias_ih = get_weight_bias("ih")
|
||||
weight_hh, bias_hh = get_weight_bias("hh")
|
||||
|
||||
if dtype == torch.int8:
|
||||
cell_params = torch.ops.quantized.make_quantized_cell_params(
|
||||
weight_ih, weight_hh, bias_ih, bias_hh
|
||||
)
|
||||
else:
|
||||
packed_ih = torch.ops.quantized.linear_prepack_fp16(
|
||||
weight_ih.float(), bias_ih
|
||||
)
|
||||
packed_hh = torch.ops.quantized.linear_prepack_fp16(
|
||||
weight_hh.float(), bias_hh
|
||||
)
|
||||
|
||||
cell_params = torch.ops.quantized.make_quantized_cell_params_fp16(
|
||||
packed_ih, packed_hh
|
||||
)
|
||||
|
||||
setattr(self, f"cell_params_{layer}_{suffix}", cell_params)
|
||||
self.all_weights.append(cell_params)
|
||||
|
||||
@torch.jit.script_method
|
||||
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None:
|
||||
expected_input_dim = 2 if batch_sizes is not None else 3
|
||||
if input.dim() != expected_input_dim:
|
||||
raise RuntimeError(
|
||||
f"input must have {expected_input_dim} dimensions, got {input.dim()}"
|
||||
)
|
||||
if self.input_size != input.size(-1):
|
||||
raise RuntimeError(
|
||||
f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}"
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def get_expected_hidden_size(
|
||||
self, input: Tensor, batch_sizes: Optional[Tensor]
|
||||
) -> Tuple[int, int, int]:
|
||||
if batch_sizes is not None:
|
||||
mini_batch = int(batch_sizes[0])
|
||||
else:
|
||||
mini_batch = input.size(0) if self.batch_first else input.size(1)
|
||||
num_directions = 2 if self.bidirectional else 1
|
||||
expected_hidden_size = (
|
||||
self.num_layers * num_directions,
|
||||
mini_batch,
|
||||
self.hidden_size,
|
||||
)
|
||||
return expected_hidden_size
|
||||
|
||||
@torch.jit.script_method
|
||||
def check_hidden_size(
|
||||
self,
|
||||
hx: Tensor,
|
||||
expected_hidden_size: Tuple[int, int, int],
|
||||
msg: str = "Expected hidden size {}, got {}",
|
||||
) -> None:
|
||||
if hx.size() != expected_hidden_size:
|
||||
raise RuntimeError(msg.format(expected_hidden_size, list(hx.size())))
|
||||
|
||||
@torch.jit.script_method
|
||||
def check_forward_args(
|
||||
self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor]
|
||||
) -> None:
|
||||
self.check_input(input, batch_sizes)
|
||||
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
|
||||
self.check_hidden_size(
|
||||
hidden, expected_hidden_size, msg="Expected hidden size {}, got {}"
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]) -> Tensor:
|
||||
if permutation is None:
|
||||
return hx
|
||||
return apply_permutation(hx, permutation)
|
||||
|
||||
|
||||
class QuantizedLSTM(QuantizedRNNBase):
|
||||
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
|
||||
|
||||
def __init__(self, other, dtype):
|
||||
super().__init__(other, dtype)
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedLSTM is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedLSTM is no longer supported. "
|
||||
"Please use the torch.ao.nn.quantized.dynamic.LSTM instead."
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward_impl(
|
||||
self,
|
||||
input: Tensor,
|
||||
hx: Optional[Tuple[Tensor, Tensor]],
|
||||
batch_sizes: Optional[Tensor],
|
||||
max_batch_size: int,
|
||||
sorted_indices: Optional[Tensor],
|
||||
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
|
||||
if hx is None:
|
||||
num_directions = 2 if self.bidirectional else 1
|
||||
zeros = torch.zeros(
|
||||
self.num_layers * num_directions,
|
||||
max_batch_size,
|
||||
self.hidden_size,
|
||||
dtype=input.dtype,
|
||||
device=input.device,
|
||||
)
|
||||
hx = (zeros, zeros)
|
||||
else:
|
||||
# Each batch of the hidden state should match the input sequence that
|
||||
# the user believes he/she is passing in.
|
||||
hx = self.permute_hidden(hx, sorted_indices)
|
||||
|
||||
self.check_forward_args(input, hx, batch_sizes)
|
||||
assert batch_sizes is None
|
||||
result = torch.quantized_lstm(
|
||||
input,
|
||||
hx,
|
||||
self.all_weights,
|
||||
self.bias,
|
||||
self.num_layers,
|
||||
float(self.dropout),
|
||||
self.training,
|
||||
self.bidirectional,
|
||||
self.batch_first,
|
||||
dtype=self.dtype,
|
||||
use_dynamic=False,
|
||||
)
|
||||
output = result[0]
|
||||
hidden = result[1:]
|
||||
|
||||
return output, hidden
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward_tensor(
|
||||
self, input: Tensor, hx: Optional[Tuple[Tensor, Tensor]] = None
|
||||
) -> Tuple[Tensor, Tuple[Tensor, Tensor]]:
|
||||
batch_sizes = None
|
||||
max_batch_size = input.size(0) if self.batch_first else input.size(1)
|
||||
sorted_indices = None
|
||||
unsorted_indices = None
|
||||
|
||||
output, hidden = self.forward_impl(
|
||||
input, hx, batch_sizes, max_batch_size, sorted_indices
|
||||
)
|
||||
|
||||
return output, self.permute_hidden(hidden, unsorted_indices)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward_packed(
|
||||
self, input: PackedSequence, hx: Optional[Tuple[Tensor, Tensor]] = None
|
||||
) -> Tuple[PackedSequence, Tuple[Tensor, Tensor]]:
|
||||
input_, batch_sizes, sorted_indices, unsorted_indices = input
|
||||
max_batch_size = int(batch_sizes[0])
|
||||
|
||||
output, hidden = self.forward_impl(
|
||||
input_, hx, batch_sizes, max_batch_size, sorted_indices
|
||||
)
|
||||
|
||||
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
|
||||
return output, self.permute_hidden(hidden, unsorted_indices)
|
||||
|
||||
@torch.jit.script_method
|
||||
def permute_hidden(
|
||||
self, hx: Tuple[Tensor, Tensor], permutation: Optional[Tensor]
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
if permutation is None:
|
||||
return hx
|
||||
return apply_permutation(hx[0], permutation), apply_permutation(
|
||||
hx[1], permutation
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def check_forward_args(
|
||||
self,
|
||||
input: Tensor,
|
||||
hidden: Tuple[Tensor, Tensor],
|
||||
batch_sizes: Optional[Tensor],
|
||||
) -> None:
|
||||
self.check_input(input, batch_sizes)
|
||||
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes)
|
||||
|
||||
self.check_hidden_size(
|
||||
hidden[0], expected_hidden_size, "Expected hidden[0] size {}, got {}"
|
||||
)
|
||||
self.check_hidden_size(
|
||||
hidden[1], expected_hidden_size, "Expected hidden[1] size {}, got {}"
|
||||
)
|
||||
|
||||
def forward(self, input, hx=None):
|
||||
if isinstance(input, PackedSequence):
|
||||
return self.forward_packed(input, hx)
|
||||
else:
|
||||
return self.forward_tensor(input, hx)
|
||||
|
||||
|
||||
class QuantizedGRU(QuantizedRNNBase):
|
||||
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
warnings.warn(
|
||||
"torch.jit.QuantizedGRU is deprecated and will be removed in an upcoming "
|
||||
"PyTorch release. Please use the torch.ao.nn.quantized.dynamic.GRU instead."
|
||||
raise RuntimeError(
|
||||
"torch.jit.QuantizedGRU is no longer supported. "
|
||||
"Please use the torch.ao.nn.quantized.dynamic.GRU instead."
|
||||
)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward_impl(
|
||||
self,
|
||||
input: Tensor,
|
||||
hx: Optional[Tensor],
|
||||
batch_sizes: Optional[Tensor],
|
||||
max_batch_size: int,
|
||||
sorted_indices: Optional[Tensor],
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
if hx is None:
|
||||
num_directions = 2 if self.bidirectional else 1
|
||||
hx = torch.zeros(
|
||||
self.num_layers * num_directions,
|
||||
max_batch_size,
|
||||
self.hidden_size,
|
||||
dtype=input.dtype,
|
||||
device=input.device,
|
||||
)
|
||||
else:
|
||||
# Each batch of the hidden state should match the input sequence that
|
||||
# the user believes he/she is passing in.
|
||||
hx = self.permute_hidden(hx, sorted_indices)
|
||||
|
||||
self.check_forward_args(input, hx, batch_sizes)
|
||||
if batch_sizes is None:
|
||||
result = torch.quantized_gru(
|
||||
input,
|
||||
hx,
|
||||
self.all_weights,
|
||||
self.bias,
|
||||
self.num_layers,
|
||||
float(self.dropout),
|
||||
self.training,
|
||||
self.bidirectional,
|
||||
self.batch_first,
|
||||
)
|
||||
else:
|
||||
result = torch.quantized_gru(
|
||||
input,
|
||||
batch_sizes,
|
||||
hx,
|
||||
self.all_weights,
|
||||
self.bias,
|
||||
self.num_layers,
|
||||
float(self.dropout),
|
||||
self.training,
|
||||
self.bidirectional,
|
||||
)
|
||||
|
||||
output = result[0]
|
||||
hidden = result[1]
|
||||
|
||||
return output, hidden
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward_tensor(
|
||||
self, input: Tensor, hx: Optional[Tensor] = None
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
batch_sizes = None
|
||||
max_batch_size = input.size(0) if self.batch_first else input.size(1)
|
||||
sorted_indices = None
|
||||
unsorted_indices = None
|
||||
|
||||
output, hidden = self.forward_impl(
|
||||
input, hx, batch_sizes, max_batch_size, sorted_indices
|
||||
)
|
||||
return output, self.permute_hidden(hidden, unsorted_indices)
|
||||
|
||||
@torch.jit.script_method
|
||||
def forward_packed(
|
||||
self, input: PackedSequence, hx: Optional[Tensor] = None
|
||||
) -> Tuple[PackedSequence, Tensor]:
|
||||
input_, batch_sizes, sorted_indices, unsorted_indices = input
|
||||
max_batch_size = int(batch_sizes[0])
|
||||
|
||||
output, hidden = self.forward_impl(
|
||||
input_, hx, batch_sizes, max_batch_size, sorted_indices
|
||||
)
|
||||
|
||||
output = PackedSequence(output, batch_sizes, sorted_indices, unsorted_indices)
|
||||
return output, self.permute_hidden(hidden, unsorted_indices)
|
||||
|
||||
def forward(self, input, hx=None):
|
||||
if isinstance(input, PackedSequence):
|
||||
return self.forward_packed(input, hx)
|
||||
else:
|
||||
return self.forward_tensor(input, hx)
|
||||
|
||||
|
||||
def quantize_rnn_cell_modules(module):
|
||||
warnings.warn(
|
||||
"quantize_rnn_cell_modules function has been deprecated. "
|
||||
raise RuntimeError(
|
||||
"quantize_rnn_cell_modules function is no longer supported. "
|
||||
"Please use torch.ao.quantization.quantize_dynamic API instead."
|
||||
)
|
||||
reassign = {}
|
||||
for name, mod in module.named_modules():
|
||||
if mod is module:
|
||||
continue
|
||||
new_mod = quantize_rnn_cell_modules(mod)
|
||||
if new_mod is not mod:
|
||||
reassign[name] = new_mod
|
||||
for name, mod in reassign.items():
|
||||
setattr(module, name, mod)
|
||||
if isinstance(module, torch.nn.LSTMCell):
|
||||
return QuantizedLSTMCell(module)
|
||||
if isinstance(module, torch.nn.GRUCell):
|
||||
return QuantizedGRUCell(module)
|
||||
if isinstance(module, torch.nn.RNNCell):
|
||||
return QuantizedRNNCell(module)
|
||||
return module
|
||||
|
||||
|
||||
def quantize_linear_modules(module, dtype=torch.int8):
|
||||
warnings.warn(
|
||||
"quantize_linear_modules function has been deprecated. "
|
||||
raise RuntimeError(
|
||||
"quantize_linear_modules function is no longer supported. "
|
||||
"Please use torch.ao.quantization.quantize_dynamic API instead."
|
||||
)
|
||||
|
||||
reassign = {}
|
||||
for name, mod in module.named_modules():
|
||||
if mod is module:
|
||||
continue
|
||||
new_mod = quantize_linear_modules(mod, dtype)
|
||||
if new_mod is not mod:
|
||||
reassign[name] = new_mod
|
||||
|
||||
for name, mod in reassign.items():
|
||||
setattr(module, name, mod)
|
||||
if isinstance(module, torch.nn.Linear):
|
||||
if dtype == torch.int8:
|
||||
return QuantizedLinear(module)
|
||||
elif dtype == torch.float16:
|
||||
return QuantizedLinearFP16(module)
|
||||
else:
|
||||
raise RuntimeError(f"Unsupported dtype: {dtype}")
|
||||
return module
|
||||
|
||||
|
||||
def quantize_rnn_modules(module, dtype=torch.int8):
|
||||
warnings.warn(
|
||||
"quantize_rnn_modules function has been deprecated. "
|
||||
raise RuntimeError(
|
||||
"quantize_rnn_modules function is no longer supported. "
|
||||
"Please use torch.ao.quantization.quantize_dynamic API instead."
|
||||
)
|
||||
reassign = {}
|
||||
for name, mod in module.named_modules():
|
||||
if mod is module:
|
||||
continue
|
||||
new_mod = quantize_rnn_modules(mod, dtype)
|
||||
if new_mod is not mod:
|
||||
reassign[name] = new_mod
|
||||
|
||||
for name, mod in reassign.items():
|
||||
setattr(module, name, mod)
|
||||
if isinstance(module, torch.nn.LSTM):
|
||||
if dtype != torch.int8 and dtype != torch.float16:
|
||||
raise RuntimeError(f"Unsupported dtype: {dtype}")
|
||||
return QuantizedLSTM(module, dtype)
|
||||
if isinstance(module, torch.nn.GRU):
|
||||
return QuantizedGRU(module)
|
||||
return module
|
||||
|
|
|
|||
|
|
@ -585,14 +585,6 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.fused_moving_avg_obs_fake_quant: (lambda x, observer_on, fake_quant_on, averaging_const, running_min,
|
||||
running_max, scale, zero_point, quant_min, quant_max, ch_axis,
|
||||
per_row_fake_quant=False, symmetric_quant=False: -1),
|
||||
torch.fbgemm_linear_fp16_weight: lambda input, packed_weight, bias: -1,
|
||||
torch.fbgemm_linear_fp16_weight_fp32_activation: lambda input, packed_weight, bias: -1,
|
||||
torch.fbgemm_linear_int8_weight: lambda input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias: -1,
|
||||
torch.fbgemm_linear_int8_weight_fp32_activation: (lambda input, weight, packed, col_offsets, weight_scale,
|
||||
weight_zero_point, bias: -1),
|
||||
torch.fbgemm_linear_quantize_weight: lambda input: -1,
|
||||
torch.fbgemm_pack_gemm_matrix_fp16: lambda input: -1,
|
||||
torch.fbgemm_pack_quantized_matrix: lambda input, a, b: -1,
|
||||
torch.feature_alpha_dropout: lambda input, p, train: -1,
|
||||
torch.feature_dropout: lambda input, p, train: -1,
|
||||
torch.fft.ifft: lambda input, n=None, dim=-1, norm=None: -1,
|
||||
|
|
@ -977,21 +969,12 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.quantize_per_tensor: lambda input, scale, zero_point, dtype: -1,
|
||||
torch.quantize_per_tensor_dynamic: lambda input, dtype, reduce_range: -1,
|
||||
torch.quantized_batch_norm: lambda input, weight, bias, mean, var, eps, output_scale, output_zero_point: -1,
|
||||
torch.quantized_gru_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
||||
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
||||
|
||||
torch.quantized_lstm_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
||||
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
||||
torch.quantized_max_pool1d: (lambda input, kernel_size, stride=tuple(), padding=(0,),
|
||||
dilation=(1,), ceil_mode=False: -1),
|
||||
torch.quantized_max_pool2d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0),
|
||||
dilation=(1, 1), ceil_mode=False: -1),
|
||||
torch.quantized_max_pool3d: (lambda input, kernel_size, stride=tuple(), padding=(0, 0, 0),
|
||||
dilation=(1, 1, 1), ceil_mode=False: -1),
|
||||
torch.quantized_rnn_relu_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
||||
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
||||
torch.quantized_rnn_tanh_cell: (lambda input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih,
|
||||
col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh: -1),
|
||||
torch.rad2deg: lambda input, out=None: -1,
|
||||
torch.rand_like: lambda input, dtype=None, layout=None, device=None, requires_grad=False: -1,
|
||||
torch.randint_like: lambda input, high, dtype=None, layout=torch.strided, device=None, requires_grad=False: -1,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user