From 783a9dcb6d60616d24e0936df787afbc54bb9156 Mon Sep 17 00:00:00 2001 From: Sampath Victor Date: Thu, 25 Sep 2025 02:52:04 +0000 Subject: [PATCH] [6/n] Quantization with min & max bounds support - using fbgemm changes in ATen (#162924) Summary: This diff uses the FBGEMM changes made in D78181177 & D81858256 to support using the provided per row min/max values while quantizaing float/half to 8-bit, 4-bit & 2-bit in ATen library. Please find more context on this here: https://fburl.com/gdoc/yutf32a0 Test Plan: ``` buck test mode/opt caffe2/torch/fb/model_transform/splitting/tests:split_dispatcher_test ``` https://www.internalfb.com/intern/testinfra/testrun/7881299640979446 Please refer to D80905814's test plan for integration testing. Rollback Plan: Differential Revision: D81327342 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162924 Approved by: https://github.com/jerryzh168 --- .../quantized/cpu/qembeddingbag_prepack.cpp | 119 +++++++++++++- .../quantized/cpu/qembeddingbag_prepack.h | 5 +- aten/src/ATen/native/quantized/library.cpp | 3 + test/quantization/core/test_quantized_op.py | 148 +++++++++++++----- 4 files changed, 229 insertions(+), 46 deletions(-) diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp index 807a9b25d37..40fb1c6c0f5 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.cpp @@ -158,12 +158,46 @@ c10::intrusive_ptr PackedEmbeddingBagWeight::prepack( return packed_ptr; } +#ifdef USE_FBGEMM +namespace { +/// Number of columns in the rowwise min/max buffer passed to the quantization function(s) +constexpr int kRowwiseMinMaxNumCols = 2; + +bool _validate_rowwise_min_max( + const at::Tensor& weight, + const std::optional& rowwise_min_max_opt) { + const auto is_valid_rowwise_min_max = rowwise_min_max_opt.has_value(); + + if (is_valid_rowwise_min_max) { + TORCH_CHECK( + (rowwise_min_max_opt->dim() == 2 && + rowwise_min_max_opt->size(0) == weight.size(0) && + rowwise_min_max_opt->size(1) == kRowwiseMinMaxNumCols), + "'rowwise_min_max' must be a 2D tensor with shape [num_rows(weight), 2]."); + } + + return is_valid_rowwise_min_max; +} + +auto _get_rowwise_min_max_contig( + const std::optional& rowwise_min_max_opt) { + return rowwise_min_max_opt.has_value() + ? rowwise_min_max_opt->expect_contiguous(rowwise_min_max_opt->suggest_memory_format()) + : at::borrow_from_optional_tensor(rowwise_min_max_opt); +} +} +#endif // USE_FBGEMM + namespace at::native { // Note - This is a temporary pack function for embedding bag which quantizes // and packs the float weight tensor. In the next step it will be replaced by a // quantize and pack function once we support FP scale and FP zero_point // +// The optional rowwise_min_max argument is to support callers to pass in the min/max +// values of the weight tensor. If the rowwise_min_max is not provided, the min/max +// values will be computed from the weight tensor. +// // Python example examining a packed 8bit zero_point and scale: // // >> x = torch.from_numpy(np.array([[[10, 20], [30, 40]],[[50, 60], [70, 80]]], @@ -221,7 +255,10 @@ namespace at::native { // // [[50. , 60.00000035], // [70. 
, 80.00000035]]]) -Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) { +Tensor& qembeddingbag_byte_prepack_out( + Tensor& output, + const Tensor& weight, + const std::optional& rowwise_min_max_opt) { // The "last" dimension of an N-Dimensioned batch of embedding bags is // quantization channel. E.g. for a 2D embedding bag, this has // [ row, col ] dimensions, for batched of embedding bags, dimensions might be @@ -256,9 +293,16 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) { auto* output_data = output.data_ptr(); #ifdef USE_FBGEMM + // Move these outside of the ifdef when we support non-FBGEMM flow. + const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt); + const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt); + if (weight_contig->scalar_type() == at::ScalarType::Half) { const auto weight_data = static_cast(weight_contig->data_ptr()); + const auto rowwise_min_max_data = is_valid_rowwise_min_max + ? static_cast(rowwise_min_max_contig->data_ptr()) + : nullptr; at::parallel_for( 0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) { fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat< @@ -266,17 +310,21 @@ Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight) { weight_data + start_idx * embedding_cols, end_idx - start_idx, embedding_cols, - output_data + start_idx * output_columns); + output_data + start_idx * output_columns, + (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr)); }); } else { const auto weight_data = weight_contig->data_ptr(); + const auto rowwise_min_max_data = + is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr() : nullptr; at::parallel_for( 0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) { fbgemm::FloatOrHalfToFused8BitRowwiseQuantizedSBFloat( weight_data + start_idx * embedding_cols, end_idx - start_idx, embedding_cols, - output_data + start_idx * output_columns); + output_data + start_idx * output_columns, + (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr)); }); } @@ -326,6 +374,22 @@ Tensor qembeddingbag_byte_prepack(const Tensor& weight) { return output; } +static Tensor qembeddingbag_byte_prepack_with_rowwise_min_max( + const Tensor& weight, + const Tensor& rowwise_min_max) { + const auto weight_contig = + weight.expect_contiguous(weight.suggest_memory_format()); + Tensor output = at::detail::empty_cpu( + {0}, + at::kByte, + weight_contig->layout(), + weight_contig->device(), + std::nullopt, + std::nullopt); + qembeddingbag_byte_prepack_out(output, weight, rowwise_min_max); + return output; +} + Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { const auto weight_contig = weight.expect_contiguous(weight.suggest_memory_format()); @@ -335,7 +399,7 @@ Tensor qembeddingbag_byte_prepack_meta(const Tensor& weight) { "'embedding_bag_byte_prepack' only support float32 or float16."); const auto weight_sizes = weight.sym_sizes(); const auto cols_dim = weight.ndimension() - 1; - const auto embedding_cols = weight_sizes[cols_dim]; + const auto& embedding_cols = weight_sizes[cols_dim]; // Add 8 bytes per column to store FP32 scale and zero_point per row. 
const auto output_columns = embedding_cols + 2 * sizeof(float); @@ -359,7 +423,8 @@ Tensor _qembeddingbag_nbit_prepack_helper( int bit_width, const bool optimized_qparams, const int64_t nbins, - const double ratio) { + const double ratio, + const std::optional& rowwise_min_max_opt = std::nullopt) { TORCH_CHECK( weight.scalar_type() == at::ScalarType::Float || weight.scalar_type() == at::ScalarType::Half, @@ -401,10 +466,17 @@ Tensor _qembeddingbag_nbit_prepack_helper( auto* output_data = output.data_ptr(); #ifdef USE_FBGEMM + // Move these outside of the ifdef when we support non-FBGEMM flow. + const auto is_valid_rowwise_min_max = _validate_rowwise_min_max(weight, rowwise_min_max_opt); + const auto rowwise_min_max_contig = _get_rowwise_min_max_contig(rowwise_min_max_opt); + if (!optimized_qparams) { if (weight_contig.scalar_type() == at::ScalarType::Half) { const auto weight_data = static_cast(weight_contig.data_ptr()); + const auto rowwise_min_max_data = is_valid_rowwise_min_max + ? static_cast(rowwise_min_max_contig->data_ptr()) + : nullptr; at::parallel_for( 0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) { fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf< @@ -413,10 +485,13 @@ Tensor _qembeddingbag_nbit_prepack_helper( weight_data + start_idx * embedding_cols, end_idx - start_idx, static_cast(embedding_cols), - output_data + start_idx * output_shape[1]); + output_data + start_idx * output_shape[1], + (is_valid_rowwise_min_max ? (rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr)); }); } else { const auto weight_data = weight_contig.data_ptr(); + const auto rowwise_min_max_data = + is_valid_rowwise_min_max ? rowwise_min_max_contig->data_ptr() : nullptr; at::parallel_for( 0, embedding_rows, 1, [&](int64_t start_idx, int64_t end_idx) { fbgemm::FloatOrHalfToFusedNBitRowwiseQuantizedSBHalf( @@ -424,7 +499,8 @@ Tensor _qembeddingbag_nbit_prepack_helper( weight_data + start_idx * embedding_cols, end_idx - start_idx, static_cast(embedding_cols), - output_data + start_idx * output_shape[1]); + output_data + start_idx * output_shape[1], + (is_valid_rowwise_min_max ? 
(rowwise_min_max_data + start_idx * kRowwiseMinMaxNumCols) : nullptr)); }); } } else { @@ -514,6 +590,16 @@ Tensor qembeddingbag_4bit_prepack( weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio); } +Tensor qembeddingbag_4bit_prepack_with_rowwise_min_max( + const Tensor& weight, + const Tensor& rowwise_min_max, + const bool optimized_qparams, + const int64_t nbins, + const double ratio) { + return _qembeddingbag_nbit_prepack_helper( + weight, 4 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max); +} + // Applies 2-bit row-wise quantization by determining the range // (maximum - minimum) and bias (minimum value) of each row in the input // matrix, and then scaling each element to an 2-bit number between 0 and @@ -531,6 +617,16 @@ Tensor qembeddingbag_2bit_prepack( weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio); } +Tensor qembeddingbag_2bit_prepack_with_rowwise_min_max( + const Tensor& weight, + const Tensor& rowwise_min_max, + const bool optimized_qparams, + const int64_t nbins, + const double ratio) { + return _qembeddingbag_nbit_prepack_helper( + weight, 2 /*bit_width*/, optimized_qparams, nbins, ratio, rowwise_min_max); +} + class QEmbeddingPackWeights final { public: static c10::intrusive_ptr run(const at::Tensor& weight) { @@ -542,12 +638,21 @@ TORCH_LIBRARY_IMPL(quantized, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack"), TORCH_FN(qembeddingbag_byte_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_byte_prepack_with_rowwise_min_max"), + TORCH_FN(qembeddingbag_byte_prepack_with_rowwise_min_max)); m.impl( TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack"), TORCH_FN(qembeddingbag_4bit_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max"), + TORCH_FN(qembeddingbag_4bit_prepack_with_rowwise_min_max)); m.impl( TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack"), TORCH_FN(qembeddingbag_2bit_prepack)); + m.impl( + TORCH_SELECTIVE_NAME("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max"), + TORCH_FN(qembeddingbag_2bit_prepack_with_rowwise_min_max)); } TORCH_LIBRARY_IMPL(quantized, QuantizedCPU, m) { diff --git a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h index e157405c107..c110e63b362 100644 --- a/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h +++ b/aten/src/ATen/native/quantized/cpu/qembeddingbag_prepack.h @@ -3,7 +3,10 @@ namespace at::native { -Tensor& qembeddingbag_byte_prepack_out(Tensor& output, const Tensor& weight); +Tensor& qembeddingbag_byte_prepack_out( + Tensor& output, + const Tensor& weight, + const std::optional& rowwise_min_max_opt = std::nullopt); Tensor qembeddingbag_byte_prepack(const Tensor& weight); diff --git a/aten/src/ATen/native/quantized/library.cpp b/aten/src/ATen/native/quantized/library.cpp index 550280dbf6d..9ce36192615 100644 --- a/aten/src/ATen/native/quantized/library.cpp +++ b/aten/src/ATen/native/quantized/library.cpp @@ -121,9 +121,12 @@ TORCH_LIBRARY(quantized, m) { m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_unpack(__torch__.torch.classes.quantized.EmbeddingPackedParamsBase W_prepack) -> Tensor W_origin"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag}); + 
m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max) -> Tensor"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag}); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack(Tensor weight, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag}); + m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_prepack_with_rowwise_min_max(Tensor weight, Tensor rowwise_min_max, bool optimized_qparams=False, int nbins=200, float ratio=0.16) -> Tensor"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_2bit_unpack(Tensor weight) -> Tensor"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_byte_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag}); m.def(TORCH_SELECTIVE_SCHEMA("quantized::embedding_bag_4bit_rowwise_offsets(Tensor weight, Tensor indices, Tensor? offsets=None, bool scale_grad_by_freq=False, int mode=0, bool pruned_weights=False, Tensor? per_sample_weights=None, Tensor? 
compressed_indices_mapping=None, bool include_last_offset=False) -> Tensor"), {at::Tag::pt2_compliant_tag}); diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py index b6df2089e87..f2e12d2f64e 100644 --- a/test/quantization/core/test_quantized_op.py +++ b/test/quantization/core/test_quantized_op.py @@ -4,56 +4,67 @@ import copy import itertools -import numpy as np import operator import random import unittest -from packaging.version import Version -from typing import NamedTuple +from typing import NamedTuple, TYPE_CHECKING + +import numpy as np import torch -from torch import _VF import torch.jit import torch.nn.functional as F -from torch.nn.modules.utils import _single, _pair - -from hypothesis import settings, HealthCheck -from hypothesis import assume, given, note -from hypothesis import strategies as st import torch.testing._internal.hypothesis_utils as hu + +from hypothesis import assume, given, HealthCheck, note, settings, strategies as st +from packaging.version import Version +from torch import _VF +if TYPE_CHECKING: + from torch._ops import OpOverloadPacket +from torch.nn.modules.utils import _pair, _single + hu.assert_deadline_disabled() -from torch.testing._internal.common_cuda import SM80OrLater -from torch.testing._internal.common_utils import ( - raise_on_run_directly, - TestCase, - IS_PPC, - IS_MACOS, - IS_SANDCASTLE, - IS_FBCODE, - IS_ARM64 -) -from torch.testing._internal.common_quantization import skipIfNoFBGEMM, skipIfNoQNNPACK, skipIfNoONEDNN -from torch.testing._internal.common_quantized import _quantize, _dequantize, _calculate_dynamic_qparams, \ - override_quantized_engine, supported_qengines, override_qengines, _snr -from torch.testing._internal.common_quantized import ( - qengine_is_qnnpack, - qengine_is_onednn, -) -from torch.ao.quantization import PerChannelMinMaxObserver -from torch.testing._internal.common_cuda import TEST_CUDNN, TEST_CUDNN_VERSION, TEST_CUDA -from torch.testing._internal.optests import opcheck -import torch.backends.xnnpack - -from torch.utils.cpp_extension import ROCM_HOME - from typing import Optional -np_dtype = { - torch.quint8 : np.uint8, - torch.qint8 : np.int8, - torch.qint32 : np.int32 -} +import torch.backends.xnnpack +from torch.ao.quantization import PerChannelMinMaxObserver +from torch.testing._internal.common_cuda import ( + SM80OrLater, + TEST_CUDA, + TEST_CUDNN, + TEST_CUDNN_VERSION, +) +from torch.testing._internal.common_quantization import ( + skipIfNoFBGEMM, + skipIfNoONEDNN, + skipIfNoQNNPACK, +) +from torch.testing._internal.common_quantized import ( + _calculate_dynamic_qparams, + _dequantize, + _quantize, + _snr, + override_qengines, + override_quantized_engine, + qengine_is_onednn, + qengine_is_qnnpack, + supported_qengines, +) +from torch.testing._internal.common_utils import ( + IS_ARM64, + IS_FBCODE, + IS_MACOS, + IS_PPC, + IS_SANDCASTLE, + raise_on_run_directly, + TestCase, +) +from torch.testing._internal.optests import opcheck + +from torch.utils.cpp_extension import ROCM_HOME + +np_dtype = {torch.quint8: np.uint8, torch.qint8: np.int8, torch.qint32: np.int32} TEST_ROCM = TEST_CUDA and torch.version.hip is not None and ROCM_HOME is not None @@ -8795,5 +8806,66 @@ class TestComparatorOps(TestCase): self.assertEqual(result_ref, result, msg=f"'tensor.{op}(scalar)'' failed") +"""Tests the correctness of the quantized::embedding_bag_(byte|4bit|2bit)_prepack_with_rowwise_min_max ops.""" +class TestQuantizedWithMinMax(TestCase): + """Validates that the *rowwsie_min_max* 
quantization functions are equivalent to the ones without it."""
+    def test_quantize_tensor_with_min_max(self):
+        num_rows_list = [1, 2, 10, 100]
+        num_cols_list = [4, 8, 16, 32, 64, 128]
+        # Map of quantization bit rate to tuple of quantize function (with rowwise_min_max) and
+        # quantize function (without rowwise_min_max)
+        bit_rate_to_quant_fn: dict[
+            int,
+            tuple[
+                OpOverloadPacket,
+                OpOverloadPacket,
+            ],
+        ] = {
+            8: (
+                torch.ops.quantized.embedding_bag_byte_prepack_with_rowwise_min_max,
+                torch.ops.quantized.embedding_bag_byte_prepack,
+            ),
+            4: (
+                torch.ops.quantized.embedding_bag_4bit_prepack_with_rowwise_min_max,
+                torch.ops.quantized.embedding_bag_4bit_prepack,
+            ),
+            2: (
+                torch.ops.quantized.embedding_bag_2bit_prepack_with_rowwise_min_max,
+                torch.ops.quantized.embedding_bag_2bit_prepack,
+            ),
+        }
+
+        for quant_fn_with_rowwise_min_max, quant_fn in bit_rate_to_quant_fn.values():
+            for torch_dtype in [torch.float16, torch.float32]:
+                for num_rows, num_cols in itertools.product(num_rows_list, num_cols_list):
+                    weight = torch.rand(num_rows, num_cols, dtype=torch_dtype)
+                    rowwise_min_max = torch.stack(
+                        [weight.min(dim=1).values, weight.max(dim=1).values], dim=1
+                    )
+
+                    # Perform the quantization with rowwise_min_max
+                    weight_quantized = quant_fn_with_rowwise_min_max(
+                        weight, rowwise_min_max
+                    )
+                    assert weight_quantized.dtype == torch.uint8
+
+                    # Confirm that the quantization matches the one without rowwise_min_max
+                    weight_quantized_no_rowwise_min_max = quant_fn(weight)
+                    assert torch.equal(
+                        weight_quantized, weight_quantized_no_rowwise_min_max
+                    )
+
+                    # Confirm that an incorrect rowwise_min_max results in a different quantization output
+                    incorrect_rowwise_min_max = torch.stack(
+                        [weight.max(dim=1).values, weight.max(dim=1).values], dim=1
+                    )
+                    weight_incorrectly_quantized = quant_fn_with_rowwise_min_max(
+                        weight, incorrect_rowwise_min_max
+                    )
+                    assert weight_incorrectly_quantized.dtype == torch.uint8
+                    assert not torch.equal(
+                        weight_incorrectly_quantized, weight_quantized_no_rowwise_min_max
+                    )
+
 if __name__ == "__main__":
     raise_on_run_directly("test/test_quantization.py")
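
For quick reference, a minimal usage sketch of the new byte-prepack variant, mirroring the test above. This is illustrative only: it assumes the patch is applied and PyTorch is built with FBGEMM (the rowwise_min_max handling lives under USE_FBGEMM; the comments above note the non-FBGEMM flow is not yet supported). The rowwise_min_max tensor must be shaped [num_rows, 2], with each row's minimum in column 0 and maximum in column 1, as enforced by _validate_rowwise_min_max.

```
import torch

# A float weight matrix to quantize row-wise to 8 bits.
weight = torch.rand(10, 64, dtype=torch.float32)

# Caller-supplied per-row bounds: column 0 = min, column 1 = max.
rowwise_min_max = torch.stack(
    [weight.min(dim=1).values, weight.max(dim=1).values], dim=1
)

# Quantize using the provided bounds instead of recomputing them.
packed = torch.ops.quantized.embedding_bag_byte_prepack_with_rowwise_min_max(
    weight, rowwise_min_max
)

# With exact per-row bounds, the output should match the existing op that
# derives min/max from the weight itself.
reference = torch.ops.quantized.embedding_bag_byte_prepack(weight)
assert torch.equal(packed, reference)
```

The 4-bit and 2-bit variants take the same rowwise_min_max argument after the weight, followed by the existing optimized_qparams, nbins, and ratio parameters, per the schemas registered in library.cpp.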