[6/n] Quantization with min & max bounds support - using fbgemm changes in ATen (#162924)
Summary: This diff uses the FBGEMM changes made in D78181177 & D81858256 to support using the provided per-row min/max values while quantizing float/half to 8-bit, 4-bit & 2-bit in the ATen library. Please find more context on this here: https://fburl.com/gdoc/yutf32a0

Test Plan:
```
buck test mode/opt caffe2/torch/fb/model_transform/splitting/tests:split_dispatcher_test
```
https://www.internalfb.com/intern/testinfra/testrun/7881299640979446

Please refer to D80905814's test plan for integration testing.

Rollback Plan:

Differential Revision: D81327342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162924
Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
ad2f7315ca
commit
783a9dcb6d
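Before the diff itself, a minimal usage sketch of the new ops from Python. The op name, the `[num_rows, 2]` min/max layout, and the equivalence claim are taken from the schemas, validation logic, and tests in this diff; the tensor values are illustrative:

```python
import torch

# Per-row (min, max) bounds: shape [num_rows, 2]; the kernels read the
# buffer with the weight's dtype, so it should match the weight.
weight = torch.rand(10, 16, dtype=torch.float32)
rowwise_min_max = torch.stack(
    [weight.min(dim=1).values, weight.max(dim=1).values], dim=1
)

# 8-bit prepack with caller-provided bounds. When the bounds equal the true
# row min/max, the output matches plain embedding_bag_byte_prepack.
packed = torch.ops.quantized.embedding_bag_byte_prepack_with_rowwise_min_max(
    weight, rowwise_min_max
)
assert packed.dtype == torch.uint8
```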
@@ -158,12 +158,46 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
   return packed_ptr;
 }
 
+#ifdef USE_FBGEMM
+namespace {
+/// Number of columns in the rowwise min/max buffer passed to the quantization function(s)
+constexpr int kRowwiseMinMaxNumCols = 2;
+
+bool _validate_rowwise_min_max(
+    const at::Tensor& weight,
+    const std::optional<at::Tensor>& rowwise_min_max_opt) {
+  const auto is_valid_rowwise_min_max = rowwise_min_max_opt.has_value();
+
+  if (is_valid_rowwise_min_max) {
+    TORCH_CHECK(
+        (rowwise_min_max_opt->dim() == 2 &&
+         rowwise_min_max_opt->size(0) == weight.size(0) &&
+         rowwise_min_max_opt->size(1) == kRowwiseMinMaxNumCols),
+        "'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2].");
+  }
+
+  return is_valid_rowwise_min_max;
+}
+
+auto _get_rowwise_min_max_contig(
+    const std::optional<at::Tensor>& rowwise_min_max_opt) {
+  return rowwise_min_max_opt.has_value()
+      ? rowwise_min_max_opt->expect_contiguous(rowwise_min_max_opt->suggest_memory_format())
+      : at::borrow_from_optional_tensor(rowwise_min_max_opt);
+}
+}
+#endif // USE_FBGEMM
+
 namespace at::native {
 
 // Note - This is a temporary pack function for embedding bag which quantizes
 // and packs the float weight tensor. In the next step it will be replaced by a
 // quantize and pack function once we support FP scale and FP zero_point
 //
+// The optional rowwise_min_max argument is to support callers to pass in the min/max
+// values of the weight tensor. If the rowwise_min_max is not provided, the min/max
+// values will be computed from the weight tensor.
+//
 // Python example examining a packed 8bit zero_point and scale:
 //
 // >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]],
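To make the shape contract above concrete, a small sketch of what `_validate_rowwise_min_max` accepts and rejects (hypothetical tensors; the error text is the TORCH_CHECK message above):

```python
import torch

weight = torch.rand(4, 8)

# Valid: one (min, max) pair per weight row -> shape [num_rows(weight), 2].
ok = torch.stack([weight.min(dim=1).values, weight.max(dim=1).values], dim=1)
assert ok.shape == (4, 2)  # kRowwiseMinMaxNumCols == 2

# Invalid: wrong trailing dimension -> the TORCH_CHECK above fires with
# "'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2]."
bad = weight.min(dim=1).values.unsqueeze(1)  # shape [4, 1]
```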
@@ -221,7 +255,10 @@ namespace at::native {
 //
 // [[50. , 60.00000035],
 // [70. , 80.00000035]]])
-Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
+Tensor& qembeddingbag_byte_prepack_out(
+    Tensor& output,
+    const Tensor& weight,
+    const std::optional<Tensor>& rowwise_min_max_opt) {
   // The "last" dimension of an N-Dimensioned batch of embedding bags is
   // quantization channel. E.g. for a 2D embedding bag, this has
   // [ row, col ] dimensions, for batched of embedding bags, dimensions might be
@@ -256,9 +293,16 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
   auto* output_data = output.data_ptr<uint8_t>();
 
 #ifdef USE_FBGEMM
+  // Move these outside of the ifdef when we support non-FBGEMM flow.
+  const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt);
+  const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt);
+
   if (weight_contig->scalar_type() == at::ScalarType::Half) {
     const auto weight_data =
         static_cast<fbgemm::float16*>(weight_contig->data_ptr());
+    const auto rowwise_min_max_data = is_valid_rowwise_min_max
+        ? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr())
+        : nullptr;
     at::parallel_for(
         0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
           fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<
@@ -266,17 +310,21 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
               weight_data + start_idx * embedding_cols,
               end_idx - start_idx,
               embedding_cols,
-              output_data + start_idx * output_columns);
+              output_data + start_idx * output_columns,
+              (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
         });
   } else {
     const auto weight_data = weight_contig->data_ptr<float>();
+    const auto rowwise_min_max_data =
+        is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr;
     at::parallel_for(
         0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
           fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
               weight_data + start_idx * embedding_cols,
               end_idx - start_idx,
               embedding_cols,
-              output_data + start_idx * output_columns);
+              output_data + start_idx * output_columns,
+              (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
         });
   }
 
@@ -326,6 +374,22 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
   return output;
 }
 
+static Tensor qembeddingbag_byte_prepack_with_rowwise_min_max(
+    const Tensor& weight,
+    const Tensor& rowwise_min_max) {
+  const auto weight_contig =
+      weight.expect_contiguous(weight.suggest_memory_format());
+  Tensor output = at::detail::empty_cpu(
+      {0},
+      at::kByte,
+      weight_contig->layout(),
+      weight_contig->device(),
+      std::nullopt,
+      std::nullopt);
+  qembeddingbag_byte_prepack_out(output, weight, rowwise_min_max);
+  return output;
+}
+
 Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
   const auto weight_contig =
       weight.expect_contiguous(weight.suggest_memory_format());
@@ -335,7 +399,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
       "'embedding_bag_byte_prepack' only support float32 or float16.");
   const auto weight_sizes = weight.sym_sizes();
   const auto cols_dim = weight.ndimension() - 1;
-  const auto embedding_cols = weight_sizes[cols_dim];
+  const auto& embedding_cols = weight_sizes[cols_dim];
   // Add 8 bytes per column to store FP32 scale and zero_point per row.
   const auto output_columns = embedding_cols + 2 * sizeof(float);
 
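A worked example of the row-layout arithmetic above. This is a sketch assuming FBGEMM's fused 8-bit rowwise format, where each packed row stores `embedding_cols` quantized bytes followed by an FP32 scale and an FP32 bias (the `2 * sizeof(float)` term):

```python
import torch

weight = torch.rand(3, 64, dtype=torch.float32)
packed = torch.ops.quantized.embedding_bag_byte_prepack(weight)

# output_columns = embedding_cols + 2 * sizeof(float) = 64 + 8 = 72
assert packed.shape == (3, 72)

row = packed[0]
qvals = row[:64]                                # uint8 codes, one per column
scale = row[64:68].view(torch.float32).item()   # FP32 scale for this row
bias = row[68:72].view(torch.float32).item()    # FP32 bias (row minimum)
approx = qvals.float() * scale + bias           # dequantized approximation
```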
@@ -359,7 +423,8 @@ Tensor _qembeddingbag_nbit_prepack_helper(
     int bit_width,
     const bool optimized_qparams,
     const int64_t nbins,
-    const double ratio) {
+    const double ratio,
+    const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt) {
   TORCH_CHECK(
       weight.scalar_type() == at::ScalarType::Float ||
           weight.scalar_type() == at::ScalarType::Half,
@@ -401,10 +466,17 @@ Tensor _qembeddingbag_nbit_prepack_helper(
   auto* output_data = output.data_ptr<uint8_t>();
 
 #ifdef USE_FBGEMM
+  // Move these outside of the ifdef when we support non-FBGEMM flow.
+  const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt);
+  const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt);
+
   if (!optimized_qparams) {
     if (weight_contig.scalar_type() == at::ScalarType::Half) {
       const auto weight_data =
           static_cast<fbgemm::float16*>(weight_contig.data_ptr());
+      const auto rowwise_min_max_data = is_valid_rowwise_min_max
+          ? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr())
+          : nullptr;
       at::parallel_for(
           0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
             fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<
@@ -413,10 +485,13 @@ Tensor _qembeddingbag_nbit_prepack_helper(
                 weight_data + start_idx * embedding_cols,
                 end_idx - start_idx,
                 static_cast<int>(embedding_cols),
-                output_data + start_idx * output_shape[1]);
+                output_data + start_idx * output_shape[1],
+                (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
           });
     } else {
       const auto weight_data = weight_contig.data_ptr<float>();
+      const auto rowwise_min_max_data =
+          is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr;
       at::parallel_for(
           0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
             fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
@@ -424,7 +499,8 @@ Tensor _qembeddingbag_nbit_prepack_helper(
                 weight_data + start_idx * embedding_cols,
                 end_idx - start_idx,
                 static_cast<int>(embedding_cols),
-                output_data + start_idx * output_shape[1]);
+                output_data + start_idx * output_shape[1],
+                (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
           });
     }
   } else {
@@ -514,6 +590,16 @@ Tensor qembeddingbag_4bit_prepack(
       weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio);
 }
 
+Tensor qembeddingbag_4bit_prepack_with_rowwise_min_max(
+    const Tensor& weight,
+    const Tensor& rowwise_min_max,
+    const bool optimized_qparams,
+    const int64_t nbins,
+    const double ratio) {
+  return _qembeddingbag_nbit_prepack_helper(
+      weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max);
+}
+
 // Applies 2-bit row-wise quantization by determining the range
 // (maximum - minimum) and bias (minimum value) of each row in the input
 // matrix, and then scaling each element to an 2-bit number between 0 and
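The 2-bit comment spanning this hunk describes the same row-wise scheme at a lower bit width; a tiny numeric sketch of that math (plain Python, mirroring the description rather than FBGEMM's exact rounding and half-precision scale behavior):

```python
# Row-wise 2-bit quantization: bias = row minimum, scale spans the row's
# range across the 2**2 - 1 = 3 steps above zero.
row = [0.1, 0.4, 0.7, 1.0]
lo, hi = min(row), max(row)
scale = (hi - lo) / 3 or 1.0                      # guard against zero range
codes = [round((v - lo) / scale) for v in row]    # -> [0, 1, 2, 3]
dequant = [lo + c * scale for c in codes]         # approximate originals
```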
@@ -531,6 +617,16 @@ Tensor qembeddingbag_2bit_prepack(
       weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio);
 }
 
+Tensor qembeddingbag_2bit_prepack_with_rowwise_min_max(
+    const Tensor& weight,
+    const Tensor& rowwise_min_max,
+    const bool optimized_qparams,
+    const int64_t nbins,
+    const double ratio) {
+  return _qembeddingbag_nbit_prepack_helper(
+      weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max);
+}
+
 class QEmbeddingPackWeights final {
  public:
   static c10::intrusive_ptr<EmbeddingPackedParamsBase> run(const at::Tensor& weight) {
@@ -542,12 +638,21 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"),
       TORCH_FN(qembeddingbag_byte_prepack));
+  m.impl(
+      TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack_with_rowwise_min_max"),
+      TORCH_FN(qembeddingbag_byte_prepack_with_rowwise_min_max));
   m.impl(
       TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"),
       TORCH_FN(qembeddingbag_4bit_prepack));
+  m.impl(
+      TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max"),
+      TORCH_FN(qembeddingbag_4bit_prepack_with_rowwise_min_max));
   m.impl(
       TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"),
       TORCH_FN(qembeddingbag_2bit_prepack));
+  m.impl(
+      TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max"),
+      TORCH_FN(qembeddingbag_2bit_prepack_with_rowwise_min_max));
 }
 
 TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {
@@ -3,7 +3,10 @@
 
 namespace at::native {
 
-Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight);
+Tensor& qembeddingbag_byte_prepack_out(
+    Tensor& output,
+    const Tensor& weight,
+    const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt);
 
 Tensor qembeddingbag_byte_prepack(const Tensor& weight);
 
@@ -121,9 +121,12 @@ TORCH_LIBRARY(quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_unpack(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack) -> Tensor W_origin"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
+  m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag});
   m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag});
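Once registered, these schemas surface under `torch.ops.quantized` with the same trailing defaults as the existing prepack ops; a brief sketch (argument values illustrative, defaults read from the schema strings above):

```python
import torch

weight = torch.rand(5, 16)
rowwise_min_max = torch.stack(
    [weight.min(dim=1).values, weight.max(dim=1).values], dim=1
)

# Trailing arguments default to optimized_qparams=False, nbins=200, ratio=0.16.
packed4 = torch.ops.quantized.embedding_bag_4bit_prepack_with_rowwise_min_max(
    weight, rowwise_min_max
)
packed2 = torch.ops.quantized.embedding_bag_2bit_prepack_with_rowwise_min_max(
    weight, rowwise_min_max, False, 200, 0.16
)
```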
@@ -4,56 +4,67 @@
 
 import copy
 import itertools
-import numpy as np
 import operator
 import random
 import unittest
-from packaging.version import Version
-from typing import NamedTuple
+from typing import NamedTuple, TYPE_CHECKING
+
+import numpy as np
 
 import torch
-from torch import _VF
 import torch.jit
 import torch.nn.functional as F
-from torch.nn.modules.utils import _single, _pair
-
-from hypothesis import settings, HealthCheck
-from hypothesis import assume, given, note
-from hypothesis import strategies as st
 import torch.testing._internal.hypothesis_utils as hu
+from hypothesis import assume, given, HealthCheck, note, settings, strategies as st
+from packaging.version import Version
+from torch import _VF
+
+if TYPE_CHECKING:
+    from torch._ops import OpOverloadPacket
+
+from torch.nn.modules.utils import _pair, _single
 
 hu.assert_deadline_disabled()
 
-from torch.testing._internal.common_cuda import SM80OrLater
-from torch.testing._internal.common_utils import (
-    raise_on_run_directly,
-    TestCase,
-    IS_PPC,
-    IS_MACOS,
-    IS_SANDCASTLE,
-    IS_FBCODE,
-    IS_ARM64
-)
-from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN
-from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \
-    override_quantized_engine, supported_qengines, override_qengines, _snr
-from torch.testing._internal.common_quantized import (
-    qengine_is_qnnpack,
-    qengine_is_onednn,
-)
-from torch.ao.quantization import PerChannelMinMaxObserver
-from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDNN_VERSION, TEST_CUDA
-from torch.testing._internal.optests import opcheck
-import torch.backends.xnnpack
-
-from torch.utils.cpp_extension import ROCM_HOME
-
 from typing import Optional
 
-np_dtype = {
-    torch.quint8 : np.uint8,
-    torch.qint8 : np.int8,
-    torch.qint32 : np.int32
-}
+import torch.backends.xnnpack
+from torch.ao.quantization import PerChannelMinMaxObserver
+from torch.testing._internal.common_cuda import (
+    SM80OrLater,
+    TEST_CUDA,
+    TEST_CUDNN,
+    TEST_CUDNN_VERSION,
+)
+from torch.testing._internal.common_quantization import (
+    skipIfNoFBGEMM,
+    skipIfNoONEDNN,
+    skipIfNoQNNPACK,
+)
+from torch.testing._internal.common_quantized import (
+    _calculate_dynamic_qparams,
+    _dequantize,
+    _quantize,
+    _snr,
+    override_qengines,
+    override_quantized_engine,
+    qengine_is_onednn,
+    qengine_is_qnnpack,
+    supported_qengines,
+)
+from torch.testing._internal.common_utils import (
+    IS_ARM64,
+    IS_FBCODE,
+    IS_MACOS,
+    IS_PPC,
+    IS_SANDCASTLE,
+    raise_on_run_directly,
+    TestCase,
+)
+from torch.testing._internal.optests import opcheck
+
+from torch.utils.cpp_extension import ROCM_HOME
+
+np_dtype = {torch.quint8: np.uint8, torch.qint8: np.int8, torch.qint32: np.int32}
 
 TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
@@ -8795,5 +8806,66 @@ class TestComparatorOps(TestCase):
         self.assertEqual(result_ref, result,
                          msg=f"'tensor.{op}(scalar)'' failed")
 
+
+"""Tests the correctness of the quantized::embedding_bag_(byte|4bit|2bit)_prepack_with_rowwise_min_max ops."""
+class TestQuantizedWithMinMax(TestCase):
+    """Validates that the *rowwise_min_max* quantization functions are equivalent to the ones without it."""
+
+    def test_quantize_tensor_with_min_max(self):
+        num_rows_list = [1, 2, 10, 100]
+        num_cols_list = [4, 8, 16, 32, 64, 128]
+        # Map of quantization bit rate to tuple of quantize function (with rowwise_min_max) and
+        # quantize function (without rowwise_min_max)
+        bit_rate_to_quant_fn: dict[
+            int,
+            tuple[
+                OpOverloadPacket,
+                OpOverloadPacket,
+            ],
+        ] = {
+            8: (
+                torch.ops.quantized.embedding_bag_byte_prepack_with_rowwise_min_max,
+                torch.ops.quantized.embedding_bag_byte_prepack,
+            ),
+            4: (
+                torch.ops.quantized.embedding_bag_4bit_prepack_with_rowwise_min_max,
+                torch.ops.quantized.embedding_bag_4bit_prepack,
+            ),
+            2: (
+                torch.ops.quantized.embedding_bag_2bit_prepack_with_rowwise_min_max,
+                torch.ops.quantized.embedding_bag_2bit_prepack,
+            ),
+        }
+
+        for quant_fn_with_rowwise_min_max, quant_fn in bit_rate_to_quant_fn.values():
+            for torch_dtype in [torch.float16, torch.float32]:
+                for num_rows, num_cols in itertools.product(num_rows_list, num_cols_list):
+                    weight = torch.rand(num_rows, num_cols, dtype=torch_dtype)
+                    rowwise_min_max = torch.stack(
+                        [weight.min(dim=1).values, weight.max(dim=1).values], dim=1
+                    )
+
+                    # Perform the quantization with rowwise_min_max
+                    weight_quantized = quant_fn_with_rowwise_min_max(
+                        weight, rowwise_min_max
+                    )
+                    assert weight_quantized.dtype == torch.uint8
+
+                    # Confirm that the quantization is matching the one without rowwise_min_max
+                    weight_quantized_no_rowwise_min_max = quant_fn(weight)
+                    assert torch.equal(
+                        weight_quantized, weight_quantized_no_rowwise_min_max
+                    )
+
+                    # Confirm that incorrect rowwise_min_max will result in different quantization output
+                    incorrect_rowwise_min_max = torch.stack(
+                        [weight.max(dim=1).values, weight.max(dim=1).values], dim=1
+                    )
+                    weight_incorrectly_quantized = quant_fn_with_rowwise_min_max(
+                        weight, incorrect_rowwise_min_max
+                    )
+                    assert weight_incorrectly_quantized.dtype == torch.uint8
+                    assert not torch.equal(
+                        weight_incorrectly_quantized, weight_quantized_no_rowwise_min_max
+                    )
+
+
 if __name__ == "__main__":
     raise_on_run_directly("test/test_quantization.py")