Revert "Add option to split Linear gates for Quantizable LSTM into separate ops (#140868)"

This reverts commit 3fcf66f61f.

Reverted https://github.com/pytorch/pytorch/pull/140868 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think lint is failing on this in trunk ([comment](https://github.com/pytorch/pytorch/pull/140868#issuecomment-2494076202))
This commit is contained in:
PyTorch MergeBot 2024-11-22 15:54:05 +00:00
parent 080f992d68
commit cf1d95a965
4 changed files with 48 additions and 183 deletions

View File

@ -2917,11 +2917,6 @@ class TestQuantizedOps(TestCase):
@override_qengines
def test_custom_module_lstm(self):
class QuantizableLSTMSplitGates(torch.ao.nn.quantizable.LSTM):
@classmethod
def from_float(cls, other, qconfig=None):
return super().from_float(other, qconfig, split_gates=True)
qengine = torch.backends.quantized.engine
batch_size = 4
@ -2936,7 +2931,6 @@ class TestQuantizedOps(TestCase):
Bias = [False, True]
Batch_first = [False, True]
Bidirectional = [False, True]
Split_gates = [False, True]
dtype = np.uint8
qtype = torch.quint8
@ -2949,8 +2943,8 @@ class TestQuantizedOps(TestCase):
x = qx.dequantize()
with torch.no_grad():
for bias, batch_first, bidirectional, split_gates in itertools.product(
Bias, Batch_first, Bidirectional, Split_gates):
for bias, batch_first, bidirectional in itertools.product(
Bias, Batch_first, Bidirectional):
# Assume 12dB is sufficient for functional equivalence
# Without the bias, linear performs poorly
min_power = 10 if bias else 5
@ -2974,36 +2968,17 @@ class TestQuantizedOps(TestCase):
# Prepare
lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
custom_config_dict = (
None
if not split_gates
else { # switch to class with split_gates True via from_float
"float_to_observed_custom_module_class": {
torch.nn.LSTM: QuantizableLSTMSplitGates
},
"observed_to_quantized_custom_module_class": {
QuantizableLSTMSplitGates: torch.ao.nn.quantized.LSTM,
},
}
)
lstm_prepared = torch.ao.quantization.prepare(
lstm, prepare_custom_config_dict=custom_config_dict
)
lstm_prepared = torch.ao.quantization.prepare(lstm)
self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
self.assertEqual(num_layers, len(lstm_prepared[0].layers))
self.assertEqual(
lstm_prepared[0].layers[0].layer_fw.cell.split_gates, split_gates
)
assert isinstance(lstm_prepared[0], torch.ao.nn.quantizable.LSTM)
assert type(lstm_prepared[0]) == torch.ao.nn.quantizable.LSTM
# Calibrate
y = lstm_prepared(x)
self.assertEqual(y_ref, y)
# Quantize
lstm_quantized = torch.ao.quantization.convert(
lstm_prepared, convert_custom_config_dict=custom_config_dict
)
lstm_quantized = torch.ao.quantization.convert(lstm_prepared)
assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
qy = lstm_quantized(qx)

View File

