# Owner(s): ["oncall: quantization"]

import torch
from torch.testing._internal.common_quantization import (
    skipIfNoFBGEMM
)
from torch.testing._internal.common_utils import suppress_warnings
from torch.testing._internal.jit_utils import JitTestCase

from typing import Tuple
import copy


class TestDeprecatedJitQuantized(JitTestCase):
    @skipIfNoFBGEMM
    def test_rnn_cell_quantized(self):
        d_in, d_hid = 2, 2

        for cell in [
            torch.nn.LSTMCell(d_in, d_hid).float(),
            torch.nn.GRUCell(d_in, d_hid).float(),
            torch.nn.RNNCell(d_in, d_hid).float(),
        ]:
            if isinstance(cell, torch.nn.LSTMCell):
                num_chunks = 4
            elif isinstance(cell, torch.nn.GRUCell):
                num_chunks = 3
            elif isinstance(cell, torch.nn.RNNCell):
                num_chunks = 1

            # Replace parameter values s.t. the range of values is exactly
            # 255, thus we will have 0 quantization error in the quantized
            # GEMM call. This is for testing purposes.
            #
            # Note that the current implementation does not support
            # accumulation values outside of the range representable by a
            # 16 bit integer, instead resulting in a saturated value. We
            # must take care that in our test we do not end up with a dot
            # product that overflows the int16 range, e.g.
            # (255*127+255*127) = 64770. So, we hardcode the test values
            # here and ensure a mix of signedness.
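            # (Sanity check on that bound: int16 saturates at 32767, and a
            # single maximal product 255 * 127 = 32385 already nearly fills
            # it, so two such products, 32385 + 32385 = 64770, overflow.)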
            vals = [[100, -155],
                    [100, -155],
                    [-155, 100],
                    [-155, 100],
                    [100, -155],
                    [-155, 100],
                    [-155, 100],
                    [100, -155]]
            vals = vals[:d_hid * num_chunks]
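            # The *Cell modules stack their gate weights along dim 0
            # (4 gates for LSTMCell, 3 for GRUCell, 1 for RNNCell), hence
            # the slice down to num_chunks * d_hid rows.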
            cell.weight_ih = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float),
                requires_grad=False)
            cell.weight_hh = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float),
                requires_grad=False)

            ref = copy.deepcopy(cell)

            cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
            x = torch.tensor([[100, -155],
                              [-155, 100],
                              [100, -155]], dtype=torch.float)
            h0_vals = [[-155, 100],
                       [-155, 155],
                       [100, -155]]
            hx = torch.tensor(h0_vals, dtype=torch.float)
            if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
                cx = torch.tensor(h0_vals, dtype=torch.float)
                hiddens = (hx, cx)
            else:
                hiddens = hx
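
            # TorchScript requires statically typed signatures, so two
            # wrapper variants are scripted: the LSTM cell passes hidden
            # state as an (h, c) tuple, the GRU/RNN cells as a single Tensor.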
            if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
                class ScriptWrapper(torch.jit.ScriptModule):
                    def __init__(self, cell):
                        super().__init__()
                        self.cell = cell

                    @torch.jit.script_method
                    def forward(self, x: torch.Tensor,
                                hiddens: Tuple[torch.Tensor, torch.Tensor]
                                ) -> Tuple[torch.Tensor, torch.Tensor]:
                        return self.cell(x, hiddens)
            else:

                class ScriptWrapper(torch.jit.ScriptModule):
                    def __init__(self, cell):
                        super().__init__()
                        self.cell = cell

                    @torch.jit.script_method
                    def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> torch.Tensor:
                        return self.cell(x, hiddens)

            cell = ScriptWrapper(cell)
            outs = cell(x, hiddens)
            cell = self.getExportImportCopyWithPacking(cell)

            outs = cell(x, hiddens)
            ref_outs = ref(x, hiddens)

            self.assertEqual(len(outs), len(ref_outs))
            for out, ref_out in zip(outs, ref_outs):
                torch.testing.assert_close(out, ref_out)

    @skipIfNoFBGEMM
    def test_rnn_quantized(self):
        d_in, d_hid = 2, 2

        for cell in [
            torch.nn.LSTM(d_in, d_hid).float(),
            torch.nn.GRU(d_in, d_hid).float(),
        ]:

            # Replace parameter values s.t. the range of values is exactly
            # 255, thus we will have 0 quantization error in the quantized
            # GEMM call. This is for testing purposes.
            #
            # Note that the current implementation does not support
            # accumulation values outside of the range representable by a
            # 16 bit integer, instead resulting in a saturated value. We
            # must take care that in our test we do not end up with a dot
            # product that overflows the int16 range, e.g.
            # (255*127+255*127) = 64770. So, we hardcode the test values
            # here and ensure a mix of signedness.
            vals = [[100, -155],
                    [100, -155],
                    [-155, 100],
                    [-155, 100],
                    [100, -155],
                    [-155, 100],
                    [-155, 100],
                    [100, -155]]
            if isinstance(cell, torch.nn.LSTM):
                num_chunks = 4
            elif isinstance(cell, torch.nn.GRU):
                num_chunks = 3
            vals = vals[:d_hid * num_chunks]
            cell.weight_ih_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float),
                requires_grad=False)
            cell.weight_hh_l0 = torch.nn.Parameter(
                torch.tensor(vals, dtype=torch.float),
                requires_grad=False)

            ref = copy.deepcopy(cell)
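            # The deprecated torch.jit.quantized API quantizes the RNN
            # weights ahead of time (here to int8 and fp16); the resulting
            # modules still take and return float tensors.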
            cell_int8 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.int8)
            cell_fp16 = torch.jit.quantized.quantize_rnn_modules(cell, dtype=torch.float16)

            niter = 10
            x = torch.tensor([[100, -155],
                              [-155, 100],
                              [100, -155]], dtype=torch.float).unsqueeze(0).repeat(niter, 1, 1)
            h0_vals = [[-155, 100],
                       [-155, 155],
                       [100, -155]]
            hx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
            cx = torch.tensor(h0_vals, dtype=torch.float).unsqueeze(0)
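            # Shapes: x is (seq_len=niter, batch=3, input=2); hx / cx are
            # (num_layers=1, batch=3, hidden=2), matching the default
            # batch_first=False convention of nn.LSTM / nn.GRU.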

            if isinstance(ref, torch.nn.LSTM):
                hiddens = (hx, cx)
            elif isinstance(ref, torch.nn.GRU):
                hiddens = hx

            ref_out, ref_hid = ref(x, hiddens)

            # Compare int8 quantized to unquantized
            output_int8, final_hiddens_int8 = cell_int8(x, hiddens)

            torch.testing.assert_close(output_int8, ref_out)
            for out, ref in zip(final_hiddens_int8, ref_hid):
                torch.testing.assert_close(out, ref)

            # Compare fp16 quantized to unquantized
            output_fp16, final_hiddens_fp16 = cell_fp16(x, hiddens)

            torch.testing.assert_close(output_fp16, ref_out)
            for out, ref in zip(final_hiddens_fp16, ref_hid):
                torch.testing.assert_close(out, ref)
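
            # Helper: closes over x, hiddens, ref_out and ref_hid from this
            # loop iteration; checks a scripted wrapper and its export/import
            # round trip against the unquantized reference.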
            def compare_quantized_unquantized(ScriptWrapper, cell):
                wrapper = ScriptWrapper(cell)

                # Compare quantized scripted module to unquantized
                script_out, script_hid = wrapper(x, hiddens)
                torch.testing.assert_close(script_out, ref_out)
                for out, ref in zip(script_hid, ref_hid):
                    torch.testing.assert_close(out, ref)

                # Compare export/import to unquantized
                export_import_wrapper = self.getExportImportCopyWithPacking(wrapper)
                ei_out, ei_hid = export_import_wrapper(x, hiddens)
                torch.testing.assert_close(ei_out, ref_out)
                for out, ref in zip(ei_hid, ref_hid):
                    torch.testing.assert_close(out, ref)

            if isinstance(cell, torch.jit.quantized.QuantizedGRU):
                class ScriptWrapper(torch.jit.ScriptModule):
                    def __init__(self, cell):
                        super().__init__()
                        self.cell = cell

                    @torch.jit.script_method
                    def forward(self, x: torch.Tensor, hiddens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
                        return self.cell(x, hiddens)

                compare_quantized_unquantized(ScriptWrapper, cell)
            elif isinstance(cell, torch.jit.quantized.QuantizedLSTM):
                for cell in [cell_int8, cell_fp16]:
                    class ScriptWrapper(torch.jit.ScriptModule):
                        def __init__(self, cell):
                            super().__init__()
                            self.cell = cell

                        @torch.jit.script_method
                        def forward(self, x, hiddens):
                            # type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor])
                            # -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
                            return self.cell(x, hiddens)
                    compare_quantized_unquantized(ScriptWrapper, cell)
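
    # Guarded at class-definition time: without the fbgemm engine this test
    # method is never defined at all.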
    if 'fbgemm' in torch.backends.quantized.supported_engines:
        # Suppression: using deprecated quant api
        @suppress_warnings
        def test_quantization_modules(self):
            K1, N1 = 2, 2

            class FooBar(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.linear1 = torch.nn.Linear(K1, N1).float()

                def forward(self, x):
                    x = self.linear1(x)
                    return x

            fb = FooBar()
            fb.linear1.weight = torch.nn.Parameter(
                torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False)
            fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False)

            x = (torch.rand(1, K1).float() - 0.5) / 10.0
            value = torch.tensor([[100, -150]], dtype=torch.float)

            y_ref = fb(value)

            fb_int8 = torch.jit.quantized.quantize_linear_modules(fb)
            traced_int8 = torch.jit.trace(fb_int8, (x,))
            fb_int8 = self.getExportImportCopyWithPacking(traced_int8)
            y_int8 = fb_int8(value)

            fb_fp16 = torch.jit.quantized.quantize_linear_modules(fb, torch.float16)
            traced_fp16 = torch.jit.trace(fb_fp16, (x,))
            fb_fp16 = self.getExportImportCopyWithPacking(traced_fp16)
            y_fp16 = fb_fp16(value)
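
            # Quantization introduces rounding error, so compare against the
            # float reference with loose tolerances rather than exact equality.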
            torch.testing.assert_close(y_int8, y_ref, rtol=0.0001, atol=1e-3)
            torch.testing.assert_close(y_fp16, y_ref, rtol=0.0001, atol=1e-3)

    @skipIfNoFBGEMM
    def test_erase_class_tensor_shapes(self):
        class Linear(torch.nn.Module):
            def __init__(self, in_features, out_features):
                super().__init__()
                qweight = torch._empty_affine_quantized(
                    [out_features, in_features], scale=1, zero_point=0,
                    dtype=torch.qint8)
                self._packed_weight = torch.ops.quantized.linear_prepack(qweight)
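
            # Packed weights are opaque objects, so (de)serialization goes
            # through __getstate__/__setstate__: unpack to a plain quantized
            # tensor on save, re-pack on load.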
            @torch.jit.export
            def __getstate__(self):
                return (torch.ops.quantized.linear_unpack(self._packed_weight)[0], self.training)

            def forward(self):
                return self._packed_weight

            @torch.jit.export
            def __setstate__(self, state):
                self._packed_weight = torch.ops.quantized.linear_prepack(state[0])
                self.training = state[1]

            @property
            def weight(self):
                return torch.ops.quantized.linear_unpack(self._packed_weight)[0]

            @weight.setter
            def weight(self, w):
                self._packed_weight = torch.ops.quantized.linear_prepack(w)

        with torch._jit_internal._disable_emit_hooks():
            x = torch.jit.script(Linear(10, 10))
            torch._C._jit_pass_erase_shape_information(x.graph)

if __name__ == '__main__':
    raise RuntimeError("This test file is not meant to be run directly, use:\n\n"
                       "\tpython test/test_quantization.py TESTNAME\n\n"
                       "instead.")