Support broadcast for quantized mul kernel (#30442)

Summary:
Since the tensor iterator supports broadcasting, we can simply remove the assertion that the input shapes match.
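
For illustration, a minimal sketch of what the kernel accepts after this change (the shapes and quantization parameters below are arbitrary examples, not values from the kernel):

import torch

# Two broadcastable shapes that the old numel() assertion would have rejected.
a = torch.randn(8, 1, 6, 1)
b = torch.randn(7, 1, 5)
qa = torch.quantize_per_tensor(a, scale=0.1, zero_point=0, dtype=torch.quint8)
qb = torch.quantize_per_tensor(b, scale=0.1, zero_point=0, dtype=torch.quint8)

# The result broadcasts to (8, 7, 6, 5), matching float multiplication semantics.
qc = torch.ops.quantized.mul(qa, qb, scale=0.2, zero_point=0)
assert qc.shape == (8, 7, 6, 5)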
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30442

Differential Revision: D19976562

Pulled By: lly-zero-one

fbshipit-source-id: 91b27fc8b2570f29d110c6df26eacdd16f587b9f
Authored by Lingyi Liu on 2020-02-19 16:50:34 -08:00, committed by Facebook GitHub Bot
parent ea514c819a
commit ecb05f12c3
3 changed files with 41 additions and 4 deletions

aten/src/ATen/native/quantized/cpu/qmul.cpp
@@ -14,12 +14,10 @@ namespace {
 inline void check_inputs(const Tensor& qa, const Tensor& qb) {
   TORCH_CHECK(qa.qscheme() == kPerTensorAffine,
               "Only per tensor quantization is supported in Mul.");
-  TORCH_CHECK(qa.qscheme() == qb.qscheme(),
-              "Both inputs to Mul must have the same quantization scheme.");
-  TORCH_CHECK(qa.numel() == qb.numel(),
-              "Mul operands must be the same size!");
   TORCH_CHECK(qa.scalar_type() == qb.scalar_type(),
               "Mul operands should have same data type.");
+  TORCH_CHECK(qa.qscheme() == qb.qscheme(),
+              "Both inputs to Mul must have the same quantization shceme.");
 }
 
 // Note: out is assumed to be the same size as self and other.
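
With the size assertion gone, broadcastable shapes are handled by the tensor iterator, and genuinely incompatible shapes now fail there instead of in check_inputs. A hedged sketch of that behavior (the error text in the comment is illustrative, not the verbatim message):

import torch

qa = torch.quantize_per_tensor(torch.randn(2, 3), 0.1, 0, torch.quint8)
qb = torch.quantize_per_tensor(torch.randn(4, 5), 0.1, 0, torch.quint8)
try:
    torch.ops.quantized.mul(qa, qb, scale=0.1, zero_point=0)
except RuntimeError as err:
    print(err)  # e.g. "The size of tensor a (3) must match the size of tensor b (5) ..."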

test/test_quantized.py
@@ -505,6 +505,37 @@ class TestQuantizedOps(TestCase):
             self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                              message="mulReLU.out failed")
 
+    """Tests the correctness of the mul and mul_relu op."""
+    def test_qmul_broadcast(self):
+        mul_relu = torch.ops.quantized.mul_relu
+        mul = torch.ops.quantized.mul
+        mul_out = torch.ops.quantized.mul_out
+        mul_relu_out = torch.ops.quantized.mul_relu_out
+
+        # A = torch.arange(-25, 25, dtype=torch.float)
+        # B = torch.arange(-25, 25, dtype=torch.float)
+        A = torch.randn(8, 1, 6, 1)
+        B = torch.randn(7, 1, 5)
+        scale_A = 3.0
+        zero_point_A = 7
+        scale_B = 5.0
+        zero_point_B = 127
+
+        scale_C = 0.5
+        zero_point_C = 5
+
+        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
+                                       dtype=torch.quint8)
+        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
+                                       dtype=torch.quint8)
+
+        # mul ground truth
+        C = (qA.dequantize() * qB.dequantize()).numpy()
+        qC = _quantize(C, scale_C, zero_point_C)
+        qC_hat = mul(qA, qB, scale=scale_C, zero_point=zero_point_C)
+        np.testing.assert_equal(qC, qC_hat.int_repr(),
+                                "Quantized multiplication failed.")
+
     """Tests max pool operation on quantized tensors."""
     @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
                                               min_side=1, max_side=10),
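
_quantize above is a NumPy helper defined earlier in the test file; the sketch below shows the affine arithmetic it is expected to perform (the name _quantize_sketch and the quint8 clamping range are assumptions, not the verbatim helper):

import numpy as np

def _quantize_sketch(x, scale, zero_point, qmin=0, qmax=255):
    # Affine quantization: scale the float values, round, shift by the
    # zero point, and clamp to the quint8 range.
    q = np.round(x / scale) + zero_point
    return np.clip(q, qmin, qmax).astype(np.uint8)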

torch/nn/quantized/modules/functional_modules.py
@@ -124,6 +124,14 @@ class QFunctional(torch.nn.Module):
         super(QFunctional, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                        missing_keys, unexpected_keys, error_msgs)
 
+    def _get_name(self):
+        return 'QFunctional'
+
+    def extra_repr(self):
+        return 'scale={}, zero_point={}'.format(
+            self.scale, self.zero_point
+        )
+
     def forward(self, x):
         raise RuntimeError("Functional is not intended to use the " +
                            "'forward'. Please use the underlying operation")