Quantized LSTM/GRU: Remove legacy API support (#72522)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72522

Ref #72263 for cpp_custom_type_hack removal

These overloads were deprecated in #35787, which shipped in the PyTorch 1.6
release, so the backward-compatibility (BC) period has long since expired.
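
For context, the user-visible effect is that the legacy overloads now fail
outright instead of warning and forwarding to the supported implementation.
A minimal sketch of the before/after behaviour (illustrative only; the
function names below are placeholders, not the actual ATen entry points):

#include <c10/util/Exception.h>

// Before: warn once per process, then gather the raw tensors into packed
// cell params and forward to the supported overload.
void legacy_overload_before() {
  TORCH_WARN_ONCE("quantized_gru with List[Tensor] parameters is deprecated");
  // ... gather_quantized_params(...) and call the new-style overload ...
}

// After: always throw a c10::Error. Callers must re-export the model via
// torch.jit.quantized so parameters arrive as CellParamsBase objects.
void legacy_overload_after() {
  TORCH_CHECK(false, "quantized_gru with List[Tensor] parameters is no longer supported");
}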

cc jamesr66a

Test Plan: Imported from OSS

Reviewed By: bdhirsh

Differential Revision: D34111271

Pulled By: albanD

fbshipit-source-id: 0078564188133625ca67137975fd5dd2fa2b4827
Peter Bell 2022-02-22 16:52:58 -08:00 committed by Facebook GitHub Bot
parent 7bbf29ed8e
commit 4f9c5a3ed7

@@ -3,7 +3,6 @@
#include <ATen/ATen.h>
#include <ATen/NativeFunctions.h>
#include <ATen/core/op_registration/op_registration.h>
-#include <ATen/cpp_custom_type_hack.h>
#include <ATen/native/quantized/cpu/packed_params.h>
#include <ATen/native/quantized/cpu/fbgemm_utils.h>
#include <ATen/native/quantized/cpu/qnnpack_utils.h>
@@ -579,88 +578,6 @@ static std::vector<CellParams> gather_params(TensorList params, bool has_biases,
return result;
}
-// These gather_* functions are kept solely for the purposes of backward
-// compatbility in the legacy quantized_{lstm,gru} APIs
-static c10::List<c10::intrusive_ptr<CellParamsBase>> gather_quantized_params(
-c10::List<at::Tensor> params) {
-static at::Tensor undefined;
-std::vector<c10::intrusive_ptr<CellParamsBase>> result;
-TORCH_CHECK(params.size() % 12 == 0, "got an incorrect number of quantized RNN parameters");
-for (size_t i = 0; i < params.size(); i += 12) {
-result.emplace_back(c10::make_intrusive<QuantizedCellParams>(
-static_cast<at::Tensor>(params[i]),
-static_cast<at::Tensor>(params[i + 1]),
-static_cast<at::Tensor>(params[i + 2]),
-static_cast<at::Tensor>(params[i + 3]),
-static_cast<at::Tensor>(params[i + 4]),
-static_cast<at::Tensor>(params[i + 5]),
-static_cast<at::Tensor>(params[i + 6]),
-static_cast<at::Tensor>(params[i + 7]),
-static_cast<at::Tensor>(params[i + 8]).item(),
-static_cast<at::Tensor>(params[i + 9]).item(),
-static_cast<at::Tensor>(params[i + 10]).item(),
-static_cast<at::Tensor>(params[i + 11]).item()));
-}
-return c10::List<c10::intrusive_ptr<CellParamsBase>>(result);
-}
-static c10::List<c10::intrusive_ptr<CellParamsBase>>
-gather_quantized_params_dynamic(c10::List<at::Tensor> params) {
-static at::Tensor undefined;
-std::vector<c10::intrusive_ptr<CellParamsBase>> result;
-for (size_t i = 0; i < params.size(); i += 2) {
-auto packed_struct_ih =
-cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
-static_cast<at::Tensor>(params[i]));
-auto packed_struct_hh =
-cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
-static_cast<at::Tensor>(params[i + 1]));
-auto bias_ih = packed_struct_ih->bias().value_or(undefined);
-auto bias_hh = packed_struct_hh->bias().value_or(undefined);
-result.emplace_back(c10::make_intrusive<QuantizedCellParamsDynamic>(
-std::move(packed_struct_ih),
-std::move(packed_struct_hh),
-std::move(bias_ih),
-std::move(bias_hh)));
-}
-return c10::List<c10::intrusive_ptr<CellParamsBase>>(result);
-}
-static c10::List<c10::intrusive_ptr<CellParamsBase>>
-gather_quantized_params_fp16(c10::List<at::Tensor> params) {
-static at::Tensor undefined;
-std::vector<c10::intrusive_ptr<CellParamsBase>> result;
-TORCH_CHECK(params.size() % 4 == 0,
-"incorrect number of quantized RNN parameters FP16");
-for (size_t i = 0; i < params.size(); i += 4) {
-c10::intrusive_ptr<LinearPackedParamsBase> packed_struct_ih =
-cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
-static_cast<at::Tensor>(params[i]));
-c10::intrusive_ptr<LinearPackedParamsBase> packed_struct_hh =
-cpp_custom_type_hack::cast<c10::intrusive_ptr<LinearPackedParamsBase>>(
-static_cast<at::Tensor>(params[i + 1]));
-// NB: we install the bias from the gathered parameters here because
-// in the "new world", the fp16 linear apply() method always expects
-// the bias to be present in the packed struct. In the "old world",
-// we called `fbgemm_linear_fp16_weight_fp32_activation`, which took
-// the bias explicitly and ignored the bias in the packed struct. To
-// reconcile serialized models that behavied in the old style, we
-// put the bias into the appropriate packed structures here.
-//
-// Hopefully we can remove this in the future when we eliminate
-// the old style altogether
-packed_struct_ih->set_bias(params[i + 2]);
-packed_struct_hh->set_bias(params[i + 3]);
-result.emplace_back(c10::make_intrusive<QuantizedCellParamsFP16>(
-std::move(packed_struct_ih), std::move(packed_struct_hh)));
-}
-return c10::List<c10::intrusive_ptr<CellParamsBase>>(result);
-}
////////////////////////////////////////////////////////////////////////////////
// HIDDEN STATE FUNCTIONS
//
@@ -1411,21 +1328,11 @@ std::tuple<Tensor, Tensor> quantized_gru_input_legacy(
bool train,
bool bidirectional,
bool batch_first) {
-TORCH_WARN_ONCE(
+TORCH_CHECK(
false,
"torch.quantized_gru with List[Tensor] for parameters is "
"deprecated and may be removed! Please re-export your model "
"no longer supported. Please re-export your model "
"using the newer definitions in torch.jit.quantized");
-auto params = gather_quantized_params(std::move(_params));
-return quantized_gru_input(
-_input,
-hx,
-std::move(params),
-has_biases,
-num_layers,
-dropout_p,
-train,
-bidirectional,
-batch_first);
}
std::tuple<Tensor, Tensor> quantized_gru_data_legacy(
@@ -1438,21 +1345,11 @@ std::tuple<Tensor, Tensor> quantized_gru_data_legacy(
double dropout_p,
bool train,
bool bidirectional) {
-TORCH_WARN_ONCE(
+TORCH_CHECK(
false,
"torch.quantized_gru with List[Tensor] for parameters is "
"deprecated and may be removed! Please re-export your model "
"no longer supported. Please re-export your model "
"using the newer definitions in torch.jit.quantized");
-auto params = gather_quantized_params(std::move(_params));
-return quantized_gru_data(
-data,
-batch_sizes,
-hx,
-std::move(params),
-has_biases,
-num_layers,
-dropout_p,
-train,
-bidirectional);
}
using tanf_cell_type = SimpleCell<tanh_f, CellParams>;
@@ -1768,34 +1665,11 @@ std::tuple<Tensor, Tensor, Tensor> quantized_lstm_input_legacy(
bool batch_first,
c10::optional<ScalarType> dtype,
bool use_dynamic) {
-TORCH_WARN_ONCE(
+TORCH_CHECK(
false,
"torch.quantized_lstm with List[Tensor] for parameters is "
"deprecated and may be removed! Please re-export your model "
"no longer supported. Please re-export your model "
"using the newer definitions in torch.jit.quantized");
-c10::List<c10::intrusive_ptr<CellParamsBase>> params;
-auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar;
-if (result_dtype == at::kChar || result_dtype == at::kQInt8) {
-if (use_dynamic) {
-params = gather_quantized_params_dynamic(std::move(_params_));
-} else {
-params = gather_quantized_params(std::move(_params_));
-}
-} else {
-params = gather_quantized_params_fp16(std::move(_params_));
-}
-return quantized_lstm_input(
-_input,
-std::move(hx_),
-std::move(params),
-has_biases,
-num_layers,
-dropout_p,
-train,
-bidirectional,
-batch_first,
-// NOLINTNEXTLINE(performance-move-const-arg)
-std::move(dtype),
-use_dynamic);
}
std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data(
@@ -1857,34 +1731,11 @@ std::tuple<Tensor, Tensor, Tensor> quantized_lstm_data_legacy(
bool bidirectional,
c10::optional<ScalarType> dtype,
bool use_dynamic) {
-TORCH_WARN_ONCE(
+TORCH_CHECK(
false,
"torch.quantized_lstm with List[Tensor] for parameters is "
"deprecated and may be removed! Please re-export your model "
"no longer supported. Please re-export your model "
"using the newer definitions in torch.jit.quantized");
-c10::List<c10::intrusive_ptr<CellParamsBase>> params;
-auto result_dtype = dtype.has_value() ? dtype.value() : at::kChar;
-if (result_dtype == at::kChar || result_dtype == at::kQInt8) {
-if (use_dynamic) {
-params = gather_quantized_params_dynamic(std::move(_params_));
-} else {
-params = gather_quantized_params(std::move(_params_));
-}
-} else {
-params = gather_quantized_params_fp16(std::move(_params_));
-}
-return quantized_lstm_data(
-data,
-batch_sizes,
-std::move(hx_),
-std::move(params),
-has_biases,
-num_layers,
-dropout_p,
-train,
-bidirectional,
-// NOLINTNEXTLINE(performance-move-const-arg)
-std::move(dtype),
-use_dynamic);
}
#define DEFINE_QUANTIZED_RNN_CELL(name, hx_type, cell_type, return_type, prepare_hx_fn) \