[quant][ao_migration] torch.nn.quantizable → torch.ao.nn.quantizable. (#78717)

Context: To avoid cluttering the `torch.nn` namespace, the quantized modules are
being migrated to the `torch.ao.nn` namespace.

The list of the `nn.quantized` files that are being migrated:

- [X] `torch.nn.quantized` → `torch.ao.nn.quantized`
    - [X] `torch.nn.quantized.functional` → `torch.ao.nn.quantized.functional`
    - [X] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules`
    - [X] `torch.nn.quantized.dynamic` → `torch.ao.nn.quantized.dynamic`
    - [X] `torch.nn.quantized._reference` → `torch.ao.nn.quantized._reference`
- [X] [Current PR] `torch.nn.quantizable` → `torch.ao.nn.quantizable`
- [ ] `torch.nn.qat` → `torch.ao.nn.qat`
    - [ ] `torch.nn.qat.modules` → `torch.ao.nn.qat.modules`
    - [ ] `torch.nn.qat.dynamic` → `torch.ao.nn.qat.dynamic`
- [ ] `torch.nn.intrinsic` → `torch.ao.nn.intrinsic`
    - [ ] `torch.nn.intrinsic.modules` → `torch.ao.nn.intrinsic.modules`
    - [ ] `torch.nn.intrinsic.qat` → `torch.ao.nn.intrinsic.qat`
    - [ ] `torch.nn.intrinsic.quantized` → `torch.ao.nn.intrinsic.quantized`
        - [ ] `torch.nn.intrinsic.quantized.modules` → `torch.ao.nn.intrinsic.quantized.modules`
        - [ ] `torch.nn.intrinsic.quantized.dynamic` → `torch.ao.nn.intrinsic.quantized.dynamic`

The majority of the files are simply moved to the new location.
However, the following files need to be double-checked:

- None
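
For illustration only (not part of the PR): because the old locations are kept as thin re-export shims, importing through either namespace should resolve to the same class objects. A minimal sanity check, assuming a build that includes this change:

```python
import torch.nn.quantizable as nnqa_old      # legacy path, kept as a re-export shim
import torch.ao.nn.quantizable as nnqa_new   # new home of the quantizable modules

# The shim re-imports from torch.ao.nn.quantizable.modules, so both
# paths expose the very same class objects.
assert nnqa_old.MultiheadAttention is nnqa_new.MultiheadAttention
assert nnqa_old.LSTM is nnqa_new.LSTM
assert nnqa_old.LSTMCell is nnqa_new.LSTMCell
```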

Differential Revision: [D36861090](https://our.internmc.facebook.com/intern/diff/D36861090/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36861090/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78717
Approved by: https://github.com/jerryzh168
Authored by zaf on 2022-08-21 19:34:56 -07:00; committed by PyTorch MergeBot
parent a7344e52b9
commit e0876feb49
16 changed files with 936 additions and 850 deletions

View File

@ -1133,6 +1133,8 @@ Please take a look at `Limitations of Symbolic Tracing <https://docs-preview.pyt
.. They are here for tracking purposes until they are more permanently fixed.
.. py:module:: torch.ao
.. py:module:: torch.ao.nn
.. py:module:: torch.ao.nn.quantizable
.. py:module:: torch.ao.nn.quantizable.modules
.. py:module:: torch.ao.nn.quantized
.. py:module:: torch.ao.nn.sparse
.. py:module:: torch.ao.nn.sparse.quantized

View File

@ -5,7 +5,11 @@
"torch.nn.quantized.modules": "torch.ao.nn.quantized.modules",
"torch.nn.quantized.dynamic": "torch.ao.nn.quantized.dynamic",
"torch.nn.quantized.dynamic.modules": "torch.ao.nn.quantized.dynamic.modules",
"torch.nn.quantized.dynamic.modules.rnn": "torch.ao.nn.quantized.dynamic.modules.rnn"
"torch.nn.quantized.dynamic.modules.rnn": "torch.ao.nn.quantized.dynamic.modules.rnn",
"torch.nn.quantizable": "torch.ao.nn.quantizable",
"torch.nn.quantizable.modules": "torch.ao.nn.quantizable.modules",
"torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
"torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn"
},
"torch.ao.quantization": [
"ABC",

View File

@ -374,3 +374,27 @@ class TestAOMigrationNNQuantized(AOMigrationTestCase):
]
self._test_function_import('sparse', function_list,
base='nn.quantized._reference.modules')
def test_package_import_nn_quantizable(self):
self._test_package_import('quantizable', base='nn')
def test_package_import_nn_quantizable_modules(self):
r"""Tests the migration of the torch.nn.quantizable.modules"""
self._test_package_import('modules', base='nn.quantizable')
self._test_package_import('modules.activation', base='nn.quantizable')
self._test_package_import('modules.rnn', base='nn.quantizable')
def test_import_nn_quantizable_activation(self):
module_list = [
# Modules
'MultiheadAttention',
]
self._test_function_import('activation', module_list, base='nn.quantizable.modules')
def test_import_nn_quantizable_rnn(self):
module_list = [
# Modules
'LSTM',
'LSTMCell',
]
self._test_function_import('rnn', module_list, base='nn.quantizable.modules')
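
For reference, the equivalence these migration tests exercise can be reproduced standalone. The exact behavior of `_test_package_import`/`_test_function_import` lives in the AO migration test utilities; the check below is an assumed, simplified version:

```python
import importlib

# Simplified stand-in for the package/function import checks: the legacy
# and the new module paths should expose the same objects.
old_rnn = importlib.import_module("torch.nn.quantizable.modules.rnn")
new_rnn = importlib.import_module("torch.ao.nn.quantizable.modules.rnn")
assert old_rnn.LSTM is new_rnn.LSTM
assert old_rnn.LSTMCell is new_rnn.LSTMCell
```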

View File

@ -1,3 +1,17 @@
# We are exposing both subpackages to the end-user.
from . import sparse
from . import quantized
# We are exposing all subpackages to the end-user.
# Because of possible inter-dependencies, we want to avoid
# cyclic imports, thus implementing a lazy version
# as per https://peps.python.org/pep-0562/
import importlib
__all__ = [
"quantizable",
"quantized",
"sparse",
]
def __getattr__(name):
if name in __all__:
return importlib.import_module("." + name, __name__)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
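
The hunk above uses PEP 562's module-level `__getattr__` to defer the subpackage imports. A self-contained sketch of the same pattern for a hypothetical package (names are illustrative, not taken from this PR):

```python
# mypkg/__init__.py -- hypothetical package using PEP 562 lazy submodule loading
import importlib

__all__ = ["alpha", "beta"]  # subpackages that should be importable lazily

def __getattr__(name):
    # Called only when `name` is not found as a regular module attribute;
    # importing here defers the import and breaks potential import cycles.
    if name in __all__:
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

With this in place, `mypkg.alpha` is imported on first attribute access rather than at `import mypkg` time, which is also why code that needs the subpackage eagerly (see the `torch/ao/quantization` hunk further down) has to import it explicitly.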

View File

@ -0,0 +1 @@
from .modules import * # noqa: F403

View File

@ -0,0 +1,9 @@
from .activation import MultiheadAttention
from .rnn import LSTM
from .rnn import LSTMCell
__all__ = [
'LSTM',
'LSTMCell',
'MultiheadAttention',
]
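
Since the package `__init__` in the previous hunk does `from .modules import *`, the three public classes are also reachable from the package root; for example (illustrative):

```python
# Both import paths resolve to the same classes, because the package
# __init__ re-exports everything from .modules.
from torch.ao.nn.quantizable import LSTM, LSTMCell, MultiheadAttention
from torch.ao.nn.quantizable.modules.rnn import LSTM as _LSTM

assert LSTM is _LSTM
```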

View File

@ -0,0 +1,454 @@
import torch
import torch.jit # this is needed to avoid a circular import
from torch import nn
import torch.nn.functional as nnF
from torch import Tensor
from typing import Optional, Tuple
import warnings
class MultiheadAttention(nn.MultiheadAttention):
_FLOAT_MODULE = nn.MultiheadAttention
r"""Quantizable implementation of the MultiheadAttention.
Note::
Please, refer to :class:`~torch.nn.MultiheadAttention` for more
information
Allows the model to jointly attend to information from different
representation subspaces.
See reference: Attention Is All You Need
The original MHA module is not quantizable.
This reimplements it by explicitly instantiating the linear layers.
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
bias: add bias as module parameter. Default: True.
add_bias_kv: add bias to the key and value sequences at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
kdim: total number of features in key. Default: None.
vdim: total number of features in value. Default: None.
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
to :attr:`embed_dim` such that query, key, and value have the same
number of features.
Examples::
>>> import torch.nn.quantizable as nnqa
>>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
Note::
Please, follow the quantization flow to convert the quantizable MHA.
"""
__constants__ = ['batch_first']
def __init__(self, embed_dim: int, num_heads: int,
dropout: float = 0., bias: bool = True,
add_bias_kv: bool = False, add_zero_attn: bool = False,
kdim: int = None, vdim: int = None, batch_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout,
bias, add_bias_kv,
add_zero_attn, kdim, vdim, batch_first,
**factory_kwargs)
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
# for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]
# Functionals
self.q_scaling_product = torch.nn.quantized.FloatFunctional()
# note: importing torch.nn.quantized at top creates a circular import
# Quant/Dequant
self.quant_attn_output = torch.ao.quantization.QuantStub()
self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
self.dequant_q = torch.ao.quantization.DeQuantStub()
self.dequant_k = torch.ao.quantization.DeQuantStub()
self.dequant_v = torch.ao.quantization.DeQuantStub()
def _get_name(self):
return 'QuantizableMultiheadAttention'
@classmethod
def from_float(cls, other):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
# Setting the dropout to 0.0!
observed = cls(other.embed_dim, other.num_heads, other.dropout,
(other.in_proj_bias is not None),
(other.bias_k is not None),
other.add_zero_attn, other.kdim, other.vdim)
observed.bias_k = other.bias_k
observed.bias_v = other.bias_v
observed.qconfig = other.qconfig
# Set the linear weights
# for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
if other._qkv_same_embed_dim:
# Use separate params
bias = other.in_proj_bias
_start = 0
_end = _start + other.embed_dim
weight = other.in_proj_weight[_start:_end, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
observed.linear_Q.weight = torch.nn.Parameter(weight,
weight.requires_grad)
observed.linear_Q.bias = bias
bias = other.in_proj_bias
_start = _end
_end = _start + other.embed_dim
weight = other.in_proj_weight[_start:_end, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
observed.linear_K.weight = torch.nn.Parameter(weight,
weight.requires_grad)
observed.linear_K.bias = bias
bias = other.in_proj_bias
_start = _end
weight = other.in_proj_weight[_start:, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
observed.linear_V.weight = torch.nn.Parameter(weight,
weight.requires_grad)
observed.linear_V.bias = bias
else:
observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
if other.in_proj_bias is None:
observed.linear_Q.bias = None # type: ignore[assignment]
observed.linear_K.bias = None # type: ignore[assignment]
observed.linear_V.bias = None # type: ignore[assignment]
else:
observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
observed.eval()
# Explicit prepare
observed = torch.ao.quantization.prepare(observed, inplace=True)
return observed
@torch.jit.unused
def dequantize(self):
r"""Utility to convert the quantized MHA back to float.
The motivation for this is that it is not trivial to convert the weights
from the format that is used in the quantized version back to
float.
"""
fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
(self.in_proj_bias is not None),
(self.bias_k is not None),
self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
if self.bias_k is not None:
fp.bias_k = nn.Parameter(self.bias_k.dequantize())
if self.bias_v is not None:
fp.bias_v = nn.Parameter(self.bias_v.dequantize())
# Set the linear weights
# Note: Because the linear layers are quantized, mypy does not know how
# to deal with them -- might need to ignore the typing checks.
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type]
fp.out_proj.weight = nn.Parameter(w.dequantize())
if b is not None:
fp.out_proj.bias = nn.Parameter(b)
wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator]
wQ = wQ.dequantize()
wK, bK = self.linear_K._weight_bias() # type: ignore[operator]
wK = wK.dequantize()
wV, bV = self.linear_V._weight_bias() # type: ignore[operator]
wV = wV.dequantize()
if fp._qkv_same_embed_dim:
# Use separate params
_start = 0
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wQ
if fp.in_proj_bias is not None:
assert all(bQ == 0)
fp.in_proj_bias[_start:_end] = bQ
_start = _end
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wK
if fp.in_proj_bias is not None:
assert all(bK == 0)
fp.in_proj_bias[_start:_end] = bK
_start = _end
fp.in_proj_weight[_start:, :] = wV
if fp.in_proj_bias is not None:
assert all(bV == 0)
fp.in_proj_bias[_start:] = bV
else:
fp.q_proj_weight = nn.Parameter(wQ)
fp.k_proj_weight = nn.Parameter(wK)
fp.v_proj_weight = nn.Parameter(wV)
if fp.in_proj_bias is None:
self.linear_Q.bias = None
self.linear_K.bias = None
self.linear_V.bias = None
else:
fp.in_proj_bias[0:fp.embed_dim] = bQ
fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
fp.in_proj_bias[(fp.embed_dim * 2):] = bV
return fp
@classmethod
def from_observed(cls, other):
# The whole flow is float -> observed -> quantized
# This class does float -> observed only
# See nn.quantized.MultiheadAttention
raise NotImplementedError("It looks like you are trying to prepare an "
"MHA module. Please, see "
"the examples on quantizable MHAs.")
def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Note::
Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
information
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows specifying a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: True (i.e. average weights across heads)
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
- attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(N, num_heads, L, S)`.
"""
return self._forward_impl(query, key, value, key_padding_mask,
need_weights, attn_mask, average_attn_weights)
def _forward_impl(self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
# This version will not deal with the static key/value pairs.
# Keeping it here for future changes.
#
# TODO: This method has some duplicate lines with the
# `torch.nn.functional.multi_head_attention`. Will need to refactor.
static_k = None
static_v = None
if self.batch_first:
query, key, value = [x.transpose(0, 1) for x in (query, key, value)]
tgt_len, bsz, embed_dim_to_check = query.size()
assert self.embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = self.embed_dim // self.num_heads
assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
q = self.linear_Q(query)
k = self.linear_K(key)
v = self.linear_V(value)
q = self.q_scaling_product.mul_scalar(q, scaling)
if attn_mask is not None:
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
if attn_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError('The size of the 2D attn_mask is not correct.')
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
raise RuntimeError('The size of the 3D attn_mask is not correct.')
else:
raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
key_padding_mask = key_padding_mask.to(torch.bool)
if self.bias_k is not None and self.bias_v is not None:
if static_k is None and static_v is None:
# Explicitly assert that bias_k and bias_v are not None
# in a way that TorchScript can understand.
bias_k = self.bias_k
assert bias_k is not None
bias_v = self.bias_v
assert bias_v is not None
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = nnF.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
else:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
else:
assert self.bias_k is None
assert self.bias_v is None
q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * self.num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * self.num_heads
assert static_v.size(2) == head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
src_len += 1
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
if k.is_quantized:
k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
k = torch.cat([k, k_zeros], dim=1)
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
if v.is_quantized:
v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
v = torch.cat([v, v_zeros], dim=1)
if attn_mask is not None:
attn_mask = nnF.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
# Leaving the quantized zone here
q = self.dequant_q(q)
k = self.dequant_k(k)
v = self.dequant_v(v)
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_output_weights = nnF.softmax(
attn_output_weights, dim=-1)
attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
if self.batch_first:
attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
else:
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
# Reentering the quantized zone
attn_output = self.quant_attn_output(attn_output)
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
attn_output = self.out_proj(attn_output) # type: ignore[has-type]
attn_output_weights = self.quant_attn_output_weights(attn_output_weights)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
if average_attn_weights:
attn_output_weights = attn_output_weights.mean(dim=1)
return attn_output, attn_output_weights
else:
return attn_output, None
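
The docstring example above still imports from the pre-migration `torch.nn.quantizable` path; after this PR the same usage is available through the new namespace (the old path keeps working via the compatibility shim). A minimal usage sketch with arbitrarily chosen shapes:

```python
import torch
import torch.ao.nn.quantizable as nnqa  # new namespace introduced by this PR

embed_dim, num_heads = 8, 2
mha = nnqa.MultiheadAttention(embed_dim, num_heads)

# (seq_len, batch, embed_dim), since batch_first defaults to False
query = key = value = torch.randn(5, 3, embed_dim)
attn_output, attn_output_weights = mha(query, key, value)
```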

View File

@ -0,0 +1,386 @@
import numbers
from typing import Optional, Tuple
import warnings
import torch
from torch import Tensor
"""
We will recreate all the RNN modules, as we require the modules to be decomposed
into their building blocks in order to be able to observe them.
"""
class LSTMCell(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM) cell.
For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
Examples::
>>> import torch.nn.quantizable as nnqa
>>> rnn = nnqa.LSTMCell(10, 20)
>>> input = torch.randn(6, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
... hx, cx = rnn(input[i], (hx, cx))
... output.append(hx)
"""
_FLOAT_MODULE = torch.nn.LSTMCell
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.input_size = input_dim
self.hidden_size = hidden_dim
self.bias = bias
self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
self.gates = torch.ao.nn.quantized.FloatFunctional()
self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()
self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
if hidden is None or hidden[0] is None or hidden[1] is None:
hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
hx, cx = hidden
igates = self.igates(x)
hgates = self.hgates(hx)
gates = self.gates.add(igates, hgates)
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
input_gate = torch.sigmoid(input_gate)
forget_gate = torch.sigmoid(forget_gate)
cell_gate = torch.tanh(cell_gate)
out_gate = torch.sigmoid(out_gate)
fgate_cx = self.fgate_cx.mul(forget_gate, cx)
igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
cy = fgate_cx_igate_cgate
tanh_cy = torch.tanh(cy)
hy = self.ogate_cy.mul(out_gate, tanh_cy)
return hy, cy
def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
if is_quantized:
h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8)
c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8)
return h, c
def _get_name(self):
return 'QuantizableLSTMCell'
@classmethod
def from_params(cls, wi, wh, bi=None, bh=None):
"""Uses the weights and biases to create a new LSTM cell.
Args:
wi, wh: Weights for the input and hidden layers
bi, bh: Biases for the input and hidden layers
"""
assert (bi is None) == (bh is None) # Either both None or both have values
input_size = wi.shape[1]
hidden_size = wh.shape[1]
cell = cls(input_dim=input_size, hidden_dim=hidden_size,
bias=(bi is not None))
cell.igates.weight = torch.nn.Parameter(wi)
if bi is not None:
cell.igates.bias = torch.nn.Parameter(bi)
cell.hgates.weight = torch.nn.Parameter(wh)
if bh is not None:
cell.hgates.bias = torch.nn.Parameter(bh)
return cell
@classmethod
def from_float(cls, other):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
observed = cls.from_params(other.weight_ih, other.weight_hh,
other.bias_ih, other.bias_hh)
observed.qconfig = other.qconfig
observed.igates.qconfig = other.qconfig
observed.hgates.qconfig = other.qconfig
return observed
class _LSTMSingleLayer(torch.nn.Module):
r"""A single one-directional LSTM layer.
The difference between a layer and a cell is that the layer can process a
sequence, while the cell only expects an instantaneous value.
"""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
result = []
for xx in x:
hidden = self.cell(xx, hidden)
result.append(hidden[0]) # type: ignore[index]
result_tensor = torch.stack(result, 0)
return result_tensor, hidden
@classmethod
def from_params(cls, *args, **kwargs):
cell = LSTMCell.from_params(*args, **kwargs)
layer = cls(cell.input_size, cell.hidden_size, cell.bias)
layer.cell = cell
return layer
class _LSTMLayer(torch.nn.Module):
r"""A single bi-directional LSTM layer."""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
batch_first: bool = False, bidirectional: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.batch_first = batch_first
self.bidirectional = bidirectional
self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
if self.bidirectional:
self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
x = x.transpose(0, 1)
if hidden is None:
hx_fw, cx_fw = (None, None)
else:
hx_fw, cx_fw = hidden
hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
if self.bidirectional:
if hx_fw is None:
hx_bw = None
else:
hx_bw = hx_fw[1]
hx_fw = hx_fw[0]
if cx_fw is None:
cx_bw = None
else:
cx_bw = cx_fw[1]
cx_fw = cx_fw[0]
if hx_bw is not None and cx_bw is not None:
hidden_bw = hx_bw, cx_bw
if hx_fw is None and cx_fw is None:
hidden_fw = None
else:
hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
if hasattr(self, 'layer_bw') and self.bidirectional:
x_reversed = x.flip(0)
result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
result_bw = result_bw.flip(0)
result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
if hidden_fw is None and hidden_bw is None:
h = None
c = None
elif hidden_fw is None:
(h, c) = torch.jit._unwrap_optional(hidden_bw)
elif hidden_bw is None:
(h, c) = torch.jit._unwrap_optional(hidden_fw)
else:
h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item]
c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item]
else:
result = result_fw
h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment]
if self.batch_first:
result.transpose_(0, 1)
return result, (h, c)
@classmethod
def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
r"""
There is no FP equivalent of this class. This function is here just to
mimic the behavior of the `prepare` within the `torch.ao.quantization`
flow.
"""
assert hasattr(other, 'qconfig') or (qconfig is not None)
input_size = kwargs.get('input_size', other.input_size)
hidden_size = kwargs.get('hidden_size', other.hidden_size)
bias = kwargs.get('bias', other.bias)
batch_first = kwargs.get('batch_first', other.batch_first)
bidirectional = kwargs.get('bidirectional', other.bidirectional)
layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
layer.qconfig = getattr(other, 'qconfig', qconfig)
wi = getattr(other, f'weight_ih_l{layer_idx}')
wh = getattr(other, f'weight_hh_l{layer_idx}')
bi = getattr(other, f'bias_ih_l{layer_idx}', None)
bh = getattr(other, f'bias_hh_l{layer_idx}', None)
layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
if other.bidirectional:
wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
return layer
class LSTM(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM).
For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
Attributes:
layers : instances of the `_LSTMLayer`
.. note::
To access the weights and biases, you need to access them per layer.
See examples below.
Examples::
>>> import torch.nn.quantizable as nnqa
>>> rnn = nnqa.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> # To get the weights:
>>> # xdoctest: +SKIP
>>> print(rnn.layers[0].weight_ih)
tensor([[...]])
>>> print(rnn.layers[0].weight_hh)
AssertionError: There is no reverse path in the non-bidirectional layer
"""
_FLOAT_MODULE = torch.nn.LSTM
def __init__(self, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True,
batch_first: bool = False, dropout: float = 0.,
bidirectional: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = float(dropout)
self.bidirectional = bidirectional
self.training = False # We don't want to train using this module
num_directions = 2 if bidirectional else 1
if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
isinstance(dropout, bool):
raise ValueError("dropout should be a number in range [0, 1] "
"representing the probability of an element being "
"zeroed")
if dropout > 0:
warnings.warn("dropout option for quantizable LSTM is ignored. "
"If you are training, please, use nn.LSTM version "
"followed by `prepare` step.")
if num_layers == 1:
warnings.warn("dropout option adds dropout after all but last "
"recurrent layer, so non-zero dropout expects "
"num_layers greater than 1, but got dropout={} "
"and num_layers={}".format(dropout, num_layers))
layers = [_LSTMLayer(self.input_size, self.hidden_size,
self.bias, batch_first=False,
bidirectional=self.bidirectional, **factory_kwargs)]
for layer in range(1, num_layers):
layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
self.bias, batch_first=False,
bidirectional=self.bidirectional,
**factory_kwargs))
self.layers = torch.nn.ModuleList(layers)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
x = x.transpose(0, 1)
max_batch_size = x.size(1)
num_directions = 2 if self.bidirectional else 1
if hidden is None:
zeros = torch.zeros(num_directions, max_batch_size,
self.hidden_size, dtype=torch.float,
device=x.device)
zeros.squeeze_(0)
if x.is_quantized:
zeros = torch.quantize_per_tensor(zeros, scale=1.0,
zero_point=0, dtype=x.dtype)
hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
else:
hidden_non_opt = torch.jit._unwrap_optional(hidden)
if isinstance(hidden_non_opt[0], Tensor):
hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
max_batch_size,
self.hidden_size).unbind(0)
cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
max_batch_size,
self.hidden_size).unbind(0)
hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
else:
hxcx = hidden_non_opt
hx_list = []
cx_list = []
for idx, layer in enumerate(self.layers):
x, (h, c) = layer(x, hxcx[idx])
hx_list.append(torch.jit._unwrap_optional(h))
cx_list.append(torch.jit._unwrap_optional(c))
hx_tensor = torch.stack(hx_list)
cx_tensor = torch.stack(cx_list)
# We are creating another dimension for bidirectional case
# need to collapse it
hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])
if self.batch_first:
x = x.transpose(0, 1)
return x, (hx_tensor, cx_tensor)
def _get_name(self):
return 'QuantizableLSTM'
@classmethod
def from_float(cls, other, qconfig=None):
assert isinstance(other, cls._FLOAT_MODULE)
assert (hasattr(other, 'qconfig') or qconfig)
observed = cls(other.input_size, other.hidden_size, other.num_layers,
other.bias, other.batch_first, other.dropout,
other.bidirectional)
observed.qconfig = getattr(other, 'qconfig', qconfig)
for idx in range(other.num_layers):
observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
batch_first=False)
observed.eval()
observed = torch.ao.quantization.prepare(observed, inplace=True)
return observed
@classmethod
def from_observed(cls, other):
# The whole flow is float -> observed -> quantized
# This class does float -> observed only
raise NotImplementedError("It looks like you are trying to convert a "
"non-quantizable LSTM module. Please, see "
"the examples on quantizable LSTMs.")

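As with the MHA module above, the LSTM docstring example still references the old import path; a minimal sketch of the same usage under the new namespace (shapes mirror the docstring example):

```python
import torch
import torch.ao.nn.quantizable as nnqa  # new namespace introduced by this PR

rnn = nnqa.LSTM(10, 20, 2)    # input_size=10, hidden_size=20, num_layers=2
inp = torch.randn(5, 3, 10)   # (seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20)    # (num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20)
output, (hn, cn) = rnn(inp, (h0, c0))
```
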
View File

@ -1,4 +1,12 @@
import torch
# The quantized modules use `torch.nn` and `torch.ao.nn.quantizable`
# packages. However, the `quantizable` package uses "lazy imports"
# to avoid a circular dependency.
# Hence, we need to import it here to make sure it is resolved before
# its contents are used in the quantized modules.
import torch.ao.nn.quantizable
from torch.nn.modules.pooling import MaxPool2d
from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU

View File

@ -184,8 +184,9 @@ class Softmax(torch.nn.Softmax):
def from_reference(cls, mod, scale, zero_point):
return cls(mod.dim, float(scale), int(zero_point))
class MultiheadAttention(torch.nn.quantizable.MultiheadAttention):
_FLOAT_MODULE = torch.nn.quantizable.MultiheadAttention
class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
_FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
def _get_name(self):
return "QuantizedMultiheadAttention"

View File

@ -1,6 +1,6 @@
import torch
class LSTM(torch.nn.quantizable.LSTM):
class LSTM(torch.ao.nn.quantizable.LSTM):
r"""A quantized long short-term memory (LSTM).
For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`

View File

@ -16,6 +16,9 @@ import torch.nn.qat.dynamic as nnqatd
from typing import Optional, Union, Dict, Set, Callable, Any
# Because `torch.ao.nn` uses lazy imports, we need to make
# sure we import the contents explicitly here.
import torch.ao.nn.sparse
import torch.ao.nn as ao_nn
from torch.ao.quantization.stubs import QuantStub, DeQuantStub
from torch.ao.quantization.fake_quantize import (

View File

@ -1,13 +1,12 @@
import torch
import torch.ao.nn
def get_static_sparse_quantized_mapping():
import torch.ao.nn.sparse
_static_sparse_quantized_mapping = dict({
torch.nn.Linear: torch.ao.nn.sparse.quantized.Linear,
})
return _static_sparse_quantized_mapping
def get_dynamic_sparse_quantized_mapping():
import torch.ao.nn.sparse
_dynamic_sparse_quantized_mapping = dict({
torch.nn.Linear: torch.ao.nn.sparse.quantized.dynamic.Linear,
})

View File

@ -1,6 +1,6 @@
from .activation import MultiheadAttention
from .rnn import LSTM
from .rnn import LSTMCell
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
from torch.ao.nn.quantizable.modules.rnn import LSTM
from torch.ao.nn.quantizable.modules.rnn import LSTMCell
__all__ = [
'LSTM',

View File

@ -1,454 +1,10 @@
import torch
import torch.jit # this is needed to avoid a circular import
from torch import nn
import torch.nn.functional as nnF
# flake8: noqa: F401
r"""Quantizable Modules
from torch import Tensor
from typing import Optional, Tuple
import warnings
class MultiheadAttention(nn.MultiheadAttention):
_FLOAT_MODULE = nn.MultiheadAttention
r"""Quantizable implementation of the MultiheadAttention.
Note::
Please, refer to :class:`~torch.nn.MultiheadAttention` for more
information
Allows the model to jointly attend to information from different
representation subspaces.
See reference: Attention Is All You Need
The original MHA module is not quantizable.
This reimplements it by explicitly instantiating the linear layers.
.. math::
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
Args:
embed_dim: total dimension of the model.
num_heads: parallel attention heads.
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
bias: add bias as module parameter. Default: True.
add_bias_kv: add bias to the key and value sequences at dim=0.
add_zero_attn: add a new batch of zeros to the key and
value sequences at dim=1.
kdim: total number of features in key. Default: None.
vdim: total number of features in value. Default: None.
batch_first: If ``True``, then the input and output tensors are provided
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
to :attr:`embed_dim` such that query, key, and value have the same
number of features.
Examples::
>>> import torch.nn.quantizable as nnqa
>>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
Note::
Please, follow the quantization flow to convert the quantizable MHA.
"""
__constants__ = ['batch_first']
def __init__(self, embed_dim: int, num_heads: int,
dropout: float = 0., bias: bool = True,
add_bias_kv: bool = False, add_zero_attn: bool = False,
kdim: int = None, vdim: int = None, batch_first: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout,
bias, add_bias_kv,
add_zero_attn, kdim, vdim, batch_first,
**factory_kwargs)
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
# for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]
# Functionals
self.q_scaling_product = torch.nn.quantized.FloatFunctional()
# note: importing torch.nn.quantized at top creates a circular import
# Quant/Dequant
self.quant_attn_output = torch.ao.quantization.QuantStub()
self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
self.dequant_q = torch.ao.quantization.DeQuantStub()
self.dequant_k = torch.ao.quantization.DeQuantStub()
self.dequant_v = torch.ao.quantization.DeQuantStub()
def _get_name(self):
return 'QuantizableMultiheadAttention'
@classmethod
def from_float(cls, other):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
# Setting the dropout to 0.0!
observed = cls(other.embed_dim, other.num_heads, other.dropout,
(other.in_proj_bias is not None),
(other.bias_k is not None),
other.add_zero_attn, other.kdim, other.vdim)
observed.bias_k = other.bias_k
observed.bias_v = other.bias_v
observed.qconfig = other.qconfig
# Set the linear weights
# for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
if other._qkv_same_embed_dim:
# Use separate params
bias = other.in_proj_bias
_start = 0
_end = _start + other.embed_dim
weight = other.in_proj_weight[_start:_end, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
observed.linear_Q.weight = torch.nn.Parameter(weight,
weight.requires_grad)
observed.linear_Q.bias = bias
bias = other.in_proj_bias
_start = _end
_end = _start + other.embed_dim
weight = other.in_proj_weight[_start:_end, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
observed.linear_K.weight = torch.nn.Parameter(weight,
weight.requires_grad)
observed.linear_K.bias = bias
bias = other.in_proj_bias
_start = _end
weight = other.in_proj_weight[_start:, :]
if bias is not None:
bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
observed.linear_V.weight = torch.nn.Parameter(weight,
weight.requires_grad)
observed.linear_V.bias = bias
else:
observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
if other.in_proj_bias is None:
observed.linear_Q.bias = None # type: ignore[assignment]
observed.linear_K.bias = None # type: ignore[assignment]
observed.linear_V.bias = None # type: ignore[assignment]
else:
observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
observed.eval()
# Explicit prepare
observed = torch.ao.quantization.prepare(observed, inplace=True)
return observed
@torch.jit.unused
def dequantize(self):
r"""Utility to convert the quantized MHA back to float.
The motivation for this is that it is not trivial to convert the weights
from the format that is used in the quantized version back to
float.
"""
fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
(self.in_proj_bias is not None),
(self.bias_k is not None),
self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
if self.bias_k is not None:
fp.bias_k = nn.Parameter(self.bias_k.dequantize())
if self.bias_v is not None:
fp.bias_v = nn.Parameter(self.bias_v.dequantize())
# Set the linear weights
# Note: Because the linear layers are quantized, mypy does not know how
# to deal with them -- might need to ignore the typing checks.
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type]
fp.out_proj.weight = nn.Parameter(w.dequantize())
if b is not None:
fp.out_proj.bias = nn.Parameter(b)
wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator]
wQ = wQ.dequantize()
wK, bK = self.linear_K._weight_bias() # type: ignore[operator]
wK = wK.dequantize()
wV, bV = self.linear_V._weight_bias() # type: ignore[operator]
wV = wV.dequantize()
if fp._qkv_same_embed_dim:
# Use separate params
_start = 0
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wQ
if fp.in_proj_bias is not None:
assert all(bQ == 0)
fp.in_proj_bias[_start:_end] = bQ
_start = _end
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wK
if fp.in_proj_bias is not None:
assert all(bK == 0)
fp.in_proj_bias[_start:_end] = bK
_start = _end
fp.in_proj_weight[_start:, :] = wV
if fp.in_proj_bias is not None:
assert all(bV == 0)
fp.in_proj_bias[_start:] = bV
else:
fp.q_proj_weight = nn.Parameter(wQ)
fp.k_proj_weight = nn.Parameter(wK)
fp.v_proj_weight = nn.Parameter(wV)
if fp.in_proj_bias is None:
self.linear_Q.bias = None
self.linear_K.bias = None
self.linear_V.bias = None
else:
fp.in_proj_bias[0:fp.embed_dim] = bQ
fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
fp.in_proj_bias[(fp.embed_dim * 2):] = bV
return fp
@classmethod
def from_observed(cls, other):
# The whole flow is float -> observed -> quantized
# This class does float -> observed only
# See nn.quantized.MultiheadAttention
raise NotImplementedError("It looks like you are trying to prepare an "
"MHA module. Please, see "
"the examples on quantizable MHAs.")
def forward(self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
r"""
Note::
Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
information
Args:
query, key, value: map a query and a set of key-value pairs to an output.
See "Attention Is All You Need" for more details.
key_padding_mask: if provided, specified padding elements in the key will
be ignored by the attention. When given a binary mask and a value is True,
the corresponding value on the attention layer will be ignored. When given
a byte mask and a value is non-zero, the corresponding value on the attention
layer will be ignored
need_weights: output attn_output_weights.
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
the batches while a 3D mask allows specifying a different mask for the entries of each batch.
Shape:
- Inputs:
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
If a ByteTensor is provided, the non-zero positions will be ignored while the position
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
is provided, it will be added to the attention weight.
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
effect when ``need_weights=True``. Default: True (i.e. average weights across heads)
- Outputs:
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
- attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
head of shape :math:`(N, num_heads, L, S)`.
"""
return self._forward_impl(query, key, value, key_padding_mask,
need_weights, attn_mask, average_attn_weights)
def _forward_impl(self,
query: Tensor,
key: Tensor,
value: Tensor,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
# This version will not deal with the static key/value pairs.
# Keeping it here for future changes.
#
# TODO: This method has some duplicate lines with the
# `torch.nn.functional.multi_head_attention`. Will need to refactor.
static_k = None
static_v = None
if self.batch_first:
query, key, value = [x.transpose(0, 1) for x in (query, key, value)]
tgt_len, bsz, embed_dim_to_check = query.size()
assert self.embed_dim == embed_dim_to_check
# allow MHA to have different sizes for the feature dimension
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
head_dim = self.embed_dim // self.num_heads
assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
scaling = float(head_dim) ** -0.5
q = self.linear_Q(query)
k = self.linear_K(key)
v = self.linear_V(value)
q = self.q_scaling_product.mul_scalar(q, scaling)
if attn_mask is not None:
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
if attn_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
attn_mask = attn_mask.to(torch.bool)
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(0)
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
raise RuntimeError('The size of the 2D attn_mask is not correct.')
elif attn_mask.dim() == 3:
if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
raise RuntimeError('The size of the 3D attn_mask is not correct.')
else:
raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
# attn_mask's dim is 3 now.
# convert ByteTensor key_padding_mask to bool
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
key_padding_mask = key_padding_mask.to(torch.bool)
if self.bias_k is not None and self.bias_v is not None:
if static_k is None and static_v is None:
# Explicitly assert that bias_k and bias_v are not None
# in a way that TorchScript can understand.
bias_k = self.bias_k
assert bias_k is not None
bias_v = self.bias_v
assert bias_v is not None
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
attn_mask = nnF.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
else:
assert static_k is None, "bias cannot be added to static key."
assert static_v is None, "bias cannot be added to static value."
else:
assert self.bias_k is None
assert self.bias_v is None
q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
if k is not None:
k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
if v is not None:
v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
if static_k is not None:
assert static_k.size(0) == bsz * self.num_heads
assert static_k.size(2) == head_dim
k = static_k
if static_v is not None:
assert static_v.size(0) == bsz * self.num_heads
assert static_v.size(2) == head_dim
v = static_v
src_len = k.size(1)
if key_padding_mask is not None:
assert key_padding_mask.size(0) == bsz
assert key_padding_mask.size(1) == src_len
if self.add_zero_attn:
src_len += 1
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
if k.is_quantized:
k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
k = torch.cat([k, k_zeros], dim=1)
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
if v.is_quantized:
v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
v = torch.cat([v, v_zeros], dim=1)
if attn_mask is not None:
attn_mask = nnF.pad(attn_mask, (0, 1))
if key_padding_mask is not None:
key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
# Leaving the quantized zone here
q = self.dequant_q(q)
k = self.dequant_k(k)
v = self.dequant_v(v)
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
if attn_mask is not None:
if attn_mask.dtype == torch.bool:
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
else:
attn_output_weights += attn_mask
if key_padding_mask is not None:
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
attn_output_weights = attn_output_weights.masked_fill(
key_padding_mask.unsqueeze(1).unsqueeze(2),
float('-inf'),
)
attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_output_weights = nnF.softmax(
attn_output_weights, dim=-1)
attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)
attn_output = torch.bmm(attn_output_weights, v)
assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
if self.batch_first:
attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
else:
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
# Reentering the quantized zone
attn_output = self.quant_attn_output(attn_output)
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
attn_output = self.out_proj(attn_output) # type: ignore[has-type]
attn_output_weights = self.quant_attn_output_weights(attn_output_weights)
if need_weights:
# average attention weights over heads
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
if average_attn_weights:
attn_output_weights = attn_output_weights.mean(dim=1)
return attn_output, attn_output_weights
else:
return attn_output, None
This file is in the process of migration to `torch/ao/nn/quantizable`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantizable/modules`,
while adding an import statement here.
"""
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention

View File

@ -1,386 +1,11 @@
import numbers
from typing import Optional, Tuple
import warnings
import torch
from torch import Tensor
# flake8: noqa: F401
r"""Quantizable Modules
This file is in the process of migration to `torch/ao/nn/quantizable`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantizable/modules`,
while adding an import statement here.
"""
We will recreate all the RNN modules, as we require the modules to be decomposed
into their building blocks in order to be able to observe them.
"""
class LSTMCell(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM) cell.
For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
Examples::
>>> import torch.nn.quantizable as nnqa
>>> rnn = nnqa.LSTMCell(10, 20)
>>> input = torch.randn(6, 10)
>>> hx = torch.randn(3, 20)
>>> cx = torch.randn(3, 20)
>>> output = []
>>> for i in range(6):
... hx, cx = rnn(input[i], (hx, cx))
... output.append(hx)
"""
_FLOAT_MODULE = torch.nn.LSTMCell
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.input_size = input_dim
self.hidden_size = hidden_dim
self.bias = bias
self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
self.gates = torch.ao.nn.quantized.FloatFunctional()
self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()
self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()
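# Note: each binary op in the cell (add/mul) gets its own FloatFunctional so that
# `prepare` can attach an observer to every intermediate tensor; during `convert`
# these are swapped for their quantized QFunctional counterparts.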
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
if hidden is None or hidden[0] is None or hidden[1] is None:
hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
hx, cx = hidden
igates = self.igates(x)
hgates = self.hgates(hx)
gates = self.gates.add(igates, hgates)
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
input_gate = torch.sigmoid(input_gate)
forget_gate = torch.sigmoid(forget_gate)
cell_gate = torch.tanh(cell_gate)
out_gate = torch.sigmoid(out_gate)
fgate_cx = self.fgate_cx.mul(forget_gate, cx)
igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
cy = fgate_cx_igate_cgate
tanh_cy = torch.tanh(cy)
hy = self.ogate_cy.mul(out_gate, tanh_cy)
return hy, cy
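# Written out, the forward above is the standard LSTM cell update: the four gate
# pre-activations are chunked from `gates`, then
#   i, f, o = sigmoid(pre-activation), g = tanh(pre-activation)
#   cy = f * cx + i * g
#   hy = o * tanh(cy)
# which is numerically the same computation as torch.nn.LSTMCell.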
def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
if is_quantized:
h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8)
c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8)
return h, c
def _get_name(self):
return 'QuantizableLSTMCell'
@classmethod
def from_params(cls, wi, wh, bi=None, bh=None):
"""Uses the weights and biases to create a new LSTM cell.
Args:
wi, wh: Weights for the input and hidden layers
bi, bh: Biases for the input and hidden layers
"""
assert (bi is None) == (bh is None) # Either both None or both have values
input_size = wi.shape[1]
hidden_size = wh.shape[1]
cell = cls(input_dim=input_size, hidden_dim=hidden_size,
bias=(bi is not None))
cell.igates.weight = torch.nn.Parameter(wi)
if bi is not None:
cell.igates.bias = torch.nn.Parameter(bi)
cell.hgates.weight = torch.nn.Parameter(wh)
if bh is not None:
cell.hgates.bias = torch.nn.Parameter(bh)
return cell
@classmethod
def from_float(cls, other):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
observed = cls.from_params(other.weight_ih, other.weight_hh,
other.bias_ih, other.bias_hh)
observed.qconfig = other.qconfig
observed.igates.qconfig = other.qconfig
observed.hgates.qconfig = other.qconfig
return observed
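# A hedged equivalence sketch (float mode only, illustrative tolerance): a cell
# rebuilt from a torch.nn.LSTMCell's parameters should reproduce its outputs.
# Assumes the class is importable from its post-migration location.
import torch
from torch.ao.nn.quantizable import LSTMCell as QuantizableLSTMCell

float_cell = torch.nn.LSTMCell(10, 20)
cell = QuantizableLSTMCell.from_params(float_cell.weight_ih, float_cell.weight_hh,
                                       float_cell.bias_ih, float_cell.bias_hh)
x = torch.randn(3, 10)
hx, cx = torch.randn(3, 20), torch.randn(3, 20)
ref_h, ref_c = float_cell(x, (hx, cx))
out_h, out_c = cell(x, (hx, cx))
assert torch.allclose(ref_h, out_h, atol=1e-6)
assert torch.allclose(ref_c, out_c, atol=1e-6)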
class _LSTMSingleLayer(torch.nn.Module):
r"""A single one-directional LSTM layer.
The difference between a layer and a cell is that the layer can process a
sequence, while the cell only expects an instantaneous value.
"""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
result = []
for xx in x:
hidden = self.cell(xx, hidden)
result.append(hidden[0]) # type: ignore[index]
result_tensor = torch.stack(result, 0)
return result_tensor, hidden
@classmethod
def from_params(cls, *args, **kwargs):
cell = LSTMCell.from_params(*args, **kwargs)
layer = cls(cell.input_size, cell.hidden_size, cell.bias)
layer.cell = cell
return layer
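# A small shape sketch (hypothetical sizes): the single layer simply unrolls the
# cell over the leading (time) dimension of a (seq_len, batch, input_size) input.
# Assumes the private class is imported from its post-migration location.
import torch
from torch.ao.nn.quantizable.modules.rnn import _LSTMSingleLayer

layer = _LSTMSingleLayer(input_dim=10, hidden_dim=20)
x = torch.randn(5, 3, 10)           # (seq_len, batch, input_size)
out, (h, c) = layer(x)              # out: (5, 3, 20); h and c: (3, 20)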
class _LSTMLayer(torch.nn.Module):
r"""A single bi-directional LSTM layer."""
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
batch_first: bool = False, bidirectional: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.batch_first = batch_first
self.bidirectional = bidirectional
self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
if self.bidirectional:
self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
x = x.transpose(0, 1)
if hidden is None:
hx_fw, cx_fw = (None, None)
else:
hx_fw, cx_fw = hidden
hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
if self.bidirectional:
if hx_fw is None:
hx_bw = None
else:
hx_bw = hx_fw[1]
hx_fw = hx_fw[0]
if cx_fw is None:
cx_bw = None
else:
cx_bw = cx_fw[1]
cx_fw = cx_fw[0]
if hx_bw is not None and cx_bw is not None:
hidden_bw = hx_bw, cx_bw
if hx_fw is None and cx_fw is None:
hidden_fw = None
else:
hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
if hasattr(self, 'layer_bw') and self.bidirectional:
x_reversed = x.flip(0)
result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
result_bw = result_bw.flip(0)
result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
if hidden_fw is None and hidden_bw is None:
h = None
c = None
elif hidden_fw is None:
(h, c) = torch.jit._unwrap_optional(hidden_bw)
elif hidden_bw is None:
(h, c) = torch.jit._unwrap_optional(hidden_fw)
else:
h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item]
c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item]
else:
result = result_fw
h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment]
if self.batch_first:
result.transpose_(0, 1)
return result, (h, c)
@classmethod
def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
r"""
There is no FP equivalent of this class. This function is here just to
mimic the behavior of `prepare` within the `torch.ao.quantization`
flow.
"""
assert hasattr(other, 'qconfig') or (qconfig is not None)
input_size = kwargs.get('input_size', other.input_size)
hidden_size = kwargs.get('hidden_size', other.hidden_size)
bias = kwargs.get('bias', other.bias)
batch_first = kwargs.get('batch_first', other.batch_first)
bidirectional = kwargs.get('bidirectional', other.bidirectional)
layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
layer.qconfig = getattr(other, 'qconfig', qconfig)
wi = getattr(other, f'weight_ih_l{layer_idx}')
wh = getattr(other, f'weight_hh_l{layer_idx}')
bi = getattr(other, f'bias_ih_l{layer_idx}', None)
bh = getattr(other, f'bias_hh_l{layer_idx}', None)
layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
if other.bidirectional:
wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
return layer
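# A hedged sketch of the bidirectional case: build the layer from layer 0 of a
# float nn.LSTM and check that forward and backward results are concatenated on
# the feature dimension. Assumes the private class at its post-migration location.
import torch
import torch.ao.quantization
from torch.ao.nn.quantizable.modules.rnn import _LSTMLayer

float_lstm = torch.nn.LSTM(10, 20, num_layers=1, bidirectional=True)
float_lstm.qconfig = torch.ao.quantization.default_qconfig  # from_float requires a qconfig
bi_layer = _LSTMLayer.from_float(float_lstm, layer_idx=0)
x = torch.randn(5, 3, 10)
y, (h, c) = bi_layer(x)             # y: (5, 3, 40); h and c: (2, 3, 20)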
class LSTM(torch.nn.Module):
r"""A quantizable long short-term memory (LSTM).
For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
Attributes:
layers : instances of the `_LSTMLayer`
.. note::
To access the weights and biases, you need to access them per layer.
See examples below.
Examples::
>>> import torch.nn.quantizable as nnqa
>>> rnn = nnqa.LSTM(10, 20, 2)
>>> input = torch.randn(5, 3, 10)
>>> h0 = torch.randn(2, 3, 20)
>>> c0 = torch.randn(2, 3, 20)
>>> output, (hn, cn) = rnn(input, (h0, c0))
>>> # To get the weights:
>>> # xdoctest: +SKIP
>>> print(rnn.layers[0].weight_ih)
tensor([[...]])
>>> print(rnn.layers[0].weight_hh)
AssertionError: There is no reverse path in the non-bidirectional layer
"""
_FLOAT_MODULE = torch.nn.LSTM
def __init__(self, input_size: int, hidden_size: int,
num_layers: int = 1, bias: bool = True,
batch_first: bool = False, dropout: float = 0.,
bidirectional: bool = False,
device=None, dtype=None) -> None:
factory_kwargs = {'device': device, 'dtype': dtype}
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.num_layers = num_layers
self.bias = bias
self.batch_first = batch_first
self.dropout = float(dropout)
self.bidirectional = bidirectional
self.training = False # We don't want to train using this module
num_directions = 2 if bidirectional else 1
if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
isinstance(dropout, bool):
raise ValueError("dropout should be a number in range [0, 1] "
"representing the probability of an element being "
"zeroed")
if dropout > 0:
warnings.warn("dropout option for quantizable LSTM is ignored. "
"If you are training, please, use nn.LSTM version "
"followed by `prepare` step.")
if num_layers == 1:
warnings.warn("dropout option adds dropout after all but last "
"recurrent layer, so non-zero dropout expects "
"num_layers greater than 1, but got dropout={} "
"and num_layers={}".format(dropout, num_layers))
layers = [_LSTMLayer(self.input_size, self.hidden_size,
self.bias, batch_first=False,
bidirectional=self.bidirectional, **factory_kwargs)]
for layer in range(1, num_layers):
layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
self.bias, batch_first=False,
bidirectional=self.bidirectional,
**factory_kwargs))
self.layers = torch.nn.ModuleList(layers)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
if self.batch_first:
x = x.transpose(0, 1)
max_batch_size = x.size(1)
num_directions = 2 if self.bidirectional else 1
if hidden is None:
zeros = torch.zeros(num_directions, max_batch_size,
self.hidden_size, dtype=torch.float,
device=x.device)
zeros.squeeze_(0)
if x.is_quantized:
zeros = torch.quantize_per_tensor(zeros, scale=1.0,
zero_point=0, dtype=x.dtype)
hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
else:
hidden_non_opt = torch.jit._unwrap_optional(hidden)
if isinstance(hidden_non_opt[0], Tensor):
hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
max_batch_size,
self.hidden_size).unbind(0)
cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
max_batch_size,
self.hidden_size).unbind(0)
hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
else:
hxcx = hidden_non_opt
hx_list = []
cx_list = []
for idx, layer in enumerate(self.layers):
x, (h, c) = layer(x, hxcx[idx])
hx_list.append(torch.jit._unwrap_optional(h))
cx_list.append(torch.jit._unwrap_optional(c))
hx_tensor = torch.stack(hx_list)
cx_tensor = torch.stack(cx_list)
# We are creating another dimension for bidirectional case
# need to collapse it
hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])
if self.batch_first:
x = x.transpose(0, 1)
return x, (hx_tensor, cx_tensor)
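# Shape summary for the forward above (seq-first layout): `x` comes out as
# (seq_len, batch, num_directions * hidden_size), while the stacked per-layer
# states are collapsed to (num_layers * num_directions, batch, hidden_size),
# matching the (h_n, c_n) contract of torch.nn.LSTM.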
def _get_name(self):
return 'QuantizableLSTM'
@classmethod
def from_float(cls, other, qconfig=None):
assert isinstance(other, cls._FLOAT_MODULE)
assert (hasattr(other, 'qconfig') or qconfig)
observed = cls(other.input_size, other.hidden_size, other.num_layers,
other.bias, other.batch_first, other.dropout,
other.bidirectional)
observed.qconfig = getattr(other, 'qconfig', qconfig)
for idx in range(other.num_layers):
observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
batch_first=False)
observed.eval()
observed = torch.ao.quantization.prepare(observed, inplace=True)
return observed
@classmethod
def from_observed(cls, other):
# The whole flow is float -> observed -> quantized
# This class does float -> observed only
raise NotImplementedError("It looks like you are trying to convert a "
"non-quantizable LSTM module. Please, see "
"the examples on quantizable LSTMs.")
from torch.ao.nn.quantizable.modules.rnn import LSTM
from torch.ao.nn.quantizable.modules.rnn import LSTMCell
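With the shims for both `activation` and `rnn` in place, the legacy and new import paths resolve to the same objects. A minimal check, assuming a build that contains this migration:

import torch.nn.quantizable as nnqa          # legacy path, now a compatibility shim
import torch.ao.nn.quantizable as ao_nnqa    # new canonical path

assert nnqa.LSTM is ao_nnqa.LSTM
assert nnqa.LSTMCell is ao_nnqa.LSTMCell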