Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/29529
Pull Request resolved: https://github.com/pytorch/glow/pull/3771

We would like to replace `conv_prepack` with `conv2d_prepack` and `conv_unpack` with `conv2d_unpack`. This makes the naming consistent between 2D and 3D conv:

```
torch.ops.quantized.conv2d_prepack
torch.ops.quantized.conv2d_unpack
torch.ops.quantized.conv2d
torch.ops.quantized.conv3d_prepack
torch.ops.quantized.conv3d_unpack
torch.ops.quantized.conv3d
```

We should do this sooner rather than later, before the quantized conv2d ops gain more users, for better engineering. The replacement bash commands are as follows:

```
find ./ -type f -exec sed -i -e 's/quantized::conv_prepack/quantized::conv2d_prepack/g' {} \;
find ./ -type f -exec sed -i -e 's/quantized::conv_unpack/quantized::conv2d_unpack/g' {} \;
find ./ -type f -exec sed -i -e 's/torch.ops.quantized.conv_prepack/torch.ops.quantized.conv2d_prepack/g' {} \;
find ./ -type f -exec sed -i -e 's/torch.ops.quantized.conv_unpack/torch.ops.quantized.conv2d_unpack/g' {} \;
```

ghstack-source-id: 93661879

Test Plan: CI

Reviewed By: jackm321

Differential Revision: D18421079

fbshipit-source-id: 17ae8b1ee79223bd2c5d4bbccd57af6580c4ab12
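The renamed ops are drop-in replacements for the old names; only the op names change, not the signatures. Below is a minimal sketch of prepacking and unpacking a quantized conv2d weight with the new names, assuming the fbgemm backend is available (the shapes and quantization parameters are illustrative, not taken from the diff):

```python
import torch

# Hypothetical weight; any valid per-tensor qint8 conv weight works the same way.
w = torch.randn(8, 4, 3, 3)
wq = torch.quantize_per_tensor(w, scale=0.1, zero_point=0, dtype=torch.qint8)

# New names introduced by this change (previously conv_prepack / conv_unpack);
# the argument list (weight, bias, stride, padding, dilation, groups) is unchanged.
packed = torch.ops.quantized.conv2d_prepack(wq, None, [1, 1], [0, 0], [1, 1], 1)
wq_back, bias = torch.ops.quantized.conv2d_unpack(packed)
```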
150 lines
5.3 KiB
Python
from __future__ import absolute_import, division, print_function, unicode_literals

import torch
from typing import List, Optional, Tuple  # names referenced in the TorchScript type comments below

from .qconfig import QConfig


class ConvPackedParams(torch.nn.Module):
    # Dummy module that holds a prepacked conv2d weight/bias plus the conv
    # parameters; a scripted instance is handed to the fold_prepack JIT pass below.
    def __init__(self):
        super(ConvPackedParams, self).__init__()
        wq = torch._empty_affine_quantized([1, 1, 1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
        self.stride = [1, 1]
        self.padding = [0, 0]
        self.dilation = [1, 1]
        self.groups = 1
        self.set_weight_bias(wq, None)

    @torch.jit.export
    def set_conv_params(self, stride, padding, dilation, groups):
        # type: (List[int], List[int], List[int], int) -> None
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups

    @torch.jit.export
    def set_weight_bias(self, weight, bias):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
        self._packed_params = torch.ops.quantized.conv2d_prepack(weight, bias, self.stride, self.padding, self.dilation, self.groups)

    @torch.jit.export
    def _weight_bias(self):
        return torch.ops.quantized.conv2d_unpack(self._packed_params)

    def forward(self, x):
        return x

    @torch.jit.export
    def __getstate__(self):
        qweight, bias = self._weight_bias()
        return (qweight,
                bias,
                self.stride,
                self.padding,
                self.dilation,
                self.groups,
                self.training)

    @torch.jit.export
    def __setstate__(self, state):
        self.stride = state[2]
        self.padding = state[3]
        self.dilation = state[4]
        self.groups = state[5]
        self.set_weight_bias(state[0],
                             state[1])
        self.training = state[6]


class LinearPackedParams(torch.nn.Module):
    # Same idea as ConvPackedParams, but for quantized linear weights.
    def __init__(self):
        super(LinearPackedParams, self).__init__()
        wq = torch._empty_affine_quantized([1, 1], scale=1.0, zero_point=0, dtype=torch.qint8)
        self.set_weight_bias(wq, None)

    @torch.jit.export
    def set_weight_bias(self, weight, bias):
        # type: (torch.Tensor, Optional[torch.Tensor]) -> None
        self._packed_params = torch.ops.quantized.linear_prepack(weight, bias)

    @torch.jit.export
    def _weight_bias(self):
        return torch.ops.quantized.linear_unpack(self._packed_params)

    def forward(self, x):
        return x

    @torch.jit.export
    def __getstate__(self):
        qweight, bias = self._weight_bias()
        return qweight, bias, self.training

    @torch.jit.export
    def __setstate__(self, state):
        # type: (Tuple[Tensor, Optional[Tensor], bool]) -> None
        self.set_weight_bias(state[0], state[1])
        self.training = state[2]


# Scripted packed-parameter templates used by the fold_prepack pass; they are
# only created when the fbgemm backend is available.
linear_packed_params = None
conv_packed_params = None
if 'fbgemm' in torch.backends.quantized.supported_engines:
    linear_packed_params = torch.jit.script(LinearPackedParams())._c
    conv_packed_params = torch.jit.script(ConvPackedParams())._c


def _check_is_script_module(model):
    if not isinstance(model, torch.jit.ScriptModule):
        raise ValueError('input must be a script module, got: ' + str(type(model)))


def prepare_script(model, qconfig_dict, inplace=False):
    # Insert observer modules into the 'forward' method according to qconfig_dict.
    _check_is_script_module(model)
    if not inplace:
        model = model.copy()
    torch._C._jit_pass_insert_observers(model._c,
                                        'forward',
                                        qconfig_dict,
                                        True)
    return model


def convert_script(model, inplace=False):
    # Insert quant/dequant nodes and, when the fbgemm backend is available,
    # insert prepack/unpack ops and fold them into packed parameters.
    _check_is_script_module(model)
    if not inplace:
        model = model.copy()
    torch._C._jit_pass_insert_quant_dequant(model._c, 'forward', True)
    if 'fbgemm' in torch.backends.quantized.supported_engines:
        torch._C._jit_pass_insert_prepack_unpack(model._c)
        if linear_packed_params and conv_packed_params:
            torch._C._jit_pass_fold_prepack(model._c,
                                            linear_packed_params,
                                            conv_packed_params)

    return model


# TODO: non-scriptable QConfig will be supported later
def script_qconfig(qconfig):
    return QConfig(
        activation=torch.jit.script(qconfig.activation())._c,
        weight=torch.jit.script(qconfig.weight())._c)


def quantize_script(model, qconfig_dict, run_fn, run_args, inplace=False):
    _check_is_script_module(model)
    if not model._c._has_method('forward'):
        raise ValueError('input script module does not have forward method')
    assert not inplace, "We don't support inplace right now"
    if not inplace:
        model = model.copy()
    scripted_qconfig_dict = {k: script_qconfig(v) for k, v in qconfig_dict.items()}
    # We are not going to run the fold_convbn pass right now since it does not
    # work correctly yet; we will revisit after constants are properly handled
    # in the JIT.
    # torch._C._jit_pass_fold_convbn(model._c)
    prepare_script(model, scripted_qconfig_dict, True)
    run_fn(model._c._get_method('forward'), *run_args)
    # When we mutate the graph we do not create a new ClassType, so the graph
    # executor would run an outdated version of the graph if we mutated it in
    # place; therefore we copy the model here.
    # [TODO] This will be fixed once we figure out how to properly mutate types.
    model = convert_script(model, False)
    return model
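For context, here is a rough usage sketch of `quantize_script`, the entry point defined in this file. The model, calibration function, and data are hypothetical stand-ins; the import path is assumed from the file's relative import of `.qconfig` (i.e. that it lives under `torch.quantization`), the `''` key meaning "apply this qconfig to the whole model" is an assumption about the qconfig_dict convention, and `default_qconfig` is assumed to provide observers that `script_qconfig` can script:

```python
import torch
from torch.quantization import default_qconfig
# Assumed module path for this file, based on its `from .qconfig import QConfig`.
from torch.quantization._quantize_script import quantize_script


class M(torch.nn.Module):  # hypothetical toy model
    def __init__(self):
        super(M, self).__init__()
        self.conv = torch.nn.Conv2d(1, 1, 1)

    def forward(self, x):
        return self.conv(x)


def calibrate(forward_method, data):
    # quantize_script passes the scripted 'forward' method plus *run_args here.
    for x in data:
        forward_method(x)


model = torch.jit.script(M())
data = [torch.randn(1, 1, 4, 4) for _ in range(4)]
# '' key: qconfig applied to the whole model (assumed convention).
quantized = quantize_script(model, {'': default_qconfig}, calibrate, [data], inplace=False)
```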