Add wrapper for fbgemm quantization operations (#122763)
Summary: We add wrappers for fbgemm's packing ops so we can pass them through PT2 to the lowering phase of AOTInductor.

Test Plan: Included in commit: test_quantized_ops::test_wrapped_fbgemm_linear_fp16

Differential Revision: [D55433204](https://our.internmc.facebook.com/intern/diff/D55433204)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122763
Approved by: https://github.com/jerryzh168
ghstack dependencies: #122762
This commit is contained in:
parent e296722e0e
commit 966ae943df
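As context for the diff below: the intent is that the two new wrapped ops can be traced by PT2 directly, instead of going through the existing torchbind packed-params path. A minimal sketch of that flow, assuming a CPU build of PyTorch with FBGEMM enabled and this patch applied; the function name and shapes are illustrative and mirror the new test added at the bottom of this diff:

import torch

def fp16_linear(x, w, b):
    # Pack the fp32 weight into FBGEMM's fp16 packed format, then run the
    # fp16-weight / fp32-activation linear through the new wrapped ops.
    packed_w = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(w)
    return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(
        x, packed_w, b, w.size(0))

x = torch.randn(2, 4)
w = torch.randn(7, 4)   # (output_channels, input_channels)
b = torch.zeros(7)

eager_out = fp16_linear(x, w, b)
compiled_out = torch.compile(fp16_linear)(x, w, b)
torch.testing.assert_close(eager_out, compiled_out)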
@@ -16,6 +16,9 @@
 #include <ATen/ops/_empty_affine_quantized.h>
 #include <ATen/ops/aminmax.h>
 #include <ATen/ops/empty.h>
+#include <ATen/ops/fbgemm_linear_fp16_weight_fp32_activation_native.h>
+#include <ATen/ops/fbgemm_linear_fp16_weight_native.h>
+#include <ATen/ops/fbgemm_pack_gemm_matrix_fp16_native.h>
 #include <ATen/ops/quantize_per_tensor.h>
 #endif
@@ -725,6 +728,57 @@ class QLinearUnpackedDynamicFp16 final {
 #endif // USE_FBGEMM
 };
 
+at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16(const at::Tensor weight) {
+#ifdef USE_FBGEMM
+  TORCH_CHECK(
+      weight.dim() == 2,
+      "fbgemm weight packing only packs matrices, not vectors.");
+  return at::native::fbgemm_pack_gemm_matrix_fp16(weight);
+#else // USE_FBGEMM
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+#endif // USE_FBGEMM
+}
+
+at::Tensor wrapped_fbgemm_pack_gemm_matrix_fp16_meta(const at::Tensor weight) {
+#ifdef USE_FBGEMM
+  // Strictly speaking this is not correct. However, we do not know the exact
+  // size of the packed matrix, as it is maintained by the packed object
+  // itself, so we return a small placeholder tensor here.
+  return at::empty({8}, weight.options().dtype(at::kByte));
+#else // USE_FBGEMM
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+#endif // USE_FBGEMM
+}
+
+at::Tensor wrapped_fbgemm_linear_fp16_weight(at::Tensor input, const at::Tensor weight, const at::Tensor bias, int64_t out_channel) {
+#ifdef USE_FBGEMM
+  return at::native::fbgemm_linear_fp16_weight(input, weight, bias);
+#else // USE_FBGEMM
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+#endif // USE_FBGEMM
+}
+
+at::Tensor wrapped_fbgemm_linear_fp16_weight_meta(at::Tensor input, const at::Tensor weight, const at::Tensor bias, int64_t out_channel) {
+#ifdef USE_FBGEMM
+  // For the meta function, we need users to provide the output dimension
+  // explicitly, as we do not have access to the original (unpacked) weight.
+  auto out_sizes = input.sym_sizes().vec();
+  if (out_channel == -1) {
+    out_sizes.pop_back();
+  } else {
+    out_sizes.back() = out_channel;
+  }
+  return at::empty_symint(out_sizes, input.options());
+#else // USE_FBGEMM
+  TORCH_CHECK(
+      false, "This PyTorch installation was not built with FBGEMM operators");
+#endif // USE_FBGEMM
+}
+
 TORCH_LIBRARY_IMPL(quantized, CPU, m) {
   register_linear_params();
   m.impl(
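The two *_meta kernels above only propagate shapes and dtypes: the packed result is represented by a small placeholder byte tensor (its real size lives inside the FBGEMM packed object), and the linear output takes its last dimension from the explicit out_channel argument, dropping it when out_channel == -1. A rough illustration of what that buys during tracing, assuming an FBGEMM-enabled build with this patch applied; the FakeTensorMode usage here is an assumption for illustration, not part of the diff:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Under fake tensors, these ops should resolve to the Meta kernels registered
# in the next hunk, so no real FBGEMM packing or GEMM work happens.
with FakeTensorMode():
    x = torch.randn(2, 4)
    w = torch.randn(7, 4)
    b = torch.randn(7)
    packed = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(w)
    out = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(x, packed, b, 7)
    print(packed.shape)  # torch.Size([8])   -- placeholder, not the real packed size
    print(out.shape)     # torch.Size([2, 7]) -- last dim replaced by out_channel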
@@ -755,6 +809,21 @@ TORCH_LIBRARY_IMPL(_quantized, CPU, m) {
   m.impl(
       TORCH_SELECTIVE_NAME("_quantized::linear_dynamic"),
       TORCH_FN(QLinearDynamicInt8<false>::run));
+  m.impl(
+      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
+      wrapped_fbgemm_pack_gemm_matrix_fp16);
+  m.impl(
+      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
+      wrapped_fbgemm_linear_fp16_weight);
+}
+
+TORCH_LIBRARY_IMPL(_quantized, Meta, m) {
+  m.impl(
+      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16"),
+      wrapped_fbgemm_pack_gemm_matrix_fp16_meta);
+  m.impl(
+      TORCH_SELECTIVE_NAME("_quantized::wrapped_fbgemm_linear_fp16_weight"),
+      wrapped_fbgemm_linear_fp16_weight_meta);
 }
 
 } // namespace
@@ -247,6 +247,8 @@ TORCH_LIBRARY(_quantized, m) {
   m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16(Tensor W, Tensor? B=None) -> __torch__.torch.classes.quantized.LinearPackedParamsBase W_prepack"));
   m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
   m.def(TORCH_SELECTIVE_SCHEMA("_quantized::linear_prepack_fp16_legacy(Tensor W, Tensor? B=None) -> Tensor W_prepack"));
+  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_pack_gemm_matrix_fp16(Tensor W) -> Tensor"));
+  m.def(TORCH_SELECTIVE_SCHEMA("_quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor"));
 }
 
 TORCH_LIBRARY(onednn, m) {
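Once these schemas are registered, the ops surface under torch.ops._quantized like any other custom op. A small sketch of inspecting them, again assuming an FBGEMM-enabled build with this patch applied:

import torch

# The overload packet and its default overload carry the schema defined above.
op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight
print(op.default._schema)
# _quantized::wrapped_fbgemm_linear_fp16_weight(Tensor X, Tensor W, Tensor B, int out_channel) -> Tensor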
@@ -4,9 +4,10 @@
 import copy
 import itertools
 import numpy as np
-import unittest
 import operator
 import random
+import sys
+import unittest
 from typing import NamedTuple, List
 
 import torch
@@ -3377,6 +3378,55 @@ class TestDynamicQuantizedOps(TestCase):
 
         opcheck(qlinear_dynamic, (x, w, bias))
 
+    @skipIfNoFBGEMM
+    def test_wrapped_fbgemm_linear_fp16(self):
+        options = itertools.product(
+            (2, 4),  # batch_size
+            (4, 5),  # input_channels
+            (4, 7),  # output_channels
+        )
+        for batch_size, input_channels, output_channels in options:
+            pack_op = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16
+            linear_op = torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight
+
+            x = torch.randn(batch_size, input_channels)
+            w = torch.randn(output_channels, input_channels)
+            bias = torch.randn(output_channels)
+
+            w_packed = pack_op(w)
+            out = linear_op(x, w_packed, bias, output_channels)
+
+            # Reference: emulate fp16 weights by rounding through float16.
+            w_fp16 = w.to(torch.float16).to(torch.float32)
+            ref = F.linear(x, w_fp16, bias)
+
+            self.assertEqual(out, ref)
+
+    @unittest.skipIf(
+        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
+    )
+    @skipIfNoFBGEMM
+    def test_wrapped_fbgemm_pack_gemm_matrix_fp16_pt2_compliant(self):
+        # We do not use opcheck here because the output of the op under test
+        # (_quantized.wrapped_fbgemm_pack_gemm_matrix_fp16) is not deterministic
+        # due to the C struct it produces, which would fail the check that
+        # matches the compiled result against the eager one.
+        #
+        # This is only a temporary solution; long term, we should be able to
+        # support PT2 with torchbind natively.
+        def func(X, W, B):
+            packed_W = torch.ops._quantized.wrapped_fbgemm_pack_gemm_matrix_fp16(W)
+            return torch.ops._quantized.wrapped_fbgemm_linear_fp16_weight(X, packed_W, B, W.size(0))
+
+        x = torch.randn(1, 4, device="cpu")
+        w = torch.randn(4, 4, device="cpu")
+        b = torch.zeros(4, device="cpu")
+
+        ref_out = func(x, w, b)
+
+        compiled = torch.compile(func)
+        compiled_out = compiled(x, w, b)
+
+        self.assertEqual(ref_out, compiled_out)
+
     """Tests the correctness of the dynamic quantized lstm/gru."""