mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add wrapper for fbgemm quantization operations (#122763)
Summary: We add wrappers for fbgemm's packing so we can pass it through PT2 to lowering phase of AOTInductor. Test Plan: Included in commit. test_quantized_ops::test_wrapped_fbgemm_linear_fp16 Reviewers: Subscribers: Tasks: Tags: Differential Revision: [D55433204](https://our.internmc.facebook.com/intern/diff/D55433204) Pull Request resolved: https://github.com/pytorch/pytorch/pull/122763 Approved by: https://github.com/jerryzh168 ghstack dependencies: #122762
This commit is contained in:
parent
e296722e0e
commit
966ae943df
|
|
@ -16,6 +16,9 @@
|
|||
#include <ATen/ops/_empty_affine_quantized.h>
|
||||
#include <ATen/ops/aminmax.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
|
||||
#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
|
||||
#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
|
||||
#include <ATen/ops/quantize_per_tensor.h>
|
||||
#endif
|
||||
|
||||
|
|
@ -725,6 +728,57 @@ class QLinearUnpackedDynamicFp16 final {
|
|||
#endif // USE_FBGEMM
|
||||
};
|
||||
|
||||
at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16(const at::Tensor weight) {
|
||||
#ifdef USE_FBGEMM
|
||||
TORCH_CHECK(
|
||||
weight.dim() == 2,
|
||||
"fbgemm weight packing only packs matrices not vectors.");
|
||||
return at::native::fbgemm_pack_gemm_matrix_fp16(weight);
|
||||
#else // USE_FBGEMM
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
#endif // USE_FBGEMM
|
||||
}
|
||||
|
||||
at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16_meta(const at::Tensor weight) {
|
||||
#ifdef USE_FBGEMM
|
||||
// Strictly speaking this is not correct. However we do not know the exact
|
||||
// size of the packed matrix as it's being maintained by the object itself,
|
||||
// therefore we return the view we have here.
|
||||
return at::empty({8}, weight.options().dtype(at::kByte));
|
||||
#else // USE_FBGEMM
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
#endif // USE_FBGEMM
|
||||
}
|
||||
|
||||
at::Tensor wrapped_fbgemm_linear_fp16_weight(at::Tensor input, const at::Tensor weight, const at::Tensor bias, int64_t out_channel) {
|
||||
#ifdef USE_FBGEMM
|
||||
return at::native::fbgemm_linear_fp16_weight(input, weight, bias);
|
||||
#else // USE_FBGEMM
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
#endif // USE_FBGEMM
|
||||
}
|
||||
|
||||
at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(at::Tensor input, const at::Tensor weight, const at::Tensor bias, int64_t out_channel) {
|
||||
#ifdef USE_FBGEMM
|
||||
// For the meta function, we need users to provide the dimension explicitly
|
||||
// as we don't have access to the weight.
|
||||
auto out_sizes = input.sym_sizes().vec();
|
||||
if (out_channel == -1) {
|
||||
out_sizes.pop_back();
|
||||
} else {
|
||||
out_sizes.back() = out_channel;
|
||||
}
|
||||
return at::empty_symint(out_sizes, input.options());
|
||||
#else // USE_FBGEMM
|
||||
TORCH_CHECK(
|
||||
false, "This PyTorch installation was not built with FBGEMM operators");
|
||||
#endif // USE_FBGEMM
|
||||
}
|
||||
|
||||
|
||||
TORCH_LIBRARY_IMPL(quantized, CPU, m) {
|
||||
register_linear_params();
|
||||
m.impl(
|
||||
|
|
@ -755,6 +809,21 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
|
|||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("_quantized::linear_dynamic"),
|
||||
TORCH_FN(QLinearDynamicInt8<false>::run));
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
|
||||
wrapped_fbgemm_pack_gemm_matrix_fp16);
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
|
||||
wrapped_fbgemm_linear_fp16_weight);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
|
||||
wrapped_fbgemm_pack_gemm_matrix_fp16_meta);
|
||||
m.impl(
|
||||
TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
|
||||
wrapped_fbgemm_linear_fp16_weight_meta);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
|
|
|||
|
|
@ -247,6 +247,8 @@ TORCH_LIBRARY(_quantized, m) {
|
|||
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
|
||||
m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
|
||||
}
|
||||
|
||||
TORCH_LIBRARY(onednn, m) {
|
||||
|
|
|
|||
|
|
@ -4,9 +4,10 @@
|
|||
import copy
|
||||
import itertools
|
||||
import numpy as np
|
||||
import unittest
|
||||
import operator
|
||||
import random
|
||||
import sys
|
||||
import unittest
|
||||
from typing import NamedTuple, List
|
||||
|
||||
import torch
|
||||
|
|
@ -3377,6 +3378,55 @@ class TestDynamicQuantizedOps(TestCase):
|
|||
|
||||
opcheck(qlinear_dynamic, (x, w, bias))
|
||||
|
||||
@skipIfNoFBGEMM
|
||||
def test_wrapped_fbgemm_linear_fp16(self):
|
||||
options = itertools.product(
|
||||
(2, 4), # batch_size
|
||||
(4, 5), # input_channels
|
||||
(4, 7), # output_channels
|
||||
)
|
||||
for batch_size, input_channels, output_channels in options:
|
||||
pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
|
||||
linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight
|
||||
|
||||
x = torch.randn(batch_size, input_channels)
|
||||
w = torch.randn(output_channels, input_channels)
|
||||
bias = torch.randn(output_channels)
|
||||
|
||||
w_packed = pack_op(w)
|
||||
out = linear_op(x, w_packed, bias, output_channels)
|
||||
|
||||
w_fp16 = w.to(torch.float16).to(torch.float32)
|
||||
ref = F.linear(x, w_fp16, bias)
|
||||
|
||||
self.assertEqual(out, ref)
|
||||
|
||||
@unittest.skipIf(
|
||||
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
|
||||
)
|
||||
@skipIfNoFBGEMM
|
||||
def test_wrapped_fbgemm_pack_gemm_matrix_fp16_pt2_compliant(self):
|
||||
# We are not using opcheck over here because the output for the op we're testing
|
||||
# (_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) is not deterministic
|
||||
# due to the C-struct it's procuding. This would fail the check when we're trying
|
||||
# to match the result between compiled and eager version.
|
||||
#
|
||||
# This is only a temporary solution, long term, we should be able to support PT2
|
||||
# with torchbind natively.
|
||||
def func(X, W, B):
|
||||
packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
|
||||
return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, B, W.size(0))
|
||||
|
||||
x = torch.randn(1, 4, device="cpu")
|
||||
w = torch.randn(4, 4, device="cpu")
|
||||
b = torch.zeros(4, device="cpu")
|
||||
|
||||
ref_out = func(x, w, b)
|
||||
|
||||
compiled = torch.compile(func)
|
||||
compiled_out = compiled(x, w, b)
|
||||
|
||||
self.assertEqual(ref_out, compiled_out)
|
||||
|
||||
"""Tests the correctness of the dynamic quantized lstm/gru."""
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user