Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Support broadcast for quantized mul kernel (#30442)
Summary: Since TensorIterator already supports broadcasting, we can simply remove the assertion on input shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/30442
Differential Revision: D19976562
Pulled By: lly-zero-one
fbshipit-source-id: 91b27fc8b2570f29d110c6df26eacdd16f587b9f
This commit is contained in:
parent
ea514c819a
commit
ecb05f12c3
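To make the change concrete, here is a minimal sketch (not part of the commit) of what the relaxed kernel accepts, reusing the shapes and quantization parameters from the test added below:

    import torch

    # Shapes that broadcast to (8, 7, 6, 5) but differ in numel;
    # the removed assertion would have rejected this pair outright.
    A = torch.randn(8, 1, 6, 1)
    B = torch.randn(7, 1, 5)
    qA = torch.quantize_per_tensor(A, scale=3.0, zero_point=7, dtype=torch.quint8)
    qB = torch.quantize_per_tensor(B, scale=5.0, zero_point=127, dtype=torch.quint8)

    # With the assertion gone, shape handling is deferred to TensorIterator,
    # which broadcasts the operands just as it does for float tensors.
    qC = torch.ops.quantized.mul(qA, qB, scale=0.5, zero_point=5)
    print(qC.shape)  # torch.Size([8, 7, 6, 5])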
@@ -14,12 +14,10 @@ namespace {
 inline void check_inputs(const Tensor& qa, const Tensor& qb) {
   TORCH_CHECK(qa.qscheme() == kPerTensorAffine,
               "Only per tensor quantization is supported in Mul.");
+  TORCH_CHECK(qa.qscheme() == qb.qscheme(),
+              "Both inputs to Mul must have the same quantization scheme.");
-  TORCH_CHECK(qa.numel() == qb.numel(),
-              "Mul operands must be the same size!");
   TORCH_CHECK(qa.scalar_type() == qb.scalar_type(),
               "Mul operands should have same data type.");
-  TORCH_CHECK(qa.qscheme() == qb.qscheme(),
-              "Both inputs to Mul must have the same quantization shceme.");
 }

 // Note: out is assumed to be the same size as self and other.
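Note that the surviving checks are still enforced; only the size assertion is gone. A hedged sketch (not from the commit) of one error path that remains, assuming a per-channel-quantized second operand:

    import torch

    qA = torch.quantize_per_tensor(torch.randn(4, 3), scale=1.0, zero_point=0,
                                   dtype=torch.quint8)
    qB = torch.quantize_per_channel(torch.randn(4, 3),
                                    scales=torch.ones(3),
                                    zero_points=torch.zeros(3, dtype=torch.long),
                                    axis=1, dtype=torch.quint8)
    try:
        torch.ops.quantized.mul(qA, qB, scale=1.0, zero_point=0)
    except RuntimeError as err:
        # check_inputs still rejects mismatched quantization schemes
        print(err)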
@@ -505,6 +505,37 @@ class TestQuantizedOps(TestCase):
         self.assertEqual(qCrelu_hat, qCrelu_out_hat,
                          message="mulReLU.out failed")

+    """Tests the correctness of the mul and mul_relu op."""
+    def test_qmul_broadcast(self):
+        mul_relu = torch.ops.quantized.mul_relu
+        mul = torch.ops.quantized.mul
+        mul_out = torch.ops.quantized.mul_out
+        mul_relu_out = torch.ops.quantized.mul_relu_out
+
+        # A = torch.arange(-25, 25, dtype=torch.float)
+        # B = torch.arange(-25, 25, dtype=torch.float)
+        A = torch.randn(8, 1, 6, 1)
+        B = torch.randn(7, 1, 5)
+        scale_A = 3.0
+        zero_point_A = 7
+        scale_B = 5.0
+        zero_point_B = 127
+
+        scale_C = 0.5
+        zero_point_C = 5
+
+        qA = torch.quantize_per_tensor(A, scale=scale_A, zero_point=zero_point_A,
+                                       dtype=torch.quint8)
+        qB = torch.quantize_per_tensor(B, scale=scale_B, zero_point=zero_point_B,
+                                       dtype=torch.quint8)
+
+        # mul ground truth
+        C = (qA.dequantize() * qB.dequantize()).numpy()
+        qC = _quantize(C, scale_C, zero_point_C)
+        qC_hat = mul(qA, qB, scale=scale_C, zero_point=zero_point_C)
+        np.testing.assert_equal(qC, qC_hat.int_repr(),
+                                "Quantized multiplication failed.")
+
     """Tests max pool operation on quantized tensors."""
     @given(X=hu.tensor(shapes=hu.array_shapes(min_dims=3, max_dims=4,
                                               min_side=1, max_side=10),
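The new test leans on a `_quantize` reference helper from the quantization test utilities, whose definition is outside this diff. A hypothetical stand-in, assuming the usual affine quantization formula for quint8:

    import numpy as np

    def _quantize(x, scale, zero_point, qmin=0, qmax=255):
        # Hypothetical stand-in for the test suite's reference helper:
        # affine-quantize a float array onto the quint8 integer grid.
        q = np.round(x / scale) + zero_point            # map to integer grid
        return np.clip(q, qmin, qmax).astype(np.uint8)  # saturate to [0, 255]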
@@ -124,6 +124,14 @@ class QFunctional(torch.nn.Module):
         super(QFunctional, self)._load_from_state_dict(state_dict, prefix, local_metadata, False,
                                                        missing_keys, unexpected_keys, error_msgs)

+    def _get_name(self):
+        return 'QFunctional'
+
+    def extra_repr(self):
+        return 'scale={}, zero_point={}'.format(
+            self.scale, self.zero_point
+        )
+
     def forward(self, x):
         raise RuntimeError("Functional is not intended to use the " +
                            "'forward'. Please use the underlying operation")
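A brief usage sketch (not part of the commit) of what the two new methods buy: printing a QFunctional instance now shows its quantization parameters. This assumes scale and zero_point are plain attributes, as they are in this version of the module:

    import torch

    qf = torch.nn.quantized.QFunctional()
    qf.scale = 0.5      # output quantization parameters
    qf.zero_point = 5
    print(qf)           # QFunctional(scale=0.5, zero_point=5)

    qA = torch.quantize_per_tensor(torch.randn(2, 3), 1.0, 0, torch.quint8)
    qB = torch.quantize_per_tensor(torch.randn(2, 3), 1.0, 0, torch.quint8)
    out = qf.mul(qA, qB)  # use the named ops; forward() deliberately raises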