Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Revert "Add option to split Linear gates for Quantizable LSTM into separate ops (#140868)"
This reverts commit 3fcf66f61f.
Reverted https://github.com/pytorch/pytorch/pull/140868 on behalf of https://github.com/huydhn due to Sorry for reverting your change but I think lint is failing on this in trunk ([comment](https://github.com/pytorch/pytorch/pull/140868#issuecomment-2494076202))
This commit is contained in:
parent 080f992d68
commit cf1d95a965
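For context, the reverted PR let callers opt into per-gate Linear ops when quantizing a float torch.nn.LSTM through the eager-mode custom-module path. The minimal sketch below is assembled from the removed test code in the diff that follows; the Sequential wrapper, the layer sizes, and the dummy calibration input are illustrative assumptions, and the flow only works on a tree that still contains the reverted change:

import torch

class QuantizableLSTMSplitGates(torch.ao.nn.quantizable.LSTM):
    @classmethod
    def from_float(cls, other, qconfig=None):
        # Forward the (reverted) split_gates flag down to the quantizable LSTM.
        return super().from_float(other, qconfig, split_gates=True)

# Float model; the Sequential wrapper and sizes mirror the removed test (illustrative).
lstm = torch.nn.Sequential(torch.nn.LSTM(3, 7))
lstm.qconfig = torch.ao.quantization.get_default_qconfig(
    torch.backends.quantized.engine
)

# Route float -> observed -> quantized through the split-gates subclass.
custom_config_dict = {
    "float_to_observed_custom_module_class": {
        torch.nn.LSTM: QuantizableLSTMSplitGates
    },
    "observed_to_quantized_custom_module_class": {
        QuantizableLSTMSplitGates: torch.ao.nn.quantized.LSTM
    },
}

prepared = torch.ao.quantization.prepare(
    lstm, prepare_custom_config_dict=custom_config_dict
)
prepared(torch.randn(2, 4, 3))  # calibrate with a dummy (seq, batch, feature) input
quantized = torch.ao.quantization.convert(
    prepared, convert_custom_config_dict=custom_config_dict
)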
@@ -2917,11 +2917,6 @@ class TestQuantizedOps(TestCase):

     @override_qengines
     def test_custom_module_lstm(self):
-        class QuantizableLSTMSplitGates(torch.ao.nn.quantizable.LSTM):
-            @classmethod
-            def from_float(cls, other, qconfig=None):
-                return super().from_float(other, qconfig, split_gates=True)
-
         qengine = torch.backends.quantized.engine

         batch_size = 4
@@ -2936,7 +2931,6 @@ class TestQuantizedOps(TestCase):
         Bias = [False, True]
         Batch_first = [False, True]
         Bidirectional = [False, True]
-        Split_gates = [False, True]

         dtype = np.uint8
         qtype = torch.quint8
@@ -2949,8 +2943,8 @@ class TestQuantizedOps(TestCase):
         x = qx.dequantize()

         with torch.no_grad():
-            for bias, batch_first, bidirectional, split_gates in itertools.product(
-                    Bias, Batch_first, Bidirectional, Split_gates):
+            for bias, batch_first, bidirectional in itertools.product(
+                    Bias, Batch_first, Bidirectional):
                 # Assume 12dB is sufficient for functional equivalence
                 # Without the bias, linear performs poorly
                 min_power = 10 if bias else 5
@@ -2974,36 +2968,17 @@ class TestQuantizedOps(TestCase):

                 # Prepare
                 lstm.qconfig = torch.ao.quantization.get_default_qconfig(qengine)
-                custom_config_dict = (
-                    None
-                    if not split_gates
-                    else {  # switch to class with split_gates True via from_float
-                        "float_to_observed_custom_module_class": {
-                            torch.nn.LSTM: QuantizableLSTMSplitGates
-                        },
-                        "observed_to_quantized_custom_module_class": {
-                            QuantizableLSTMSplitGates: torch.ao.nn.quantized.LSTM,
-                        },
-                    }
-                )
-                lstm_prepared = torch.ao.quantization.prepare(
-                    lstm, prepare_custom_config_dict=custom_config_dict
-                )
+                lstm_prepared = torch.ao.quantization.prepare(lstm)
                 self.assertTrue(hasattr(lstm_prepared[0], 'layers'))
                 self.assertEqual(num_layers, len(lstm_prepared[0].layers))
-                self.assertEqual(
-                    lstm_prepared[0].layers[0].layer_fw.cell.split_gates, split_gates
-                )
-                assert isinstance(lstm_prepared[0], torch.ao.nn.quantizable.LSTM)
+                assert type(lstm_prepared[0]) == torch.ao.nn.quantizable.LSTM

                 # Calibrate
                 y = lstm_prepared(x)
                 self.assertEqual(y_ref, y)

                 # Quantize
-                lstm_quantized = torch.ao.quantization.convert(
-                    lstm_prepared, convert_custom_config_dict=custom_config_dict
-                )
+                lstm_quantized = torch.ao.quantization.convert(lstm_prepared)
                 assert type(lstm_quantized[0]) == torch.ao.nn.quantized.LSTM
                 qy = lstm_quantized(qx)
@@ -21,11 +21,6 @@ class LSTMCell(torch.nn.Module):

    For the description and the argument types, please, refer to :class:`~torch.nn.LSTMCell`

-    `split_gates`: specify True to compute the input/forget/cell/output gates separately
-    to avoid an intermediate tensor which is subsequently chunk'd. This optimization can
-    be beneficial for on-device inference latency. This flag is cascaded down from the
-    parent classes.
-
    Examples::

        >>> import torch.ao.nn.quantizable as nnqa
@@ -39,7 +34,6 @@ class LSTMCell(torch.nn.Module):
        ...     output.append(hx)
    """
    _FLOAT_MODULE = torch.nn.LSTMCell
-    __constants__ = ["split_gates"]  # for jit.script

    def __init__(
        self,
@@ -48,37 +42,20 @@ class LSTMCell(torch.nn.Module):
        bias: bool = True,
        device=None,
        dtype=None,
-        *,
-        split_gates=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.input_size = input_dim
        self.hidden_size = hidden_dim
        self.bias = bias
-        self.split_gates = split_gates

-        if not split_gates:
-            self.igates: torch.nn.Module = torch.nn.Linear(
-                input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
-            )
-            self.hgates: torch.nn.Module = torch.nn.Linear(
-                hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
-            )
-            self.gates: torch.nn.Module = torch.ao.nn.quantized.FloatFunctional()
-        else:
-            # keep separate Linear layers for each gate
-            self.igates = torch.nn.ModuleDict()
-            self.hgates = torch.nn.ModuleDict()
-            self.gates = torch.nn.ModuleDict()
-            for g in ["input", "forget", "cell", "output"]:
-                self.igates[g] = torch.nn.Linear(
-                    input_dim, hidden_dim, bias=bias, **factory_kwargs
-                )
-                self.hgates[g] = torch.nn.Linear(
-                    hidden_dim, hidden_dim, bias=bias, **factory_kwargs
-                )
-                self.gates[g] = torch.ao.nn.quantized.FloatFunctional()
+        self.igates = torch.nn.Linear(
+            input_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
+        )
+        self.hgates = torch.nn.Linear(
+            hidden_dim, 4 * hidden_dim, bias=bias, **factory_kwargs
+        )
+        self.gates = torch.ao.nn.quantized.FloatFunctional()

        self.input_gate = torch.nn.Sigmoid()
        self.forget_gate = torch.nn.Sigmoid()
@@ -103,29 +80,16 @@ class LSTMCell(torch.nn.Module):
        hidden = self.initialize_hidden(x.shape[0], x.is_quantized)
        hx, cx = hidden

-        if not self.split_gates:
-            igates = self.igates(x)
-            hgates = self.hgates(hx)
-            gates = self.gates.add(igates, hgates)
+        igates = self.igates(x)
+        hgates = self.hgates(hx)
+        gates = self.gates.add(igates, hgates)

-            input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)
+        input_gate, forget_gate, cell_gate, out_gate = gates.chunk(4, 1)

-            input_gate = self.input_gate(input_gate)
-            forget_gate = self.forget_gate(forget_gate)
-            cell_gate = self.cell_gate(cell_gate)
-            out_gate = self.output_gate(out_gate)
-        else:
-            # apply each input + hidden projection and add together
-            gate = {}
-            for (key, gates), igates, hgates in zip(
-                self.gates.items(), self.igates.values(), self.hgates.values()
-            ):
-                gate[key] = gates.add(igates(x), hgates(hx))
-
-            input_gate = self.input_gate(gate["input"])
-            forget_gate = self.forget_gate(gate["forget"])
-            cell_gate = self.cell_gate(gate["cell"])
-            out_gate = self.output_gate(gate["output"])
+        input_gate = self.input_gate(input_gate)
+        forget_gate = self.forget_gate(forget_gate)
+        cell_gate = self.cell_gate(cell_gate)
+        out_gate = self.output_gate(out_gate)

        fgate_cx = self.fgate_cx.mul(forget_gate, cx)
        igate_cgate = self.igate_cgate.mul(input_gate, cell_gate)
@@ -158,7 +122,7 @@ class LSTMCell(torch.nn.Module):
        return "QuantizableLSTMCell"

    @classmethod
-    def from_params(cls, wi, wh, bi=None, bh=None, split_gates=False):
+    def from_params(cls, wi, wh, bi=None, bh=None):
        """Uses the weights and biases to create a new LSTM cell.

        Args:
@@ -168,52 +132,25 @@ class LSTMCell(torch.nn.Module):
        assert (bi is None) == (bh is None)  # Either both None or both have values
        input_size = wi.shape[1]
        hidden_size = wh.shape[1]
-        cell = cls(
-            input_dim=input_size,
-            hidden_dim=hidden_size,
-            bias=(bi is not None),
-            split_gates=split_gates,
-        )
-
-        if not split_gates:
-            cell.igates.weight = torch.nn.Parameter(wi)
-            if bi is not None:
-                cell.igates.bias = torch.nn.Parameter(bi)
-            cell.hgates.weight = torch.nn.Parameter(wh)
-            if bh is not None:
-                cell.hgates.bias = torch.nn.Parameter(bh)
-        else:
-            # split weight/bias
-            for w, b, gates in zip([wi, wh], [bi, bh], [cell.igates, cell.hgates]):
-                for w_chunk, gate in zip(w.chunk(4, dim=0), gates.values()):
-                    gate.weight = torch.nn.Parameter(w_chunk)
-
-                if b is not None:
-                    for b_chunk, gate in zip(b.chunk(4, dim=0), gates.values()):
-                        gate.bias = torch.nn.Parameter(b_chunk)
-
+        cell = cls(input_dim=input_size, hidden_dim=hidden_size, bias=(bi is not None))
+        cell.igates.weight = torch.nn.Parameter(wi)
+        if bi is not None:
+            cell.igates.bias = torch.nn.Parameter(bi)
+        cell.hgates.weight = torch.nn.Parameter(wh)
+        if bh is not None:
+            cell.hgates.bias = torch.nn.Parameter(bh)
        return cell

    @classmethod
-    def from_float(cls, other, use_precomputed_fake_quant=False, split_gates=False):
+    def from_float(cls, other, use_precomputed_fake_quant=False):
        assert type(other) == cls._FLOAT_MODULE
        assert hasattr(other, "qconfig"), "The float module must have 'qconfig'"
        observed = cls.from_params(
-            other.weight_ih,
-            other.weight_hh,
-            other.bias_ih,
-            other.bias_hh,
-            split_gates=split_gates,
+            other.weight_ih, other.weight_hh, other.bias_ih, other.bias_hh
        )
        observed.qconfig = other.qconfig
        observed.igates.qconfig = other.qconfig
        observed.hgates.qconfig = other.qconfig
-        if split_gates:
-            # also apply qconfig directly to Linear modules
-            for g in observed.igates.values():
-                g.qconfig = other.qconfig
-            for g in observed.hgates.values():
-                g.qconfig = other.qconfig
        return observed


@@ -231,14 +168,10 @@ class _LSTMSingleLayer(torch.nn.Module):
        bias: bool = True,
        device=None,
        dtype=None,
-        *,
-        split_gates=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
-        self.cell = LSTMCell(
-            input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs
-        )
+        self.cell = LSTMCell(input_dim, hidden_dim, bias=bias, **factory_kwargs)

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
        result = []
@@ -252,9 +185,7 @@ class _LSTMSingleLayer(torch.nn.Module):
    @classmethod
    def from_params(cls, *args, **kwargs):
        cell = LSTMCell.from_params(*args, **kwargs)
-        layer = cls(
-            cell.input_size, cell.hidden_size, cell.bias, split_gates=cell.split_gates
-        )
+        layer = cls(cell.input_size, cell.hidden_size, cell.bias)
        layer.cell = cell
        return layer

@@ -271,23 +202,17 @@ class _LSTMLayer(torch.nn.Module):
        bidirectional: bool = False,
        device=None,
        dtype=None,
-        *,
-        split_gates=False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.batch_first = batch_first
        self.bidirectional = bidirectional
        self.layer_fw = _LSTMSingleLayer(
-            input_dim, hidden_dim, bias=bias, split_gates=split_gates, **factory_kwargs
+            input_dim, hidden_dim, bias=bias, **factory_kwargs
        )
        if self.bidirectional:
            self.layer_bw = _LSTMSingleLayer(
-                input_dim,
-                hidden_dim,
-                bias=bias,
-                split_gates=split_gates,
-                **factory_kwargs,
+                input_dim, hidden_dim, bias=bias, **factory_kwargs
            )

    def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
@@ -358,34 +283,22 @@ class _LSTMLayer(torch.nn.Module):
        bias = kwargs.get("bias", other.bias)
        batch_first = kwargs.get("batch_first", other.batch_first)
        bidirectional = kwargs.get("bidirectional", other.bidirectional)
-        split_gates = kwargs.get("split_gates", False)

-        layer = cls(
-            input_size,
-            hidden_size,
-            bias,
-            batch_first,
-            bidirectional,
-            split_gates=split_gates,
-        )
+        layer = cls(input_size, hidden_size, bias, batch_first, bidirectional)
        layer.qconfig = getattr(other, "qconfig", qconfig)
        wi = getattr(other, f"weight_ih_l{layer_idx}")
        wh = getattr(other, f"weight_hh_l{layer_idx}")
        bi = getattr(other, f"bias_ih_l{layer_idx}", None)
        bh = getattr(other, f"bias_hh_l{layer_idx}", None)

-        layer.layer_fw = _LSTMSingleLayer.from_params(
-            wi, wh, bi, bh, split_gates=split_gates
-        )
+        layer.layer_fw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)

        if other.bidirectional:
            wi = getattr(other, f"weight_ih_l{layer_idx}_reverse")
            wh = getattr(other, f"weight_hh_l{layer_idx}_reverse")
            bi = getattr(other, f"bias_ih_l{layer_idx}_reverse", None)
            bh = getattr(other, f"bias_hh_l{layer_idx}_reverse", None)
-            layer.layer_bw = _LSTMSingleLayer.from_params(
-                wi, wh, bi, bh, split_gates=split_gates
-            )
+            layer.layer_bw = _LSTMSingleLayer.from_params(wi, wh, bi, bh)
        return layer


@@ -429,8 +342,6 @@ class LSTM(torch.nn.Module):
        bidirectional: bool = False,
        device=None,
        dtype=None,
-        *,
-        split_gates: bool = False,
    ) -> None:
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
@@ -475,7 +386,6 @@ class LSTM(torch.nn.Module):
                self.bias,
                batch_first=False,
                bidirectional=self.bidirectional,
-                split_gates=split_gates,
                **factory_kwargs,
            )
        ]
@@ -486,7 +396,6 @@ class LSTM(torch.nn.Module):
                self.bias,
                batch_first=False,
                bidirectional=self.bidirectional,
-                split_gates=split_gates,
                **factory_kwargs,
            )
            for layer in range(1, num_layers)
@@ -552,7 +461,7 @@ class LSTM(torch.nn.Module):
        return "QuantizableLSTM"

    @classmethod
-    def from_float(cls, other, qconfig=None, split_gates=False):
+    def from_float(cls, other, qconfig=None):
        assert isinstance(other, cls._FLOAT_MODULE)
        assert hasattr(other, "qconfig") or qconfig
        observed = cls(
@@ -563,12 +472,11 @@ class LSTM(torch.nn.Module):
            other.batch_first,
            other.dropout,
            other.bidirectional,
-            split_gates=split_gates,
        )
        observed.qconfig = getattr(other, "qconfig", qconfig)
        for idx in range(other.num_layers):
            observed.layers[idx] = _LSTMLayer.from_float(
-                other, idx, qconfig, batch_first=False, split_gates=split_gates
+                other, idx, qconfig, batch_first=False
            )

        # Prepare the model
@@ -49,7 +49,7 @@ class LSTM(torch.ao.nn.quantizable.LSTM):

    @classmethod
    def from_observed(cls, other):
-        assert isinstance(other, cls._FLOAT_MODULE)
+        assert type(other) == cls._FLOAT_MODULE  # type: ignore[has-type]
        converted = torch.ao.quantization.convert(
            other, inplace=False, remove_qconfig=True
        )
@@ -25,7 +25,6 @@ def _get_lstm_with_individually_observed_parts(
    tanh_obs_ctr: Optional[_PartialWrapper] = None,
    cell_state_obs_ctr: Optional[_PartialWrapper] = None,
    hidden_state_obs_ctr: Optional[_PartialWrapper] = None,
-    split_gates: bool = False,
) -> torch.ao.nn.quantizable.LSTM:
    """
    Return an observed `torch.ao.nn.quantizable.LSTM` created from a `torch.nn.LSTM`
@@ -76,7 +75,6 @@ def _get_lstm_with_individually_observed_parts(
        float_lstm.batch_first,
        float_lstm.dropout,
        float_lstm.bidirectional,
-        split_gates=split_gates,
    )
    quantizable_lstm.qconfig = float_lstm.qconfig

@@ -84,11 +82,7 @@ def _get_lstm_with_individually_observed_parts(
        quantizable_lstm.layers[
            idx
        ] = torch.ao.nn.quantizable.modules.rnn._LSTMLayer.from_float(
-            float_lstm,
-            idx,
-            float_lstm.qconfig,
-            batch_first=False,
-            split_gates=split_gates,
+            float_lstm, idx, float_lstm.qconfig, batch_first=False
        )

    # Build QConfigMapping for the LSTM cell
@@ -111,25 +105,13 @@ def _get_lstm_with_individually_observed_parts(
        # to configure these ops in FX graph mode quantization today. This is because
        # the FloatFunctional modules simply disappear from the graph after tracing.
        # In the future, we should rewrite quantizable LSTM without FloatFunctionals.
-        if not split_gates:
-            op_index_to_activation_post_process_ctr = {
-                (torch.add, 0): linear_output_obs_ctr,  # gates.add
-                (torch.mul, 0): cell_state_obs_ctr,  # fgate_cx.mul
-                (torch.mul, 1): cell_state_obs_ctr,  # igate_cgate.mul
-                (torch.add, 1): cell_state_obs_ctr,  # fgate_cx_igate_cgate.add
-                (torch.mul, 2): hidden_state_obs_ctr,  # ogate_cy.mul
-            }
-        else:
-            op_index_to_activation_post_process_ctr = {
-                (torch.add, 0): linear_output_obs_ctr,  # gates.add (input)
-                (torch.add, 1): linear_output_obs_ctr,  # gates.add (forget)
-                (torch.add, 2): linear_output_obs_ctr,  # gates.add (cell)
-                (torch.add, 3): linear_output_obs_ctr,  # gates.add (output)
-                (torch.mul, 0): cell_state_obs_ctr,  # fgate_cx.mul
-                (torch.mul, 1): cell_state_obs_ctr,  # igate_cgate.mul
-                (torch.add, 4): cell_state_obs_ctr,  # fgate_cx_igate_cgate.add
-                (torch.mul, 2): hidden_state_obs_ctr,  # ogate_cy.mul
-            }
+        op_index_to_activation_post_process_ctr = {
+            (torch.add, 0): linear_output_obs_ctr,  # gates.add
+            (torch.mul, 0): cell_state_obs_ctr,  # fgate_cx.mul
+            (torch.mul, 1): cell_state_obs_ctr,  # igate_cgate.mul
+            (torch.add, 1): cell_state_obs_ctr,  # fgate_cx_igate_cgate.add
+            (torch.mul, 2): hidden_state_obs_ctr,  # ogate_cy.mul
+        }
        add_count = 0
        mul_count = 0
        for node in cell.graph.nodes: