mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary:
This PR implements the quantized upsample_bilinear2d case for the interpolate kernel.
The following benchmark measures the performance improvement from NHWC handling:
import torch, time
for dtype in [torch.qint8, torch.quint8, torch.qint32]:
    print('****', str(dtype), '*****')
    x = torch.rand(1, 56, 56, 256)
    q_x = torch.quantize_per_tensor(x, 0.5, 1, dtype)
    q_x = q_x.permute([0, 3, 1, 2])
    x = x.permute([0, 3, 1, 2])
    NITER = 100
    s = time.time()
    for i in range(NITER):
        float_out = torch.nn.functional.interpolate(x, size=5, scale_factor=None, mode="bilinear", align_corners=True)
    time_per_iter_float = (time.time() - s) / NITER
    s = time.time()
    for i in range(NITER):
        quant_out = torch.nn.quantized.functional.interpolate(q_x, size=5, scale_factor=None, mode="bilinear", align_corners=True)
    time_per_iter_quant = (time.time() - s) / NITER
    ref_quantized = torch.quantize_per_tensor(float_out, 0.5, 1, dtype)
    # torch.testing.assert_allclose(ref_quantized.dequantize(), quant_out.dequantize())
    print('time/iter ms (float)', 'time/iter ms (quant)', 'quant/float', sep='\t')
    print(time_per_iter_float * 1000, time_per_iter_quant * 1000, time_per_iter_quant / time_per_iter_float, sep='\t')
    bytes_float = (x.numel() + float_out.numel()) * x.element_size()
    bytes_quant = (q_x.numel() + quant_out.numel()) * q_x.element_size()
    float_bw_gbps = bytes_float / time_per_iter_float / 1e9
    quant_bw_gbps = bytes_quant / time_per_iter_quant / 1e9
    print('GB/s float', 'GB/s quant', sep='\t')
    print(float_bw_gbps, quant_bw_gbps, sep='\t')
===========without nhwc handling===========
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
1.999044418334961 2.5860953330993652 1.2936657681940702
GB/s float GB/s quant
1.6192056416115257 0.3129103516188541
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.02730655670166 2.6061582565307617 1.2855274639721328
GB/s float GB/s quant
1.596632728927902 0.3105014816242217
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.0180463790893555 2.4047350883483887 1.1916153728010588
GB/s float GB/s quant
1.603959172365819 1.3460376636426636
===========with nhwc handling===========
**** torch.qint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.0913314819335938 0.09696483612060547 0.04636512047863123
GB/s float GB/s quant
1.5477527249803915 8.345458337015
**** torch.quint8 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.1065664291381836 0.09959936141967773 0.04728042754408879
GB/s float GB/s quant
1.5365591871338384 8.124710725706763
**** torch.qint32 *****
time/iter ms (float) time/iter ms (quant) quant/float
2.044203281402588 0.6003522872924805 0.29368521846837126
GB/s float GB/s quant
1.5834354779917448 5.391607675216635
Pull Request resolved: https://github.com/pytorch/pytorch/pull/26631
Differential Revision: D17521498
Pulled By: llyfacebook
fbshipit-source-id: 385ae0f77777cd8bee385cafb80e492127b7d103
141 lines
5.8 KiB
Python
141 lines
5.8 KiB
Python
r""" Functional interface (quantized)."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
import torch
|
|
from torch._jit_internal import List as _List
|
|
from torch.nn.modules.utils import _pair
|
|
|
|
|
|
def relu(input, inplace=False):
    # type: (Tensor, bool) -> Tensor
    r"""relu(input, inplace=False) -> Tensor

    Element-wise rectified linear unit for quantized tensors. See
    :class:`~torch.nn.ReLU` for more details.

    Args:
        input (Tensor): quantized input tensor
        inplace (bool): if ``True``, the result is computed with
            ``torch.relu_`` (the in-place variant) instead of ``torch.relu``.
            Default: ``False``

    Raises:
        ValueError: if ``input`` is not a quantized tensor.
    """
    # Guard clause: this entry point only accepts quantized tensors.
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.relu' must be quantized!")
    return torch.relu_(input) if inplace else torch.relu(input)
|
|
|
|
def linear(input, weight, bias=None, scale=None, zero_point=None):
    # type: (Tensor, Tensor, Optional[Tensor], Optional[float], Optional[int]) -> Tensor
    r"""
    Applies a linear transformation to the incoming quantized data:
    :math:`y = xA^T + b`.
    See :class:`~torch.nn.Linear`

    .. note::

      Current implementation packs weights on every call, which has penalty on performance.
      If you want to avoid the overhead, use :class:`~torch.nn.quantized.Linear`.

    Args:
      input (Tensor): Quantized input of type `torch.quint8`
      weight (Tensor): Quantized weight of type `torch.qint8`
      bias (Tensor): None or fp32 bias of type `torch.float`
      scale (double): output scale. If None, derived from the input scale
      zero_point (long): output zero point. If None, derived from the input zero_point

    Shape:
        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    # The type comment above must list one type per parameter; `scale` and
    # `zero_point` are optional because `None` means "reuse the input's
    # quantization parameters".
    if scale is None:
        scale = input.q_scale()
    if zero_point is None:
        zero_point = input.q_zero_point()
    # Weight (and bias) are packed on each call before the quantized op runs.
    _packed_params = torch.ops.quantized.linear_prepack(weight, bias)
    return torch.ops.quantized.linear(input, _packed_params, scale,
                                      zero_point)
