[quant][ao_migration] torch.nn.quantizable → torch.ao.nn.quantizable. (#78717)
Context: To avoid cluttering the `torch.nn` namespace, the quantized modules are being moved to the `torch.ao.nn` namespace.
The list of the `torch.nn` quantization-related files being migrated:
- [X] `torch.nn.quantized` → `torch.ao.nn.quantized`
- [X] `torch.nn.quantized.functional` → `torch.ao.nn.quantized.functional`
- [X] `torch.nn.quantized.modules` → `torch.ao.nn.quantized.modules`
- [X] `torch.nn.quantized.dynamic` → `torch.ao.nn.quantized.dynamic`
- [X] `torch.nn.quantized._reference` → `torch.ao.nn.quantized._reference`
- [X] [Current PR] `torch.nn.quantizable` → `torch.ao.nn.quantizable`
- [ ] `torch.nn.qat` → `torch.ao.nn.qat`
- [ ] `torch.nn.qat.modules` → `torch.ao.nn.qat.modules`
- [ ] `torch.nn.qat.dynamic` → `torch.ao.nn.qat.dynamic`
- [ ] `torch.nn.intrinsic` → `torch.ao.nn.intrinsic`
- [ ] `torch.nn.intrinsic.modules` → `torch.ao.nn.intrinsic.modules`
- [ ] `torch.nn.intrinsic.qat` → `torch.ao.nn.intrinsic.qat`
- [ ] `torch.nn.intrinsic.quantized` → `torch.ao.nn.intrinsic.quantized`
- [ ] `torch.nn.intrinsic.quantized.modules` → `torch.ao.nn.intrinsic.quantized.modules`
- [ ] `torch.nn.intrinsic.quantized.dynamic` → `torch.ao.nn.intrinsic.quantized.dynamic`
The majority of the files are simply moved to the new location.
However, the following files need to be double-checked:
- None
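Because the old `torch.nn.quantizable` package is kept as a thin wrapper that re-exports from the new location, imports from either path should resolve to the same class objects. A minimal sketch of this backward-compatibility property (an illustration, not part of the diff):

```python
import torch.nn.quantizable as old_nnqa     # deprecated location, kept for compatibility
import torch.ao.nn.quantizable as new_nnqa  # new location under torch.ao

# The old package only re-exports, so both names refer to the same classes.
assert old_nnqa.MultiheadAttention is new_nnqa.MultiheadAttention
assert old_nnqa.LSTM is new_nnqa.LSTM
assert old_nnqa.LSTMCell is new_nnqa.LSTMCell
```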
Differential Revision: [D36861090](https://our.internmc.facebook.com/intern/diff/D36861090/)
**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36861090/)!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78717
Approved by: https://github.com/jerryzh168
This commit is contained in:
parent a7344e52b9
commit e0876feb49
@@ -1133,6 +1133,8 @@ Please take a look at `Limitations of Symbolic Tracing <https://docs-preview.pyt
.. They are here for tracking purposes until they are more permanently fixed.
.. py:module:: torch.ao
.. py:module:: torch.ao.nn
.. py:module:: torch.ao.nn.quantizable
.. py:module:: torch.ao.nn.quantizable.modules
.. py:module:: torch.ao.nn.quantized
.. py:module:: torch.ao.nn.sparse
.. py:module:: torch.ao.nn.sparse.quantized
@@ -5,7 +5,11 @@
"torch.nn.quantized.modules": "torch.ao.nn.quantized.modules",
"torch.nn.quantized.dynamic": "torch.ao.nn.quantized.dynamic",
"torch.nn.quantized.dynamic.modules": "torch.ao.nn.quantized.dynamic.modules",
"torch.nn.quantized.dynamic.modules.rnn": "torch.ao.nn.quantized.dynamic.modules.rnn"
"torch.nn.quantized.dynamic.modules.rnn": "torch.ao.nn.quantized.dynamic.modules.rnn",
"torch.nn.quantizable": "torch.ao.nn.quantizable",
"torch.nn.quantizable.modules": "torch.ao.nn.quantizable.modules",
"torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
"torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn"
},
"torch.ao.quantization": [
"ABC",
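The hunk above extends what appears to be the table of migrated module names consumed by the AO-migration checks; the exact consumer is not shown in this diff. A hedged, standalone sketch of the kind of check such a table enables (the dict literal below is illustrative and only lists the quantizable entries added in this PR):

```python
import importlib

ALIASES = {
    "torch.nn.quantizable": "torch.ao.nn.quantizable",
    "torch.nn.quantizable.modules": "torch.ao.nn.quantizable.modules",
    "torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
    "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn",
}

for old_name, new_name in ALIASES.items():
    old_mod = importlib.import_module(old_name)   # deprecated path must stay importable
    new_mod = importlib.import_module(new_name)   # new torch.ao path
    # The old module only re-exports, so every public attribute of the new
    # module must be the exact same object when reached through the old path.
    for attr in getattr(new_mod, "__all__", []):
        assert getattr(old_mod, attr) is getattr(new_mod, attr), (old_name, attr)
```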
@@ -374,3 +374,27 @@ class TestAOMigrationNNQuantized(AOMigrationTestCase):
        ]
        self._test_function_import('sparse', function_list,
                                   base='nn.quantized._reference.modules')

    def test_package_import_nn_quantizable(self):
        self._test_package_import('quantizable', base='nn')

    def test_package_import_nn_quantizable_modules(self):
        r"""Tests the migration of the torch.nn.quantizable.modules"""
        self._test_package_import('modules', base='nn.quantizable')
        self._test_package_import('modules.activation', base='nn.quantizable')
        self._test_package_import('modules.rnn', base='nn.quantizable')

    def test_import_nn_quantizable_activation(self):
        module_list = [
            # Modules
            'MultiheadAttention',
        ]
        self._test_function_import('activation', module_list, base='nn.quantizable.modules')

    def test_import_nn_quantizable_rnn(self):
        module_list = [
            # Modules
            'LSTM',
            'LSTMCell',
        ]
        self._test_function_import('rnn', module_list, base='nn.quantizable.modules')
@@ -1,3 +1,17 @@
# We are exposing both subpackages to the end-user.
from . import sparse
from . import quantized
# We are exposing all subpackages to the end-user.
# Because of possible inter-dependency, we want to avoid
# the cyclic imports, thus implementing lazy version
# as per https://peps.python.org/pep-0562/

import importlib

__all__ = [
    "quantizable",
    "quantized",
    "sparse",
]

def __getattr__(name):
    if name in __all__:
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
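For readers unfamiliar with the pattern above: PEP 562 lets a package define a module-level `__getattr__`, so a submodule is only imported the first time it is actually accessed as an attribute. A self-contained sketch of the same mechanism, using a hypothetical `mypkg` package rather than PyTorch itself:

```python
# mypkg/__init__.py -- hypothetical package showing PEP 562 lazy submodules
import importlib

__all__ = ["heavy"]  # submodules exposed lazily

def __getattr__(name):
    # Called only when normal attribute lookup on the package fails,
    # i.e. the first time `mypkg.heavy` is touched.
    if name in __all__:
        return importlib.import_module("." + name, __name__)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```

With this in place, `import mypkg` does not pull in `mypkg.heavy`; the import is deferred until code evaluates `mypkg.heavy`, which is what breaks the potential import cycles mentioned in the comment above.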
torch/ao/nn/quantizable/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from .modules import *  # noqa: F403
torch/ao/nn/quantizable/modules/__init__.py (new file, 9 lines)
@@ -0,0 +1,9 @@
from .activation import MultiheadAttention
from .rnn import LSTM
from .rnn import LSTMCell

__all__ = [
    'LSTM',
    'LSTMCell',
    'MultiheadAttention',
]
torch/ao/nn/quantizable/modules/activation.py (new file, 454 lines)
@@ -0,0 +1,454 @@
import torch
|
||||
import torch.jit # this is needed to avoid a circular import
|
||||
from torch import nn
|
||||
import torch.nn.functional as nnF
|
||||
|
||||
from torch import Tensor
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import warnings
|
||||
|
||||
class MultiheadAttention(nn.MultiheadAttention):
|
||||
_FLOAT_MODULE = nn.MultiheadAttention
|
||||
|
||||
r"""Quantizable implementation of the MultiheadAttention.
|
||||
|
||||
Note::
|
||||
Please, refer to :class:`~torch.nn.MultiheadAttention` for more
|
||||
information
|
||||
|
||||
Allows the model to jointly attend to information from different
|
||||
representation subspaces.
|
||||
See reference: Attention Is All You Need
|
||||
|
||||
The original MHA module is not quantizable.
|
||||
This reimplements it by explicitly instantiating the linear layers.
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||
bias: add bias as module parameter. Default: True.
|
||||
add_bias_kv: add bias to the key and value sequences at dim=0.
|
||||
add_zero_attn: add a new batch of zeros to the key and
|
||||
value sequences at dim=1.
|
||||
kdim: total number of features in key. Default: None.
|
||||
vdim: total number of features in value. Default: None.
|
||||
batch_first: If ``True``, then the input and output tensors are provided
|
||||
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
||||
|
||||
Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
|
||||
to :attr:`embed_dim` such that query, key, and value have the same
|
||||
number of features.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> import torch.nn.quantizable as nnqa
|
||||
>>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
|
||||
Note::
|
||||
Please, follow the quantization flow to convert the quantizable MHA.
|
||||
"""
|
||||
__constants__ = ['batch_first']
|
||||
|
||||
def __init__(self, embed_dim: int, num_heads: int,
|
||||
dropout: float = 0., bias: bool = True,
|
||||
add_bias_kv: bool = False, add_zero_attn: bool = False,
|
||||
kdim: int = None, vdim: int = None, batch_first: bool = False,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout,
|
||||
bias, add_bias_kv,
|
||||
add_zero_attn, kdim, vdim, batch_first,
|
||||
**factory_kwargs)
|
||||
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
|
||||
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
|
||||
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
|
||||
# for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]
|
||||
|
||||
# Functionals
|
||||
self.q_scaling_product = torch.nn.quantized.FloatFunctional()
|
||||
# note: importing torch.nn.quantized at top creates a circular import
|
||||
|
||||
# Quant/Dequant
|
||||
self.quant_attn_output = torch.ao.quantization.QuantStub()
|
||||
self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
|
||||
self.dequant_q = torch.ao.quantization.DeQuantStub()
|
||||
self.dequant_k = torch.ao.quantization.DeQuantStub()
|
||||
self.dequant_v = torch.ao.quantization.DeQuantStub()
|
||||
|
||||
def _get_name(self):
|
||||
return 'QuantizableMultiheadAttention'
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, other):
|
||||
assert type(other) == cls._FLOAT_MODULE
|
||||
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
|
||||
# Setting the dropout to 0.0!
|
||||
observed = cls(other.embed_dim, other.num_heads, other.dropout,
|
||||
(other.in_proj_bias is not None),
|
||||
(other.bias_k is not None),
|
||||
other.add_zero_attn, other.kdim, other.vdim)
|
||||
observed.bias_k = other.bias_k
|
||||
observed.bias_v = other.bias_v
|
||||
observed.qconfig = other.qconfig
|
||||
|
||||
# Set the linear weights
|
||||
# for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
|
||||
observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
|
||||
observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
|
||||
if other._qkv_same_embed_dim:
|
||||
# Use separate params
|
||||
bias = other.in_proj_bias
|
||||
_start = 0
|
||||
_end = _start + other.embed_dim
|
||||
weight = other.in_proj_weight[_start:_end, :]
|
||||
if bias is not None:
|
||||
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
|
||||
observed.linear_Q.weight = torch.nn.Parameter(weight,
|
||||
weight.requires_grad)
|
||||
observed.linear_Q.bias = bias
|
||||
|
||||
bias = other.in_proj_bias
|
||||
_start = _end
|
||||
_end = _start + other.embed_dim
|
||||
weight = other.in_proj_weight[_start:_end, :]
|
||||
if bias is not None:
|
||||
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
|
||||
observed.linear_K.weight = torch.nn.Parameter(weight,
|
||||
weight.requires_grad)
|
||||
observed.linear_K.bias = bias
|
||||
|
||||
bias = other.in_proj_bias
|
||||
_start = _end
|
||||
weight = other.in_proj_weight[_start:, :]
|
||||
if bias is not None:
|
||||
bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
|
||||
observed.linear_V.weight = torch.nn.Parameter(weight,
|
||||
weight.requires_grad)
|
||||
observed.linear_V.bias = bias
|
||||
else:
|
||||
observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
|
||||
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
|
||||
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
|
||||
if other.in_proj_bias is None:
|
||||
observed.linear_Q.bias = None # type: ignore[assignment]
|
||||
observed.linear_K.bias = None # type: ignore[assignment]
|
||||
observed.linear_V.bias = None # type: ignore[assignment]
|
||||
else:
|
||||
observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
|
||||
observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
|
||||
observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
|
||||
observed.eval()
|
||||
# Explicit prepare
|
||||
observed = torch.ao.quantization.prepare(observed, inplace=True)
|
||||
return observed
|
||||
|
||||
@torch.jit.unused
|
||||
def dequantize(self):
|
||||
r"""Utility to convert the quantized MHA back to float.
|
||||
|
||||
The motivation for this is that it is not trivial to convert the weights
|
||||
from the format that is used in the quantized version back to the
|
||||
float.
|
||||
"""
|
||||
fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
|
||||
(self.in_proj_bias is not None),
|
||||
(self.bias_k is not None),
|
||||
self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
|
||||
assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
|
||||
if self.bias_k is not None:
|
||||
fp.bias_k = nn.Parameter(self.bias_k.dequantize())
|
||||
if self.bias_v is not None:
|
||||
fp.bias_v = nn.Parameter(self.bias_v.dequantize())
|
||||
|
||||
# Set the linear weights
|
||||
# Note: Because the linear layers are quantized, mypy does not know how
|
||||
# to deal with them -- might need to ignore the typing checks.
|
||||
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
|
||||
w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type]
|
||||
fp.out_proj.weight = nn.Parameter(w.dequantize())
|
||||
if b is not None:
|
||||
fp.out_proj.bias = nn.Parameter(b)
|
||||
|
||||
wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator]
|
||||
wQ = wQ.dequantize()
|
||||
wK, bK = self.linear_K._weight_bias() # type: ignore[operator]
|
||||
wK = wK.dequantize()
|
||||
wV, bV = self.linear_V._weight_bias() # type: ignore[operator]
|
||||
wV = wV.dequantize()
|
||||
if fp._qkv_same_embed_dim:
|
||||
# Use separate params
|
||||
_start = 0
|
||||
_end = _start + fp.embed_dim
|
||||
fp.in_proj_weight[_start:_end, :] = wQ
|
||||
if fp.in_proj_bias is not None:
|
||||
assert all(bQ == 0)
|
||||
fp.in_proj_bias[_start:_end] = bQ
|
||||
|
||||
_start = _end
|
||||
_end = _start + fp.embed_dim
|
||||
fp.in_proj_weight[_start:_end, :] = wK
|
||||
if fp.in_proj_bias is not None:
|
||||
assert all(bK == 0)
|
||||
fp.in_proj_bias[_start:_end] = bK
|
||||
|
||||
_start = _end
|
||||
fp.in_proj_weight[_start:, :] = wV
|
||||
if fp.in_proj_bias is not None:
|
||||
assert all(bV == 0)
|
||||
fp.in_proj_bias[_start:] = bV
|
||||
else:
|
||||
fp.q_proj_weight = nn.Parameter(wQ)
|
||||
fp.k_proj_weight = nn.Parameter(wK)
|
||||
fp.v_proj_weight = nn.Parameter(wV)
|
||||
if fp.in_proj_bias is None:
|
||||
self.linear_Q.bias = None
|
||||
self.linear_K.bias = None
|
||||
self.linear_V.bias = None
|
||||
else:
|
||||
fp.in_proj_bias[0:fp.embed_dim] = bQ
|
||||
fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
|
||||
fp.in_proj_bias[(fp.embed_dim * 2):] = bV
|
||||
|
||||
return fp
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_observed(cls, other):
|
||||
# The whole flow is float -> observed -> quantized
|
||||
# This class does float -> observed only
|
||||
# See nn.quantized.MultiheadAttention
|
||||
raise NotImplementedError("It looks like you are trying to prepare an "
|
||||
"MHA module. Please, see "
|
||||
"the examples on quantizable MHAs.")
|
||||
|
||||
def forward(self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
r"""
|
||||
Note::
|
||||
Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
|
||||
information
|
||||
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
See "Attention Is All You Need" for more details.
|
||||
key_padding_mask: if provided, specified padding elements in the key will
|
||||
be ignored by the attention. When given a binary mask and a value is True,
|
||||
the corresponding value on the attention layer will be ignored. When given
|
||||
a byte mask and a value is non-zero, the corresponding value on the attention
|
||||
layer will be ignored
|
||||
need_weights: output attn_output_weights.
|
||||
attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcasted for all
|
||||
the batches while a 3D mask allows to specify a different mask for the entries of each batch.
|
||||
|
||||
Shape:
|
||||
- Inputs:
|
||||
- query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
|
||||
the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
|
||||
- key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
|
||||
- value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
|
||||
the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
|
||||
- key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
|
||||
If a ByteTensor is provided, the non-zero positions will be ignored while the position
|
||||
with the zero positions will be unchanged. If a BoolTensor is provided, the positions with the
|
||||
value of ``True`` will be ignored while the position with the value of ``False`` will be unchanged.
|
||||
- attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
|
||||
3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
|
||||
positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
|
||||
while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
|
||||
is not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
|
||||
is provided, it will be added to the attention weight.
|
||||
- average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
|
||||
heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
|
||||
effect when ``need_weights=True``. Default: True (i.e. average weights across heads)
|
||||
|
||||
- Outputs:
|
||||
- attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
|
||||
E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
|
||||
- attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
|
||||
across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
|
||||
S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
|
||||
head of shape :math:`(N, num_heads, L, S)`.
|
||||
"""
|
||||
return self._forward_impl(query, key, value, key_padding_mask,
|
||||
need_weights, attn_mask, average_attn_weights)
|
||||
|
||||
def _forward_impl(self,
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
# This version will not deal with the static key/value pairs.
|
||||
# Keeping it here for future changes.
|
||||
#
|
||||
# TODO: This method has some duplicate lines with the
|
||||
# `torch.nn.functional.multi_head_attention`. Will need to refactor.
|
||||
static_k = None
|
||||
static_v = None
|
||||
|
||||
if self.batch_first:
|
||||
query, key, value = [x.transpose(0, 1) for x in (query, key, value)]
|
||||
|
||||
tgt_len, bsz, embed_dim_to_check = query.size()
|
||||
assert self.embed_dim == embed_dim_to_check
|
||||
# allow MHA to have different sizes for the feature dimension
|
||||
assert key.size(0) == value.size(0) and key.size(1) == value.size(1)
|
||||
|
||||
head_dim = self.embed_dim // self.num_heads
|
||||
assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
scaling = float(head_dim) ** -0.5
|
||||
|
||||
q = self.linear_Q(query)
|
||||
k = self.linear_K(key)
|
||||
v = self.linear_V(value)
|
||||
|
||||
q = self.q_scaling_product.mul_scalar(q, scaling)
|
||||
|
||||
if attn_mask is not None:
|
||||
assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
|
||||
attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
|
||||
'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
|
||||
if attn_mask.dtype == torch.uint8:
|
||||
warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
||||
attn_mask = attn_mask.to(torch.bool)
|
||||
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(0)
|
||||
if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
|
||||
raise RuntimeError('The size of the 2D attn_mask is not correct.')
|
||||
elif attn_mask.dim() == 3:
|
||||
if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
|
||||
raise RuntimeError('The size of the 3D attn_mask is not correct.')
|
||||
else:
|
||||
raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
|
||||
# attn_mask's dim is 3 now.
|
||||
|
||||
# convert ByteTensor key_padding_mask to bool
|
||||
if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
|
||||
warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
|
||||
key_padding_mask = key_padding_mask.to(torch.bool)
|
||||
if self.bias_k is not None and self.bias_v is not None:
|
||||
if static_k is None and static_v is None:
|
||||
|
||||
# Explicitly assert that bias_k and bias_v are not None
|
||||
# in a way that TorchScript can understand.
|
||||
bias_k = self.bias_k
|
||||
assert bias_k is not None
|
||||
bias_v = self.bias_v
|
||||
assert bias_v is not None
|
||||
|
||||
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
|
||||
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
|
||||
if attn_mask is not None:
|
||||
attn_mask = nnF.pad(attn_mask, (0, 1))
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
|
||||
else:
|
||||
assert static_k is None, "bias cannot be added to static key."
|
||||
assert static_v is None, "bias cannot be added to static value."
|
||||
else:
|
||||
assert self.bias_k is None
|
||||
assert self.bias_v is None
|
||||
|
||||
q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
|
||||
if k is not None:
|
||||
k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
|
||||
if v is not None:
|
||||
v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
|
||||
|
||||
if static_k is not None:
|
||||
assert static_k.size(0) == bsz * self.num_heads
|
||||
assert static_k.size(2) == head_dim
|
||||
k = static_k
|
||||
|
||||
if static_v is not None:
|
||||
assert static_v.size(0) == bsz * self.num_heads
|
||||
assert static_v.size(2) == head_dim
|
||||
v = static_v
|
||||
|
||||
src_len = k.size(1)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
assert key_padding_mask.size(0) == bsz
|
||||
assert key_padding_mask.size(1) == src_len
|
||||
|
||||
if self.add_zero_attn:
|
||||
src_len += 1
|
||||
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
|
||||
if k.is_quantized:
|
||||
k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
|
||||
k = torch.cat([k, k_zeros], dim=1)
|
||||
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
|
||||
if v.is_quantized:
|
||||
v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
|
||||
v = torch.cat([v, v_zeros], dim=1)
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_mask = nnF.pad(attn_mask, (0, 1))
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
|
||||
|
||||
# Leaving the quantized zone here
|
||||
q = self.dequant_q(q)
|
||||
k = self.dequant_k(k)
|
||||
v = self.dequant_v(v)
|
||||
attn_output_weights = torch.bmm(q, k.transpose(1, 2))
|
||||
assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]
|
||||
|
||||
if attn_mask is not None:
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_output_weights.masked_fill_(attn_mask, float('-inf'))
|
||||
else:
|
||||
attn_output_weights += attn_mask
|
||||
|
||||
if key_padding_mask is not None:
|
||||
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
attn_output_weights = attn_output_weights.masked_fill(
|
||||
key_padding_mask.unsqueeze(1).unsqueeze(2),
|
||||
float('-inf'),
|
||||
)
|
||||
attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)
|
||||
|
||||
attn_output_weights = nnF.softmax(
|
||||
attn_output_weights, dim=-1)
|
||||
attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
|
||||
if self.batch_first:
|
||||
attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
|
||||
else:
|
||||
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)
|
||||
|
||||
# Reentering the quantized zone
|
||||
attn_output = self.quant_attn_output(attn_output)
|
||||
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
|
||||
attn_output = self.out_proj(attn_output) # type: ignore[has-type]
|
||||
attn_output_weights = self.quant_attn_output_weights(attn_output_weights)
|
||||
|
||||
if need_weights:
|
||||
# average attention weights over heads
|
||||
attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
|
||||
if average_attn_weights:
|
||||
attn_output_weights = attn_output_weights.mean(dim=1)
|
||||
return attn_output, attn_output_weights
|
||||
else:
|
||||
return attn_output, None
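As a usage illustration for the module above (a minimal sketch, not part of the diff): `from_float` turns a float `nn.MultiheadAttention` that carries a `qconfig` into the observed, quantizable version, which is then calibrated simply by running data through it. The qconfig below is only an example choice.

```python
import torch
import torch.ao.quantization
import torch.ao.nn.quantizable as nnqa

embed_dim, num_heads = 8, 2
float_mha = torch.nn.MultiheadAttention(embed_dim, num_heads)
float_mha.qconfig = torch.ao.quantization.default_qconfig  # from_float asserts a qconfig is present

# float -> observed: explicit linear projections with observers attached
observed_mha = nnqa.MultiheadAttention.from_float(float_mha)

q = torch.randn(4, 1, embed_dim)                # (seq, batch, feature)
attn_out, attn_weights = observed_mha(q, q, q)  # calibration pass through the observed module
```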
torch/ao/nn/quantizable/modules/rnn.py (new file, 386 lines)
@@ -0,0 +1,386 @@
import numbers
|
||||
from typing import Optional, Tuple
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
"""
|
||||
We will recreate all the RNN modules as we require the modules to be decomposed
|
||||
into their building blocks to be able to observe them.
|
||||
"""
|
||||
|
||||
class LSTMCell(torch.nn.Module):
|
||||
r"""A quantizable long short-term memory (LSTM) cell.
|
||||
|
||||
For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
|
||||
|
||||
Examples::
|
||||
|
||||
>>> import torch.nn.quantizable as nnqa
|
||||
>>> rnn = nnqa.LSTMCell(10, 20)
|
||||
>>> input = torch.randn(6, 10)
|
||||
>>> hx = torch.randn(3, 20)
|
||||
>>> cx = torch.randn(3, 20)
|
||||
>>> output = []
|
||||
>>> for i in range(6):
|
||||
... hx, cx = rnn(input[i], (hx, cx))
|
||||
... output.append(hx)
|
||||
"""
|
||||
_FLOAT_MODULE = torch.nn.LSTMCell
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.input_size = input_dim
|
||||
self.hidden_size = hidden_dim
|
||||
self.bias = bias
|
||||
|
||||
self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
|
||||
self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
|
||||
self.gates = torch.ao.nn.quantized.FloatFunctional()
|
||||
|
||||
self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
|
||||
self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
|
||||
self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()
|
||||
|
||||
self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()
|
||||
|
||||
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
|
||||
if hidden is None or hidden[0] is None or hidden[1] is None:
|
||||
hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
|
||||
hx, cx = hidden
|
||||
|
||||
igates = self.igates(x)
|
||||
hgates = self.hgates(hx)
|
||||
gates = self.gates.add(igates, hgates)
|
||||
|
||||
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
|
||||
|
||||
input_gate = torch.sigmoid(input_gate)
|
||||
forget_gate = torch.sigmoid(forget_gate)
|
||||
cell_gate = torch.tanh(cell_gate)
|
||||
out_gate = torch.sigmoid(out_gate)
|
||||
|
||||
fgate_cx = self.fgate_cx.mul(forget_gate, cx)
|
||||
igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
|
||||
fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
|
||||
cy = fgate_cx_igate_cgate
|
||||
|
||||
tanh_cy = torch.tanh(cy)
|
||||
hy = self.ogate_cy.mul(out_gate, tanh_cy)
|
||||
return hy, cy
|
||||
|
||||
def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
|
||||
h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
|
||||
if is_quantized:
|
||||
h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8)
|
||||
c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8)
|
||||
return h, c
|
||||
|
||||
def _get_name(self):
|
||||
return 'QuantizableLSTMCell'
|
||||
|
||||
@classmethod
|
||||
def from_params(cls, wi, wh, bi=None, bh=None):
|
||||
"""Uses the weights and biases to create a new LSTM cell.
|
||||
|
||||
Args:
|
||||
wi, wh: Weights for the input and hidden layers
|
||||
bi, bh: Biases for the input and hidden layers
|
||||
"""
|
||||
assert (bi is None) == (bh is None) # Either both None or both have values
|
||||
input_size = wi.shape[1]
|
||||
hidden_size = wh.shape[1]
|
||||
cell = cls(input_dim=input_size, hidden_dim=hidden_size,
|
||||
bias=(bi is not None))
|
||||
cell.igates.weight = torch.nn.Parameter(wi)
|
||||
if bi is not None:
|
||||
cell.igates.bias = torch.nn.Parameter(bi)
|
||||
cell.hgates.weight = torch.nn.Parameter(wh)
|
||||
if bh is not None:
|
||||
cell.hgates.bias = torch.nn.Parameter(bh)
|
||||
return cell
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, other):
|
||||
assert type(other) == cls._FLOAT_MODULE
|
||||
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
|
||||
observed = cls.from_params(other.weight_ih, other.weight_hh,
|
||||
other.bias_ih, other.bias_hh)
|
||||
observed.qconfig = other.qconfig
|
||||
observed.igates.qconfig = other.qconfig
|
||||
observed.hgates.qconfig = other.qconfig
|
||||
return observed
|
||||
|
||||
|
||||
class _LSTMSingleLayer(torch.nn.Module):
|
||||
r"""A single one-directional LSTM layer.
|
||||
|
||||
The difference between a layer and a cell is that the layer can process a
|
||||
sequence, while the cell only expects an instantaneous value.
|
||||
"""
|
||||
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
|
||||
result = []
|
||||
for xx in x:
|
||||
hidden = self.cell(xx, hidden)
|
||||
result.append(hidden[0]) # type: ignore[index]
|
||||
result_tensor = torch.stack(result, 0)
|
||||
return result_tensor, hidden
|
||||
|
||||
@classmethod
|
||||
def from_params(cls, *args, **kwargs):
|
||||
cell = LSTMCell.from_params(*args, **kwargs)
|
||||
layer = cls(cell.input_size, cell.hidden_size, cell.bias)
|
||||
layer.cell = cell
|
||||
return layer
|
||||
|
||||
|
||||
class _LSTMLayer(torch.nn.Module):
|
||||
r"""A single bi-directional LSTM layer."""
|
||||
def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
|
||||
batch_first: bool = False, bidirectional: bool = False,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.batch_first = batch_first
|
||||
self.bidirectional = bidirectional
|
||||
self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
|
||||
if self.bidirectional:
|
||||
self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
|
||||
if self.batch_first:
|
||||
x = x.transpose(0, 1)
|
||||
if hidden is None:
|
||||
hx_fw, cx_fw = (None, None)
|
||||
else:
|
||||
hx_fw, cx_fw = hidden
|
||||
hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
|
||||
if self.bidirectional:
|
||||
if hx_fw is None:
|
||||
hx_bw = None
|
||||
else:
|
||||
hx_bw = hx_fw[1]
|
||||
hx_fw = hx_fw[0]
|
||||
if cx_fw is None:
|
||||
cx_bw = None
|
||||
else:
|
||||
cx_bw = cx_fw[1]
|
||||
cx_fw = cx_fw[0]
|
||||
if hx_bw is not None and cx_bw is not None:
|
||||
hidden_bw = hx_bw, cx_bw
|
||||
if hx_fw is None and cx_fw is None:
|
||||
hidden_fw = None
|
||||
else:
|
||||
hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
|
||||
result_fw, hidden_fw = self.layer_fw(x, hidden_fw)
|
||||
|
||||
if hasattr(self, 'layer_bw') and self.bidirectional:
|
||||
x_reversed = x.flip(0)
|
||||
result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
|
||||
result_bw = result_bw.flip(0)
|
||||
|
||||
result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
|
||||
if hidden_fw is None and hidden_bw is None:
|
||||
h = None
|
||||
c = None
|
||||
elif hidden_fw is None:
|
||||
(h, c) = torch.jit._unwrap_optional(hidden_bw)
|
||||
elif hidden_bw is None:
|
||||
(h, c) = torch.jit._unwrap_optional(hidden_fw)
|
||||
else:
|
||||
h = torch.stack([hidden_fw[0], hidden_bw[0]], 0) # type: ignore[list-item]
|
||||
c = torch.stack([hidden_fw[1], hidden_bw[1]], 0) # type: ignore[list-item]
|
||||
else:
|
||||
result = result_fw
|
||||
h, c = torch.jit._unwrap_optional(hidden_fw) # type: ignore[assignment]
|
||||
|
||||
if self.batch_first:
|
||||
result.transpose_(0, 1)
|
||||
|
||||
return result, (h, c)
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
|
||||
r"""
|
||||
There is no FP equivalent of this class. This function is here just to
|
||||
mimic the behavior of the `prepare` within the `torch.ao.quantization`
|
||||
flow.
|
||||
"""
|
||||
assert hasattr(other, 'qconfig') or (qconfig is not None)
|
||||
|
||||
input_size = kwargs.get('input_size', other.input_size)
|
||||
hidden_size = kwargs.get('hidden_size', other.hidden_size)
|
||||
bias = kwargs.get('bias', other.bias)
|
||||
batch_first = kwargs.get('batch_first', other.batch_first)
|
||||
bidirectional = kwargs.get('bidirectional', other.bidirectional)
|
||||
|
||||
layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
|
||||
layer.qconfig = getattr(other, 'qconfig', qconfig)
|
||||
wi = getattr(other, f'weight_ih_l{layer_idx}')
|
||||
wh = getattr(other, f'weight_hh_l{layer_idx}')
|
||||
bi = getattr(other, f'bias_ih_l{layer_idx}', None)
|
||||
bh = getattr(other, f'bias_hh_l{layer_idx}', None)
|
||||
|
||||
layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
|
||||
|
||||
if other.bidirectional:
|
||||
wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
|
||||
wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
|
||||
bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
|
||||
bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
|
||||
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
|
||||
return layer
|
||||
|
||||
|
||||
class LSTM(torch.nn.Module):
|
||||
r"""A quantizable long short-term memory (LSTM).
|
||||
|
||||
For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
|
||||
|
||||
Attributes:
|
||||
layers : instances of the `_LSTMLayer`
|
||||
|
||||
.. note::
|
||||
To access the weights and biases, you need to access them per layer.
|
||||
See examples below.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> import torch.nn.quantizable as nnqa
|
||||
>>> rnn = nnqa.LSTM(10, 20, 2)
|
||||
>>> input = torch.randn(5, 3, 10)
|
||||
>>> h0 = torch.randn(2, 3, 20)
|
||||
>>> c0 = torch.randn(2, 3, 20)
|
||||
>>> output, (hn, cn) = rnn(input, (h0, c0))
|
||||
>>> # To get the weights:
|
||||
>>> # xdoctest: +SKIP
|
||||
>>> print(rnn.layers[0].weight_ih)
|
||||
tensor([[...]])
|
||||
>>> print(rnn.layers[0].weight_hh)
|
||||
AssertionError: There is no reverse path in the non-bidirectional layer
|
||||
"""
|
||||
_FLOAT_MODULE = torch.nn.LSTM
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int,
|
||||
num_layers: int = 1, bias: bool = True,
|
||||
batch_first: bool = False, dropout: float = 0.,
|
||||
bidirectional: bool = False,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.num_layers = num_layers
|
||||
self.bias = bias
|
||||
self.batch_first = batch_first
|
||||
self.dropout = float(dropout)
|
||||
self.bidirectional = bidirectional
|
||||
self.training = False # We don't want to train using this module
|
||||
num_directions = 2 if bidirectional else 1
|
||||
|
||||
if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
|
||||
isinstance(dropout, bool):
|
||||
raise ValueError("dropout should be a number in range [0, 1] "
|
||||
"representing the probability of an element being "
|
||||
"zeroed")
|
||||
if dropout > 0:
|
||||
warnings.warn("dropout option for quantizable LSTM is ignored. "
|
||||
"If you are training, please, use nn.LSTM version "
|
||||
"followed by `prepare` step.")
|
||||
if num_layers == 1:
|
||||
warnings.warn("dropout option adds dropout after all but last "
|
||||
"recurrent layer, so non-zero dropout expects "
|
||||
"num_layers greater than 1, but got dropout={} "
|
||||
"and num_layers={}".format(dropout, num_layers))
|
||||
|
||||
layers = [_LSTMLayer(self.input_size, self.hidden_size,
|
||||
self.bias, batch_first=False,
|
||||
bidirectional=self.bidirectional, **factory_kwargs)]
|
||||
for layer in range(1, num_layers):
|
||||
layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
|
||||
self.bias, batch_first=False,
|
||||
bidirectional=self.bidirectional,
|
||||
**factory_kwargs))
|
||||
self.layers = torch.nn.ModuleList(layers)
|
||||
|
||||
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
|
||||
if self.batch_first:
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
max_batch_size = x.size(1)
|
||||
num_directions = 2 if self.bidirectional else 1
|
||||
if hidden is None:
|
||||
zeros = torch.zeros(num_directions, max_batch_size,
|
||||
self.hidden_size, dtype=torch.float,
|
||||
device=x.device)
|
||||
zeros.squeeze_(0)
|
||||
if x.is_quantized:
|
||||
zeros = torch.quantize_per_tensor(zeros, scale=1.0,
|
||||
zero_point=0, dtype=x.dtype)
|
||||
hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
|
||||
else:
|
||||
hidden_non_opt = torch.jit._unwrap_optional(hidden)
|
||||
if isinstance(hidden_non_opt[0], Tensor):
|
||||
hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
|
||||
max_batch_size,
|
||||
self.hidden_size).unbind(0)
|
||||
cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
|
||||
max_batch_size,
|
||||
self.hidden_size).unbind(0)
|
||||
hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
|
||||
else:
|
||||
hxcx = hidden_non_opt
|
||||
|
||||
hx_list = []
|
||||
cx_list = []
|
||||
for idx, layer in enumerate(self.layers):
|
||||
x, (h, c) = layer(x, hxcx[idx])
|
||||
hx_list.append(torch.jit._unwrap_optional(h))
|
||||
cx_list.append(torch.jit._unwrap_optional(c))
|
||||
hx_tensor = torch.stack(hx_list)
|
||||
cx_tensor = torch.stack(cx_list)
|
||||
|
||||
# We are creating another dimension for bidirectional case
|
||||
# need to collapse it
|
||||
hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
|
||||
cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])
|
||||
|
||||
if self.batch_first:
|
||||
x = x.transpose(0, 1)
|
||||
|
||||
return x, (hx_tensor, cx_tensor)
|
||||
|
||||
def _get_name(self):
|
||||
return 'QuantizableLSTM'
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, other, qconfig=None):
|
||||
assert isinstance(other, cls._FLOAT_MODULE)
|
||||
assert (hasattr(other, 'qconfig') or qconfig)
|
||||
observed = cls(other.input_size, other.hidden_size, other.num_layers,
|
||||
other.bias, other.batch_first, other.dropout,
|
||||
other.bidirectional)
|
||||
observed.qconfig = getattr(other, 'qconfig', qconfig)
|
||||
for idx in range(other.num_layers):
|
||||
observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
|
||||
batch_first=False)
|
||||
observed.eval()
|
||||
observed = torch.ao.quantization.prepare(observed, inplace=True)
|
||||
return observed
|
||||
|
||||
@classmethod
|
||||
def from_observed(cls, other):
|
||||
# The whole flow is float -> observed -> quantized
|
||||
# This class does float -> observed only
|
||||
raise NotImplementedError("It looks like you are trying to convert a "
|
||||
"non-quantizable LSTM module. Please, see "
|
||||
"the examples on quantizable LSTMs.")
@@ -1,4 +1,12 @@
import torch

# The quantized modules use `torch.nn` and `torch.ao.nn.quantizable`
# packages. However, the `quantizable` package uses "lazy imports"
# to avoid circular dependency.
# Hence we need to include it here to make sure it is resolved before
# they are used in the modules.
import torch.ao.nn.quantizable

from torch.nn.modules.pooling import MaxPool2d

from .activation import ReLU6, Hardswish, ELU, LeakyReLU, Sigmoid, Softmax, MultiheadAttention, PReLU
@@ -184,8 +184,9 @@ class Softmax(torch.nn.Softmax):
    def from_reference(cls, mod, scale, zero_point):
        return cls(mod.dim, float(scale), int(zero_point))

class MultiheadAttention(torch.nn.quantizable.MultiheadAttention):
    _FLOAT_MODULE = torch.nn.quantizable.MultiheadAttention

class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
    _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention

    def _get_name(self):
        return "QuantizedMultiheadAttention"
@@ -1,6 +1,6 @@
import torch

class LSTM(torch.nn.quantizable.LSTM):
class LSTM(torch.ao.nn.quantizable.LSTM):
    r"""A quantized long short-term memory (LSTM).

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`
@@ -16,6 +16,9 @@ import torch.nn.qat.dynamic as nnqatd

from typing import Optional, Union, Dict, Set, Callable, Any

# Because `torch.ao.nn` uses lazy imports, we need to make
# sure we import the contents explicitly here.
import torch.ao.nn.sparse
import torch.ao.nn as ao_nn
from torch.ao.quantization.stubs import QuantStub, DeQuantStub
from torch.ao.quantization.fake_quantize import (
@@ -1,13 +1,12 @@
import torch
import torch.ao.nn

def get_static_sparse_quantized_mapping():
    import torch.ao.nn.sparse
    _static_sparse_quantized_mapping = dict({
        torch.nn.Linear: torch.ao.nn.sparse.quantized.Linear,
    })
    return _static_sparse_quantized_mapping

def get_dynamic_sparse_quantized_mapping():
    import torch.ao.nn.sparse
    _dynamic_sparse_quantized_mapping = dict({
        torch.nn.Linear: torch.ao.nn.sparse.quantized.dynamic.Linear,
    })
@@ -1,6 +1,6 @@
from .activation import MultiheadAttention
from .rnn import LSTM
from .rnn import LSTMCell
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention
from torch.ao.nn.quantizable.modules.rnn import LSTM
from torch.ao.nn.quantizable.modules.rnn import LSTMCell

__all__ = [
    'LSTM',
@ -1,454 +1,10 @@
|
|||
import torch
|
||||
import torch.jit # this is needed to avoid a circular import
|
||||
from torch import nn
|
||||
import torch.nn.functional as nnF
|
||||
# flake8: noqa: F401
|
||||
r"""Quantizable Modules
|
||||
|
||||
from torch import Tensor
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import warnings
|
||||
|
||||
class MultiheadAttention(nn.MultiheadAttention):
|
||||
_FLOAT_MODULE = nn.MultiheadAttention
|
||||
|
||||
r"""Quantizable implementation of the MultiheadAttention.
|
||||
|
||||
Note::
|
||||
Please, refer to :class:`~torch.nn.MultiheadAttention` for more
|
||||
information
|
||||
|
||||
Allows the model to jointly attend to information from different
|
||||
representation subspaces.
|
||||
See reference: Attention Is All You Need
|
||||
|
||||
The original MHA module is not quantizable.
|
||||
This reimplements it by explicitly instantiating the linear layers.
|
||||
|
||||
.. math::
|
||||
\text{MultiHead}(Q, K, V) = \text{Concat}(head_1,\dots,head_h)W^O
|
||||
\text{where} head_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V)
|
||||
|
||||
Args:
|
||||
embed_dim: total dimension of the model.
|
||||
num_heads: parallel attention heads.
|
||||
dropout: a Dropout layer on attn_output_weights. Default: 0.0.
|
||||
bias: add bias as module parameter. Default: True.
|
||||
add_bias_kv: add bias to the key and value sequences at dim=0.
|
||||
add_zero_attn: add a new batch of zeros to the key and
|
||||
value sequences at dim=1.
|
||||
kdim: total number of features in key. Default: None.
|
||||
vdim: total number of features in value. Default: None.
|
||||
batch_first: If ``True``, then the input and output tensors are provided
|
||||
as (batch, seq, feature). Default: ``False`` (seq, batch, feature).
|
||||
|
||||
Note that if :attr:`kdim` and :attr:`vdim` are None, they will be set
|
||||
to :attr:`embed_dim` such that query, key, and value have the same
|
||||
number of features.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> import torch.nn.quantizable as nnqa
|
||||
>>> multihead_attn = nnqa.MultiheadAttention(embed_dim, num_heads)
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
|
||||
Note::
|
||||
Please, follow the quantization flow to convert the quantizable MHA.
|
||||
"""
|
||||
__constants__ = ['batch_first']
|
||||
|
||||
def __init__(self, embed_dim: int, num_heads: int,
|
||||
dropout: float = 0., bias: bool = True,
|
||||
add_bias_kv: bool = False, add_zero_attn: bool = False,
|
||||
kdim: int = None, vdim: int = None, batch_first: bool = False,
|
||||
device=None, dtype=None) -> None:
|
||||
factory_kwargs = {'device': device, 'dtype': dtype}
|
||||
super(MultiheadAttention, self).__init__(embed_dim, num_heads, dropout,
|
||||
bias, add_bias_kv,
|
||||
add_zero_attn, kdim, vdim, batch_first,
|
||||
**factory_kwargs)
|
||||
self.linear_Q = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs)
|
||||
self.linear_K = nn.Linear(self.kdim, self.embed_dim, bias=bias, **factory_kwargs)
|
||||
self.linear_V = nn.Linear(self.vdim, self.embed_dim, bias=bias, **factory_kwargs)
|
||||
# for the type: ignore, see https://github.com/pytorch/pytorch/issues/58969
|
||||
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=bias, **factory_kwargs) # type: ignore[assignment]
|
||||
|
||||
# Functionals
|
||||
self.q_scaling_product = torch.nn.quantized.FloatFunctional()
|
||||
# note: importing torch.nn.quantized at top creates a circular import
|
||||
|
||||
# Quant/Dequant
|
||||
self.quant_attn_output = torch.ao.quantization.QuantStub()
|
||||
self.quant_attn_output_weights = torch.ao.quantization.QuantStub()
|
||||
self.dequant_q = torch.ao.quantization.DeQuantStub()
|
||||
self.dequant_k = torch.ao.quantization.DeQuantStub()
|
||||
self.dequant_v = torch.ao.quantization.DeQuantStub()
|
||||
|
||||
def _get_name(self):
|
||||
return 'QuantizableMultiheadAttention'
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, other):
|
||||
assert type(other) == cls._FLOAT_MODULE
|
||||
assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
|
||||
# Setting the dropout to 0.0!
|
||||
observed = cls(other.embed_dim, other.num_heads, other.dropout,
|
||||
(other.in_proj_bias is not None),
|
||||
(other.bias_k is not None),
|
||||
other.add_zero_attn, other.kdim, other.vdim)
|
||||
observed.bias_k = other.bias_k
|
||||
observed.bias_v = other.bias_v
|
||||
observed.qconfig = other.qconfig
|
||||
|
||||
# Set the linear weights
|
||||
# for the type: ignores, see https://github.com/pytorch/pytorch/issues/58969
|
||||
observed.out_proj.weight = other.out_proj.weight # type: ignore[has-type]
|
||||
observed.out_proj.bias = other.out_proj.bias # type: ignore[has-type]
|
||||
if other._qkv_same_embed_dim:
|
||||
# Use separate params
|
||||
bias = other.in_proj_bias
|
||||
_start = 0
|
||||
_end = _start + other.embed_dim
|
||||
weight = other.in_proj_weight[_start:_end, :]
|
||||
if bias is not None:
|
||||
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
|
||||
observed.linear_Q.weight = torch.nn.Parameter(weight,
|
||||
weight.requires_grad)
|
||||
observed.linear_Q.bias = bias
|
||||
|
||||
bias = other.in_proj_bias
|
||||
_start = _end
|
||||
_end = _start + other.embed_dim
|
||||
weight = other.in_proj_weight[_start:_end, :]
|
||||
if bias is not None:
|
||||
bias = torch.nn.Parameter(bias[_start:_end], bias.requires_grad)
|
||||
observed.linear_K.weight = torch.nn.Parameter(weight,
|
||||
weight.requires_grad)
|
||||
observed.linear_K.bias = bias
|
||||
|
||||
bias = other.in_proj_bias
|
||||
_start = _end
|
||||
weight = other.in_proj_weight[_start:, :]
|
||||
if bias is not None:
|
||||
bias = torch.nn.Parameter(bias[_start:], bias.requires_grad)
|
||||
observed.linear_V.weight = torch.nn.Parameter(weight,
|
||||
weight.requires_grad)
|
||||
observed.linear_V.bias = bias
|
||||
else:
|
||||
observed.linear_Q.weight = nn.Parameter(other.q_proj_weight)
|
||||
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
|
||||
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
|
||||
if other.in_proj_bias is None:
|
||||
observed.linear_Q.bias = None # type: ignore[assignment]
|
||||
observed.linear_K.bias = None # type: ignore[assignment]
|
||||
observed.linear_V.bias = None # type: ignore[assignment]
|
||||
else:
|
||||
observed.linear_Q.bias = nn.Parameter(other.in_proj_bias[0:other.embed_dim])
|
||||
observed.linear_K.bias = nn.Parameter(other.in_proj_bias[other.embed_dim:(other.embed_dim * 2)])
|
||||
observed.linear_V.bias = nn.Parameter(other.in_proj_bias[(other.embed_dim * 2):])
|
||||
observed.eval()
|
||||
# Explicit prepare
|
||||
observed = torch.ao.quantization.prepare(observed, inplace=True)
|
||||
return observed
|
||||
|
||||
@torch.jit.unused
|
||||
def dequantize(self):
|
||||
r"""Utility to convert the quantized MHA back to float.
|
||||
|
||||
The motivation for this is that it is not trivial to conver the weights
|
||||
from the format that is used in the quantized version back to the
|
||||
float.
|
||||
"""
|
||||
fp = self._FLOAT_MODULE(self.embed_dim, self.num_heads, self.dropout,
|
||||
(self.in_proj_bias is not None),
|
||||
(self.bias_k is not None),
|
||||
self.add_zero_attn, self.kdim, self.vdim, self.batch_first)
|
||||
assert fp._qkv_same_embed_dim == self._qkv_same_embed_dim
|
||||
if self.bias_k is not None:
|
||||
fp.bias_k = nn.Parameter(self.bias_k.dequantize())
|
||||
if self.bias_v is not None:
|
||||
fp.bias_v = nn.Parameter(self.bias_v.dequantize())
|
||||
|
||||
# Set the linear weights
|
||||
# Note: Because the linear layers are quantized, mypy does not nkow how
|
||||
# to deal with them -- might need to ignore the typing checks.
|
||||
# for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
|
||||
w, b = self.out_proj._weight_bias() # type: ignore[operator, has-type]
|
||||
fp.out_proj.weight = nn.Parameter(w.dequantize())
|
||||
if b is not None:
|
||||
fp.out_proj.bias = nn.Parameter(b)
|
||||
|
||||
wQ, bQ = self.linear_Q._weight_bias() # type: ignore[operator]
|
||||
wQ = wQ.dequantize()
|
||||
wK, bK = self.linear_K._weight_bias() # type: ignore[operator]
|
||||
wK = wK.dequantize()
|
||||
wV, bV = self.linear_V._weight_bias() # type: ignore[operator]
|
||||
wV = wV.dequantize()
|
||||
if fp._qkv_same_embed_dim:
|
||||
# Use separate params
|
||||
_start = 0
|
||||
_end = _start + fp.embed_dim
|
||||
fp.in_proj_weight[_start:_end, :] = wQ
|
||||
if fp.in_proj_bias is not None:
|
||||
assert all(bQ == 0)
|
||||
fp.in_proj_bias[_start:_end] = bQ
|
||||
|
||||
_start = _end
|
||||
_end = _start + fp.embed_dim
|
||||
fp.in_proj_weight[_start:_end, :] = wK
|
||||
if fp.in_proj_bias is not None:
|
||||
assert all(bK == 0)
|
||||
fp.in_proj_bias[_start:_end] = bK
|
||||
|
||||
_start = _end
|
||||
fp.in_proj_weight[_start:, :] = wV
|
||||
if fp.in_proj_bias is not None:
|
||||
assert all(bV == 0)
|
||||
fp.in_proj_bias[_start:] = bV
|
||||
else:
|
||||
fp.q_proj_weight = nn.Parameter(wQ)
|
||||
fp.k_proj_weight = nn.Parameter(wK)
|
||||
fp.v_proj_weight = nn.Parameter(wV)
|
||||
if fp.in_proj_bias is None:
|
||||
self.linear_Q.bias = None
|
||||
self.linear_K.bias = None
|
||||
self.linear_V.bias = None
|
||||
else:
|
||||
fp.in_proj_bias[0:fp.embed_dim] = bQ
|
||||
fp.in_proj_bias[fp.embed_dim:(fp.embed_dim * 2)] = bK
|
||||
fp.in_proj_bias[(fp.embed_dim * 2):] = bV
|
||||
|
||||
return fp
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_observed(cls, other):
|
||||
# The whole flow is float -> observed -> quantized
|
||||
# This class does float -> observed only
|
||||
# See nn.quantized.MultiheadAttention
|
||||
raise NotImplementedError("It looks like you are trying to prepare an "
|
||||
"MHA module. Please, see "
|
||||
"the examples on quantizable MHAs.")
|
||||
|
||||
    def forward(self,
                query: Tensor,
                key: Tensor,
                value: Tensor,
                key_padding_mask: Optional[Tensor] = None,
                need_weights: bool = True,
                attn_mask: Optional[Tensor] = None,
                average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        r"""
        Note::
            Please, refer to :func:`~torch.nn.MultiheadAttention.forward` for more
            information.

        Args:
            query, key, value: map a query and a set of key-value pairs to an output.
                See "Attention Is All You Need" for more details.
            key_padding_mask: if provided, specified padding elements in the key will
                be ignored by the attention. When given a binary mask and a value is True,
                the corresponding value on the attention layer will be ignored. When given
                a byte mask and a value is non-zero, the corresponding value on the attention
                layer will be ignored.
            need_weights: output attn_output_weights.
            attn_mask: 2D or 3D mask that prevents attention to certain positions. A 2D mask will be broadcast
                across all the batches, while a 3D mask allows specifying a different mask for each entry in the batch.

        Shape:
            - Inputs:
            - query: :math:`(L, N, E)` where L is the target sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - key: :math:`(S, N, E)`, where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - value: :math:`(S, N, E)` where S is the source sequence length, N is the batch size, E is
              the embedding dimension. :math:`(N, S, E)` if ``batch_first`` is ``True``.
            - key_padding_mask: :math:`(N, S)` where N is the batch size, S is the source sequence length.
              If a ByteTensor is provided, the non-zero positions will be ignored while the zero positions
              will be unchanged. If a BoolTensor is provided, the positions with the value of ``True`` will
              be ignored while the positions with the value of ``False`` will be unchanged.
            - attn_mask: 2D mask :math:`(L, S)` where L is the target sequence length, S is the source sequence length.
              3D mask :math:`(N*num_heads, L, S)` where N is the batch size, L is the target sequence length,
              S is the source sequence length. attn_mask ensures that position i is allowed to attend the unmasked
              positions. If a ByteTensor is provided, the non-zero positions are not allowed to attend
              while the zero positions will be unchanged. If a BoolTensor is provided, positions with ``True``
              are not allowed to attend while ``False`` values will be unchanged. If a FloatTensor
              is provided, it will be added to the attention weight.
            - average_attn_weights: If true, indicates that the returned ``attn_weights`` should be averaged across
              heads. Otherwise, ``attn_weights`` are provided separately per head. Note that this flag only has an
              effect when ``need_weights=True``. Default: True (i.e. average weights across heads)

            - Outputs:
            - attn_output: :math:`(L, N, E)` where L is the target sequence length, N is the batch size,
              E is the embedding dimension. :math:`(N, L, E)` if ``batch_first`` is ``True``.
            - attn_output_weights: If ``average_attn_weights=True``, returns attention weights averaged
              across heads of shape :math:`(N, L, S)`, where N is the batch size, L is the target sequence length,
              S is the source sequence length. If ``average_attn_weights=False``, returns attention weights per
              head of shape :math:`(N, num_heads, L, S)`.
        """
        return self._forward_impl(query, key, value, key_padding_mask,
                                  need_weights, attn_mask, average_attn_weights)

    def _forward_impl(self,
                      query: Tensor,
                      key: Tensor,
                      value: Tensor,
                      key_padding_mask: Optional[Tensor] = None,
                      need_weights: bool = True,
                      attn_mask: Optional[Tensor] = None,
                      average_attn_weights: bool = True) -> Tuple[Tensor, Optional[Tensor]]:
        # This version will not deal with the static key/value pairs.
        # Keeping it here for future changes.
        #
        # TODO: This method has some duplicate lines with
        # `torch.nn.functional.multi_head_attention_forward`. Will need to refactor.
        static_k = None
        static_v = None

        if self.batch_first:
            query, key, value = [x.transpose(0, 1) for x in (query, key, value)]

        tgt_len, bsz, embed_dim_to_check = query.size()
        assert self.embed_dim == embed_dim_to_check
        # allow MHA to have different sizes for the feature dimension
        assert key.size(0) == value.size(0) and key.size(1) == value.size(1)

        head_dim = self.embed_dim // self.num_heads
        assert head_dim * self.num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        scaling = float(head_dim) ** -0.5

        q = self.linear_Q(query)
        k = self.linear_K(key)
        v = self.linear_V(value)

        q = self.q_scaling_product.mul_scalar(q, scaling)

        if attn_mask is not None:
            assert attn_mask.dtype == torch.float32 or attn_mask.dtype == torch.float64 or \
                attn_mask.dtype == torch.float16 or attn_mask.dtype == torch.uint8 or attn_mask.dtype == torch.bool, \
                'Only float, byte, and bool types are supported for attn_mask, not {}'.format(attn_mask.dtype)
            if attn_mask.dtype == torch.uint8:
                warnings.warn("Byte tensor for attn_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
                attn_mask = attn_mask.to(torch.bool)

            if attn_mask.dim() == 2:
                attn_mask = attn_mask.unsqueeze(0)
                if list(attn_mask.size()) != [1, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 2D attn_mask is not correct.')
            elif attn_mask.dim() == 3:
                if list(attn_mask.size()) != [bsz * self.num_heads, query.size(0), key.size(0)]:
                    raise RuntimeError('The size of the 3D attn_mask is not correct.')
            else:
                raise RuntimeError("attn_mask's dimension {} is not supported".format(attn_mask.dim()))
            # attn_mask's dim is 3 now.

        # convert ByteTensor key_padding_mask to bool
        if key_padding_mask is not None and key_padding_mask.dtype == torch.uint8:
            warnings.warn("Byte tensor for key_padding_mask in nn.MultiheadAttention is deprecated. Use bool tensor instead.")
            key_padding_mask = key_padding_mask.to(torch.bool)
        if self.bias_k is not None and self.bias_v is not None:
            if static_k is None and static_v is None:

                # Explicitly assert that bias_k and bias_v are not None
                # in a way that TorchScript can understand.
                bias_k = self.bias_k
                assert bias_k is not None
                bias_v = self.bias_v
                assert bias_v is not None

                k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
                v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
                if attn_mask is not None:
                    attn_mask = nnF.pad(attn_mask, (0, 1))
                if key_padding_mask is not None:
                    key_padding_mask = nnF.pad(key_padding_mask, (0, 1))
            else:
                assert static_k is None, "bias cannot be added to static key."
                assert static_v is None, "bias cannot be added to static value."
        else:
            assert self.bias_k is None
            assert self.bias_v is None

        q = q.contiguous().view(tgt_len, bsz * self.num_heads, head_dim).transpose(0, 1)
        if k is not None:
            k = k.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)
        if v is not None:
            v = v.contiguous().view(-1, bsz * self.num_heads, head_dim).transpose(0, 1)

        if static_k is not None:
            assert static_k.size(0) == bsz * self.num_heads
            assert static_k.size(2) == head_dim
            k = static_k

        if static_v is not None:
            assert static_v.size(0) == bsz * self.num_heads
            assert static_v.size(2) == head_dim
            v = static_v

        src_len = k.size(1)

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            src_len += 1
            k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
            if k.is_quantized:
                k_zeros = torch.quantize_per_tensor(k_zeros, k.q_scale(), k.q_zero_point(), k.dtype)
            k = torch.cat([k, k_zeros], dim=1)
            v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
            if v.is_quantized:
                v_zeros = torch.quantize_per_tensor(v_zeros, v.q_scale(), v.q_zero_point(), v.dtype)
            v = torch.cat([v, v_zeros], dim=1)

            if attn_mask is not None:
                attn_mask = nnF.pad(attn_mask, (0, 1))
            if key_padding_mask is not None:
                key_padding_mask = nnF.pad(key_padding_mask, (0, 1))

        # Leaving the quantized zone here
        q = self.dequant_q(q)
        k = self.dequant_k(k)
        v = self.dequant_v(v)
        attn_output_weights = torch.bmm(q, k.transpose(1, 2))
        assert list(attn_output_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_output_weights.masked_fill_(attn_mask, float('-inf'))
            else:
                attn_output_weights += attn_mask

        if key_padding_mask is not None:
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            attn_output_weights = attn_output_weights.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2),
                float('-inf'),
            )
            attn_output_weights = attn_output_weights.view(bsz * self.num_heads, tgt_len, src_len)

        attn_output_weights = nnF.softmax(
            attn_output_weights, dim=-1)
        attn_output_weights = nnF.dropout(attn_output_weights, p=self.dropout, training=self.training)

        attn_output = torch.bmm(attn_output_weights, v)
        assert list(attn_output.size()) == [bsz * self.num_heads, tgt_len, head_dim]
        if self.batch_first:
            attn_output = attn_output.view(bsz, tgt_len, self.embed_dim)
        else:
            attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, self.embed_dim)

        # Reentering the quantized zone
        attn_output = self.quant_attn_output(attn_output)
        # for the type: ignore[has-type], see https://github.com/pytorch/pytorch/issues/58969
        attn_output = self.out_proj(attn_output)  # type: ignore[has-type]
        attn_output_weights = self.quant_attn_output_weights(attn_output_weights)

        if need_weights:
            # average attention weights over heads
            attn_output_weights = attn_output_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if average_attn_weights:
                attn_output_weights = attn_output_weights.mean(dim=1)
            return attn_output, attn_output_weights
        else:
            return attn_output, None

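# Editor's note -- a hedged usage sketch, not part of this diff: the forward()
# docstring above describes the expected tensor and mask shapes. The snippet
# below exercises them on the float nn.MultiheadAttention, whose call signature
# the quantizable module mirrors; all sizes are illustrative.
import torch
import torch.nn as nn

L, S, N, E, H = 4, 5, 2, 8, 2                  # tgt len, src len, batch, embed dim, heads
mha = nn.MultiheadAttention(E, H)              # batch_first defaults to False
query = torch.randn(L, N, E)
key = torch.randn(S, N, E)
value = torch.randn(S, N, E)
attn_mask = torch.zeros(L, S, dtype=torch.bool)           # 2D mask, broadcast over batches
key_padding_mask = torch.zeros(N, S, dtype=torch.bool)    # True marks padded key positions
out, weights = mha(query, key, value,
                   key_padding_mask=key_padding_mask,
                   attn_mask=attn_mask)
# out: (L, N, E); weights, averaged over heads: (N, L, S)
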
This file is in the process of migration to `torch/ao/nn/quantizable`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantizable/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantizable.modules.activation import MultiheadAttention

@@ -1,386 +1,11 @@
import numbers
from typing import Optional, Tuple
import warnings

import torch
from torch import Tensor
# flake8: noqa: F401
r"""Quantizable Modules

This file is in the process of migration to `torch/ao/nn/quantizable`, and
is kept here for compatibility while the migration process is ongoing.
If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantizable/modules`,
while adding an import statement here.
"""
"""
We will recreate all the RNN modules as we require the modules to be decomposed
into its building blocks to be able to observe.
"""

class LSTMCell(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM) cell.

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`

    Examples::

        >>> import torch.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTMCell(10, 20)
        >>> input = torch.randn(6, 10)
        >>> hx = torch.randn(3, 20)
        >>> cx = torch.randn(3, 20)
        >>> output = []
        >>> for i in range(6):
        ...     hx, cx = rnn(input[i], (hx, cx))
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell

    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias

        self.igates = torch.nn.Linear(input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.hgates = torch.nn.Linear(hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs)
        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.fgate_cx = torch.ao.nn.quantized.FloatFunctional()
        self.igate_cgate = torch.ao.nn.quantized.FloatFunctional()
        self.fgate_cx_igate_cgate = torch.ao.nn.quantized.FloatFunctional()

        self.ogate_cy = torch.ao.nn.quantized.FloatFunctional()

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None) -> Tuple[Tensor, Tensor]:
        if hidden is None or hidden[0] is None or hidden[1] is None:
            hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx, cx = hidden

        igates = self.igates(x)
        hgates = self.hgates(hx)
        gates = self.gates.add(igates, hgates)

        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

        input_gate = torch.sigmoid(input_gate)
        forget_gate = torch.sigmoid(forget_gate)
        cell_gate = torch.tanh(cell_gate)
        out_gate = torch.sigmoid(out_gate)

        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
        fgate_cx_igate_cgate = self.fgate_cx_igate_cgate.add(fgate_cx, igate_cgate)
        cy = fgate_cx_igate_cgate

        tanh_cy = torch.tanh(cy)
        hy = self.ogate_cy.mul(out_gate, tanh_cy)
        return hy, cy

    def initialize_hidden(self, batch_size: int, is_quantized: bool = False) -> Tuple[Tensor, Tensor]:
        h, c = torch.zeros((batch_size, self.hidden_size)), torch.zeros((batch_size, self.hidden_size))
        if is_quantized:
            h = torch.quantize_per_tensor(h, scale=1.0, zero_point=0, dtype=torch.quint8)
            c = torch.quantize_per_tensor(c, scale=1.0, zero_point=0, dtype=torch.quint8)
        return h, c

    def _get_name(self):
        return 'QuantizableLSTMCell'

    @classmethod
    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
            wi, wh: Weights for the input and hidden layers
            bi, bh: Biases for the input and hidden layers
        """
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
        cell = cls(input_dim=input_size, hidden_dim=hidden_size,
                   bias=(bi is not None))
        cell.igates.weight = torch.nn.Parameter(wi)
        if bi is not None:
            cell.igates.bias = torch.nn.Parameter(bi)
        cell.hgates.weight = torch.nn.Parameter(wh)
        if bh is not None:
            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell

    @classmethod
    def from_float(cls, other):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, 'qconfig'), "The float module must have 'qconfig'"
        observed = cls.from_params(other.weight_ih, other.weight_hh,
                                   other.bias_ih, other.bias_hh)
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
        return observed

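# Editor's note -- a hedged sketch, not part of this diff: from_params() above
# infers the cell sizes from the weight shapes, which follow nn.LSTMCell's layout
# of (4 * hidden_size, input_size) for wi and (4 * hidden_size, hidden_size) for
# wh. The numbers below are illustrative.
import torch
from torch.ao.nn.quantizable.modules.rnn import LSTMCell

wi = torch.randn(4 * 20, 10)                   # input_size inferred as 10
wh = torch.randn(4 * 20, 20)                   # hidden_size inferred as 20
cell = LSTMCell.from_params(wi, wh)            # bias-less; pass bi/bh to add biases
hy, cy = cell(torch.randn(3, 10))              # hidden state is created internally
# hy, cy: (3, 20) each
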
class _LSTMSingleLayer(torch.nn.Module):
    r"""A single one-directional LSTM layer.

    The difference between a layer and a cell is that the layer can process a
    sequence, while the cell only expects an instantaneous value.
    """
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        result = []
        for xx in x:
            hidden = self.cell(xx, hidden)
            result.append(hidden[0])  # type: ignore[index]
        result_tensor = torch.stack(result, 0)
        return result_tensor, hidden

    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = LSTMCell.from_params(*args, **kwargs)
        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer

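# Editor's note -- a hedged sketch, not part of this diff: unlike LSTMCell, which
# consumes a single timestep of shape (batch, input_dim), the (private) layer
# above iterates over the leading time dimension. Shapes below are illustrative.
import torch
from torch.ao.nn.quantizable.modules.rnn import _LSTMSingleLayer

layer = _LSTMSingleLayer(input_dim=10, hidden_dim=20)
seq = torch.randn(6, 3, 10)                    # (time, batch, input_dim)
out, (h, c) = layer(seq)                       # out: (6, 3, 20); h, c: (3, 20)
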
class _LSTMLayer(torch.nn.Module):
    r"""A single bi-directional LSTM layer."""
    def __init__(self, input_dim: int, hidden_dim: int, bias: bool = True,
                 batch_first: bool = False, bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)
        if hidden is None:
            hx_fw, cx_fw = (None, None)
        else:
            hx_fw, cx_fw = hidden
        hidden_bw: Optional[Tuple[Tensor, Tensor]] = None
        if self.bidirectional:
            if hx_fw is None:
                hx_bw = None
            else:
                hx_bw = hx_fw[1]
                hx_fw = hx_fw[0]
            if cx_fw is None:
                cx_bw = None
            else:
                cx_bw = cx_fw[1]
                cx_fw = cx_fw[0]
            if hx_bw is not None and cx_bw is not None:
                hidden_bw = hx_bw, cx_bw
        if hx_fw is None and cx_fw is None:
            hidden_fw = None
        else:
            hidden_fw = torch.jit._unwrap_optional(hx_fw), torch.jit._unwrap_optional(cx_fw)
        result_fw, hidden_fw = self.layer_fw(x, hidden_fw)

        if hasattr(self, 'layer_bw') and self.bidirectional:
            x_reversed = x.flip(0)
            result_bw, hidden_bw = self.layer_bw(x_reversed, hidden_bw)
            result_bw = result_bw.flip(0)

            result = torch.cat([result_fw, result_bw], result_fw.dim() - 1)
            if hidden_fw is None and hidden_bw is None:
                h = None
                c = None
            elif hidden_fw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_bw)
            elif hidden_bw is None:
                (h, c) = torch.jit._unwrap_optional(hidden_fw)
            else:
                h = torch.stack([hidden_fw[0], hidden_bw[0]], 0)  # type: ignore[list-item]
                c = torch.stack([hidden_fw[1], hidden_bw[1]], 0)  # type: ignore[list-item]
        else:
            result = result_fw
            h, c = torch.jit._unwrap_optional(hidden_fw)  # type: ignore[assignment]

        if self.batch_first:
            result.transpose_(0, 1)

        return result, (h, c)

    @classmethod
    def from_float(cls, other, layer_idx=0, qconfig=None, **kwargs):
        r"""
        There is no FP equivalent of this class. This function is here just to
        mimic the behavior of the `prepare` within the `torch.ao.quantization`
        flow.
        """
        assert hasattr(other, 'qconfig') or (qconfig is not None)

        input_size = kwargs.get('input_size', other.input_size)
        hidden_size = kwargs.get('hidden_size', other.hidden_size)
        bias = kwargs.get('bias', other.bias)
        batch_first = kwargs.get('batch_first', other.batch_first)
        bidirectional = kwargs.get('bidirectional', other.bidirectional)

        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, 'qconfig', qconfig)
        wi = getattr(other, f'weight_ih_l{layer_idx}')
        wh = getattr(other, f'weight_hh_l{layer_idx}')
        bi = getattr(other, f'bias_ih_l{layer_idx}', None)
        bh = getattr(other, f'bias_hh_l{layer_idx}', None)

        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f'weight_ih_l{layer_idx}_reverse')
            wh = getattr(other, f'weight_hh_l{layer_idx}_reverse')
            bi = getattr(other, f'bias_ih_l{layer_idx}_reverse', None)
            bh = getattr(other, f'bias_hh_l{layer_idx}_reverse', None)
            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer

class LSTM(torch.nn.Module):
    r"""A quantizable long short-term memory (LSTM).

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTM`

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples below.

    Examples::

        >>> import torch.nn.quantizable as nnqa
        >>> rnn = nnqa.LSTM(10, 20, 2)
        >>> input = torch.randn(5, 3, 10)
        >>> h0 = torch.randn(2, 3, 20)
        >>> c0 = torch.randn(2, 3, 20)
        >>> output, (hn, cn) = rnn(input, (h0, c0))
        >>> # To get the weights:
        >>> # xdoctest: +SKIP
        >>> print(rnn.layers[0].weight_ih)
        tensor([[...]])
        >>> print(rnn.layers[0].weight_hh)
        AssertionError: There is no reverse path in the non-bidirectional layer
    """
    _FLOAT_MODULE = torch.nn.LSTM

    def __init__(self, input_size: int, hidden_size: int,
                 num_layers: int = 1, bias: bool = True,
                 batch_first: bool = False, dropout: float = 0.,
                 bidirectional: bool = False,
                 device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bias = bias
        self.batch_first = batch_first
        self.dropout = float(dropout)
        self.bidirectional = bidirectional
        self.training = False  # We don't want to train using this module
        num_directions = 2 if bidirectional else 1

        if not isinstance(dropout, numbers.Number) or not 0 <= dropout <= 1 or \
                isinstance(dropout, bool):
            raise ValueError("dropout should be a number in range [0, 1] "
                             "representing the probability of an element being "
                             "zeroed")
        if dropout > 0:
            warnings.warn("dropout option for quantizable LSTM is ignored. "
                          "If you are training, please, use nn.LSTM version "
                          "followed by `prepare` step.")
            if num_layers == 1:
                warnings.warn("dropout option adds dropout after all but last "
                              "recurrent layer, so non-zero dropout expects "
                              "num_layers greater than 1, but got dropout={} "
                              "and num_layers={}".format(dropout, num_layers))

        layers = [_LSTMLayer(self.input_size, self.hidden_size,
                             self.bias, batch_first=False,
                             bidirectional=self.bidirectional, **factory_kwargs)]
        for layer in range(1, num_layers):
            layers.append(_LSTMLayer(self.hidden_size, self.hidden_size,
                                     self.bias, batch_first=False,
                                     bidirectional=self.bidirectional,
                                     **factory_kwargs))
        self.layers = torch.nn.ModuleList(layers)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        if self.batch_first:
            x = x.transpose(0, 1)

        max_batch_size = x.size(1)
        num_directions = 2 if self.bidirectional else 1
        if hidden is None:
            zeros = torch.zeros(num_directions, max_batch_size,
                                self.hidden_size, dtype=torch.float,
                                device=x.device)
            zeros.squeeze_(0)
            if x.is_quantized:
                zeros = torch.quantize_per_tensor(zeros, scale=1.0,
                                                  zero_point=0, dtype=x.dtype)
            hxcx = [(zeros, zeros) for _ in range(self.num_layers)]
        else:
            hidden_non_opt = torch.jit._unwrap_optional(hidden)
            if isinstance(hidden_non_opt[0], Tensor):
                hx = hidden_non_opt[0].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size).unbind(0)
                cx = hidden_non_opt[1].reshape(self.num_layers, num_directions,
                                               max_batch_size,
                                               self.hidden_size).unbind(0)
                hxcx = [(hx[idx].squeeze_(0), cx[idx].squeeze_(0)) for idx in range(self.num_layers)]
            else:
                hxcx = hidden_non_opt

        hx_list = []
        cx_list = []
        for idx, layer in enumerate(self.layers):
            x, (h, c) = layer(x, hxcx[idx])
            hx_list.append(torch.jit._unwrap_optional(h))
            cx_list.append(torch.jit._unwrap_optional(c))
        hx_tensor = torch.stack(hx_list)
        cx_tensor = torch.stack(cx_list)

        # We are creating another dimension for bidirectional case
        # need to collapse it
        hx_tensor = hx_tensor.reshape(-1, hx_tensor.shape[-2], hx_tensor.shape[-1])
        cx_tensor = cx_tensor.reshape(-1, cx_tensor.shape[-2], cx_tensor.shape[-1])

        if self.batch_first:
            x = x.transpose(0, 1)

        return x, (hx_tensor, cx_tensor)

    def _get_name(self):
        return 'QuantizableLSTM'

    @classmethod
    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert (hasattr(other, 'qconfig') or qconfig)
        observed = cls(other.input_size, other.hidden_size, other.num_layers,
                       other.bias, other.batch_first, other.dropout,
                       other.bidirectional)
        observed.qconfig = getattr(other, 'qconfig', qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(other, idx, qconfig,
                                                         batch_first=False)
        observed.eval()
        observed = torch.ao.quantization.prepare(observed, inplace=True)
        return observed

    @classmethod
    def from_observed(cls, other):
        # The whole flow is float -> observed -> quantized
        # This class does float -> observed only
        raise NotImplementedError("It looks like you are trying to convert a "
                                  "non-quantizable LSTM module. Please, see "
                                  "the examples on quantizable LSTMs.")

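# Editor's note -- a hedged sketch, not part of this diff: from_float() above
# implements the float -> observed leg of the float -> observed -> quantized
# flow (it attaches observers via torch.ao.quantization.prepare). The qconfig
# and shapes below are illustrative; the exact convert configuration can vary
# across versions and may require a custom-module mapping.
import torch
import torch.ao.quantization as tq
from torch.ao.nn.quantizable.modules.rnn import LSTM as QuantizableLSTM

float_lstm = torch.nn.LSTM(10, 20, num_layers=2)
float_lstm.qconfig = tq.default_qconfig              # from_float() requires a qconfig

observed = QuantizableLSTM.from_float(float_lstm)    # float -> observed (already prepared)
observed(torch.randn(5, 3, 10))                      # calibrate with representative data

# observed -> quantized; the converted module then expects quantized inputs,
# e.g. fed through a QuantStub in the enclosing model.
quantized = tq.convert(observed)
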
from torch.ao.nn.quantizable.modules.rnn import LSTM
from torch.ao.nn.quantizable.modules.rnn import LSTMCell

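A quick compatibility check (editor's sketch, not part of the diff): after this
migration the old `torch.nn.quantizable` names are kept as thin re-exports of
the new `torch.ao.nn.quantizable` classes, so both import paths should resolve
to the same objects.

    import torch.nn.quantizable as nnqa          # legacy path, kept for compatibility
    import torch.ao.nn.quantizable as ao_nnqa    # new home after the migration

    assert nnqa.MultiheadAttention is ao_nnqa.MultiheadAttention
    assert nnqa.LSTM is ao_nnqa.LSTM
    assert nnqa.LSTMCell is ao_nnqa.LSTMCell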