Revert "Add option to split Linear gates for Quantizable LSTM into separate ops (#140868)"

This reverts commit 3fcf66f61f.

Reverted https://github.com/pytorch/pytorch/pull/140868 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think lint is failing on this in trunk ([comment](https://github.com/pytorch/pytorch/pull/140868#issuecomment-2494076202))
This commit is contained in:
PyTorch MergeBot 2024-11-22 15:54:05 +00:00
parent 080f992d68
commit cf1d95a965
4 changed files with 48 additions and 183 deletions

View File

@ -2917,11 +2917,6 @@ class TestQuantizedOps(TestCase):
@override_qengines @override_qengines
def test_custom_module_lstm(self): def test_custom_module_lstm(self):
class QuantizableLSTMSplitGates(torch.ao.nn.quantizable.LSTM):
@classmethod
def from_float(cls, other, qconfig=None):
return super().from_float(other, qconfig, split_gates=True)
qengine = torch.backends.quantized.engine qengine = torch.backends.quantized.engine
batch_size = 4 batch_size = 4
@ -2936,7 +2931,6 @@ class TestQuantizedOps(TestCase):
Bias = [False, True] Bias = [False, True]
Batch_first = [False, True] Batch_first = [False, True]
Bidirectional = [False, True] Bidirectional = [False, True]
Split_gates = [False, True]
dtype = np.uint8 dtype = np.uint8
qtype = torch.quint8 qtype = torch.quint8
@ -2949,8 +2943,8 @@ class TestQuantizedOps(TestCase):
x = qx.dequantize() x = qx.dequantize()
with torch.no_grad(): with torch.no_grad():
for bias, batch_first, bidirectional, split_gates in itertools.product( for bias, batch_first, bidirectional in itertools.product(
Bias, Batch_first, Bidirectional, Split_gates): Bias, Batch_first, Bidirectional):
# Assume 12dB is sufficient for functional equivalence # Assume 12dB is sufficient for functional equivalence
# Without the bias, linear performs poorly # Without the bias, linear performs poorly
min_power = 10 if bias else 5 min_power = 10 if bias else 5
@ -2974,36 +2968,17 @@ class TestQuantizedOps(TestCase):
# Prepare # Prepare
lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine) lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
custom_config_dict = ( lstm_prepared = torch.ao.quantization.prepare(lstm)
None
if not split_gates
else { # switch to class with split_gates True via from_float
"float_to_observed_custom_module_class": {
torch.nn.LSTM: QuantizableLSTMSplitGates
},
"observed_to_quantized_custom_module_class": {
QuantizableLSTMSplitGates: torch.ao.nn.quantized.LSTM,
},
}
)
lstm_prepared = torch.ao.quantization.prepare(
lstm, prepare_custom_config_dict=custom_config_dict
)
self.assertTrue(hasattr(lstm_prepared[0], 'layers')) self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
self.assertEqual(num_layers, len(lstm_prepared[0].layers)) self.assertEqual(num_layers, len(lstm_prepared[0].layers))
self.assertEqual( assert type(lstm_prepared[0]) == torch.ao.nn.quantizable.LSTM
lstm_prepared[0].layers[0].layer_fw.cell.split_gates, split_gates
)
assert isinstance(lstm_prepared[0], torch.ao.nn.quantizable.LSTM)
# Calibrate # Calibrate
y = lstm_prepared(x) y = lstm_prepared(x)
self.assertEqual(y_ref, y) self.assertEqual(y_ref, y)
# Quantize # Quantize
lstm_quantized = torch.ao.quantization.convert( lstm_quantized = torch.ao.quantization.convert(lstm_prepared)
lstm_prepared, convert_custom_config_dict=custom_config_dict
)
assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
qy = lstm_quantized(qx) qy = lstm_quantized(qx)

View File

