[6/n] Quantization with min & max bounds support - using fbgemm changes in ATen (#162924)

Summary:
This diff uses the FBGEMM changes made in D78181177 & D81858256 to support using caller-provided per-row min/max values when quantizing float/half weights to 8-bit, 4-bit & 2-bit in the ATen library.

Please find more context on this here: https://fburl.com/gdoc/yutf32a0
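
For context, a minimal Python sketch (not part of this diff's test plan; it mirrors the new unit test below) of how the byte-prepack variant can be invoked on an FBGEMM-enabled CPU build. Shapes are arbitrary; the `rowwise_min_max` layout is `[min, max]` per row:

```
import torch

# Weight rows to quantize; rowwise_min_max carries [min, max] per row, shape [num_rows, 2].
weight = torch.rand(10, 16, dtype=torch.float32)
rowwise_min_max = torch.stack(
    [weight.min(dim=1).values, weight.max(dim=1).values], dim=1
)

# Quantize using the caller-provided per-row bounds.
packed = torch.ops.quantized.embedding_bag_byte_prepack_with_rowwise_min_max(
    weight, rowwise_min_max
)

# With exact per-row min/max, the result matches the existing op,
# which derives the bounds from the weight tensor itself.
packed_ref = torch.ops.quantized.embedding_bag_byte_prepack(weight)
assert torch.equal(packed, packed_ref)
```

The 4-bit and 2-bit variants (`embedding_bag_4bit_prepack_with_rowwise_min_max`, `embedding_bag_2bit_prepack_with_rowwise_min_max`) take the same `rowwise_min_max` tensor plus the existing `optimized_qparams`/`nbins`/`ratio` arguments.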

Test Plan:
```
buck test mode/opt caffe2/torch/fb/model_transform/splitting/tests:split_dispatcher_test
```
https://www.internalfb.com/intern/testinfra/testrun/7881299640979446

Please refer to D80905814's test plan for integration testing.

Rollback Plan:

Differential Revision: D81327342

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162924
Approved by: https://github.com/jerryzh168
Sampath Victor 2025-09-25 02:52:04 +00:00 committed by PyTorch MergeBot
parent ad2f7315ca
commit 783a9dcb6d
4 changed files with 229 additions and 46 deletions


@@ -158,12 +158,46 @@ c10::intrusive_ptr<EmbeddingPackedParamsBase> PackedEmbeddingBagWeight::prepack(
return packed_ptr;
}
#ifdef USE_FBGEMM
namespace {
/// Number of columns in the rowwise min/max buffer passed to the quantization function(s)
constexpr int kRowwiseMinMaxNumCols = 2;
bool _validate_rowwise_min_max(
const at::Tensor& weight,
const std::optional<at::Tensor>& rowwise_min_max_opt) {
const auto is_valid_rowwise_min_max = rowwise_min_max_opt.has_value();
if (is_valid_rowwise_min_max) {
TORCH_CHECK(
(rowwise_min_max_opt->dim() == 2 &&
rowwise_min_max_opt->size(0) == weight.size(0) &&
rowwise_min_max_opt->size(1) == kRowwiseMinMaxNumCols),
"'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2].");
}
return is_valid_rowwise_min_max;
}
auto _get_rowwise_min_max_contig(
const std::optional<at::Tensor>& rowwise_min_max_opt) {
return rowwise_min_max_opt.has_value()
? rowwise_min_max_opt->expect_contiguous(rowwise_min_max_opt->suggest_memory_format())
: at::borrow_from_optional_tensor(rowwise_min_max_opt);
}
}
#endif // USE_FBGEMM
namespace at::native {
// Note - This is a temporary pack function for embedding bag which quantizes
// and packs the float weight tensor. In the next step it will be replaced by a
// quantize and pack function once we support FP scale and FP zero_point
//
// The optional rowwise_min_max argument lets callers pass in precomputed per-row
// min/max values of the weight tensor. If rowwise_min_max is not provided, the
// min/max values are computed from the weight tensor itself.
//
// Python example examining a packed 8bit zero_point and scale:
//
// >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]],
@@ -221,7 +255,10 @@ namespace at::native {
//
// [[50. , 60.00000035],
// [70. , 80.00000035]]])
Tensor& qembeddingbag_byte_prepack_out(
Tensor& output,
const Tensor& weight,
const std::optional<Tensor>& rowwise_min_max_opt) {
// The "last" dimension of an N-Dimensioned batch of embedding bags is
// quantization channel. E.g. for a 2D embedding bag, this has
// [ row, col ] dimensions, for batched of embedding bags, dimensions might be
@@ -256,9 +293,16 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
// Move these outside of the ifdef when we support non-FBGEMM flow.
const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt);
const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt);
if (weight_contig->scalar_type() == at::ScalarType::Half) {
const auto weight_data =
static_cast<fbgemm::float16*>(weight_contig->data_ptr());
const auto rowwise_min_max_data = is_valid_rowwise_min_max
? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr())
: nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<
@@ -266,17 +310,21 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) {
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
embedding_cols,
output_data + start_idx * output_columns,
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
});
} else {
const auto weight_data = weight_contig->data_ptr<float>();
const auto rowwise_min_max_data =
is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat<float>(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
embedding_cols,
output_data + start_idx * output_columns,
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
});
}
@@ -326,6 +374,22 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) {
return output;
}
static Tensor qembeddingbag_byte_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max) {
const auto weight_contig =
weight.expect_contiguous(weight.suggest_memory_format());
Tensor output = at::detail::empty_cpu(
{0},
at::kByte,
weight_contig->layout(),
weight_contig->device(),
std::nullopt,
std::nullopt);
qembeddingbag_byte_prepack_out(output, weight, rowwise_min_max);
return output;
}
Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
const auto weight_contig =
weight.expect_contiguous(weight.suggest_memory_format());
@@ -335,7 +399,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) {
"'embedding_bag_byte_prepack' only support float32 or float16.");
const auto weight_sizes = weight.sym_sizes();
const auto cols_dim = weight.ndimension() - 1;
const auto& embedding_cols = weight_sizes[cols_dim];
// Add 8 bytes per column to store FP32 scale and zero_point per row.
const auto output_columns = embedding_cols + 2 * sizeof(float);
@@ -359,7 +423,8 @@ Tensor _qembeddingbag_nbit_prepack_helper(
int bit_width,
const bool optimized_qparams,
const int64_t nbins,
const double ratio,
const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt) {
TORCH_CHECK(
weight.scalar_type() == at::ScalarType::Float ||
weight.scalar_type() == at::ScalarType::Half,
@@ -401,10 +466,17 @@ Tensor _qembeddingbag_nbit_prepack_helper(
auto* output_data = output.data_ptr<uint8_t>();
#ifdef USE_FBGEMM
// Move these outside of the ifdef when we support non-FBGEMM flow.
const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt);
const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt);
if (!optimized_qparams) {
if (weight_contig.scalar_type() == at::ScalarType::Half) {
const auto weight_data =
static_cast<fbgemm::float16*>(weight_contig.data_ptr());
const auto rowwise_min_max_data = is_valid_rowwise_min_max
? static_cast<fbgemm::float16*>(rowwise_min_max_contig->data_ptr())
: nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<
@@ -413,10 +485,13 @@ Tensor _qembeddingbag_nbit_prepack_helper(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
static_cast<int>(embedding_cols),
output_data + start_idx * output_shape[1],
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
});
} else {
const auto weight_data = weight_contig.data_ptr<float>();
const auto rowwise_min_max_data =
is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr<float>() : nullptr;
at::parallel_for(
0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) {
fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf<float>(
@@ -424,7 +499,8 @@ Tensor _qembeddingbag_nbit_prepack_helper(
weight_data + start_idx * embedding_cols,
end_idx - start_idx,
static_cast<int>(embedding_cols),
output_data + start_idx * output_shape[1],
(is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr));
});
}
} else {
@@ -514,6 +590,16 @@ Tensor qembeddingbag_4bit_prepack(
weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio);
}
Tensor qembeddingbag_4bit_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max,
const bool optimized_qparams,
const int64_t nbins,
const double ratio) {
return _qembeddingbag_nbit_prepack_helper(
weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max);
}
// Applies 2-bit row-wise quantization by determining the range
// (maximum - minimum) and bias (minimum value) of each row in the input
// matrix, and then scaling each element to an 2-bit number between 0 and
@@ -531,6 +617,16 @@ Tensor qembeddingbag_2bit_prepack(
weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio);
}
Tensor qembeddingbag_2bit_prepack_with_rowwise_min_max(
const Tensor& weight,
const Tensor& rowwise_min_max,
const bool optimized_qparams,
const int64_t nbins,
const double ratio) {
return _qembeddingbag_nbit_prepack_helper(
weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max);
}
class QEmbeddingPackWeights final {
public:
static c10::intrusive_ptr<EmbeddingPackedParamsBase> run(const at::Tensor& weight) {
@@ -542,12 +638,21 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) {
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"),
TORCH_FN(qembeddingbag_byte_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_byte_prepack_with_rowwise_min_max));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"),
TORCH_FN(qembeddingbag_4bit_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_4bit_prepack_with_rowwise_min_max));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"),
TORCH_FN(qembeddingbag_2bit_prepack));
m.impl(
TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max"),
TORCH_FN(qembeddingbag_2bit_prepack_with_rowwise_min_max));
}
TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) {


@@ -3,7 +3,10 @@
namespace at::native {
Tensor& qembeddingbag_byte_prepack_out(
Tensor& output,
const Tensor& weight,
const std::optional<Tensor>& rowwise_min_max_opt = std::nullopt);
Tensor qembeddingbag_byte_prepack(const Tensor& weight);


@@ -121,9 +121,12 @@ TORCH_LIBRARY(quantized, m) {
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_unpack(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack) -> Tensor W_origin"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag});
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag});


@@ -4,56 +4,67 @@
import copy
import itertools
import operator
import random
import unittest
from typing import NamedTuple, TYPE_CHECKING
import numpy as np
import torch
import torch.jit
import torch.nn.functional as F
import torch.testing._internal.hypothesis_utils as hu
from hypothesis import assume, given, HealthCheck, note, settings, strategies as st
from packaging.version import Version
from torch import _VF
if TYPE_CHECKING:
from torch._ops import OpOverloadPacket
from torch.nn.modules.utils import _pair, _single
hu.assert_deadline_disabled()
import torch.backends.xnnpack
from torch.ao.quantization import PerChannelMinMaxObserver
from torch.testing._internal.common_cuda import (
SM80OrLater,
TEST_CUDA,
TEST_CUDNN,
TEST_CUDNN_VERSION,
)
from torch.testing._internal.common_quantization import (
skipIfNoFBGEMM,
skipIfNoONEDNN,
skipIfNoQNNPACK,
)
from torch.testing._internal.common_quantized import (
_calculate_dynamic_qparams,
_dequantize,
_quantize,
_snr,
override_qengines,
override_quantized_engine,
qengine_is_onednn,
qengine_is_qnnpack,
supported_qengines,
)
from torch.testing._internal.common_utils import (
IS_ARM64,
IS_FBCODE,
IS_MACOS,
IS_PPC,
IS_SANDCASTLE,
raise_on_run_directly,
TestCase,
)
from torch.testing._internal.optests import opcheck
from torch.utils.cpp_extension import ROCM_HOME
np_dtype = {torch.quint8: np.uint8, torch.qint8: np.int8, torch.qint32: np.int32}
TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None
@@ -8795,5 +8806,66 @@ class TestComparatorOps(TestCase):
self.assertEqual(result_ref, result,
msg=f"'tensor.{op}(scalar)'' failed")
"""Tests the correctness of the quantized::embedding_bag_(byte|4bit|2bit)_prepack_with_rowwise_min_max ops."""
class TestQuantizedWithMinMax(TestCase):
"""Validates that the *rowwsie_min_max* quantization functions are equivalent to the ones without it."""
def test_quantize_tensor_with_min_max(self):
num_rows_list = [1, 2, 10, 100]
num_cols_list = [4, 8, 16, 32, 64, 128]
# Map of quantization bit rate to tuple of quantize function (with rowwise_min_max) and
# quantize function (without rowwise_min_max)
bit_rate_to_quant_fn: dict[
int,
tuple[
OpOverloadPacket,
OpOverloadPacket,
],
] = {
8: (
torch.ops.quantized.embedding_bag_byte_prepack_with_rowwise_min_max,
torch.ops.quantized.embedding_bag_byte_prepack,
),
4: (
torch.ops.quantized.embedding_bag_4bit_prepack_with_rowwise_min_max,
torch.ops.quantized.embedding_bag_4bit_prepack,
),
2: (
torch.ops.quantized.embedding_bag_2bit_prepack_with_rowwise_min_max,
torch.ops.quantized.embedding_bag_2bit_prepack,
),
}
for quant_fn_with_rowwise_min_max, quant_fn in bit_rate_to_quant_fn.values():
for torch_dtype in [torch.float16, torch.float32]:
for num_rows, num_cols in itertools.product(num_rows_list, num_cols_list):
weight = torch.rand(num_rows, num_cols, dtype=torch_dtype)
rowwise_min_max = torch.stack(
[weight.min(dim=1).values, weight.max(dim=1).values], dim=1
)
# Perform the quantization with rowwise_min_max
weight_quantized = quant_fn_with_rowwise_min_max(
weight, rowwise_min_max
)
assert weight_quantized.dtype == torch.uint8
# Confirm that the quantization is matching the one without rowwise_min_max
weight_quantized_no_rowwise_min_max = quant_fn(weight)
assert torch.equal(
weight_quantized, weight_quantized_no_rowwise_min_max
)
# Confirtm that incorrect rowwise_min_max will result in different quantization output
incorrect_rowwise_min_max = torch.stack(
[weight.max(dim=1).values, weight.max(dim=1).values], dim=1
)
weight_incorrectly_quantized = quant_fn_with_rowwise_min_max(
weight, incorrect_rowwise_min_max
)
assert weight_incorrectly_quantized.dtype == torch.uint8
assert not torch.equal(
weight_incorrectly_quantized, weight_quantized_no_rowwise_min_max
)
if __name__ == "__main__":
raise_on_run_directly("test/test_quantization.py")