mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
NNAPI: quant logistic fix (#70847)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70847 NNAPI needs a fixed zero point and scale for sigmoid (logistic) ghstack-source-id: 146555935 Test Plan: LIBNEURALNETWORKS_PATH="/path/to/libneuralnetworks.so" pytest test/test_nnapi.py Reviewed By: dreiss Differential Revision: D33237918 fbshipit-source-id: 05ef3a81bf1589ad44b599a19bce4066531c432b
This commit is contained in:
parent
ed50a35cf8
commit
a70297e7cb
|
|
@ -286,6 +286,10 @@ class TestNNAPI(TestCase):
|
|||
return torch.sigmoid(arg)
|
||||
raise Exception("Bad op")
|
||||
self.check(UnaryModule(), torch.tensor([-1.0, 1.0]))
|
||||
self.check(
|
||||
UnaryModule(),
|
||||
qpt(torch.tensor([-1.0, 1.0]), 1. / 256, 0),
|
||||
)
|
||||
|
||||
def test_pointwise_binary(self):
|
||||
for op in ["add", "sub", "mul", "div"]:
|
||||
|
|
|
|||
|
|
@ -1267,7 +1267,16 @@ class _NnapiSerializer(object):
|
|||
assert node.outputsSize() == 1
|
||||
|
||||
in_id, in_oper = self.get_tensor_operand_by_jitval(node.inputsAt(0))
|
||||
out_id = self.add_tensor_operand(node.outputsAt(0), in_oper)
|
||||
|
||||
out_oper = in_oper
|
||||
if opcode == NNAPI_OperationCode.LOGISTIC:
|
||||
# NNAPI docs: For ANEURALNETWORKS_TENSOR_QUANT8_ASYMM, the scale
|
||||
# must be 1.f / 256 and the zeroPoint must be 0.
|
||||
# https://fburl.com/h52stoog
|
||||
if in_oper.op_type == NNAPI_OperandCode.TENSOR_QUANT8_ASYMM:
|
||||
out_oper = in_oper._replace(zero_point=0, scale=1.0 / 256)
|
||||
|
||||
out_id = self.add_tensor_operand(node.outputsAt(0), out_oper)
|
||||
|
||||
for idx, dim in enumerate(in_oper.shape):
|
||||
if dim == 0:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user