@ -21,11 +21,6 @@ class LSTMCell(torch.nn.Module):
For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
`split_gates`: specify True to compute the input/forget/cell/output gates separately
to avoid an intermediate tensor which is subsequently chunk'd. This optimization can
be beneficial for on-device inference latency. This flag is cascaded down from the
parent classes.
Examples::
>>> import torch.ao.nn.quantizable as nnqa
@ -39,7 +34,6 @@ class LSTMCell(torch.nn.Module):
... output.append(hx)
"""
_FLOAT_MODULE = torch.nn.LSTMCell
__constants__ = ["split_gates"] # for jit.script
def __init__(
self,
@ -48,37 +42,20 @@ class LSTMCell(torch.nn.Module):
bias: bool = True,
device=None,
dtype=None,
*,
split_gates=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.input_size = input_dim
self.hidden_size = hidden_dim
self.bias = bias
self.split_gates = split_gates
if not split_gates:
self.igates: torch.nn.Module = torch.nn.Linear(
input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
)
self.hgates: torch.nn.Module = torch.nn.Linear(
hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
)
self.gates: torch.nn.Module = torch.ao.nn.quantized.FloatFunctional()
else:
# keep separate Linear layers for each gate
self.igates = torch.nn.ModuleDict()
self.hgates = torch.nn.ModuleDict()
self.gates = torch.nn.ModuleDict()
for g in ["input", "forget", "cell", "output"]:
self.igates[g] = torch.nn.Linear(
input_dim, hidden_dim, bias=bias, **factory_kwargs
)
self.hgates[g] = torch.nn.Linear(
hidden_dim, hidden_dim, bias=bias, **factory_kwargs
)
self.gates[g] = torch.ao.nn.quantized.FloatFunctional()
self.igates = torch.nn.Linear(
input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
)
self.hgates = torch.nn.Linear(
hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
)
self.gates = torch.ao.nn.quantized.FloatFunctional()
self.input_gate = torch.nn.Sigmoid()
self.forget_gate = torch.nn.Sigmoid()
@ -103,29 +80,16 @@ class LSTMCell(torch.nn.Module):
hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
hx, cx = hidden
if not self.split_gates:
igates = self.igates(x)
hgates = self.hgates(hx)
gates = self.gates.add(igates, hgates)
igates = self.igates(x)
hgates = self.hgates(hx)
gates = self.gates.add(igates, hgates)
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
input_gate = self.input_gate(input_gate)
forget_gate = self.forget_gate(forget_gate)
cell_gate = self.cell_gate(cell_gate)
out_gate = self.output_gate(out_gate)
else:
# apply each input + hidden projection and add together
gate = {}
for (key, gates), igates, hgates in zip(
self.gates.items(), self.igates.values(), self.hgates.values()
):
gate[key] = gates.add(igates(x), hgates(hx))
input_gate = self.input_gate(gate["input"])
forget_gate = self.forget_gate(gate["forget"])
cell_gate = self.cell_gate(gate["cell"])
out_gate = self.output_gate(gate["output"])
input_gate = self.input_gate(input_gate)
forget_gate = self.forget_gate(forget_gate)
cell_gate = self.cell_gate(cell_gate)
out_gate = self.output_gate(out_gate)
fgate_cx = self.fgate_cx.mul(forget_gate, cx)
igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
@ -158,7 +122,7 @@ class LSTMCell(torch.nn.Module):
return "QuantizableLSTMCell"
@classmethod
def from_params(cls, wi, wh, bi=None, bh=None, split_gates=False):
def from_params(cls, wi, wh, bi=None, bh=None):
"""Uses the weights and biases to create a new LSTM cell.
Args:
@ -168,52 +132,25 @@ class LSTMCell(torch.nn.Module):
assert (bi is None) == (bh is None) # Either both None or both have values
input_size = wi.shape[1]
hidden_size = wh.shape[1]
cell = cls(
input_dim=input_size,
hidden_dim=hidden_size,
bias=(bi is not None),
split_gates=split_gates,
)
if not split_gates:
cell.igates.weight = torch.nn.Parameter(wi)
if bi is not None:
cell.igates.bias = torch.nn.Parameter(bi)
cell.hgates.weight = torch.nn.Parameter(wh)
if bh is not None:
cell.hgates.bias = torch.nn.Parameter(bh)
else:
# split weight/bias
for w, b, gates in zip([wi, wh], [bi, bh], [cell.igates, cell.hgates]):
for w_chunk, gate in zip(w.chunk(4, dim=0), gates.values()):
gate.weight = torch.nn.Parameter(w_chunk)
if b is not None:
for b_chunk, gate in zip(b.chunk(4, dim=0), gates.values()):
gate.bias = torch.nn.Parameter(b_chunk)
cell = cls(input_dim=input_size, hidden_dim=hidden_size, bias=(bi is not None))
cell.igates.weight = torch.nn.Parameter(wi)
if bi is not None:
cell.igates.bias = torch.nn.Parameter(bi)
cell.hgates.weight = torch.nn.Parameter(wh)
if bh is not None:
cell.hgates.bias = torch.nn.Parameter(bh)
return cell
@classmethod
def from_float(cls, other, use_precomputed_fake_quant=False, split_gates=False):
def from_float(cls, other, use_precomputed_fake_quant=False):
assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
observed = cls.from_params(
other.weight_ih,
other.weight_hh,
other.bias_ih,
other.bias_hh,
split_gates=split_gates,
other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh
)
observed.qconfig = other.qconfig
observed.igates.qconfig = other.qconfig
observed.hgates.qconfig = other.qconfig
if split_gates:
# also apply qconfig directly to Linear modules
for g in observed.igates.values():
g.qconfig = other.qconfig
for g in observed.hgates.values():
g.qconfig = other.qconfig
return observed
@ -231,14 +168,10 @@ class _LSTMSingleLayer(torch.nn.Module):
bias: bool = True,
device=None,
dtype=None,
*,
split_gates=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.cell = LSTMCell(
input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs
)
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
result = []
@ -252,9 +185,7 @@ class _LSTMSingleLayer(torch.nn.Module):
@classmethod
def from_params(cls, *args, **kwargs):
cell = LSTMCell.from_params(*args, **kwargs)
layer = cls(
cell.input_size, cell.hidden_size, cell.bias, split_gates=cell.split_gates
)
layer = cls(cell.input_size, cell.hidden_size, cell.bias)
layer.cell = cell
return layer
@ -271,23 +202,17 @@ class _LSTMLayer(torch.nn.Module):
bidirectional: bool = False,
device=None,
dtype=None,
*,
split_gates=False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.batch_first = batch_first
self.bidirectional = bidirectional
self.layer_fw = _LSTMSingleLayer(
input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs
input_dim, hidden_dim, bias=bias, **factory_kwargs
)
if self.bidirectional:
self.layer_bw = _LSTMSingleLayer(
input_dim,
hidden_dim,
bias=bias,
split_gates=split_gates,
**factory_kwargs,
input_dim, hidden_dim, bias=bias, **factory_kwargs
)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
@ -358,34 +283,22 @@ class _LSTMLayer(torch.nn.Module):
bias = kwargs.get("bias", other.bias)
batch_first = kwargs.get("batch_first", other.batch_first)
bidirectional = kwargs.get("bidirectional", other.bidirectional)
split_gates = kwargs.get("split_gates", False)
layer = cls(
input_size,
hidden_size,
bias,
batch_first,
bidirectional,
split_gates=split_gates,
)
layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
layer.qconfig = getattr(other, "qconfig", qconfig)
wi = getattr(other, f"weight_ih_l{layer_idx}")
wh = getattr(other, f"weight_hh_l{layer_idx}")
bi = getattr(other, f"bias_ih_l{layer_idx}", None)
bh = getattr(other, f"bias_hh_l{layer_idx}", None)
layer.layer_fw = _LSTMSingleLayer.from_params(
wi, wh, bi, bh, split_gates=split_gates
)
layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
if other.bidirectional:
wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
layer.layer_bw = _LSTMSingleLayer.from_params(
wi, wh, bi, bh, split_gates=split_gates
)
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
return layer
@ -429,8 +342,6 @@ class LSTM(torch.nn.Module):
bidirectional: bool = False,
device=None,
dtype=None,
*,
split_gates: bool = False,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
@ -475,7 +386,6 @@ class LSTM(torch.nn.Module):
self.bias,
batch_first=False,
bidirectional=self.bidirectional,
split_gates=split_gates,
**factory_kwargs,
)
]
@ -486,7 +396,6 @@ class LSTM(torch.nn.Module):
self.bias,
batch_first=False,
bidirectional=self.bidirectional,
split_gates=split_gates,
**factory_kwargs,
)
for layer in range(1, num_layers)
@ -552,7 +461,7 @@ class LSTM(torch.nn.Module):
return "QuantizableLSTM"
@classmethod
def from_float(cls, other, qconfig=None, split_gates=False):
def from_float(cls, other, qconfig=None):
assert isinstance(other, cls._FLOAT_MODULE)
assert hasattr(other, "qconfig") or qconfig
observed = cls(
@ -563,12 +472,11 @@ class LSTM(torch.nn.Module):
other.batch_first,
other.dropout,
other.bidirectional,
split_gates=split_gates,
)
observed.qconfig = getattr(other, "qconfig", qconfig)
for idx in range(other.num_layers):
observed.layers[idx] = _LSTMLayer.from_float(
other, idx, qconfig, batch_first=False, split_gates=split_gates
other, idx, qconfig, batch_first=False
)
# Prepare the model

