[Quant] [PT2] Enable QLinear input with multi dims (#113733)

**Summary**
Previously, QLinear assumed a 2-dimensional input. This update extends QLinear to accept input with more than 2 dimensions: the input is reshaped to 2D for the computation, and the output is reshaped back to the original leading dimensions afterwards.
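This is the same pattern a plain fp32 `linear` follows for N-D input. A minimal sketch of the reshape round-trip in plain PyTorch (illustrative only; `linear_multi_dim` is a hypothetical helper, not part of this PR):

```python
import torch
import torch.nn.functional as F

def linear_multi_dim(x, w, b=None):
    # Collapse all leading dims into one row dim for the 2D GEMM,
    # then restore them on the output: (..., K) @ (K, N)^T -> (..., N).
    out_shape = list(x.shape)
    out_shape[-1] = w.shape[0]                    # N output channels
    y2d = F.linear(x.reshape(-1, x.shape[-1]), w, b)
    return y2d.reshape(out_shape)

x = torch.randn(4, 7, 16)                # 3D input
w = torch.randn(32, 16)                  # (N, K)
print(linear_multi_dim(x, w).shape)      # torch.Size([4, 7, 32])
```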

**Test Plan**
```
python -u -m pytest -s -v test_quantized_op.py -k test_qlinear_pt2e
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113733
Approved by: https://github.com/jgong5, https://github.com/eellison
Author: leslie-fang-intel
Date: 2023-12-04 09:53:59 +08:00
Committed by: PyTorch MergeBot
Commit: 4a624d1f8a (parent: b8ce05456c)
3 changed files with 17 additions and 6 deletions


```diff
@@ -931,10 +931,18 @@ static at::Tensor linear_int8_with_onednn_weight(
         output_scale == 1.0f && output_zero_point == 0, "onednn qlinear: expect scale=1 and zero point=0 for fp32 output");
   }
-  auto input_contig = input.contiguous();
+  // If the input has more than two dimensions, we will reshape it to a 2-dimensional form
+  // for calculation and subsequently reshape the output back.
+  auto input_contig =
+      dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();
   auto src = at::native::itensor_from_tensor(input_contig);
   auto packed_weight = at::native::itensor_from_mkldnn(onednn_weight);
+  int64_t K = input.size(dim - 1), M = input.numel() / K, N = packed_weight.get_dim(1);
+  auto output_size = input.sizes().vec();
+  output_size[dim - 1] = N;
   c10::optional<ideep::tensor> onednn_bias{c10::nullopt};
   bool with_bias = bias.has_value();
   at::Tensor bias_val_float;
@@ -1020,7 +1028,7 @@ static at::Tensor linear_int8_with_onednn_weight(
     args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, dst_zp_t});
   }
   primitive.execute(ideep::stream::default_stream(), args);
-  return output;
+  return dim == 2 ? output : output.reshape(output_size);
 }
 #endif // #if AT_MKLDNN_ENABLED()
```
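For reference, the shape bookkeeping in the hunk above (`K`, `M`, `N`, `output_size`) can be traced with a small Python sketch; the concrete sizes here are made up for illustration:

```python
import torch

x = torch.randn(4, 7, 16)       # e.g. (batch, seq, K), dim == 3
N = 32                          # output channels of the packed weight
K = x.size(-1)                  # 16
M = x.numel() // K              # 4 * 7 = 28 rows seen by the 2D GEMM
output_size = list(x.shape)
output_size[-1] = N             # [4, 7, 32], applied after the GEMM
assert x.reshape(-1, K).shape == (M, K)
```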


```diff
@@ -4186,18 +4186,19 @@ class TestQuantizedLinear(TestCase):
         w_scale, w_zp = 0.8, 0
         y_scale, y_zp = 4.7, 2
         post_op_args = []
+        input_dim_list = [2, 3]
         cases = itertools.product(
             in_channels_list, out_channels_list, use_bias_list,
-            supported_post_ops, weight_quant_per_channel_list, output_dtype_list)
+            supported_post_ops, weight_quant_per_channel_list, output_dtype_list, input_dim_list)
         with override_quantized_engine('onednn'):
-            for ic, oc, use_bias, post_op, weight_quant_per_channel, output_dtype in cases:
+            for ic, oc, use_bias, post_op, weight_quant_per_channel, output_dtype, input_dim in cases:
                 used_y_scale = y_scale
                 used_y_zp = y_zp
                 fp32_out = output_dtype == torch.float32
                 bfloat16_out = output_dtype == torch.bfloat16
                 if fp32_out or bfloat16_out:
                     used_y_scale, used_y_zp = 1.0, 0
-                x = torch.rand(batch_size, ic) * 10
+                x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
                 w = torch.rand(oc, ic) * 10
                 qx = torch.quantize_per_tensor(x, x_scale, x_zp, torch.quint8)
                 if weight_quant_per_channel:
@@ -4233,6 +4234,8 @@ class TestQuantizedLinear(TestCase):
                     used_y_zp, dtype=torch.quint8
                 ).int_repr()
+                self.assertEqual(x.dim(), qy_cpu.dim())
                 np.testing.assert_array_almost_equal(
                     qy_ref.int_repr().cpu().numpy(),
                     qy_cpu.cpu().numpy(),
```


```diff
@@ -5096,7 +5096,7 @@ def _prepare_linear_fusion_create(
     )
     output_size = output.size()
-    req_stride_order = [1, 0]
+    req_stride_order = list(reversed(range(len(x.get_size()))))
     output_stride = make_contiguous_strides_for(output_size)
     x = cls.require_stride_order(x, req_stride_order)
```
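The inductor-side fix replaces the hard-coded rank-2 stride order `[1, 0]` with `list(reversed(range(ndim)))`, the contiguous (innermost-dim-last) stride order for any rank, so 3D inputs no longer hit the rank-2 assumption. A quick illustrative check:

```python
# Contiguous stride order generalizes from [1, 0] to any rank.
for ndim in (2, 3, 4):
    print(list(reversed(range(ndim))))
# [1, 0]
# [2, 1, 0]
# [3, 2, 1, 0]
```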