|
|
|
|
def conv2d(input, weight, bias,
           stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros',
           scale=1.0, zero_point=0,
           dtype=torch.quint8):
    r"""
    conv2d(input, weight, bias,
           stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros',
           scale=1.0, zero_point=0,
           dtype=torch.quint8) -> Tensor

    Applies a 2D convolution over a quantized 2D input composed of several input
    planes.

    See :class:`~torch.nn.Conv2d` for details and output shape.

    .. note::

      The weight is prepacked on every call to this function.

    Args:
        input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
        weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
        bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
        stride: the stride of the convolving kernel. Can be a single number or a
          tuple `(sH, sW)`. Default: 1
        padding: implicit paddings on both sides of the input. Can be a
          single number or a tuple `(padH, padW)`. Default: 0
        dilation: the spacing between kernel elements. Can be a single number or
          a tuple `(dH, dW)`. Default: 1
        groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
          number of groups. Default: 1
        padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
        scale: quantization scale for the output. Default: 1.0
        zero_point: quantization zero_point for the output. Default: 0
        dtype: quantization data type to use. Default: ``torch.quint8``.
          This argument is currently not read by the function body.

    Raises:
        NotImplementedError: if ``padding_mode`` is not ``'zeros'``.
        ValueError: if ``input`` does not have exactly 4 dimensions.

    Examples::

        >>> from torch.nn.quantized import functional as qF
        >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
        >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
        >>> bias = torch.randn(8, dtype=torch.float)
        >>>
        >>> scale, zero_point = 1.0, 0
        >>> dtype_inputs = torch.quint8
        >>> dtype_filters = torch.qint8
        >>>
        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
        >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
    """  # noqa: E501
    # Validate arguments up front; the padding-mode check comes first.
    if padding_mode != 'zeros':
        raise NotImplementedError("Only zero-padding is supported!")
    if input.ndim != 4:
        raise ValueError("Input shape must be `(N, C, H, W)`!")

    # Normalize each spatial hyper-parameter with ``_pair``.
    stride_hw = _pair(stride)
    padding_hw = _pair(padding)
    dilation_hw = _pair(dilation)

    quantized_ops = torch.ops.quantized
    packed_weight = quantized_ops.conv_prepack(
        weight, bias, stride_hw, padding_hw, dilation_hw, groups)
    return quantized_ops.conv2d(
        input, packed_weight,
        stride_hw, padding_hw, dilation_hw,
        groups, scale, zero_point)
|
|
|
|
def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
               ceil_mode=False, return_indices=False):
    r"""Applies a 2D max pooling over a quantized input signal composed of
    several quantized input planes.

    The call is forwarded to :func:`torch.nn.functional.max_pool2d`.
    See :class:`~torch.nn.quantized.MaxPool2d` for details.

    Raises:
        NotImplementedError: if ``return_indices`` is truthy.
    """
    if return_indices:
        raise NotImplementedError("return_indices is not yet implemented!")
    # A missing stride is replaced by an empty, explicitly typed int list
    # (presumably so the underlying op applies its own default stride).
    effective_stride = (
        torch.jit.annotate(_List[int], []) if stride is None else stride
    )
    float_max_pool2d = torch.nn.functional.max_pool2d
    return float_max_pool2d(input, kernel_size, effective_stride, padding,
                            dilation, ceil_mode, return_indices)
|
|
|
|
# TODO(zaf): Add documentation
# The names below are plain aliases of the corresponding
# ``torch.nn.functional`` callables; no wrapper logic is added here.
avg_pool2d = torch.nn.functional.avg_pool2d
adaptive_avg_pool2d = torch.nn.functional.adaptive_avg_pool2d
interpolate = torch.nn.functional.interpolate
|