@ -21,11 +21,6 @@ class LSTMCell(torch.nn.Module):
For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell` For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`
`split_gates`: specify True to compute the input/forget/cell/output gates separately
to avoid an intermediate tensor which is subsequently chunk'd. This optimization can
be beneficial for on-device inference latency. This flag is cascaded down from the
parent classes.
Examples:: Examples::
>>> import torch.ao.nn.quantizable as nnqa >>> import torch.ao.nn.quantizable as nnqa
@ -39,7 +34,6 @@ class LSTMCell(torch.nn.Module):
... output.append(hx) ... output.append(hx)
""" """
_FLOAT_MODULE = torch.nn.LSTMCell _FLOAT_MODULE = torch.nn.LSTMCell
__constants__ = ["split_gates"] # for jit.script
def __init__( def __init__(
self, self,
@ -48,37 +42,20 @@ class LSTMCell(torch.nn.Module):
bias: bool = True, bias: bool = True,
device=None, device=None,
dtype=None, dtype=None,
*,
split_gates=False,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
self.input_size = input_dim self.input_size = input_dim
self.hidden_size = hidden_dim self.hidden_size = hidden_dim
self.bias = bias self.bias = bias
self.split_gates = split_gates
if not split_gates: self.igates = torch.nn.Linear(
self.igates: torch.nn.Module = torch.nn.Linear( input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs )
) self.hgates = torch.nn.Linear(
self.hgates: torch.nn.Module = torch.nn.Linear( hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs )
) self.gates = torch.ao.nn.quantized.FloatFunctional()
self.gates: torch.nn.Module = torch.ao.nn.quantized.FloatFunctional()
else:
# keep separate Linear layers for each gate
self.igates = torch.nn.ModuleDict()
self.hgates = torch.nn.ModuleDict()
self.gates = torch.nn.ModuleDict()
for g in ["input", "forget", "cell", "output"]:
self.igates[g] = torch.nn.Linear(
input_dim, hidden_dim, bias=bias, **factory_kwargs
)
self.hgates[g] = torch.nn.Linear(
hidden_dim, hidden_dim, bias=bias, **factory_kwargs
)
self.gates[g] = torch.ao.nn.quantized.FloatFunctional()
self.input_gate = torch.nn.Sigmoid() self.input_gate = torch.nn.Sigmoid()
self.forget_gate = torch.nn.Sigmoid() self.forget_gate = torch.nn.Sigmoid()
@ -103,29 +80,16 @@ class LSTMCell(torch.nn.Module):
hidden = self.initialize_hidden(x.shape[0], x.is_quantized) hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
hx, cx = hidden hx, cx = hidden
if not self.split_gates: igates = self.igates(x)
igates = self.igates(x) hgates = self.hgates(hx)
hgates = self.hgates(hx) gates = self.gates.add(igates, hgates)
gates = self.gates.add(igates, hgates)
input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1) input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
input_gate = self.input_gate(input_gate) input_gate = self.input_gate(input_gate)
forget_gate = self.forget_gate(forget_gate) forget_gate = self.forget_gate(forget_gate)
cell_gate = self.cell_gate(cell_gate) cell_gate = self.cell_gate(cell_gate)
out_gate = self.output_gate(out_gate) out_gate = self.output_gate(out_gate)
else:
# apply each input + hidden projection and add together
gate = {}
for (key, gates), igates, hgates in zip(
self.gates.items(), self.igates.values(), self.hgates.values()
):
gate[key] = gates.add(igates(x), hgates(hx))
input_gate = self.input_gate(gate["input"])
forget_gate = self.forget_gate(gate["forget"])
cell_gate = self.cell_gate(gate["cell"])
out_gate = self.output_gate(gate["output"])
fgate_cx = self.fgate_cx.mul(forget_gate, cx) fgate_cx = self.fgate_cx.mul(forget_gate, cx)
igate_cgate = self.igate_cgate.mul(input_gate, cell_gate) igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
@ -158,7 +122,7 @@ class LSTMCell(torch.nn.Module):
return "QuantizableLSTMCell" return "QuantizableLSTMCell"
@classmethod @classmethod
def from_params(cls, wi, wh, bi=None, bh=None, split_gates=False): def from_params(cls, wi, wh, bi=None, bh=None):
"""Uses the weights and biases to create a new LSTM cell. """Uses the weights and biases to create a new LSTM cell.
Args: Args:
@ -168,52 +132,25 @@ class LSTMCell(torch.nn.Module):
assert (bi is None) == (bh is None) # Either both None or both have values assert (bi is None) == (bh is None) # Either both None or both have values
input_size = wi.shape[1] input_size = wi.shape[1]
hidden_size = wh.shape[1] hidden_size = wh.shape[1]
cell = cls( cell = cls(input_dim=input_size, hidden_dim=hidden_size, bias=(bi is not None))
input_dim=input_size, cell.igates.weight = torch.nn.Parameter(wi)
hidden_dim=hidden_size, if bi is not None:
bias=(bi is not None), cell.igates.bias = torch.nn.Parameter(bi)
split_gates=split_gates, cell.hgates.weight = torch.nn.Parameter(wh)
) if bh is not None:
cell.hgates.bias = torch.nn.Parameter(bh)
if not split_gates:
cell.igates.weight = torch.nn.Parameter(wi)
if bi is not None:
cell.igates.bias = torch.nn.Parameter(bi)
cell.hgates.weight = torch.nn.Parameter(wh)
if bh is not None:
cell.hgates.bias = torch.nn.Parameter(bh)
else:
# split weight/bias
for w, b, gates in zip([wi, wh], [bi, bh], [cell.igates, cell.hgates]):
for w_chunk, gate in zip(w.chunk(4, dim=0), gates.values()):
gate.weight = torch.nn.Parameter(w_chunk)
if b is not None:
for b_chunk, gate in zip(b.chunk(4, dim=0), gates.values()):
gate.bias = torch.nn.Parameter(b_chunk)
return cell return cell
@classmethod @classmethod
def from_float(cls, other, use_precomputed_fake_quant=False, split_gates=False): def from_float(cls, other, use_precomputed_fake_quant=False):
assert type(other) == cls._FLOAT_MODULE assert type(other) == cls._FLOAT_MODULE
assert hasattr(other, "qconfig"), "The float module must have 'qconfig'" assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
observed = cls.from_params( observed = cls.from_params(
other.weight_ih, other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh
other.weight_hh,
other.bias_ih,
other.bias_hh,
split_gates=split_gates,
) )
observed.qconfig = other.qconfig observed.qconfig = other.qconfig
observed.igates.qconfig = other.qconfig observed.igates.qconfig = other.qconfig
observed.hgates.qconfig = other.qconfig observed.hgates.qconfig = other.qconfig
if split_gates:
# also apply qconfig directly to Linear modules
for g in observed.igates.values():
g.qconfig = other.qconfig
for g in observed.hgates.values():
g.qconfig = other.qconfig
return observed return observed
@ -231,14 +168,10 @@ class _LSTMSingleLayer(torch.nn.Module):
bias: bool = True, bias: bool = True,
device=None, device=None,
dtype=None, dtype=None,
*,
split_gates=False,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
self.cell = LSTMCell( self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)
input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs
)
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
result = [] result = []
@ -252,9 +185,7 @@ class _LSTMSingleLayer(torch.nn.Module):
@classmethod @classmethod
def from_params(cls, *args, **kwargs): def from_params(cls, *args, **kwargs):
cell = LSTMCell.from_params(*args, **kwargs) cell = LSTMCell.from_params(*args, **kwargs)
layer = cls( layer = cls(cell.input_size, cell.hidden_size, cell.bias)
cell.input_size, cell.hidden_size, cell.bias, split_gates=cell.split_gates
)
layer.cell = cell layer.cell = cell
return layer return layer
@ -271,23 +202,17 @@ class _LSTMLayer(torch.nn.Module):
bidirectional: bool = False, bidirectional: bool = False,
device=None, device=None,
dtype=None, dtype=None,
*,
split_gates=False,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
self.batch_first = batch_first self.batch_first = batch_first
self.bidirectional = bidirectional self.bidirectional = bidirectional
self.layer_fw = _LSTMSingleLayer( self.layer_fw = _LSTMSingleLayer(
input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs input_dim, hidden_dim, bias=bias, **factory_kwargs
) )
if self.bidirectional: if self.bidirectional:
self.layer_bw = _LSTMSingleLayer( self.layer_bw = _LSTMSingleLayer(
input_dim, input_dim, hidden_dim, bias=bias, **factory_kwargs
hidden_dim,
bias=bias,
split_gates=split_gates,
**factory_kwargs,
) )
def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None): def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
@ -358,34 +283,22 @@ class _LSTMLayer(torch.nn.Module):
bias = kwargs.get("bias", other.bias) bias = kwargs.get("bias", other.bias)
batch_first = kwargs.get("batch_first", other.batch_first) batch_first = kwargs.get("batch_first", other.batch_first)
bidirectional = kwargs.get("bidirectional", other.bidirectional) bidirectional = kwargs.get("bidirectional", other.bidirectional)
split_gates = kwargs.get("split_gates", False)
layer = cls( layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
input_size,
hidden_size,
bias,
batch_first,
bidirectional,
split_gates=split_gates,
)
layer.qconfig = getattr(other, "qconfig", qconfig) layer.qconfig = getattr(other, "qconfig", qconfig)
wi = getattr(other, f"weight_ih_l{layer_idx}") wi = getattr(other, f"weight_ih_l{layer_idx}")
wh = getattr(other, f"weight_hh_l{layer_idx}") wh = getattr(other, f"weight_hh_l{layer_idx}")
bi = getattr(other, f"bias_ih_l{layer_idx}", None) bi = getattr(other, f"bias_ih_l{layer_idx}", None)
bh = getattr(other, f"bias_hh_l{layer_idx}", None) bh = getattr(other, f"bias_hh_l{layer_idx}", None)
layer.layer_fw = _LSTMSingleLayer.from_params( layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
wi, wh, bi, bh, split_gates=split_gates
)
if other.bidirectional: if other.bidirectional:
wi = getattr(other, f"weight_ih_l{layer_idx}_reverse") wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
wh = getattr(other, f"weight_hh_l{layer_idx}_reverse") wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None) bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None) bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
layer.layer_bw = _LSTMSingleLayer.from_params( layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
wi, wh, bi, bh, split_gates=split_gates
)
return layer return layer
@ -429,8 +342,6 @@ class LSTM(torch.nn.Module):
bidirectional: bool = False, bidirectional: bool = False,
device=None, device=None,
dtype=None, dtype=None,
*,
split_gates: bool = False,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
super().__init__() super().__init__()
@ -475,7 +386,6 @@ class LSTM(torch.nn.Module):
self.bias, self.bias,
batch_first=False, batch_first=False,
bidirectional=self.bidirectional, bidirectional=self.bidirectional,
split_gates=split_gates,
**factory_kwargs, **factory_kwargs,
) )
] ]
@ -486,7 +396,6 @@ class LSTM(torch.nn.Module):
self.bias, self.bias,
batch_first=False, batch_first=False,
bidirectional=self.bidirectional, bidirectional=self.bidirectional,
split_gates=split_gates,
**factory_kwargs, **factory_kwargs,
) )
for layer in range(1, num_layers) for layer in range(1, num_layers)
@ -552,7 +461,7 @@ class LSTM(torch.nn.Module):
return "QuantizableLSTM" return "QuantizableLSTM"
@classmethod @classmethod
def from_float(cls, other, qconfig=None, split_gates=False): def from_float(cls, other, qconfig=None):
assert isinstance(other, cls._FLOAT_MODULE) assert isinstance(other, cls._FLOAT_MODULE)
assert hasattr(other, "qconfig") or qconfig assert hasattr(other, "qconfig") or qconfig
observed = cls( observed = cls(
@ -563,12 +472,11 @@ class LSTM(torch.nn.Module):
other.batch_first, other.batch_first,
other.dropout, other.dropout,
other.bidirectional, other.bidirectional,
split_gates=split_gates,
) )
observed.qconfig = getattr(other, "qconfig", qconfig) observed.qconfig = getattr(other, "qconfig", qconfig)
for idx in range(other.num_layers): for idx in range(other.num_layers):
observed.layers[idx] = _LSTMLayer.from_float( observed.layers[idx] = _LSTMLayer.from_float(
other, idx, qconfig, batch_first=False, split_gates=split_gates other, idx, qconfig, batch_first=False
) )
# Prepare the model # Prepare the model

