r""" Functional interface (quantized)."""
|
|
from __future__ import absolute_import
|
|
from __future__ import division
|
|
from __future__ import print_function
|
|
from __future__ import unicode_literals
|
|
|
|
from typing import List, Optional
|
|
|
|
import torch
|
|
from torch import Tensor
|
|
from torch.nn.modules.utils import _pair, _triple
|
|
from torch.nn.quantized.modules.utils import _pair_from_first
|
|
|
|
# Although some of the functions and docstrings are mirrored from the torch.nn,
|
|
# we want to have them here for future changes.
|
|
|
|


def avg_pool2d(input, kernel_size, stride=None, padding=0, ceil_mode=False,
               count_include_pad=True, divisor_override=None):
    r"""
    Applies 2D average-pooling operation in :math:`kH \times kW` regions by step size
    :math:`sH \times sW`. The number of output features is equal to the number of
    input planes.

    .. note:: The input quantization parameters propagate to the output.

    See :class:`~torch.nn.quantized.AvgPool2d` for details and output shape.

    Args:
        input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
        kernel_size: size of the pooling region. Can be a single number or a
            tuple `(kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
            tuple `(sH, sW)`. Default: :attr:`kernel_size`
        padding: implicit zero paddings on both sides of the input. Can be a
            single number or a tuple `(padH, padW)`. Default: 0
        ceil_mode: when True, will use `ceil` instead of `floor` in the formula
            to compute the output shape. Default: ``False``
        count_include_pad: when True, will include the zero-padding in the
            averaging calculation. Default: ``True``
        divisor_override: if specified, it will be used as the divisor; otherwise
            the size of the pooling region will be used. Default: None
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.avg_pool2d' must be quantized!")
    return torch.nn.functional.avg_pool2d(input, kernel_size, stride, padding,
                                          ceil_mode, count_include_pad,
                                          divisor_override)
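

# A minimal usage sketch for ``avg_pool2d`` (illustrative only; the scale and
# zero point below are arbitrary):
#
#     >>> x = torch.randn(1, 3, 8, 8, dtype=torch.float)
#     >>> qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0,
#     ...                                dtype=torch.quint8)
#     >>> avg_pool2d(qx, kernel_size=2)  # output reuses qx's scale/zero_point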


def avg_pool3d(input, kernel_size, stride=None, padding=0, ceil_mode=False,
               count_include_pad=True, divisor_override=None):
    r"""
    Applies 3D average-pooling operation in :math:`kD \times kH \times kW` regions by step
    size :math:`sD \times sH \times sW`. The number of output features is equal to the
    number of input planes.

    .. note:: The input quantization parameters propagate to the output.

    Args:
        input: quantized input tensor :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
        kernel_size: size of the pooling region. Can be a single number or a
            tuple `(kD, kH, kW)`
        stride: stride of the pooling operation. Can be a single number or a
            tuple `(sD, sH, sW)`. Default: :attr:`kernel_size`
        padding: implicit zero paddings on both sides of the input. Can be a
            single number or a tuple `(padD, padH, padW)`. Default: 0
        ceil_mode: when True, will use `ceil` instead of `floor` in the formula
            to compute the output shape. Default: ``False``
        count_include_pad: when True, will include the zero-padding in the
            averaging calculation. Default: ``True``
        divisor_override: if specified, it will be used as the divisor; otherwise
            the size of the pooling region will be used. Default: None
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.avg_pool3d' must be quantized!")
    return torch.nn.functional.avg_pool3d(input, kernel_size, stride, padding,
                                          ceil_mode, count_include_pad,
                                          divisor_override)
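

# A minimal usage sketch for ``avg_pool3d`` on a 5D quantized tensor
# (illustrative only; quantization parameters are arbitrary):
#
#     >>> x = torch.randn(1, 3, 4, 8, 8, dtype=torch.float)
#     >>> qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0,
#     ...                                dtype=torch.quint8)
#     >>> avg_pool3d(qx, kernel_size=2)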


def adaptive_avg_pool2d(input, output_size):
    # type: (Tensor, BroadcastingList2[int]) -> Tensor
    r"""
    Applies a 2D adaptive average pooling over a quantized input signal composed
    of several quantized input planes.

    .. note:: The input quantization parameters propagate to the output.

    See :class:`~torch.nn.quantized.AdaptiveAvgPool2d` for details and output shape.

    Args:
        output_size: the target output size (single integer or
            double-integer tuple)
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.adaptive_avg_pool2d' must be quantized!")
    return torch.nn.functional.adaptive_avg_pool2d(input, output_size)
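

# A minimal usage sketch for ``adaptive_avg_pool2d`` (illustrative only); the
# output spatial size is fixed regardless of the input's H and W:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(1, 3, 7, 9), scale=0.1,
#     ...                                zero_point=0, dtype=torch.quint8)
#     >>> adaptive_avg_pool2d(qx, (1, 1))  # shape (1, 3, 1, 1)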


def conv1d(input, weight, bias,
           stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros',
           scale=1.0, zero_point=0,
           dtype=torch.quint8):
    r"""
    Applies a 1D convolution over a quantized 1D input composed of several input
    planes.

    See :class:`~torch.nn.quantized.Conv1d` for details and output shape.

    Args:
        input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iW)`
        weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kW)`
        bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
        stride: the stride of the convolving kernel. Can be a single number or a
            tuple `(sW,)`. Default: 1
        padding: implicit paddings on both sides of the input. Can be a
            single number or a tuple `(padW,)`. Default: 0
        dilation: the spacing between kernel elements. Can be a single number or
            a tuple `(dW,)`. Default: 1
        groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
            number of groups. Default: 1
        padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
        scale: quantization scale for the output. Default: 1.0
        zero_point: quantization zero_point for the output. Default: 0
        dtype: quantization data type to use. Default: ``torch.quint8``

    Examples::

        >>> from torch.nn.quantized import functional as qF
        >>> filters = torch.randn(33, 16, 3, dtype=torch.float)
        >>> inputs = torch.randn(20, 16, 50, dtype=torch.float)
        >>> bias = torch.randn(33, dtype=torch.float)
        >>>
        >>> scale, zero_point = 1.0, 0
        >>> dtype_inputs = torch.quint8
        >>> dtype_filters = torch.qint8
        >>>
        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
        >>> qF.conv1d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
    """  # noqa: E501
    if padding_mode != 'zeros':
        raise NotImplementedError("Only zero-padding is supported!")
    if input.dtype != torch.quint8:
        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
    if weight.dtype != torch.qint8:
        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
    if input.ndim != 3:
        raise ValueError("Input shape must be `(N, C, L)`!")
    stride = _pair_from_first(stride)
    padding = _pair_from_first(padding)
    dilation = _pair_from_first(dilation)

    packed_params = torch.ops.quantized.conv1d_prepack(
        weight, bias, stride, padding, dilation, groups)
    return torch.ops.quantized.conv1d(input, packed_params, scale, zero_point)


def conv2d(input, weight, bias,
           stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros',
           scale=1.0, zero_point=0,
           dtype=torch.quint8):
    r"""
    Applies a 2D convolution over a quantized 2D input composed of several input
    planes.

    See :class:`~torch.nn.quantized.Conv2d` for details and output shape.

    Args:
        input: quantized input tensor of shape :math:`(\text{minibatch} , \text{in\_channels} , iH , iW)`
        weight: quantized filters of shape :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kH , kW)`
        bias: **non-quantized** bias tensor of shape :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
        stride: the stride of the convolving kernel. Can be a single number or a
            tuple `(sH, sW)`. Default: 1
        padding: implicit paddings on both sides of the input. Can be a
            single number or a tuple `(padH, padW)`. Default: 0
        dilation: the spacing between kernel elements. Can be a single number or
            a tuple `(dH, dW)`. Default: 1
        groups: split input into groups, :math:`\text{in\_channels}` should be divisible by the
            number of groups. Default: 1
        padding_mode: the padding mode to use. Only "zeros" is supported for quantized convolution at the moment. Default: "zeros"
        scale: quantization scale for the output. Default: 1.0
        zero_point: quantization zero_point for the output. Default: 0
        dtype: quantization data type to use. Default: ``torch.quint8``

    Examples::

        >>> from torch.nn.quantized import functional as qF
        >>> filters = torch.randn(8, 4, 3, 3, dtype=torch.float)
        >>> inputs = torch.randn(1, 4, 5, 5, dtype=torch.float)
        >>> bias = torch.randn(8, dtype=torch.float)
        >>>
        >>> scale, zero_point = 1.0, 0
        >>> dtype_inputs = torch.quint8
        >>> dtype_filters = torch.qint8
        >>>
        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
        >>> qF.conv2d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
    """  # noqa: E501
    if padding_mode != 'zeros':
        raise NotImplementedError("Only zero-padding is supported!")
    if input.dtype != torch.quint8:
        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
    if weight.dtype != torch.qint8:
        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
    if input.ndim != 4:
        raise ValueError("Input shape must be `(N, C, H, W)`!")
    stride = _pair(stride)
    padding = _pair(padding)
    dilation = _pair(dilation)

    packed_params = torch.ops.quantized.conv2d_prepack(
        weight, bias, stride, padding, dilation, groups)
    return torch.ops.quantized.conv2d(input, packed_params, scale, zero_point)


def conv3d(input, weight, bias, stride=1, padding=0, dilation=1, groups=1,
           padding_mode='zeros', scale=1.0, zero_point=0, dtype=torch.quint8):
    r"""
    Applies a 3D convolution over a quantized 3D input composed of several input
    planes.

    See :class:`~torch.nn.quantized.Conv3d` for details and output shape.

    Args:
        input: quantized input tensor of shape
            :math:`(\text{minibatch} , \text{in\_channels} , iD , iH , iW)`
        weight: quantized filters of shape
            :math:`(\text{out\_channels} , \frac{\text{in\_channels}}{\text{groups}} , kD , kH , kW)`
        bias: **non-quantized** bias tensor of shape
            :math:`(\text{out\_channels})`. The tensor type must be `torch.float`.
        stride: the stride of the convolving kernel. Can be a single number or a
            tuple `(sD, sH, sW)`. Default: 1
        padding: implicit paddings on both sides of the input. Can be a
            single number or a tuple `(padD, padH, padW)`. Default: 0
        dilation: the spacing between kernel elements. Can be a single number or
            a tuple `(dD, dH, dW)`. Default: 1
        groups: split input into groups, :math:`\text{in\_channels}` should be
            divisible by the number of groups. Default: 1
        padding_mode: the padding mode to use. Only "zeros" is supported for
            quantized convolution at the moment. Default: "zeros"
        scale: quantization scale for the output. Default: 1.0
        zero_point: quantization zero_point for the output. Default: 0
        dtype: quantization data type to use. Default: ``torch.quint8``

    Examples::

        >>> from torch.nn.quantized import functional as qF
        >>> filters = torch.randn(8, 4, 3, 3, 3, dtype=torch.float)
        >>> inputs = torch.randn(1, 4, 5, 5, 5, dtype=torch.float)
        >>> bias = torch.randn(8, dtype=torch.float)
        >>>
        >>> scale, zero_point = 1.0, 0
        >>> dtype_inputs = torch.quint8
        >>> dtype_filters = torch.qint8
        >>>
        >>> q_filters = torch.quantize_per_tensor(filters, scale, zero_point, dtype_filters)
        >>> q_inputs = torch.quantize_per_tensor(inputs, scale, zero_point, dtype_inputs)
        >>> qF.conv3d(q_inputs, q_filters, bias, padding=1, scale=scale, zero_point=zero_point)
    """  # noqa: E501
    if padding_mode != 'zeros':
        raise NotImplementedError("Only zero-padding is supported!")
    if input.dtype != torch.quint8:
        raise NotImplementedError("Only torch.quint8 is supported for activation tensor!")
    if weight.dtype != torch.qint8:
        raise NotImplementedError("Only torch.qint8 is supported for weight tensor!")
    if input.ndim != 5:
        raise ValueError("Input shape must be `(N, C, D, H, W)`!")
    stride = _triple(stride)
    padding = _triple(padding)
    dilation = _triple(dilation)

    packed_params = torch.ops.quantized.conv3d_prepack(
        weight, bias, stride, padding, dilation, groups)
    return torch.ops.quantized.conv3d(input, packed_params, scale, zero_point)


def interpolate(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
    r"""Down/up samples the input to either the given :attr:`size` or the given
    :attr:`scale_factor`

    See :func:`torch.nn.functional.interpolate` for implementation details.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [optional depth] x [optional height] x width`.

    .. note:: The input quantization parameters propagate to the output.

    .. note:: Only 2D/3D input is supported for quantized inputs

    .. note:: Only the following modes are supported for the quantized inputs:

        - `bilinear`
        - `nearest`

    Args:
        input (Tensor): the input tensor
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (float or Tuple[float]): multiplier for spatial size. Has to match input size if it is a tuple.
        mode (str): algorithm used for upsampling:
            ``'nearest'`` | ``'bilinear'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input and output as squares rather than points.
            If set to ``True``, the input and output tensors are aligned by the
            center points of their corner pixels, preserving the values at the corner pixels.
            If set to ``False``, the input and output tensors are aligned by the corner
            points of their corner pixels, and the interpolation uses edge value padding
            for out-of-boundary values, making this operation *independent* of input size
            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
            is ``'bilinear'``.
            Default: ``False``
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.interpolate' must be quantized!")
    return torch.nn.functional.interpolate(input, size, scale_factor, mode,
                                           align_corners)
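

# A minimal usage sketch for ``interpolate`` (illustrative only); in
# 'nearest' mode the output keeps the input's quantization parameters:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(1, 3, 4, 4), scale=0.1,
#     ...                                zero_point=0, dtype=torch.quint8)
#     >>> interpolate(qx, scale_factor=2.0, mode='nearest')  # shape (1, 3, 8, 8)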


def linear(input, weight, bias=None, scale=None, zero_point=None):
    # type: (Tensor, Tensor, Optional[Tensor], Optional[float], Optional[int]) -> Tensor
    r"""
    Applies a linear transformation to the incoming quantized data:
    :math:`y = xA^T + b`.
    See :class:`~torch.nn.quantized.Linear`

    .. note::

        The current implementation packs weights on every call, which incurs a
        performance penalty. If you want to avoid the overhead, use
        :class:`~torch.nn.quantized.Linear`.

    Args:
        input (Tensor): Quantized input of type `torch.quint8`
        weight (Tensor): Quantized weight of type `torch.qint8`
        bias (Tensor): None or fp32 bias of type `torch.float`
        scale (double): output scale. If None, derived from the input scale
        zero_point (long): output zero point. If None, derived from the input zero_point

    Shape:
        - Input: :math:`(N, *, in\_features)` where `*` means any number of
          additional dimensions
        - Weight: :math:`(out\_features, in\_features)`
        - Bias: :math:`(out\_features)`
        - Output: :math:`(N, *, out\_features)`
    """
    if scale is None:
        scale = input.q_scale()
    if zero_point is None:
        zero_point = input.q_zero_point()
    _packed_params = torch.ops.quantized.linear_prepack(weight, bias)
    return torch.ops.quantized.linear(input, _packed_params, scale, zero_point)
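

# A minimal usage sketch for the functional ``linear`` (illustrative only;
# scales and zero points are arbitrary). Activations must be ``torch.quint8``
# and weights ``torch.qint8``, with a float bias:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(2, 4), 0.1, 0, torch.quint8)
#     >>> qw = torch.quantize_per_tensor(torch.randn(3, 4), 0.1, 0, torch.qint8)
#     >>> linear(qx, qw, bias=torch.randn(3), scale=0.2, zero_point=0)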


def max_pool2d(input, kernel_size, stride=None, padding=0, dilation=1,
               ceil_mode=False, return_indices=False):
    r"""Applies a 2D max pooling over a quantized input signal composed of
    several quantized input planes.

    .. note:: The input quantization parameters are propagated to the output.

    See :class:`~torch.nn.quantized.MaxPool2d` for details.
    """
    if return_indices:
        raise NotImplementedError("return_indices is not yet implemented!")
    if stride is None:
        stride = torch.jit.annotate(List[int], [])
    return torch.nn.functional.max_pool2d(input, kernel_size, stride, padding,
                                          dilation, ceil_mode, return_indices)
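

# A minimal usage sketch for ``max_pool2d`` (illustrative only):
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(1, 3, 8, 8), scale=0.1,
#     ...                                zero_point=0, dtype=torch.quint8)
#     >>> max_pool2d(qx, kernel_size=2, stride=2)  # shape (1, 3, 4, 4)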


def relu(input, inplace=False):
    # type: (Tensor, bool) -> Tensor
    r"""relu(input, inplace=False) -> Tensor

    Applies the rectified linear unit function element-wise.
    See :class:`~torch.nn.quantized.ReLU` for more details.

    Args:
        input: quantized input
        inplace: perform the computation inplace
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.relu' must be quantized!")
    if inplace:
        return torch.relu_(input)
    else:
        return torch.relu(input)
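

# A minimal usage sketch for ``relu`` (illustrative only); the clamp happens
# in the quantized domain, so the output keeps the input's scale/zero_point.
# A nonzero zero_point lets the quint8 input represent negative values:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#     ...                                zero_point=64, dtype=torch.quint8)
#     >>> relu(qx)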


def leaky_relu(input, negative_slope=0.01, inplace=False,
               scale=None, zero_point=None):
    # type: (Tensor, float, bool, Optional[float], Optional[int]) -> Tensor
    r"""
    Quantized version of
    leaky_relu(input, negative_slope=0.01, inplace=False, scale=None, zero_point=None) -> Tensor

    Applies element-wise,
    :math:`\text{LeakyReLU}(x) = \max(0, x) + \text{negative\_slope} * \min(0, x)`

    Args:
        input: Quantized input
        negative_slope: The slope of the negative input
        inplace: Inplace modification of the input tensor
        scale, zero_point: Scale and zero point of the output tensor.

    See :class:`~torch.nn.LeakyReLU` for more details.
    """
    if scale is not None and zero_point is not None:
        assert not inplace, "Cannot rescale with `inplace`"
        output = torch._empty_affine_quantized(
            input.shape, scale=scale, zero_point=int(zero_point), dtype=input.dtype)
        torch._C._nn.leaky_relu(input, negative_slope, out=output)
        return output
    if inplace:
        result = torch._C._nn.leaky_relu_(input, negative_slope)
    else:
        result = torch._C._nn.leaky_relu(input, negative_slope)
    return result
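

# A minimal usage sketch for ``leaky_relu`` (illustrative only). Passing both
# ``scale`` and ``zero_point`` requantizes the output; omitting them reuses
# the input's quantization parameters:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#     ...                                zero_point=64, dtype=torch.quint8)
#     >>> leaky_relu(qx, negative_slope=0.05, scale=0.2, zero_point=32)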


def hardtanh(input, min_val=-1., max_val=1., inplace=False):
    # type: (Tensor, float, float, bool) -> Tensor
    r"""
    hardtanh(input, min_val=-1., max_val=1., inplace=False) -> Tensor

    Applies the quantized HardTanh function element-wise, with scale and
    zero-point carried over from the input tensor. See :class:`~torch.nn.Hardtanh`
    for more details.
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.hardtanh' must be quantized!")
    if inplace:
        return torch._C._nn.hardtanh_(input, min_val, max_val)
    return torch._C._nn.hardtanh(input, min_val, max_val)
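

# A minimal usage sketch for ``hardtanh`` (illustrative only); values are
# clamped to ``[min_val, max_val]`` in the quantized domain:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#     ...                                zero_point=64, dtype=torch.quint8)
#     >>> hardtanh(qx, min_val=-1., max_val=1.)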


def hardswish(input, scale, zero_point):
    # type: (Tensor, float, int) -> Tensor
    r"""Applies the quantized version of the hardswish function, element-wise,
    as described in the paper:

    `Searching for MobileNetV3`_.

    .. math::
        \text{Hardswish}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            x & \text{if~} x \ge +3, \\
            x \cdot (x + 3) / 6 & \text{otherwise}
        \end{cases}

    Args:
        input: quantized input
        scale, zero_point: Scale and zero point of the output tensor.

    See :class:`~torch.nn.Hardswish` for more details.

    .. _`Searching for MobileNetV3`:
        https://arxiv.org/abs/1905.02244
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.hardswish' must be quantized!")
    return torch._ops.ops.quantized.hardswish(input, scale, zero_point)
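

# A minimal usage sketch for ``hardswish`` (illustrative only). Unlike the
# pooling ops above, this op requires explicit output quantization parameters:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#     ...                                zero_point=64, dtype=torch.quint8)
#     >>> hardswish(qx, scale=0.1, zero_point=64)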


def threshold(input, threshold, value):
    # type: (Tensor, float, float) -> Tensor
    r"""Applies the quantized version of the threshold function element-wise:

    .. math::
        x = \begin{cases}
            x & \text{if~} x > \text{threshold} \\
            \text{value} & \text{otherwise}
        \end{cases}

    See :class:`~torch.nn.Threshold` for more details.
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.threshold' must be quantized!")
    if threshold is None:
        raise ValueError("Input to 'threshold' must be specified!")
    if value is None:
        raise ValueError("Input to 'value' must be specified!")
    return torch._ops.ops.quantized.threshold(input, threshold, value)
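

# A minimal usage sketch for ``threshold`` (illustrative only); elements at or
# below the threshold are replaced by ``value``:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#     ...                                zero_point=64, dtype=torch.quint8)
#     >>> threshold(qx, 0.0, -1.0)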


def elu(input, scale, zero_point, alpha=1.):
    # type: (Tensor, float, int, float) -> Tensor
    r"""
    Applies the quantized ELU function element-wise:

    .. math::
        \text{ELU}(x) = \max(0,x) + \min(0, \alpha * (\exp(x) - 1))

    Args:
        input: quantized input
        scale, zero_point: Scale and zero point of the output tensor.
        alpha: the :math:`\alpha` value for the ELU formulation. Default: 1.0
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.elu' must be quantized!")
    return torch.ops.quantized.elu(input, scale, zero_point, alpha)
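

# A minimal usage sketch for ``elu`` (illustrative only). ELU's output range
# is :math:`(-\alpha, \infty)`, so the op takes observed output quantization
# parameters rather than reusing the input's:
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#     ...                                zero_point=64, dtype=torch.quint8)
#     >>> elu(qx, scale=0.05, zero_point=20, alpha=1.)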


def hardsigmoid(input):
    # type: (Tensor) -> Tensor
    r"""
    Applies the quantized element-wise function

    .. math::
        \text{Hardsigmoid}(x) = \begin{cases}
            0 & \text{if~} x \le -3, \\
            1 & \text{if~} x \ge +3, \\
            x / 6 + 1 / 2 & \text{otherwise}
        \end{cases}

    See :class:`~torch.nn.Hardsigmoid` for more details.
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.hardsigmoid' must be quantized!")
    return torch._C._nn.hardsigmoid(input)
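

# A minimal usage sketch for ``hardsigmoid`` (illustrative only):
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#     ...                                zero_point=64, dtype=torch.quint8)
#     >>> hardsigmoid(qx)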


def clamp(input, min_, max_):
    # type: (Tensor, float, float) -> Tensor
    r"""clamp(input, min_, max_) -> Tensor

    Applies the clamp function element-wise.
    See :class:`~torch.nn.quantized.clamp` for more details.

    Args:
        input: quantized input
        min_: minimum value for clamping
        max_: maximum value for clamping
    """
    if not input.is_quantized:
        raise ValueError("Input to 'quantized.clamp' must be quantized!")
    return torch.clamp(input, min_, max_)
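

# A minimal usage sketch for ``clamp`` (illustrative only):
#
#     >>> qx = torch.quantize_per_tensor(torch.randn(4), scale=0.1,
#     ...                                zero_point=64, dtype=torch.quint8)
#     >>> clamp(qx, -0.5, 0.5)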


def upsample(input, size=None, scale_factor=None, mode='nearest', align_corners=None):
    r"""Upsamples the input to either the given :attr:`size` or the given
    :attr:`scale_factor`

    .. warning::
        This function is deprecated in favor of
        :func:`torch.nn.quantized.functional.interpolate`.
        This is equivalent to ``nn.quantized.functional.interpolate(...)``.

    See :func:`torch.nn.functional.interpolate` for implementation details.

    The input dimensions are interpreted in the form:
    `mini-batch x channels x [optional depth] x [optional height] x width`.

    .. note:: The input quantization parameters propagate to the output.

    .. note:: Only 2D input is supported for quantized inputs

    .. note:: Only the following modes are supported for the quantized inputs:

        - `bilinear`
        - `nearest`

    Args:
        input (Tensor): quantized input tensor
        size (int or Tuple[int] or Tuple[int, int] or Tuple[int, int, int]):
            output spatial size.
        scale_factor (int or Tuple[int]): multiplier for spatial size. Has to be an integer.
        mode (str): algorithm used for upsampling:
            ``'nearest'`` | ``'bilinear'``
        align_corners (bool, optional): Geometrically, we consider the pixels of the
            input and output as squares rather than points.
            If set to ``True``, the input and output tensors are aligned by the
            center points of their corner pixels, preserving the values at the corner pixels.
            If set to ``False``, the input and output tensors are aligned by the corner
            points of their corner pixels, and the interpolation uses edge value padding
            for out-of-boundary values, making this operation *independent* of input size
            when :attr:`scale_factor` is kept the same. This only has an effect when :attr:`mode`
            is ``'bilinear'``.
            Default: ``False``

    .. warning::
        With ``align_corners = True``, the linearly interpolating modes
        (`bilinear`) don't proportionally align the output and input pixels,
        and thus the output values can depend on the input size. This was the
        default behavior for these modes up to version 0.3.1. Since then, the
        default behavior is ``align_corners = False``. See
        :class:`~torch.nn.Upsample` for concrete examples on how this affects
        the outputs.
    """
    warnings.warn("nn.quantized.functional.upsample is deprecated. Use nn.quantized.functional.interpolate instead.")
    return interpolate(input, size, scale_factor, mode, align_corners)


def upsample_bilinear(input, size=None, scale_factor=None):
    r"""Upsamples the input, using bilinear upsampling.

    .. warning::
        This function is deprecated in favor of
        :func:`torch.nn.quantized.functional.interpolate`.
        This is equivalent to
        ``nn.quantized.functional.interpolate(..., mode='bilinear', align_corners=True)``.

    .. note:: The input quantization parameters propagate to the output.

    .. note:: Only 2D inputs are supported

    Args:
        input (Tensor): quantized input
        size (int or Tuple[int, int]): output spatial size.
        scale_factor (int or Tuple[int, int]): multiplier for spatial size
    """
    # DeprecationWarning is ignored by default
    warnings.warn("nn.quantized.functional.upsample_bilinear is deprecated. Use nn.quantized.functional.interpolate instead.")
    return interpolate(input, size, scale_factor, mode='bilinear', align_corners=True)


def upsample_nearest(input, size=None, scale_factor=None):
    r"""Upsamples the input, using nearest neighbours' pixel values.

    .. warning::
        This function is deprecated in favor of
        :func:`torch.nn.quantized.functional.interpolate`.
        This is equivalent to ``nn.quantized.functional.interpolate(..., mode='nearest')``.

    .. note:: The input quantization parameters propagate to the output.

    .. note:: Only 2D inputs are supported

    Args:
        input (Tensor): quantized input
        size (int or Tuple[int, int] or Tuple[int, int, int]): output spatial
            size.
        scale_factor (int): multiplier for spatial size. Has to be an integer.
    """
    # DeprecationWarning is ignored by default
    warnings.warn("nn.quantized.functional.upsample_nearest is deprecated. Use nn.quantized.functional.interpolate instead.")
    return interpolate(input, size, scale_factor, mode='nearest')