mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytorch][ao] Add torch.matmul in FloatFunctional/QFunctional (#106831)
Summary: As title Test Plan: new unit tests Differential Revision: D48172841 Pull Request resolved: https://github.com/pytorch/pytorch/pull/106831 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
dfb1b95919
commit
152203d3c3
|
|
@ -327,6 +327,7 @@ class TestQuantizeEagerOps(QuantizationTestCase):
|
|||
self.assertEqual(type(model.myadd), torch.ao.nn.quantized.QFunctional)
|
||||
self.assertEqual(type(model.mycat), torch.ao.nn.quantized.QFunctional)
|
||||
self.assertEqual(type(model.myadd_relu), torch.ao.nn.quantized.QFunctional)
|
||||
self.assertEqual(type(model.mymatmul), torch.ao.nn.quantized.QFunctional)
|
||||
self.checkNoQconfig(model)
|
||||
|
||||
checkQuantized(model)
|
||||
|
|
|
|||
|
|
@ -79,6 +79,12 @@ class FloatFunctional(torch.nn.Module):
|
|||
r = self.activation_post_process(r)
|
||||
return r
|
||||
|
||||
r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
|
||||
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
r = torch.matmul(x, y)
|
||||
r = self.activation_post_process(r)
|
||||
return r
|
||||
|
||||
class FXFloatFunctional(torch.nn.Module):
|
||||
r""" module to replace FloatFunctional module before FX graph mode quantization,
|
||||
since activation_post_process will be inserted in top level module directly
|
||||
|
|
@ -126,6 +132,11 @@ class FXFloatFunctional(torch.nn.Module):
|
|||
r = torch.nn.functional.relu(r)
|
||||
return r
|
||||
|
||||
r"""Operation equivalent to ``torch.matmul(Tensor, Tensor)``"""
|
||||
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
r = torch.matmul(x, y)
|
||||
return r
|
||||
|
||||
class QFunctional(torch.nn.Module):
|
||||
r"""Wrapper class for quantized operations.
|
||||
|
||||
|
|
@ -220,6 +231,13 @@ class QFunctional(torch.nn.Module):
|
|||
r = self.activation_post_process(r)
|
||||
return r
|
||||
|
||||
r"""Operation equivalent to ``torch.ops.quantized.matmul(Tensor, Tensor)``"""
|
||||
def matmul(self, x: Tensor, y: Tensor) -> Tensor:
|
||||
r = ops.quantized.matmul(x, y, scale=self.scale, zero_point=self.zero_point)
|
||||
# Note: this operation is not observed because the observation is not
|
||||
# needed for the quantized op.
|
||||
return r
|
||||
|
||||
@classmethod
|
||||
def from_float(cls, mod):
|
||||
assert type(mod) == FloatFunctional,\
|
||||
|
|
|
|||
|
|
@ -2218,6 +2218,7 @@ class ModelWithFunctionals(torch.nn.Module):
|
|||
self.mycat = nnq.FloatFunctional()
|
||||
self.myadd = nnq.FloatFunctional()
|
||||
self.myadd_relu = nnq.FloatFunctional()
|
||||
self.mymatmul = nnq.FloatFunctional()
|
||||
# Tracing doesnt work yet for c10 ops with scalar inputs
|
||||
# https://github.com/pytorch/pytorch/issues/27097
|
||||
# self.my_scalar_add = nnq.FloatFunctional()
|
||||
|
|
@ -2227,11 +2228,12 @@ class ModelWithFunctionals(torch.nn.Module):
|
|||
y = self.mycat.cat([x, x, x])
|
||||
z = self.myadd.add(y, y)
|
||||
w = self.myadd_relu.add_relu(z, z)
|
||||
u = self.mymatmul.matmul(w, w.T)
|
||||
# Tracing doesnt work yet for c10 ops with scalar inputs
|
||||
# https://github.com/pytorch/pytorch/issues/27097
|
||||
# w = self.my_scalar_add.add_scalar(w, -0.5)
|
||||
# w = self.my_scalar_mul.mul_scalar(w, 0.5)
|
||||
return w
|
||||
return u
|
||||
|
||||
|
||||
class ResNetBase(torch.nn.Module):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user