Summary: The names of analogous files in the quantized directory (previously snake case) were inconsistent with their non-quantized filename counterparts (pascal case). This is the first in a series of PRs that renames all files in the quantized directory (and its sub-directories) to pascal case. `aten/src/ATen/native/quantized/qconv_unpack.cpp` has not been renamed yet because, for reasons currently unknown, `import torch` produces the error below after the rename (renaming `qlinear_unpack.cpp` also seems to fail some Phabricator CI tests for similar reasons). We suspect that these may be undefined-behavior errors and will revisit renaming these files in a future PR.
```
terminate called after throwing an instance of 'c10::Error'
what(): Type c10::intrusive_ptr<ConvPackedParamsBase<2> > could not be converted to any of the known types.
Exception raised from operator() at ../aten/src/ATen/core/jit_type.h:1735 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x55 (0x7f26745c0c65 in /data/users/dzdang/pytorch/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xb1 (0x7f26745bdcd1 in /data/users/dzdang/pytorch/torch/lib/libc10.so)
frame #2: <unknown function> + 0x1494e24 (0x7f2663b14e24 in /data/users/dzdang/pytorch/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0xfed0bc (0x7f266366d0bc in /data/users/dzdang/pytorch/torch/lib/libtorch_cpu.so)
frame #4: c10::detail::infer_schema::make_function_schema(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >&&, c10::ArrayRef<c10::detail::infer_schema::ArgumentDef>, c10::ArrayRef<c10::detail::infer_schema::ArgumentDef>) + 0x5a (0x7f266366d71a in /data/users/dzdang/pytorch/torch/lib/libtorch_cpu.so)
frame #5: c10::detail::infer_schema::make_function_schema(c10::ArrayRef<c10::detail::infer_schema::ArgumentDef>, c10::ArrayRef<c10::detail::infer_schema::ArgumentDef>) + 0x7b (0x7f266366e06b in /data/users/dzdang/pytorch/torch/lib/libtorch_cpu.so)
frame #6: <unknown function> + 0x1493f32 (0x7f2663b13f32 in /data/users/dzdang/pytorch/torch/lib/libtorch_cpu.so)
frame #7: <unknown function> + 0xe227dd (0x7f26634a27dd in /data/users/dzdang/pytorch/torch/lib/libtorch_cpu.so)
frame #8: <unknown function> + 0x14e0a (0x7f268c934e0a in /lib64/ld-linux-x86-64.so.2)
..........................truncated.............
```

Test Plan:
```
python test/test_quantization.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77037
Approved by: https://github.com/jerryzh168
#include <gtest/gtest.h>

#include <ATen/native/quantized/PackedParams.h>
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
#include <cmath>
#include <sstream>
#include "torch/csrc/jit/tensorexpr/eval.h"
#include "torch/csrc/jit/tensorexpr/ir.h"
namespace torch {
namespace jit {

using namespace torch::jit::tensorexpr;
using SimpleIRExprEval = ExprEval<SimpleIREvaluator>;
using namespace torch::indexing;

class Quantization : public ::testing::Test {
 public:
  // NOLINTNEXTLINE(modernize-use-override,cppcoreguidelines-explicit-virtual-functions)
  void SetUp() {
    getTEMustUseLLVMOnCPU() = false;
  }
};

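// Round-trips a float tensor through aten::quantize_per_tensor and
// aten::dequantize (qint8, scale 0.1, zero point 13) and checks that the
// TensorExprKernel output matches the eager result.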
TEST_F(Quantization, QuantDequantInt8) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=12]()
        %3 : int = prim::Constant[value=13]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQInt8);
  auto y_expected = at::dequantize(q);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

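// Same quantize/dequantize round trip for quint8 (zero point 122); the input
// is scaled by 2 to exercise a wider value range.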
TEST_F(Quantization, QuantDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=122]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = 2 * at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
  auto y_expected = at::dequantize(q);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

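// quint8 round trip on a non-contiguous input: the strides are overridden to
// [4, 1, 2] so the kernel must handle a channels-last-style (NLC) layout.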
TEST_F(Quantization, QuantDequantUInt8_NLC) {
  const auto graph_string = R"IR(
      graph(%x.1 : Float(1, 2, 2, strides=[4, 1, 2], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %3 : int = prim::Constant[value=122]()
        %4 : float = prim::Constant[value=0.1]()
        %q.1 : QUInt8(1, 2, 2) = aten::quantize_per_tensor(%x.1, %4, %3, %2)
        %6 : Float(1, 2, 2) = aten::dequantize(%q.1)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = 2 * at::rand({1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  x.unsafeGetTensorImpl()->set_sizes_and_strides({1, 2, 2}, {4, 1, 2});
  auto q = at::quantize_per_tensor(x, 0.1f, 122, at::kQUInt8);
  auto y_expected = at::dequantize(q);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

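// Helper that calls the quantized::add operator through the dispatcher to
// produce an eager-mode reference result.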
at::Tensor quantized_add(
    at::Tensor x1,
    at::Tensor x2,
    double scale,
    int64_t zero) {
  const auto qadd_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::add", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
  return qadd_op.call(x1, x2, scale, zero);
}

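// Quantizes two inputs (qint8), adds them with quantized::add, dequantizes,
// and compares the kernel output against the quantized_add() reference.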
TEST_F(Quantization, QuantAddDequantInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=12]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQInt8);
  auto qa = quantized_add(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);
  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

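// quint8 variant of the quantize/add/dequantize check above.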
TEST_F(Quantization, QuantAddDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QUInt8(2, 2) = quantized::add(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
  auto qa = quantized_add(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

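// Applies aten::sigmoid to a quantized (quint8) tensor, then dequantizes and
// compares against the eager-mode result.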
TEST_F(Quantization, QuantSigmoidDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %qa : QUInt8(2, 2) = aten::sigmoid(%q1)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto qs = at::sigmoid(q1);
  auto y_expected = at::dequantize(qs);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "qs:\n" << qs << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

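// Helper that calls the quantized::mul operator through the dispatcher to
// produce an eager-mode reference result.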
at::Tensor quantized_mul(
    at::Tensor x1,
    at::Tensor x2,
    double scale,
    int64_t zero) {
  const auto op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("quantized::mul", "")
          .typed<at::Tensor(at::Tensor, at::Tensor, double, int64_t)>();
  return op.call(x1, x2, scale, zero);
}

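// Quantizes two inputs (quint8), multiplies them with quantized::mul,
// dequantizes, and compares against the quantized_mul() reference.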
TEST_F(Quantization, QuantMulDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x1 : Float(2, 2, strides=[2, 1], device=cpu), %x2 : Float(2, 2, strides=[2, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %qz1 : int = prim::Constant[value=13]()
        %qs1 : float = prim::Constant[value=0.1]()
        %qz2 : int = prim::Constant[value=13]()
        %qs2 : float = prim::Constant[value=0.1]()
        %qza : int = prim::Constant[value=13]()
        %qsa : float = prim::Constant[value=0.1]()
        %q1 : QUInt8(2, 2) = aten::quantize_per_tensor(%x1, %qs1, %qz1, %2)
        %q2 : QUInt8(2, 2) = aten::quantize_per_tensor(%x2, %qs2, %qz2, %2)
        %qa : QUInt8(2, 2) = quantized::mul(%q1, %q2, %qsa, %qza)
        %6 : Float(2, 2) = aten::dequantize(%qa)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x1 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto x2 = at::rand({2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q1 = at::quantize_per_tensor(x1, 0.1f, 13, at::kQUInt8);
  auto q2 = at::quantize_per_tensor(x2, 0.1f, 13, at::kQUInt8);
  auto qa = quantized_mul(q1, q2, 0.1f, 13);
  auto y_expected = at::dequantize(qa);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x1, x2};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x1:\n" << x1 << std::endl;
    std::cout << "q1:\n" << q1 << std::endl;
    std::cout << "x2:\n" << x2 << std::endl;
    std::cout << "q2:\n" << q2 << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

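// Upsamples a quantized (quint8) 4x4 input to 6x6 with
// aten::upsample_nearest2d, then dequantizes and compares against eager mode.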
TEST_F(Quantization, QuantUpsampleNearest2dDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 4, 4, strides=[16, 16, 4, 1], device=cpu)):
        %2 : int = prim::Constant[value=13]()
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[6, 6]]()
        %qz : int = prim::Constant[value=13]()
        %qs : float = prim::Constant[value=0.1]()
        %q : QUInt8(1, 1, 4, 4) = aten::quantize_per_tensor(%x, %qs, %qz, %2)
        %qu : QUInt8(1, 1, 6, 6) = aten::upsample_nearest2d(%q, %3, %4)
        %6 : Float(1, 1, 6, 6) = aten::dequantize(%qu)
        return (%6))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 4, 4}, TensorOptions(kCPU).dtype(at::kFloat));
  auto q = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qu = at::upsample_nearest2d(q, {6, 6});
  auto y_expected = at::dequantize(qu);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "q:\n" << q << std::endl;
    std::cout << "qu:\n" << qu << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

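// Non-quantized baseline: plain float aten::upsample_nearest2d from 2x2 to
// 4x4.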
TEST_F(Quantization, UpsampleNearest2d) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %4 : NoneType = prim::Constant()
        %3 : int[] = prim::Constant[value=[4, 4]]()
        %u : Float(1, 1, 4, 4) = aten::upsample_nearest2d(%x, %3, %4)
        return (%u))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y_expected = at::upsample_nearest2d(x, {4, 4});

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto y = stack[0].toTensor();
  bool check = at::allclose(y_expected, y);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y_expected:\n" << y_expected << std::endl;
    std::cout << "y:\n" << y << std::endl;
  }
  CHECK_EQ(check, 1);
}

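// Helper that redispatches quantized::cat directly to the QuantizedCPU
// backend to produce an eager-mode reference result.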
at::Tensor quantized_cat(
    c10::List<at::Tensor> const& xs,
    int64_t dim,
    double scale,
    int64_t zero) {
  const auto op = c10::Dispatcher::singleton()
                      .findSchemaOrThrow("quantized::cat", "")
                      .typed<at::Tensor(
                          c10::List<at::Tensor> const&,
                          int64_t,
                          c10::optional<double>,
                          c10::optional<int64_t>)>();
  return op.redispatch(
      DispatchKeySet({DispatchKey::QuantizedCPU}), xs, dim, scale, zero);
}

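// Concatenates three quint8 tensors with different scales/zero points along
// dim 0; quantized::cat requantizes the output to the scale/zero point passed
// in (here the first input's parameters).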
TEST_F(Quantization, QuantCatDequantUInt8) {
  const auto graph_string = R"IR(
      graph(%x : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %y : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu), %z : Float(1, 1, 2, 2, strides=[2, 2, 2, 1], device=cpu)):
        %qdt : int = prim::Constant[value=13]()
        %qxz : int = prim::Constant[value=13]()
        %qxs : float = prim::Constant[value=0.1]()
        %qyz : int = prim::Constant[value=16]()
        %qys : float = prim::Constant[value=0.15]()
        %qzz : int = prim::Constant[value=19]()
        %qzs : float = prim::Constant[value=0.2]()
        %qx : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%x, %qxs, %qxz, %qdt)
        %qy : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%y, %qys, %qyz, %qdt)
        %qz : QUInt8(1, 1, 2, 2) = aten::quantize_per_tensor(%z, %qzs, %qzz, %qdt)
        %catx : Tensor[] = prim::ListConstruct(%qx, %qy, %qz)
        %catd : int = prim::Constant[value=0]()
        %qcat : QUInt8(3, 1, 2, 2) = quantized::cat(%catx, %catd, %qxs, %qxz)
        %cat : Float(3, 1, 2, 2) = aten::dequantize(%qcat)
        return (%cat))IR";
  auto graph = std::make_shared<Graph>();
  parseIR(graph_string, &*graph);

  auto x = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto y = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto z = at::rand({1, 1, 2, 2}, TensorOptions(kCPU).dtype(at::kFloat));
  auto qx = at::quantize_per_tensor(x, 0.1f, 13, at::kQUInt8);
  auto qy = at::quantize_per_tensor(y, 0.15f, 16, at::kQUInt8);
  auto qz = at::quantize_per_tensor(z, 0.2f, 19, at::kQUInt8);
  auto qcat = quantized_cat({qx, qy, qz}, 0, 0.1f, 13);
  auto expected = at::dequantize(qcat);

  TensorExprKernel k(graph);
  std::vector<at::Tensor> inputs = {x, y, z};
  StmtPtr s = k.getCodeGenStmt();

  std::vector<IValue> stack = fmap<IValue>(inputs);
  k.run(stack);
  auto result = stack[0].toTensor();
  bool check = at::allclose(expected, result);
  if (!check) {
    std::cout << "x:\n" << x << std::endl;
    std::cout << "y:\n" << y << std::endl;
    std::cout << "z:\n" << z << std::endl;
    std::cout << "qx:\n" << qx << std::endl;
    std::cout << "qy:\n" << qy << std::endl;
    std::cout << "qz:\n" << qz << std::endl;
    std::cout << "qcat:\n" << qcat << std::endl;
    std::cout << "expected:\n" << expected << std::endl;
    std::cout << "result:\n" << result << std::endl;
  }
  CHECK_EQ(check, 1);
}

} // namespace jit
} // namespace torch