mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Add option to split Linear gates for Quantizable LSTM into separate ops (#140868)"
This reverts commit 3fcf66f61f.
Reverted https://github.com/pytorch/pytorch/pull/140868 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think lint is failing on this in trunk ([comment](https://github.com/pytorch/pytorch/pull/140868#issuecomment-2494076202))
This commit is contained in:
parent
080f992d68
commit
cf1d95a965
|
|
@ -2917,11 +2917,6 @@ class TestQuantizedOps(TestCase):
|
|||
|
||||
@override_qengines
|
||||
def test_custom_module_lstm(self):
|
||||
class QuantizableLSTMSplitGates(torch.ao.nn.quantizable.LSTM):
|
||||
@classmethod
|
||||
def from_float(cls, other, qconfig=None):
|
||||
return super().from_float(other, qconfig, split_gates=True)
|
||||
|
||||
qengine = torch.backends.quantized.engine
|
||||
|
||||
batch_size = 4
|
||||
|
|
@ -2936,7 +2931,6 @@ class TestQuantizedOps(TestCase):
|
|||
Bias = [False, True]
|
||||
Batch_first = [False, True]
|
||||
Bidirectional = [False, True]
|
||||
Split_gates = [False, True]
|
||||
|
||||
dtype = np.uint8
|
||||
qtype = torch.quint8
|
||||
|
|
@ -2949,8 +2943,8 @@ class TestQuantizedOps(TestCase):
|
|||
x = qx.dequantize()
|
||||
|
||||
with torch.no_grad():
|
||||
for bias, batch_first, bidirectional, split_gates in itertools.product(
|
||||
Bias, Batch_first, Bidirectional, Split_gates):
|
||||
for bias, batch_first, bidirectional in itertools.product(
|
||||
Bias, Batch_first, Bidirectional):
|
||||
# Assume 12dB is sufficient for functional equivalence
|
||||
# Without the bias, linear performs poorly
|
||||
min_power = 10 if bias else 5
|
||||
|
|
@ -2974,36 +2968,17 @@ class TestQuantizedOps(TestCase):
|
|||
|
||||
# Prepare
|
||||
lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
|
||||
custom_config_dict = (
|
||||
None
|
||||
if not split_gates
|
||||
else { # switch to class with split_gates True via from_float
|
||||
"float_to_observed_custom_module_class": {
|
||||
torch.nn.LSTM: QuantizableLSTMSplitGates
|
||||
},
|
||||
"observed_to_quantized_custom_module_class": {
|
||||
QuantizableLSTMSplitGates: torch.ao.nn.quantized.LSTM,
|
||||
},
|
||||
}
|
||||
)
|
||||
lstm_prepared = torch.ao.quantization.prepare(
|
||||
lstm, prepare_custom_config_dict=custom_config_dict
|
||||
)
|
||||
lstm_prepared = torch.ao.quantization.prepare(lstm)
|
||||
self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
|
||||
self.assertEqual(num_layers, len(lstm_prepared[0].layers))
|
||||
self.assertEqual(
|
||||
lstm_prepared[0].layers[0].layer_fw.cell.split_gates, split_gates
|
||||
)
|
||||
assert isinstance(lstm_prepared[0], torch.ao.nn.quantizable.LSTM)
|
||||
assert type(lstm_prepared[0]) == torch.ao.nn.quantizable.LSTM
|
||||
|
||||
# Calibrate
|
||||
y = lstm_prepared(x)
|
||||
self.assertEqual(y_ref, y)
|
||||
|
||||
# Quantize
|
||||
lstm_quantized = torch.ao.quantization.convert(
|
||||
lstm_prepared, convert_custom_config_dict=custom_config_dict
|
||||
)
|
||||
lstm_quantized = torch.ao.quantization.convert(lstm_prepared)
|
||||
assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
|
||||
qy = lstm_quantized(qx)
|
||||
|
||||
|
|
|
|||
|
|
@ -21,11 +21,6 @@ class LSTMCell(torch.nn.Module):
|
|||
|
||||
For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
|
||||
|
||||
`split_gates`: specify True to compute the input/forget/cell/output gates separately
|
||||
to avoid an intermediate tensor which is subsequently chunk'd. This optimization can
|
||||
be beneficial for on-device inference latency. This flag is cascaded down from the
|
||||
parent classes.
|
||||
|
||||
Examples::
|
||||
|
||||
>>> import torch.ao.nn.quantizable as nnqa
|
||||
|
|
@ -39,7 +34,6 @@ class LSTMCell(torch.nn.Module):
|
|||
... output.append(hx)
|
||||
"""
|
||||
_FLOAT_MODULE = torch.nn.LSTMCell
|
||||
__constants__ = ["split_gates"] # for jit.script
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
|
@ -48,37 +42,20 @@ class LSTMCell(torch.nn.Module):
|
|||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
*,
|
||||
split_gates=False,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.input_size = input_dim
|
||||
self.hidden_size = hidden_dim
|
||||
self.bias = bias
|
||||
self.split_gates = split_gates
|
||||
|
||||
if not split_gates:
|
||||
self.igates: torch.nn.Module = torch.nn.Linear(
|
||||
input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.hgates: torch.nn.Module = torch.nn.Linear(
|
||||
hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.gates: torch.nn.Module = torch.ao.nn.quantized.FloatFunctional()
|
||||
else:
|
||||
# keep separate Linear layers for each gate
|
||||
self.igates = torch.nn.ModuleDict()
|
||||
self.hgates = torch.nn.ModuleDict()
|
||||
self.gates = torch.nn.ModuleDict()
|
||||
for g in ["input", "forget", "cell", "output"]:
|
||||
self.igates[g] = torch.nn.Linear(
|
||||
input_dim, hidden_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.hgates[g] = torch.nn.Linear(
|
||||
hidden_dim, hidden_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.gates[g] = torch.ao.nn.quantized.FloatFunctional()
|
||||
self.igates = torch.nn.Linear(
|
||||
input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.hgates = torch.nn.Linear(
|
||||
hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.gates = torch.ao.nn.quantized.FloatFunctional()
|
||||
|
||||
self.input_gate = torch.nn.Sigmoid()
|
||||
self.forget_gate = torch.nn.Sigmoid()
|
||||
|
|
@ -103,29 +80,16 @@ class LSTMCell(torch.nn.Module):
|
|||
hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
|
||||
hx, cx = hidden
|
||||
|
||||
if not self.split_gates:
|
||||
igates = self.igates(x)
|
||||
hgates = self.hgates(hx)
|
||||
gates = self.gates.add(igates, hgates)
|
||||
igates = self.igates(x)
|
||||
hgates = self.hgates(hx)
|
||||
gates = self.gates.add(igates, hgates)
|
||||
|
||||
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
|
||||
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
|
||||
|
||||
input_gate = self.input_gate(input_gate)
|
||||
forget_gate = self.forget_gate(forget_gate)
|
||||
cell_gate = self.cell_gate(cell_gate)
|
||||
out_gate = self.output_gate(out_gate)
|
||||
else:
|
||||
# apply each input + hidden projection and add together
|
||||
gate = {}
|
||||
for (key, gates), igates, hgates in zip(
|
||||
self.gates.items(), self.igates.values(), self.hgates.values()
|
||||
):
|
||||
gate[key] = gates.add(igates(x), hgates(hx))
|
||||
|
||||
input_gate = self.input_gate(gate["input"])
|
||||
forget_gate = self.forget_gate(gate["forget"])
|
||||
cell_gate = self.cell_gate(gate["cell"])
|
||||
out_gate = self.output_gate(gate["output"])
|
||||
input_gate = self.input_gate(input_gate)
|
||||
forget_gate = self.forget_gate(forget_gate)
|
||||
cell_gate = self.cell_gate(cell_gate)
|
||||
out_gate = self.output_gate(out_gate)
|
||||
|
||||
fgate_cx = self.fgate_cx.mul(forget_gate, cx)
|
||||
igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
|
||||
|
|
@ -158,7 +122,7 @@ class LSTMCell(torch.nn.Module):
|
|||
return "QuantizableLSTMCell"
|
||||
|
||||
@classmethod
|
||||
def from_params(cls, wi, wh, bi=None, bh=None, split_gates=False):
|
||||
def from_params(cls, wi, wh, bi=None, bh=None):
|
||||
"""Uses the weights and biases to create a new LSTM cell.
|
||||
|
||||
Args:
|
||||
|
|
@ -168,52 +132,25 @@ class LSTMCell(torch.nn.Module):
|
|||
assert (bi is None) == (bh is None) # Either both None or both have values
|
||||
input_size = wi.shape[1]
|
||||
hidden_size = wh.shape[1]
|
||||
cell = cls(
|
||||
input_dim=input_size,
|
||||
hidden_dim=hidden_size,
|
||||
bias=(bi is not None),
|
||||
split_gates=split_gates,
|
||||
)
|
||||
|
||||
if not split_gates:
|
||||
cell.igates.weight = torch.nn.Parameter(wi)
|
||||
if bi is not None:
|
||||
cell.igates.bias = torch.nn.Parameter(bi)
|
||||
cell.hgates.weight = torch.nn.Parameter(wh)
|
||||
if bh is not None:
|
||||
cell.hgates.bias = torch.nn.Parameter(bh)
|
||||
else:
|
||||
# split weight/bias
|
||||
for w, b, gates in zip([wi, wh], [bi, bh], [cell.igates, cell.hgates]):
|
||||
for w_chunk, gate in zip(w.chunk(4, dim=0), gates.values()):
|
||||
gate.weight = torch.nn.Parameter(w_chunk)
|
||||
|
||||
if b is not None:
|
||||
for b_chunk, gate in zip(b.chunk(4, dim=0), gates.values()):
|
||||
gate.bias = torch.nn.Parameter(b_chunk)
|
||||
|
||||
cell = cls(input_dim=input_size, hidden_dim=hidden_size, bias=(bi is not None))
|
||||
cell.igates.weight = torch.nn.Parameter(wi)
|
||||
if bi is not None:
|
||||
cell.igates.bias = torch.nn.Parameter(bi)
|
||||
cell.hgates.weight = torch.nn.Parameter(wh)
|
||||
if bh is not None:
|
||||
cell.hgates.bias = torch.nn.Parameter(bh)
|
||||
return cell
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, other, use_precomputed_fake_quant=False, split_gates=False):
|
||||
def from_float(cls, other, use_precomputed_fake_quant=False):
|
||||
assert type(other) == cls._FLOAT_MODULE
|
||||
assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
|
||||
observed = cls.from_params(
|
||||
other.weight_ih,
|
||||
other.weight_hh,
|
||||
other.bias_ih,
|
||||
other.bias_hh,
|
||||
split_gates=split_gates,
|
||||
other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh
|
||||
)
|
||||
observed.qconfig = other.qconfig
|
||||
observed.igates.qconfig = other.qconfig
|
||||
observed.hgates.qconfig = other.qconfig
|
||||
if split_gates:
|
||||
# also apply qconfig directly to Linear modules
|
||||
for g in observed.igates.values():
|
||||
g.qconfig = other.qconfig
|
||||
for g in observed.hgates.values():
|
||||
g.qconfig = other.qconfig
|
||||
return observed
|
||||
|
||||
|
||||
|
|
@ -231,14 +168,10 @@ class _LSTMSingleLayer(torch.nn.Module):
|
|||
bias: bool = True,
|
||||
device=None,
|
||||
dtype=None,
|
||||
*,
|
||||
split_gates=False,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.cell = LSTMCell(
|
||||
input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs
|
||||
)
|
||||
self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
|
||||
|
||||
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
|
||||
result = []
|
||||
|
|
@ -252,9 +185,7 @@ class _LSTMSingleLayer(torch.nn.Module):
|
|||
@classmethod
|
||||
def from_params(cls, *args, **kwargs):
|
||||
cell = LSTMCell.from_params(*args, **kwargs)
|
||||
layer = cls(
|
||||
cell.input_size, cell.hidden_size, cell.bias, split_gates=cell.split_gates
|
||||
)
|
||||
layer = cls(cell.input_size, cell.hidden_size, cell.bias)
|
||||
layer.cell = cell
|
||||
return layer
|
||||
|
||||
|
|
@ -271,23 +202,17 @@ class _LSTMLayer(torch.nn.Module):
|
|||
bidirectional: bool = False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
*,
|
||||
split_gates=False,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
self.batch_first = batch_first
|
||||
self.bidirectional = bidirectional
|
||||
self.layer_fw = _LSTMSingleLayer(
|
||||
input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs
|
||||
input_dim, hidden_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
if self.bidirectional:
|
||||
self.layer_bw = _LSTMSingleLayer(
|
||||
input_dim,
|
||||
hidden_dim,
|
||||
bias=bias,
|
||||
split_gates=split_gates,
|
||||
**factory_kwargs,
|
||||
input_dim, hidden_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
|
||||
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
|
||||
|
|
@ -358,34 +283,22 @@ class _LSTMLayer(torch.nn.Module):
|
|||
bias = kwargs.get("bias", other.bias)
|
||||
batch_first = kwargs.get("batch_first", other.batch_first)
|
||||
bidirectional = kwargs.get("bidirectional", other.bidirectional)
|
||||
split_gates = kwargs.get("split_gates", False)
|
||||
|
||||
layer = cls(
|
||||
input_size,
|
||||
hidden_size,
|
||||
bias,
|
||||
batch_first,
|
||||
bidirectional,
|
||||
split_gates=split_gates,
|
||||
)
|
||||
layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
|
||||
layer.qconfig = getattr(other, "qconfig", qconfig)
|
||||
wi = getattr(other, f"weight_ih_l{layer_idx}")
|
||||
wh = getattr(other, f"weight_hh_l{layer_idx}")
|
||||
bi = getattr(other, f"bias_ih_l{layer_idx}", None)
|
||||
bh = getattr(other, f"bias_hh_l{layer_idx}", None)
|
||||
|
||||
layer.layer_fw = _LSTMSingleLayer.from_params(
|
||||
wi, wh, bi, bh, split_gates=split_gates
|
||||
)
|
||||
layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
|
||||
|
||||
if other.bidirectional:
|
||||
wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
|
||||
wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
|
||||
bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
|
||||
bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
|
||||
layer.layer_bw = _LSTMSingleLayer.from_params(
|
||||
wi, wh, bi, bh, split_gates=split_gates
|
||||
)
|
||||
layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
|
||||
return layer
|
||||
|
||||
|
||||
|
|
@ -429,8 +342,6 @@ class LSTM(torch.nn.Module):
|
|||
bidirectional: bool = False,
|
||||
device=None,
|
||||
dtype=None,
|
||||
*,
|
||||
split_gates: bool = False,
|
||||
) -> None:
|
||||
factory_kwargs = {"device": device, "dtype": dtype}
|
||||
super().__init__()
|
||||
|
|
@ -475,7 +386,6 @@ class LSTM(torch.nn.Module):
|
|||
self.bias,
|
||||
batch_first=False,
|
||||
bidirectional=self.bidirectional,
|
||||
split_gates=split_gates,
|
||||
**factory_kwargs,
|
||||
)
|
||||
]
|
||||
|
|
@ -486,7 +396,6 @@ class LSTM(torch.nn.Module):
|
|||
self.bias,
|
||||
batch_first=False,
|
||||
bidirectional=self.bidirectional,
|
||||
split_gates=split_gates,
|
||||
**factory_kwargs,
|
||||
)
|
||||
for layer in range(1, num_layers)
|
||||
|
|
@ -552,7 +461,7 @@ class LSTM(torch.nn.Module):
|
|||
return "QuantizableLSTM"
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, other, qconfig=None, split_gates=False):
|
||||
def from_float(cls, other, qconfig=None):
|
||||
assert isinstance(other, cls._FLOAT_MODULE)
|
||||
assert hasattr(other, "qconfig") or qconfig
|
||||
observed = cls(
|
||||
|
|
@ -563,12 +472,11 @@ class LSTM(torch.nn.Module):
|
|||
other.batch_first,
|
||||
other.dropout,
|
||||
other.bidirectional,
|
||||
split_gates=split_gates,
|
||||
)
|
||||
observed.qconfig = getattr(other, "qconfig", qconfig)
|
||||
for idx in range(other.num_layers):
|
||||
observed.layers[idx] = _LSTMLayer.from_float(
|
||||
other, idx, qconfig, batch_first=False, split_gates=split_gates
|
||||
other, idx, qconfig, batch_first=False
|
||||
)
|
||||
|
||||
# Prepare the model
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ class LSTM(torch.ao.nn.quantizable.LSTM):
|
|||
|
||||
@classmethod
|
||||
def from_observed(cls, other):
|
||||
assert isinstance(other, cls._FLOAT_MODULE)
|
||||
assert type(other) == cls._FLOAT_MODULE # type: ignore[has-type]
|
||||
converted = torch.ao.quantization.convert(
|
||||
other, inplace=False, remove_qconfig=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ def _get_lstm_with_individually_observed_parts(
|
|||
tanh_obs_ctr: Optional[_PartialWrapper] = None,
|
||||
cell_state_obs_ctr: Optional[_PartialWrapper] = None,
|
||||
hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
|
||||
split_gates: bool = False,
|
||||
) -> torch.ao.nn.quantizable.LSTM:
|
||||
"""
|
||||
Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
|
||||
|
|
@ -76,7 +75,6 @@ def _get_lstm_with_individually_observed_parts(
|
|||
float_lstm.batch_first,
|
||||
float_lstm.dropout,
|
||||
float_lstm.bidirectional,
|
||||
split_gates=split_gates,
|
||||
)
|
||||
quantizable_lstm.qconfig = float_lstm.qconfig
|
||||
|
||||
|
|
@ -84,11 +82,7 @@ def _get_lstm_with_individually_observed_parts(
|
|||
quantizable_lstm.layers[
|
||||
idx
|
||||
] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(
|
||||
float_lstm,
|
||||
idx,
|
||||
float_lstm.qconfig,
|
||||
batch_first=False,
|
||||
split_gates=split_gates,
|
||||
float_lstm, idx, float_lstm.qconfig, batch_first=False
|
||||
)
|
||||
|
||||
# Build QConfigMapping for the LSTM cell
|
||||
|
|
@ -111,25 +105,13 @@ def _get_lstm_with_individually_observed_parts(
|
|||
# to configure these ops in FX graph mode quantization today. This is because
|
||||
# the FloatFunctional modules simply disappear from the graph after tracing.
|
||||
# In the future, we should rewrite quantizable LSTM without FloatFunctionals.
|
||||
if not split_gates:
|
||||
op_index_to_activation_post_process_ctr = {
|
||||
(torch.add, 0): linear_output_obs_ctr, # gates.add
|
||||
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
|
||||
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
|
||||
(torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
|
||||
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
|
||||
}
|
||||
else:
|
||||
op_index_to_activation_post_process_ctr = {
|
||||
(torch.add, 0): linear_output_obs_ctr, # gates.add (input)
|
||||
(torch.add, 1): linear_output_obs_ctr, # gates.add (forget)
|
||||
(torch.add, 2): linear_output_obs_ctr, # gates.add (cell)
|
||||
(torch.add, 3): linear_output_obs_ctr, # gates.add (output)
|
||||
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
|
||||
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
|
||||
(torch.add, 4): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
|
||||
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
|
||||
}
|
||||
op_index_to_activation_post_process_ctr = {
|
||||
(torch.add, 0): linear_output_obs_ctr, # gates.add
|
||||
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
|
||||
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
|
||||
(torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
|
||||
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
|
||||
}
|
||||
add_count = 0
|
||||
mul_count = 0
|
||||
for node in cell.graph.nodes:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user