Add wrapper for fbgemm quantization operations (#122763)

Summary:
We add wrappers for fbgemm's packing ops so they can be passed through PT2 to
the lowering phase of AOTInductor.
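
A minimal usage sketch of the wrapped ops (not part of this commit; the helper
name, shapes, and the torch.compile call are illustrative assumptions, and an
FBGEMM-enabled build is required):

import torch

def fp16_linear(x, w, b):
    # Pack the fp32 weight into fbgemm's fp16 packed format, then run the
    # wrapped linear. out_channel is passed explicitly (here w.size(0)).
    packed_w = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(w)
    return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(x, packed_w, b, w.size(0))

x = torch.randn(2, 4)
w = torch.randn(3, 4)
b = torch.zeros(3)

# Eager call; the same function can also be traced through PT2.
out = fp16_linear(x, w, b)
out_compiled = torch.compile(fp16_linear)(x, w, b)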

Test Plan:
Included in this commit:
test_quantized_ops::test_wrapped_fbgemm_linear_fp16

Differential Revision: [D55433204](https://our.internmc.facebook.com/intern/diff/D55433204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122763
Approved by: https://github.com/jerryzh168
ghstack dependencies: #122762
Authored by Mu-Chu Lee on 2024-03-27 21:19:46 -07:00; committed by PyTorch MergeBot
parent e296722e0e
commit 966ae943df
3 changed files with 122 additions and 1 deletion

View File

@@ -16,6 +16,9 @@
 #include <ATen/ops/_empty_affine_quantized.h>
 #include <ATen/ops/aminmax.h>
 #include <ATen/ops/empty.h>
+#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
+#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
+#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
 #include <ATen/ops/quantize_per_tensor.h>
 #endif
@@ -725,6 +728,57 @@ class QLinearUnpackedDynamicFp16 final {
 #endif // USE_FBGEMM
 };

+at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16(const at::Tensor weight) {
+#ifdef USE_FBGEMM
+  TORCH_CHECK(
+      weight.dim() == 2,
+      "fbgemm weight packing only packs matrices not vectors.");
+  return at::native::fbgemm_pack_gemm_matrix_fp16(weight);
+#else // USE_FBGEMM
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+#endif // USE_FBGEMM
+}
+
+at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16_meta(const at::Tensor weight) {
+#ifdef USE_FBGEMM
+  // Strictly speaking this is not correct. However we do not know the exact
+  // size of the packed matrix as it's being maintained by the object itself,
+  // therefore we return the view we have here.
+  return at::empty({8}, weight.options().dtype(at::kByte));
+#else // USE_FBGEMM
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+#endif // USE_FBGEMM
+}
+
+at::Tensor wrapped_fbgemm_linear_fp16_weight(at::Tensor input, const at::Tensor weight, const at::Tensor bias, int64_t out_channel) {
+#ifdef USE_FBGEMM
+  return at::native::fbgemm_linear_fp16_weight(input, weight, bias);
+#else // USE_FBGEMM
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+#endif // USE_FBGEMM
+}
+
+at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(at::Tensor input, const at::Tensor weight, const at::Tensor bias, int64_t out_channel) {
+#ifdef USE_FBGEMM
+  // For the meta function, we need users to provide the dimension explicitly
+  // as we don't have access to the weight.
+  auto out_sizes = input.sym_sizes().vec();
+  if (out_channel == -1) {
+    out_sizes.pop_back();
+  } else {
+    out_sizes.back() = out_channel;
+  }
+  return at::empty_symint(out_sizes, input.options());
+#else // USE_FBGEMM
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+#endif // USE_FBGEMM
+}
+
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   register_linear_params();
   m.impl(
@@ -755,6 +809,21 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("_quantized::linear_dynamic"),
       TORCH_FN(QLinearDynamicInt8<false>::run));
+  m.impl(
+      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
+      wrapped_fbgemm_pack_gemm_matrix_fp16);
+  m.impl(
+      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
+      wrapped_fbgemm_linear_fp16_weight);
+}
+
+TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
+      wrapped_fbgemm_pack_gemm_matrix_fp16_meta);
+  m.impl(
+      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
+      wrapped_fbgemm_linear_fp16_weight_meta);
 }

 } // namespace
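
As a side note, a rough sketch of how the Meta registrations above are expected
to behave under fake tensors during PT2 tracing (the FakeTensorMode usage and
the shapes are illustrative assumptions, not part of this diff):

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    w = torch.randn(16, 8)
    x = torch.randn(4, 8)
    b = torch.zeros(16)
    # The packed weight is opaque, so the meta kernel only returns a small
    # byte placeholder rather than the real packed buffer.
    packed = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(w)
    # The explicit out_channel argument (16) supplies the output width, since
    # it cannot be recovered from the placeholder packed weight.
    out = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(x, packed, b, 16)
    assert out.shape == (4, 16)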

View File

@@ -247,6 +247,8 @@ TORCH_LIBRARY(_quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
   m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
   m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
+  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
 }

 TORCH_LIBRARY(onednn, m) {

View File

@@ -4,9 +4,10 @@
 import copy
 import itertools
 import numpy as np
-import unittest
 import operator
 import random
+import sys
+import unittest
 from typing import NamedTuple, List

 import torch
@@ -3377,6 +3378,55 @@ class TestDynamicQuantizedOps(TestCase):
         opcheck(qlinear_dynamic, (x, w, bias))

+    @skipIfNoFBGEMM
+    def test_wrapped_fbgemm_linear_fp16(self):
+        options = itertools.product(
+            (2, 4),  # batch_size
+            (4, 5),  # input_channels
+            (4, 7),  # output_channels
+        )
+        for batch_size, input_channels, output_channels in options:
+            pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
+            linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight
+
+            x = torch.randn(batch_size, input_channels)
+            w = torch.randn(output_channels, input_channels)
+            bias = torch.randn(output_channels)
+
+            w_packed = pack_op(w)
+            out = linear_op(x, w_packed, bias, output_channels)
+
+            w_fp16 = w.to(torch.float16).to(torch.float32)
+            ref = F.linear(x, w_fp16, bias)
+
+            self.assertEqual(out, ref)
+
+    @unittest.skipIf(
+        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
+    )
+    @skipIfNoFBGEMM
+    def test_wrapped_fbgemm_pack_gemm_matrix_fp16_pt2_compliant(self):
+        # We are not using opcheck here because the output of the op we are
+        # testing (_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) is not
+        # deterministic due to the C-struct it is producing, which would make
+        # the check fail when matching the compiled and eager results.
+        #
+        # This is only a temporary solution; long term, we should be able to
+        # support PT2 with torchbind natively.
+        def func(X, W, B):
+            packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
+            return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, B, W.size(0))
+
+        x = torch.randn(1, 4, device="cpu")
+        w = torch.randn(4, 4, device="cpu")
+        b = torch.zeros(4, device="cpu")
+
+        ref_out = func(x, w, b)
+
+        compiled = torch.compile(func)
+        compiled_out = compiled(x, w, b)
+
+        self.assertEqual(ref_out, compiled_out)
+
     """Tests the correctness of the dynamic quantized lstm/gru."""