pytorch/torch/ao/nn/quantized/modules/rnn.py

from typing import Any

import torch

__all__ = [
    "LSTM",
]

class LSTM(torch.ao.nn.quantizable.LSTM):
    r"""A quantized long short-term memory (LSTM).

    For the description and the argument types, please refer to :class:`~torch.nn.LSTM`.

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples in :class:`~torch.ao.nn.quantizable.LSTM`.

    Examples::
        >>> # xdoctest: +SKIP
        >>> custom_module_config = {
        ...     'float_to_observed_custom_module_class': {
        ...         nn.LSTM: nn.quantizable.LSTM,
        ...     },
        ...     'observed_to_quantized_custom_module_class': {
        ...         nn.quantizable.LSTM: nn.quantized.LSTM,
        ...     }
        ... }
        >>> tq.prepare(model, prepare_custom_config_dict=custom_module_config)
        >>> tq.convert(model, convert_custom_config_dict=custom_module_config)
    """

    _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM  # type: ignore[assignment]

    def _get_name(self) -> str:
        return "QuantizedLSTM"

    @classmethod
    def from_float(cls, *args: Any, **kwargs: Any) -> None:
        # The whole flow is float -> observed -> quantized
        # This class does observed -> quantized only
        raise NotImplementedError(
            "It looks like you are trying to convert a "
            "non-observed LSTM module. Please, see "
            "the examples on quantizable LSTMs."
        )

    @classmethod
    def from_observed(cls: type["LSTM"], other: torch.ao.nn.quantizable.LSTM) -> "LSTM":
        # `other` must already be an observed (calibrated) quantizable LSTM.
        assert isinstance(other, cls._FLOAT_MODULE)  # type: ignore[has-type]
        # Convert the observed submodules to their quantized counterparts and
        # strip the qconfig attributes that are no longer needed.
        converted = torch.ao.quantization.convert(
            other, inplace=False, remove_qconfig=True
        )
        # Keep the converted structure; only the class is swapped so the module
        # reports itself as the quantized variant.
        converted.__class__ = cls
        return converted
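

# A minimal end-to-end sketch of the float -> observed -> quantized flow described
# in the class docstring. It is an illustration only, not part of this module's API:
# the `_Demo` model, the qnnpack qconfig, and the tensor shapes below are assumptions
# that mirror the docstring example.
if __name__ == "__main__":
    import torch.ao.quantization as tq
    from torch import nn

    class _Demo(nn.Module):
        """Toy float model wrapping a single nn.LSTM."""

        def __init__(self) -> None:
            super().__init__()
            self.lstm = nn.LSTM(input_size=4, hidden_size=8)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            out, _ = self.lstm(x)
            return out

    model = _Demo().eval()
    model.qconfig = tq.get_default_qconfig("qnnpack")
    custom_module_config = {
        "float_to_observed_custom_module_class": {
            nn.LSTM: torch.ao.nn.quantizable.LSTM,
        },
        "observed_to_quantized_custom_module_class": {
            torch.ao.nn.quantizable.LSTM: LSTM,
        },
    }
    # float -> observed: swap nn.LSTM for the observed quantizable LSTM.
    observed = tq.prepare(model, prepare_custom_config_dict=custom_module_config)
    # Run a calibration pass so the observers record activation ranges.
    observed(torch.randn(5, 1, 4))
    # observed -> quantized: this calls LSTM.from_observed under the hood.
    quantized = tq.convert(observed, convert_custom_config_dict=custom_module_config)
    print(quantized.lstm._get_name())  # QuantizedLSTM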