View File

@ -49,7 +49,7 @@ class LSTM(torch.ao.nn.quantizable.LSTM):
@classmethod @classmethod
def from_observed(cls, other): def from_observed(cls, other):
assert isinstance(other, cls._FLOAT_MODULE) assert type(other) == cls._FLOAT_MODULE # type: ignore[has-type]
converted = torch.ao.quantization.convert( converted = torch.ao.quantization.convert(
other, inplace=False, remove_qconfig=True other, inplace=False, remove_qconfig=True
) )

View File

@ -25,7 +25,6 @@ def _get_lstm_with_individually_observed_parts(
tanh_obs_ctr: Optional[_PartialWrapper] = None, tanh_obs_ctr: Optional[_PartialWrapper] = None,
cell_state_obs_ctr: Optional[_PartialWrapper] = None, cell_state_obs_ctr: Optional[_PartialWrapper] = None,
hidden_state_obs_ctr: Optional[_PartialWrapper] = None, hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
split_gates: bool = False,
) -> torch.ao.nn.quantizable.LSTM: ) -> torch.ao.nn.quantizable.LSTM:
""" """
Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM` Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
@ -76,7 +75,6 @@ def _get_lstm_with_individually_observed_parts(
float_lstm.batch_first, float_lstm.batch_first,
float_lstm.dropout, float_lstm.dropout,
float_lstm.bidirectional, float_lstm.bidirectional,
split_gates=split_gates,
) )
quantizable_lstm.qconfig = float_lstm.qconfig quantizable_lstm.qconfig = float_lstm.qconfig
@ -84,11 +82,7 @@ def _get_lstm_with_individually_observed_parts(
quantizable_lstm.layers[ quantizable_lstm.layers[
idx idx
] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float( ] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(
float_lstm, float_lstm, idx, float_lstm.qconfig, batch_first=False
idx,
float_lstm.qconfig,
batch_first=False,
split_gates=split_gates,
) )
# Build QConfigMapping for the LSTM cell # Build QConfigMapping for the LSTM cell
@ -111,25 +105,13 @@ def _get_lstm_with_individually_observed_parts(
# to configure these ops in FX graph mode quantization today. This is because # to configure these ops in FX graph mode quantization today. This is because
# the FloatFunctional modules simply disappear from the graph after tracing. # the FloatFunctional modules simply disappear from the graph after tracing.
# In the future, we should rewrite quantizable LSTM without FloatFunctionals. # In the future, we should rewrite quantizable LSTM without FloatFunctionals.
if not split_gates: op_index_to_activation_post_process_ctr = {
op_index_to_activation_post_process_ctr = { (torch.add, 0): linear_output_obs_ctr, # gates.add
(torch.add, 0): linear_output_obs_ctr, # gates.add (torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul (torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul (torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
(torch.add, 1): cell_state_obs_ctr, # fgate_cx_igate_cgate.add (torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul }
}
else:
op_index_to_activation_post_process_ctr = {
(torch.add, 0): linear_output_obs_ctr, # gates.add (input)
(torch.add, 1): linear_output_obs_ctr, # gates.add (forget)
(torch.add, 2): linear_output_obs_ctr, # gates.add (cell)
(torch.add, 3): linear_output_obs_ctr, # gates.add (output)
(torch.mul, 0): cell_state_obs_ctr, # fgate_cx.mul
(torch.mul, 1): cell_state_obs_ctr, # igate_cgate.mul
(torch.add, 4): cell_state_obs_ctr, # fgate_cx_igate_cgate.add
(torch.mul, 2): hidden_state_obs_ctr, # ogate_cy.mul
}
add_count = 0 add_count = 0
mul_count = 0 mul_count = 0
for node in cell.graph.nodes: for node in cell.graph.nodes: