Add wrapper for fbgemm quantization operations (#122763)

Summary:
We add wrappers for fbgemm's fp16 packing and linear operators so they can be
passed through PT2 to the lowering phase of AOTInductor.
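
For context, a minimal usage sketch of the new wrapped ops (mirroring the test added
below; the function name and shapes are illustrative, and an FBGEMM-enabled build of
PyTorch is assumed):

import torch

def fp16_linear(x, w, b):
    # Pack the fp32 weight into fbgemm's fp16 packed format.
    packed_w = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(w)
    # out_channel is passed explicitly so the Meta kernel can infer the output
    # shape without inspecting the packed weight.
    return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(x, packed_w, b, w.size(0))

x, w, b = torch.randn(2, 4), torch.randn(3, 4), torch.zeros(3)
out = torch.compile(fp16_linear)(x, w, b)  # the wrapped ops trace through the PT2 stack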

Test Plan:
Included in commit.
test_quantized_ops::test_wrapped_fbgemm_linear_fp16

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D55433204](https://our.internmc.facebook.com/intern/diff/D55433204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122763
Approved by: https://github.com/jerryzh168
ghstack dependencies: #122762
Authored by Mu-Chu Lee on 2024-03-27 21:19:46 -07:00, committed by PyTorch MergeBot
parent e296722e0e
commit 966ae943df
3 changed files with 122 additions and 1 deletion


@@ -16,6 +16,9 @@
#include <ATen/ops/_empty_affine_quantized.h>
#include <ATen/ops/aminmax.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
#include <ATen/ops/quantize_per_tensor.h>
#endif
@@ -725,6 +728,57 @@ class QLinearUnpackedDynamicFp16 final {
#endif // USE_FBGEMM
};

at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16(const at::Tensor weight) {
#ifdef USE_FBGEMM
  TORCH_CHECK(
      weight.dim() == 2,
      "fbgemm weight packing only packs matrices not vectors.");
  return at::native::fbgemm_pack_gemm_matrix_fp16(weight);
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16_meta(const at::Tensor weight) {
#ifdef USE_FBGEMM
  // Strictly speaking this is not correct: we do not know the exact size of
  // the packed matrix, since it is maintained by the packed object itself, so
  // we return a small placeholder tensor here.
  return at::empty({8}, weight.options().dtype(at::kByte));
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_linear_fp16_weight(at::Tensor input, const at::Tensor weight, const at::Tensor bias, int64_t out_channel) {
#ifdef USE_FBGEMM
  // out_channel is unused here; it is only needed by the Meta kernel, since
  // the CPU path can read the output size from the packed weight itself.
  return at::native::fbgemm_linear_fp16_weight(input, weight, bias);
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(at::Tensor input, const at::Tensor weight, const at::Tensor bias, int64_t out_channel) {
#ifdef USE_FBGEMM
  // For the meta function, we need users to provide the dimension explicitly
  // as we don't have access to the weight.
  auto out_sizes = input.sym_sizes().vec();
  if (out_channel == -1) {
    out_sizes.pop_back();
  } else {
    out_sizes.back() = out_channel;
  }
  return at::empty_symint(out_sizes, input.options());
#else // USE_FBGEMM
  TORCH_CHECK(
      false, "This PyTorch installation was not built with FBGEMM operators");
#endif // USE_FBGEMM
}

TORCH_LIBRARY_IMPL(quantized, CPU, m) {
  register_linear_params();
  m.impl(
@@ -755,6 +809,21 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::linear_dynamic"),
      TORCH_FN(QLinearDynamicInt8<false>::run));
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
      wrapped_fbgemm_pack_gemm_matrix_fp16);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
      wrapped_fbgemm_linear_fp16_weight);
}

TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
      wrapped_fbgemm_pack_gemm_matrix_fp16_meta);
  m.impl(
      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
      wrapped_fbgemm_linear_fp16_weight_meta);
}
} // namespace


@@ -247,6 +247,8 @@ TORCH_LIBRARY(_quantized, m) {
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
}
TORCH_LIBRARY(onednn, m) {


@@ -4,9 +4,10 @@
import copy
import itertools
import numpy as np
import operator
import random
import sys
import unittest
from typing import NamedTuple, List
import torch
@@ -3377,6 +3378,55 @@ class TestDynamicQuantizedOps(TestCase):
        opcheck(qlinear_dynamic, (x, w, bias))

    @skipIfNoFBGEMM
    def test_wrapped_fbgemm_linear_fp16(self):
        options = itertools.product(
            (2, 4),  # batch_size
            (4, 5),  # input_channels
            (4, 7),  # output_channels
        )
        for batch_size, input_channels, output_channels in options:
            pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
            linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight

            x = torch.randn(batch_size, input_channels)
            w = torch.randn(output_channels, input_channels)
            bias = torch.randn(output_channels)

            w_packed = pack_op(w)
            out = linear_op(x, w_packed, bias, output_channels)

            w_fp16 = w.to(torch.float16).to(torch.float32)
            ref = F.linear(x, w_fp16, bias)

            self.assertEqual(out, ref)

    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @skipIfNoFBGEMM
    def test_wrapped_fbgemm_pack_gemm_matrix_fp16_pt2_compliant(self):
        # We are not using opcheck here because the output of the op we're testing
        # (_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) is not deterministic
        # due to the C struct it's producing, so the check that matches results
        # between the compiled and eager versions would fail.
        #
        # This is only a temporary solution; long term, we should be able to
        # support PT2 with torchbind natively.
        def func(X, W, B):
            packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
            return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, B, W.size(0))

        x = torch.randn(1, 4, device="cpu")
        w = torch.randn(4, 4, device="cpu")
        b = torch.zeros(4, device="cpu")

        ref_out = func(x, w, b)

        compiled = torch.compile(func)
        compiled_out = compiled(x, w, b)

        self.assertEqual(ref_out, compiled_out)
"""Tests the correctness of the dynamic quantized lstm/gru."""