# mypy: allow-untyped-defs
from typing import ClassVar, Literal, Union

import torch
import torch.nn as nn
from torch.ao.nn.intrinsic import _FusedModule
from torch.nn.common_types import _size_1_t, _size_2_t, _size_3_t
from torch.nn.modules.utils import _pair, _single, _triple


__all__ = ["Conv1d", "Conv2d", "Conv3d"]


class _ConvNd(nn.modules.conv._ConvNd):
    _FLOAT_MODULE: ClassVar[type[nn.modules.conv._ConvNd]]

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple[int, ...],
        stride: tuple[int, ...],
        padding: Union[str, tuple[int, ...]],
        dilation: tuple[int, ...],
        transposed: bool,
        output_padding: tuple[int, ...],
        groups: int,
        bias: bool,
        padding_mode: Literal["zeros", "reflect", "replicate", "circular"],
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        nn.modules.conv._ConvNd.__init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            transposed,
            output_padding,
            groups,
            bias,
            padding_mode,
            **factory_kwargs,
        )
        assert qconfig, "qconfig must be provided for QAT module"
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight(factory_kwargs=factory_kwargs)

    def forward(self, input):
        # Fake-quantize the weight on every forward pass so that training
        # observes the quantization error on the weights.
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    # Declared as a staticmethod taking ``cls`` explicitly; the concrete
    # subclasses below expose it as a classmethod and forward their ``cls``.
    @staticmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        r"""Create a QAT module from a float module.

        Args:
            `mod`: a float module, either produced by torch.ao.quantization
                utilities or directly from the user
        """
        assert type(mod) == cls._FLOAT_MODULE, (
            "qat."
            + cls.__name__
            + ".from_float only works for "
            + cls._FLOAT_MODULE.__name__
        )
        assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
        assert mod.qconfig, "Input float module must have a valid qconfig"
        if issubclass(type(mod), _FusedModule):
            # For fused modules (e.g. conv + relu), the conv is the first submodule.
            mod = mod[0]
        qconfig = mod.qconfig
        qat_conv = cls(
            mod.in_channels,
            mod.out_channels,
            mod.kernel_size,
            stride=mod.stride,
            padding=mod.padding,
            dilation=mod.dilation,
            groups=mod.groups,
            bias=mod.bias is not None,
            padding_mode=mod.padding_mode,
            qconfig=qconfig,
        )
        qat_conv.weight = mod.weight
        qat_conv.bias = mod.bias
        return qat_conv

    def to_float(self):
        """Convert the QAT module to a floating point module.

        Works for both a standalone QAT conv and a fused QAT conv-relu module.
        """
        cls = type(self)
        conv = cls._FLOAT_CONV_MODULE(  # type: ignore[attr-defined]
            self.in_channels,
            self.out_channels,
            self.kernel_size,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            self.bias is not None,
            self.padding_mode,
        )
        conv.weight = torch.nn.Parameter(self.weight.detach())
        if self.bias is not None:
            conv.bias = torch.nn.Parameter(self.bias.detach())
        # conv relu
        if issubclass(cls, _FusedModule):
            modules = [conv]
            assert hasattr(cls, "_FLOAT_RELU_MODULE")
            relu = cls._FLOAT_RELU_MODULE()
            modules.append(relu)
            # pyrefly: ignore # missing-attribute
            fused = cls._FLOAT_MODULE(*modules)
            fused.train(self.training)
            return fused
        else:
            return conv
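

# A minimal usage sketch (illustrative only, not part of this module's API),
# assuming the default QAT qconfig from torch.ao.quantization and the Conv2d
# subclass defined below:
#
#     import torch.ao.quantization as tq
#
#     float_conv = nn.Conv2d(3, 16, kernel_size=3)
#     float_conv.qconfig = tq.get_default_qat_qconfig("fbgemm")
#     qat_conv = Conv2d.from_float(float_conv)  # attaches weight_fake_quant
#     _ = qat_conv(torch.randn(1, 3, 8, 8))     # weight is fake-quantized
#     restored = qat_conv.to_float()            # back to a plain nn.Conv2d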


class Conv1d(_ConvNd, nn.Conv1d):
    r"""
    A Conv1d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as :class:`~torch.nn.Conv1d`.

    Similar to :class:`~torch.nn.Conv1d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d
    _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv1d]] = nn.Conv1d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_1_t,
        stride: _size_1_t = 1,
        padding: Union[str, _size_1_t] = 0,
        dilation: _size_1_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _single(kernel_size)
        stride_ = _single(stride)
        padding_ = padding if isinstance(padding, str) else _single(padding)
        dilation_ = _single(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_single(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
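

# Illustrative only: a QAT Conv1d can also be constructed directly, provided a
# qconfig is supplied (the default QAT qconfig below is an example, not a
# requirement of this module):
#
#     import torch.ao.quantization as tq
#
#     qat = Conv1d(4, 8, kernel_size=3, qconfig=tq.get_default_qat_qconfig("fbgemm"))
#     y = qat(torch.randn(2, 4, 16))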


class Conv2d(_ConvNd, nn.Conv2d):
    r"""
    A Conv2d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv2d`; please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv2d#torch.nn.Conv2d
    for documentation.

    Similar to `torch.nn.Conv2d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d
    _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv2d]] = nn.Conv2d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_2_t,
        stride: _size_2_t = 1,
        padding: Union[str, _size_2_t] = 0,
        dilation: _size_2_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _pair(kernel_size)
        stride_ = _pair(stride)
        padding_ = padding if isinstance(padding, str) else _pair(padding)
        dilation_ = _pair(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_pair(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )


class Conv3d(_ConvNd, nn.Conv3d):
    r"""
    A Conv3d module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Conv3d`; please see
    https://pytorch.org/docs/stable/nn.html?highlight=conv3d#torch.nn.Conv3d
    for documentation.

    Similar to `torch.nn.Conv3d`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight_fake_quant: fake quant module for weight
    """

    _FLOAT_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d
    _FLOAT_CONV_MODULE: ClassVar[type[nn.Conv3d]] = nn.Conv3d

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: _size_3_t,
        stride: _size_3_t = 1,
        padding: Union[str, _size_3_t] = 0,
        dilation: _size_3_t = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: Literal["zeros", "reflect", "replicate", "circular"] = "zeros",
        qconfig=None,
        device=None,
        dtype=None,
    ) -> None:
        kernel_size_ = _triple(kernel_size)
        stride_ = _triple(stride)
        padding_ = padding if isinstance(padding, str) else _triple(padding)
        dilation_ = _triple(dilation)
        super().__init__(
            in_channels,
            out_channels,
            kernel_size_,
            stride=stride_,
            padding=padding_,
            dilation=dilation_,
            transposed=False,
            output_padding=_triple(0),
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
            qconfig=qconfig,
            device=device,
            dtype=dtype,
        )

    def forward(self, input):
        return self._conv_forward(input, self.weight_fake_quant(self.weight), self.bias)

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):  # type: ignore[override]
        return super().from_float(
            cls, mod, use_precomputed_fake_quant=use_precomputed_fake_quant
        )
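

# Illustrative sketch of the fused path in _ConvNd.from_float / to_float,
# assuming the fused classes from torch.ao.nn.intrinsic (defined outside
# this file):
#
#     import torch.ao.quantization as tq
#     from torch.ao.nn.intrinsic import ConvReLU2d
#     from torch.ao.nn.intrinsic.qat import ConvReLU2d as QATConvReLU2d
#
#     fused = ConvReLU2d(nn.Conv2d(3, 16, 3), nn.ReLU())
#     fused.qconfig = tq.get_default_qat_qconfig("fbgemm")
#     qat_fused = QATConvReLU2d.from_float(fused)  # reads the conv at mod[0]
#     back = qat_fused.to_float()                  # rebuilds ConvReLU2d(conv, relu)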