[pytorch][ao] Add torch.matmul in FloatFunctional/QFunctional (#106831)

Summary: Add a matmul op to FloatFunctional, FXFloatFunctional, and QFunctional so torch.matmul can be used through the eager mode quantization functional wrappers.

Test Plan: new unit tests

Differential Revision: D48172841

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106831
Approved by: https://github.com/jerryzh168
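
For context (not part of this commit), a minimal sketch of how the new op is exercised in eager mode post training quantization; the module name MatmulModule, the shapes, and the calibration data are illustrative:

import torch
import torch.ao.quantization
import torch.ao.nn.quantized as nnq

class MatmulModule(torch.nn.Module):  # hypothetical example module, not from this PR
    def __init__(self):
        super().__init__()
        self.mymatmul = nnq.FloatFunctional()

    def forward(self, x, y):
        # FloatFunctional.matmul runs torch.matmul and feeds the result
        # through activation_post_process so an observer can record it.
        return self.mymatmul.matmul(x, y)

m = MatmulModule().eval()
m.qconfig = torch.ao.quantization.default_qconfig
torch.ao.quantization.prepare(m, inplace=True)   # attach observers
m(torch.randn(4, 8), torch.randn(8, 4))          # calibration pass in fp32
torch.ao.quantization.convert(m, inplace=True)   # FloatFunctional -> QFunctional
assert type(m.mymatmul) == torch.ao.nn.quantized.QFunctional

Running the converted module would additionally require quantized tensor inputs (e.g. via QuantStub/DeQuantStub); the snippet above only checks the module swap, mirroring the new unit test below.
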
Jiaxu Zhu, 2023-08-10 22:43:33 +00:00 (committed by PyTorch MergeBot)
parent dfb1b95919
commit 152203d3c3
3 changed files with 22 additions and 1 deletion


@@ -327,6 +327,7 @@ class TestQuantizeEagerOps(QuantizationTestCase):
             self.assertEqual(type(model.myadd), torch.ao.nn.quantized.QFunctional)
             self.assertEqual(type(model.mycat), torch.ao.nn.quantized.QFunctional)
             self.assertEqual(type(model.myadd_relu), torch.ao.nn.quantized.QFunctional)
+            self.assertEqual(type(model.mymatmul), torch.ao.nn.quantized.QFunctional)
             self.checkNoQconfig(model)
         checkQuantized(model)


@@ -79,6 +79,12 @@ class FloatFunctional(torch.nn.Module):
         r = self.activation_post_process(r)
         return r
+    r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
+    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.matmul(x, y)
+        r = self.activation_post_process(r)
+        return r
 class FXFloatFunctional(torch.nn.Module):
     r""" module to replace FloatFunctional module before FX graph mode quantization,
     since activation_post_process will be inserted in top level module directly
@@ -126,6 +132,11 @@ class FXFloatFunctional(torch.nn.Module):
         r = torch.nn.functional.relu(r)
         return r
+    r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
+    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = torch.matmul(x, y)
+        return r
 class QFunctional(torch.nn.Module):
     r"""Wrapper class for quantized operations.
@@ -220,6 +231,13 @@ class QFunctional(torch.nn.Module):
         r = self.activation_post_process(r)
         return r
+    r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""
+    def matmul(self, x: Tensor, y: Tensor) -> Tensor:
+        r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
+        # Note: this operation is not observed because the observation is not
+        # needed for the quantized op.
+        return r
     @classmethod
     def from_float(cls, mod):
         assert type(mod) == FloatFunctional,\
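
A note on the three variants above: FloatFunctional.matmul runs the fp32 op and passes the result through activation_post_process so an observer can record output statistics; FXFloatFunctional.matmul skips that step because FX graph mode inserts observers at the graph level; QFunctional.matmul requantizes with the module's own scale and zero_point, which is why no further observation is needed. As a rough reference for the quantized semantics, assuming the kernel behaves like dequantize -> fp32 matmul -> requantize (an illustrative sketch, not the actual kernel):

import torch

def reference_quantized_matmul(qx, qy, scale, zero_point):
    # Illustrative reference only: dequantize, matmul in fp32, then requantize
    # with output qparams of the kind QFunctional carries (populated in
    # from_float from the observer attached to the FloatFunctional it replaces).
    r = torch.matmul(qx.dequantize(), qy.dequantize())
    return torch.quantize_per_tensor(r, scale, zero_point, dtype=qx.dtype)

qx = torch.quantize_per_tensor(torch.randn(4, 8), 0.05, 0, torch.quint8)
qy = torch.quantize_per_tensor(torch.randn(8, 4), 0.05, 0, torch.quint8)
out = reference_quantized_matmul(qx, qy, 0.1, 0)
print(out.shape, out.q_scale(), out.q_zero_point())
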


@@ -2218,6 +2218,7 @@ class ModelWithFunctionals(torch.nn.Module):
         self.mycat = nnq.FloatFunctional()
         self.myadd = nnq.FloatFunctional()
         self.myadd_relu = nnq.FloatFunctional()
+        self.mymatmul = nnq.FloatFunctional()
         # Tracing doesnt work yet for c10 ops with scalar inputs
         # https://github.com/pytorch/pytorch/issues/27097
         # self.my_scalar_add = nnq.FloatFunctional()
@@ -2227,11 +2228,12 @@
         y = self.mycat.cat([x, x, x])
         z = self.myadd.add(y, y)
         w = self.myadd_relu.add_relu(z, z)
+        u = self.mymatmul.matmul(w, w.T)
         # Tracing doesnt work yet for c10 ops with scalar inputs
         # https://github.com/pytorch/pytorch/issues/27097
         # w = self.my_scalar_add.add_scalar(w, -0.5)
         # w = self.my_scalar_mul.mul_scalar(w, 0.5)
-        return w
+        return u
 class ResNetBase(torch.nn.Module):