View File

@ -49,7 +49,7 @@ class LSTM(torch.ao.nn.quantizable.LSTM):
@classmethod
def from_observed(cls, other):
assert isinstance(other, cls._FLOAT_MODULE)
assert type(other) == cls._FLOAT_MODULE # type: ignore[has-type]
converted = torch.ao.quantization.convert(
other, inplace=False, remove_qconfig=True
)

View File

@ -25,7 +25,6 @@ def _get_lstm_with_individually_observed_parts(
tanh_obs_ctr: Optional[_PartialWrapper] = None,
cell_state_obs_ctr: Optional[_PartialWrapper] = None,
hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
split_gates: bool = False,
) -> torch.ao.nn.quantizable.LSTM:
"""
Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
@ -76,7 +75,6 @@ def _get_lstm_with_individually_observed_parts(
float_lstm.batch_first,
float_lstm.dropout,
float_lstm.bidirectional,
split_gates=split_gates,
)
quantizable_lstm.qconfig = float_lstm.qconfig
@ -84,11 +82,7 @@ def _get_lstm_with_individually_observed_parts(
quantizable_lstm.layers[
idx
] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(
float_lstm,
idx,
float_lstm.qconfig,
batch_first=False,
split_gates=split_gates,
float_lstm, idx, float_lstm.qconfig, batch_first=False
)
# Build QConfigMapping for the LSTM cell
@ -111,25 +105,13 @@ def _get_lstm_with_individually_observed_parts(
# to configure these ops in FX graph mode quantization today. This is because
# the FloatFunctional modules simply disappear from the graph after tracing.
# In the future, we should rewrite quantizable LSTM without FloatFunctionals.
if not split_gates:
op_index_to_activation_post_process_ctr = {
(torch.add, 0): linear_output_obs_ctr, # gates.add
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
(torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
}
else:
op_index_to_activation_post_process_ctr = {
(torch.add, 0): linear_output_obs_ctr, # gates.add (input)
(torch.add, 1): linear_output_obs_ctr, # gates.add (forget)
(torch.add, 2): linear_output_obs_ctr, # gates.add (cell)
(torch.add, 3): linear_output_obs_ctr, # gates.add (output)
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
(torch.add, 4): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
}
op_index_to_activation_post_process_ctr = {
(torch.add, 0): linear_output_obs_ctr, # gates.add
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
(torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
}
add_count = 0
mul_count = 0
for node in cell.graph.nodes: