[6/n] Quantization with min & max bounds support - using fbgemm changes in ATen (#162924)

Summary:
This diff uses the FBGEMM changes made in D78181177 & D81858256 to support using caller-provided per-row min/max values when quantizing float/half weights to 8-bit, 4-bit, and 2-bit in the ATen library.

Please find more context on this here: https://fburl.com/gdoc/yutf32a0
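
For quick reference, a minimal usage sketch (not part of this diff) based on the new operator schemas and the OSS test added below; it assumes `rowwise_min_max` has shape `[num_rows(weight), 2]`, with column 0 holding each row's min and column 1 its max:

```python
import torch

# Pre-compute per-row min/max for a float weight matrix and pass it to the
# new prepack ops added in this PR.
weight = torch.rand(10, 16, dtype=torch.float32)
rowwise_min_max = torch.stack(
    [weight.min(dim=1).values, weight.max(dim=1).values], dim=1
)

packed_8bit = torch.ops.quantized.embedding_bag_byte_prepack_with_rowwise_min_max(
    weight, rowwise_min_max
)
packed_4bit = torch.ops.quantized.embedding_bag_4bit_prepack_with_rowwise_min_max(
    weight, rowwise_min_max
)
packed_2bit = torch.ops.quantized.embedding_bag_2bit_prepack_with_rowwise_min_max(
    weight, rowwise_min_max
)

# When the provided bounds equal the true row-wise min/max, the packed output
# should match the existing ops that compute min/max internally.
assert torch.equal(
    packed_8bit, torch.ops.quantized.embedding_bag_byte_prepack(weight)
)
```

When the provided bounds equal the true row-wise min/max, the packed output is expected to be identical to the existing prepack ops, which is what the new unit test below verifies.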

Test Plan:
```
buck test mode/opt caffe2/torch/fb/model_transform/splitting/tests:split_dispatcher_test
```
https://www.internalfb.com/intern/testinfra/testrun/7881299640979446

Please refer to D80905814's test plan for integration testing.

Rollback Plan:

Differential Revision: D81327342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162924
Approved by: https://github.com/jerryzh168
Sampath Victor 2025-09-25 02:52:04 +00:00 committed by PyTorch MergeBot
parent ad2f7315ca
commit 783a9dcb6d
4 changed files with 229 additions and 46 deletions

View File

@@ -158,12 +158,46 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
return packed_ptr;
}
#ifdef USE_FBGEMM
namespace {
/// Number of columns in the rowwise min/max buffer passed to the quantization function(s)
constexpr int kRowwiseMinMaxNumCols = 2;
bool _validate_rowwise_min_max(
const at::Tensor& weight,
const std::optional<at::Tensor>& rowwise_min_max_opt) {
const auto is_valid_rowwise_min_max = rowwise_min_max_opt.has_value();
if (is_valid_rowwise_min_max) {
TORCH_CHECK(
(rowwise_min_max_opt->dim() == 2 &&
rowwise_min_max_opt->size(0) == weight.size(0) &&
rowwise_min_max_opt->size(1) == kRowwiseMinMaxNumCols),
"'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2].");
}
return is_valid_rowwise_min_max;
}
auto _get_rowwise_min_max_contig(
const std::optional<at::Tensor>& rowwise_min_max_opt) {
return rowwise_min_max_opt.has_value()
? rowwise_min_max_opt->expect_contiguous(rowwise_min_max_opt->suggest_memory_format())
: at::borrow_from_optional_tensor(rowwise_min_max_opt);
}
}
#endif // USE_FBGEMM
namespace at::native {
// Note - This is a temporary pack function for embedding bag which quantizes
// and packs the float weight tensor. In the next step it will be replaced by a
// quantize and pack function once we support FP scale and FP zero_point
//
// The optional rowwise_min_max argument is to support callers to pass in the min/max
// values of the weight tensor. If the rowwise_min_max is not provided, the min/max
// values will be computed from the weight tensor.
//
// Python example examining a packed 8bit zero_point and scale:
//
// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]],
@@ -221,7 +255,10 @@ namespace at::native {
//
// [[50. , 60.00000035],
// [70. , 80.00000035]]])
Tensor& qembeddingbag_byte_prepack_out(
Tensor& output,
const Tensor& weight,
const std::optional<Tensor>& rowwise_min_max_opt) {
// The "last" dimension of an N-Dimensioned batch of embedding bags is
// quantization channel. E.g. for a 2D embedding bag, this has
// [ row, col ] dimensions, for batched of embedding bags, dimensions might be
@@ -256,9 +293,16 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
// Move these outside of the ifdef when we support non-FBGEMM flow.
const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt);
const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt);
if (weight_contig->scalar_type() == at::ScalarType::Half) {
const auto weight_data =
static_cast<fbgemm::float16*>(weight_contig->data_ptr());
const auto rowwise_min_max_data = is_valid_rowwise_min_max
? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr())
: nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<
@@ -266,17 +310,21 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
embedding_cols,
output_data + start_idx * output_columns,
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
});
} else {
const auto weight_data = weight_contig->data_ptr<float>();
const auto rowwise_min_max_data =
is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
embedding_cols,
output_data + start_idx * output_columns,
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
});
}
@@ -326,6 +374,22 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
return output;
}
static Tensor qembeddingbag_byte_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max) {
const auto weight_contig =
weight.expect_contiguous(weight.suggest_memory_format());
Tensor output = at::detail::empty_cpu(
{0},
at::kByte,
weight_contig->layout(),
weight_contig->device(),
std::nullopt,
std::nullopt);
qembeddingbag_byte_prepack_out(output, weight, rowwise_min_max);
return output;
}
Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
const auto weight_contig =
weight.expect_contiguous(weight.suggest_memory_format());
@@ -335,7 +399,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
"'embedding_bag_byte_prepack' only support float32 or float16.");
const auto weight_sizes = weight.sym_sizes();
const auto cols_dim = weight.ndimension() - 1;
const auto& embedding_cols = weight_sizes[cols_dim];
// Add 8 bytes per column to store FP32 scale and zero_point per row.
const auto output_columns = embedding_cols + 2 * sizeof(float);
@@ -359,7 +423,8 @@ Tensor _qembeddingbag_nbit_prepack_helper(
int bit_width,
const bool optimized_qparams,
const int64_t nbins,
const double ratio,
const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt) {
TORCH_CHECK(
weight.scalar_type() == at::ScalarType::Float ||
weight.scalar_type() == at::ScalarType::Half,
@@ -401,10 +466,17 @@ Tensor _qembeddingbag_nbit_prepack_helper(
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
// Move these outside of the ifdef when we support non-FBGEMM flow.
const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt);
const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt);
if (!optimized_qparams) {
if (weight_contig.scalar_type() == at::ScalarType::Half) {
const auto weight_data =
static_cast<fbgemm::float16*>(weight_contig.data_ptr());
const auto rowwise_min_max_data = is_valid_rowwise_min_max
? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr())
: nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<
@@ -413,10 +485,13 @@ Tensor _qembeddingbag_nbit_prepack_helper(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
static_cast<int>(embedding_cols),
output_data + start_idx * output_shape[1],
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
});
} else {
const auto weight_data = weight_contig.data_ptr<float>();
const auto rowwise_min_max_data =
is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
@@ -424,7 +499,8 @@ Tensor _qembeddingbag_nbit_prepack_helper(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
static_cast<int>(embedding_cols),
output_data + start_idx * output_shape[1],
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
});
}
} else {
@@ -514,6 +590,16 @@ Tensor qembeddingbag_4bit_prepack(
weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio);
}
Tensor qembeddingbag_4bit_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max,
const bool optimized_qparams,
const int64_t nbins,
const double ratio) {
return _qembeddingbag_nbit_prepack_helper(
weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max);
}
// Applies 2-bit row-wise quantization by determining the range // Applies 2-bit row-wise quantization by determining the range
// (maximum - minimum) and bias (minimum value) of each row in the input // (maximum - minimum) and bias (minimum value) of each row in the input
// matrix, and then scaling each element to an 2-bit number between 0 and // matrix, and then scaling each element to an 2-bit number between 0 and
@@ -531,6 +617,16 @@ Tensor qembeddingbag_2bit_prepack(
weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio);
}
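
For intuition, here is a rough reference of the row-wise range/bias scheme described in the comment above, with the caller-provided min/max substituted for the internally computed ones. This is a hypothetical sketch only, not the FBGEMM kernels used here: the real kernels additionally pack multiple elements per byte and append the per-row scale/bias (roughly FP16 for the n-bit path, FP32 for the 8-bit path), and handle rounding and edge cases in their own way.

```python
import torch

def rowwise_nbit_quantize_reference(
    weight: torch.Tensor,           # [num_rows, num_cols], float32/float16
    rowwise_min_max: torch.Tensor,  # [num_rows, 2]; col 0 = row min, col 1 = row max
    bit_width: int,                 # 2, 4, or 8
) -> torch.Tensor:
    row_min = rowwise_min_max[:, 0:1].to(torch.float32)
    row_max = rowwise_min_max[:, 1:2].to(torch.float32)
    levels = (1 << bit_width) - 1   # 3, 15, or 255
    scale = (row_max - row_min) / levels
    scale = torch.where(scale == 0, torch.ones_like(scale), scale)  # guard constant rows
    q = torch.round((weight.to(torch.float32) - row_min) / scale)
    return q.clamp_(0, levels).to(torch.uint8)  # quantized levels, before bit-packing
```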
Tensor qembeddingbag_2bit_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max,
const bool optimized_qparams,
const int64_t nbins,
const double ratio) {
return _qembeddingbag_nbit_prepack_helper(
weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max);
}
class QEmbeddingPackWeights final {
public:
static c10::intrusive_ptr<EmbeddingPackedParamsBase> run(const at::Tensor& weight) {
@@ -542,12 +638,21 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"),
TORCH_FN(qembeddingbag_byte_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_byte_prepack_with_rowwise_min_max));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"),
TORCH_FN(qembeddingbag_4bit_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_4bit_prepack_with_rowwise_min_max));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"),
TORCH_FN(qembeddingbag_2bit_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_2bit_prepack_with_rowwise_min_max));
}
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {

View File

@@ -3,7 +3,10 @@
namespace at::native {
Tensor& qembeddingbag_byte_prepack_out(
Tensor& output,
const Tensor& weight,
const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt);
Tensor qembeddingbag_byte_prepack(const Tensor& weight);

View File

@@ -121,9 +121,12 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_unpack(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack) -> Tensor W_origin"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag});

View File

@@ -4,56 +4,67 @@
import copy
import itertools
import operator
import random
import unittest
from typing import NamedTuple, TYPE_CHECKING

import numpy as np

import torch
import torch.jit
import torch.nn.functional as F
import torch.testing._internal.hypothesis_utils as hu
from hypothesis import assume, given, HealthCheck, note, settings, strategies as st
from packaging.version import Version
from torch import _VF

if TYPE_CHECKING:
    from torch._ops import OpOverloadPacket

from torch.nn.modules.utils import _pair, _single

hu.assert_deadline_disabled()

from typing import Optional

import torch.backends.xnnpack
from torch.ao.quantization import PerChannelMinMaxObserver
from torch.testing._internal.common_cuda import (
    SM80OrLater,
    TEST_CUDA,
    TEST_CUDNN,
    TEST_CUDNN_VERSION,
)
from torch.testing._internal.common_quantization import (
    skipIfNoFBGEMM,
    skipIfNoONEDNN,
    skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import (
    _calculate_dynamic_qparams,
    _dequantize,
    _quantize,
    _snr,
    override_qengines,
    override_quantized_engine,
    qengine_is_onednn,
    qengine_is_qnnpack,
    supported_qengines,
)
from torch.testing._internal.common_utils import (
    IS_ARM64,
    IS_FBCODE,
    IS_MACOS,
    IS_PPC,
    IS_SANDCASTLE,
    raise_on_run_directly,
    TestCase,
)
from torch.testing._internal.optests import opcheck
from torch.utils.cpp_extension import ROCM_HOME

np_dtype = {torch.quint8: np.uint8, torch.qint8: np.int8, torch.qint32: np.int32}

TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
@@ -8795,5 +8806,66 @@ class TestComparatorOps(TestCase):
        self.assertEqual(result_ref, result,
                         msg=f"'tensor.{op}(scalar)'' failed")


"""Tests the correctness of the quantized::embedding_bag_(byte|4bit|2bit)_prepack_with_rowwise_min_max ops."""
class TestQuantizedWithMinMax(TestCase):
    """Validates that the *rowwise_min_max* quantization functions are equivalent to the ones without it."""

    def test_quantize_tensor_with_min_max(self):
        num_rows_list = [1, 2, 10, 100]
        num_cols_list = [4, 8, 16, 32, 64, 128]

        # Map of quantization bit rate to tuple of quantize function (with rowwise_min_max) and
        # quantize function (without rowwise_min_max)
        bit_rate_to_quant_fn: dict[
            int,
            tuple[
                OpOverloadPacket,
                OpOverloadPacket,
            ],
        ] = {
            8: (
                torch.ops.quantized.embedding_bag_byte_prepack_with_rowwise_min_max,
                torch.ops.quantized.embedding_bag_byte_prepack,
            ),
            4: (
                torch.ops.quantized.embedding_bag_4bit_prepack_with_rowwise_min_max,
                torch.ops.quantized.embedding_bag_4bit_prepack,
            ),
            2: (
                torch.ops.quantized.embedding_bag_2bit_prepack_with_rowwise_min_max,
                torch.ops.quantized.embedding_bag_2bit_prepack,
            ),
        }

        for quant_fn_with_rowwise_min_max, quant_fn in bit_rate_to_quant_fn.values():
            for torch_dtype in [torch.float16, torch.float32]:
                for num_rows, num_cols in itertools.product(num_rows_list, num_cols_list):
                    weight = torch.rand(num_rows, num_cols, dtype=torch_dtype)
                    rowwise_min_max = torch.stack(
                        [weight.min(dim=1).values, weight.max(dim=1).values], dim=1
                    )

                    # Perform the quantization with rowwise_min_max
                    weight_quantized = quant_fn_with_rowwise_min_max(
                        weight, rowwise_min_max
                    )
                    assert weight_quantized.dtype == torch.uint8

                    # Confirm that the quantization matches the one without rowwise_min_max
                    weight_quantized_no_rowwise_min_max = quant_fn(weight)
                    assert torch.equal(
                        weight_quantized, weight_quantized_no_rowwise_min_max
                    )

                    # Confirm that an incorrect rowwise_min_max results in a different quantization output
                    incorrect_rowwise_min_max = torch.stack(
                        [weight.max(dim=1).values, weight.max(dim=1).values], dim=1
                    )
                    weight_incorrectly_quantized = quant_fn_with_rowwise_min_max(
                        weight, incorrect_rowwise_min_max
                    )
                    assert weight_incorrectly_quantized.dtype == torch.uint8
                    assert not torch.equal(
                        weight_incorrectly_quantized, weight_quantized_no_rowwise_min_max
                    )


if __name__ == "__main__":
    raise_on_run_directly("test/test_quantization.py")