Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157234 Approved by: https://github.com/jingsh ghstack dependencies: #157231, #157232
60 lines · 1.9 KiB · Python
from typing import Any
|
|
|
|
import torch
|
|
|
|
|
|
# Names exported by ``from <module> import *``: only the quantized LSTM.
__all__ = ["LSTM"]
|
|
|
|
|
|
class LSTM(torch.ao.nn.quantizable.LSTM):
    r"""A quantized long short-term memory (LSTM).

    For a description of the arguments and their types, refer to
    :class:`~torch.nn.LSTM`.

    Attributes:
        layers : instances of the `_LSTMLayer`

    .. note::
        To access the weights and biases, you need to access them per layer.
        See examples in :class:`~torch.ao.nn.quantizable.LSTM`

    Examples::

        >>> # xdoctest: +SKIP
        >>> import torch.ao.quantization as tq
        >>> custom_module_config = {
        ...     'float_to_observed_custom_module_class': {
        ...         nn.LSTM: nn.quantizable.LSTM,
        ...     },
        ...     'observed_to_quantized_custom_module_class': {
        ...         nn.quantizable.LSTM: nn.quantized.LSTM,
        ...     }
        ... }
        >>> model = tq.prepare(model, prepare_custom_config_dict=custom_module_config)
        >>> model = tq.convert(model, convert_custom_config_dict=custom_module_config)
    """

    # The module type this class is converted *from*. Quantized LSTMs are built
    # from observed (quantizable) LSTMs, not from plain float ``nn.LSTM``s.
    _FLOAT_MODULE = torch.ao.nn.quantizable.LSTM  # type: ignore[assignment]

    def _get_name(self) -> str:
        """Return the module name, ``"QuantizedLSTM"``."""
        return "QuantizedLSTM"

    @classmethod
    def from_float(cls, *args: Any, **kwargs: Any) -> None:
        """Unsupported conversion entry point; always raises.

        The whole flow is float -> observed -> quantized, and this class only
        implements the observed -> quantized step (see :meth:`from_observed`).

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError(
            "It looks like you are trying to convert a "
            "non-observed LSTM module. Please, see "
            "the examples on quantizable LSTMs."
        )

    @classmethod
    def from_observed(cls: type["LSTM"], other: torch.ao.nn.quantizable.LSTM) -> "LSTM":
        """Convert an observed (quantizable) LSTM into a quantized LSTM.

        Args:
            other: the observed module; must be an instance of
                ``cls._FLOAT_MODULE``.

        Returns:
            The result of ``torch.ao.quantization.convert`` on ``other``
            (called with ``inplace=False`` and ``remove_qconfig=True``), with its
            class rebound to ``cls``.

        Raises:
            TypeError: if ``other`` is not an instance of ``cls._FLOAT_MODULE``.
        """
        expected = cls._FLOAT_MODULE  # type: ignore[has-type]
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(other, expected):
            raise TypeError(
                f"{cls.__name__}.from_observed expected an instance of "
                f"{expected.__name__}, got {type(other).__name__}"
            )
        converted = torch.ao.quantization.convert(
            other, inplace=False, remove_qconfig=True
        )
        # Rebind the class so the converted module is exposed as this quantized type.
        converted.__class__ = cls
        return converted
|