mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
qlinear operator level benchmark (#22914)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22914 Adding op level benchmarking for qlinear operator Reviewed By: mingzhe09088 Differential Revision: D16285204 fbshipit-source-id: 99b734ddfa0af6aada820cac7b2f38ef7a5868cb
This commit is contained in:
parent
7a99f3987b
commit
f72d754877
|
|
@ -9,7 +9,7 @@ import torch
|
|||
import torch.nn as nn
|
||||
|
||||
|
||||
"""Microbenchmarks for batchnorm operator."""
|
||||
"""Microbenchmarks for Linear operator."""
|
||||
|
||||
configs = op_bench.config_list(
|
||||
attrs=[
|
||||
|
|
|
|||
62
benchmarks/operator_benchmark/pt/qlinear_test.py
Normal file
62
benchmarks/operator_benchmark/pt/qlinear_test.py
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
|
||||
import operator_benchmark as op_bench
|
||||
import torch
|
||||
import torch.nn.quantized as nnq
|
||||
|
||||
|
||||
"""
|
||||
Microbenchmarks for Quantized Linear operators.
|
||||
"""
|
||||
|
||||
# Configs for qlinear
|
||||
# Shapes benchmarked for the quantized linear operator. Each row is bound to
# the names in ``attr_names`` below, i.e. (N, OUT, IN).
qlinear_configs = op_bench.config_list(
    attrs=[
        [1024, 1024, 1024],
        [64, 800, 320],
        [64, 768, 512],
        [16, 256, 512],
        [128, 128, 128],
        [256, 512, 256],
        [6400, 15, 141],
        [6400, 8, 141],
        [16, 211, 2504],
        [16, 369, 1434],
        [1, 1024, 3496],
        # A second [16, 256, 512] row used to sit here; it repeated the
        # fourth row exactly and was dropped so the shape is listed once.
        [1, 1600, 3456],
    ],
    # In GEMM terms these are M, N, K: N here is the GEMM "M" (input rows),
    # OUT is the GEMM "N" (output features) and IN is the GEMM "K"
    # (input features / reduction dimension).
    attr_names=["N", "OUT", "IN"],
    tags=["short"],
)
|
||||
|
||||
|
||||
class QLinearBenchmark(op_bench.TorchBenchmarkBase):
    """Benchmark that calls a quantized ``nnq.Linear`` module on a quantized input."""

    def init(self, N, IN, OUT):
        """Create the quantized input tensor and the quantized linear module.

        Args:
            N: number of rows of the input; the input is created as ``(N, IN)``.
            IN: number of input features.
            OUT: number of output features; the weight is created as ``(OUT, IN)``.
        """
        quant_scale = 1.0 / 255
        quant_zero_point = 0

        # Random tensors are drawn in the same order as before (activation
        # first, then weight) so the generated values do not change.
        activation = torch.randn(N, IN, dtype=torch.float32)
        quantized_activation = torch.quantize_linear(
            activation,
            scale=quant_scale,
            zero_point=quant_zero_point,
            dtype=torch.quint8,
        )
        weight = torch.randn(OUT, IN, dtype=torch.float32)
        # The weight is quantized to qint8 with a literal zero point of 0.
        quantized_weight = torch.quantize_linear(
            weight,
            scale=quant_scale,
            zero_point=0,
            dtype=torch.qint8,
        )

        self.input = quantized_activation

        # Configure the module, then store it on the benchmark instance.
        linear = nnq.Linear(IN, OUT)
        linear.weight = quantized_weight
        linear.scale = quant_scale
        linear.zero_point = quant_zero_point
        self.qlinear = linear

        self.set_module_name("QLinear")

    def forward(self):
        """Apply the quantized linear module to the stored input and return the result."""
        return self.qlinear(self.input)
|
||||
|
||||
|
||||
# Hand the config list and the benchmark class to the operator_benchmark
# framework. NOTE(review): presumably this registers one test per config
# row -- confirm against operator_benchmark.
op_bench.generate_pt_test(qlinear_configs, QLinearBenchmark)
|
||||
|
||||
|
||||
# When this file is executed as a script, call the benchmark runner's main().
if __name__ == "__main__":
|
||||
op_bench.benchmark_runner.main()
|
||||
Loading…
Reference in New Issue
Block a user