This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs. In quantization tests:
- Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run.
- Raise a RuntimeError on tests which have been disabled (not run)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154728
Approved by: https://github.com/ezyang
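As a rough illustration of the guard described above, here is a minimal sketch of what such a helper can look like; the real raise_on_run_directly lives in torch.testing._internal.common_utils, and its exact signature and error message may differ:

def raise_on_run_directly(file_to_run):
    # Fail fast and point the user at the intended entry point instead of
    # letting the test file run directly.
    raise RuntimeError(
        "This test file is not meant to be run directly; "
        f"run it via {file_to_run} instead."
    )

The __main__ block at the bottom of the test file below shows the intended call site.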
# Owner(s): ["oncall: quantization"]

import copy

import torch
import torch.ao.quantization._equalize as _equalize
import torch.nn as nn
from torch.ao.quantization.fuse_modules import fuse_modules
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import raise_on_run_directly


class TestEqualizeEager(QuantizationTestCase):
    def checkChannelsEqualized(self, tensor1, tensor2, output_axis, input_axis):
        """Checks that the channel ranges of tensor1 and tensor2 are the same,
        which is an indication that equalization has been applied correctly.
        """
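        # channel_range computes the per-channel range (max minus min) of the
        # weight tensor along the given axis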
        output_channel_tensor1 = _equalize.channel_range(tensor1, output_axis)
        input_channel_tensor2 = _equalize.channel_range(tensor2, input_axis)

        # ensure the channel ranges along tensor1's output axis match the
        # channel ranges along tensor2's input axis
        self.assertEqual(output_channel_tensor1, input_channel_tensor2)

    def getModule(self, model, name):
        """Given the dotted name of a submodule of model, return that submodule."""
        curr = model
        name = name.split(".")
        for subname in name:
            curr = curr._modules[subname]
        return curr

    def test_cross_layer_equalization(self):
        """Applies _equalize.cross_layer_equalization to two modules and
        checks that their channel ranges are equivalent.
        """
        module1 = nn.Conv2d(3, 4, 2)
        module2 = nn.Linear(4, 4)

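        # for both Conv2d and Linear weights, dim 0 indexes output channels
        # and dim 1 indexes input channels, hence the axes below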
        module1_output_channel_axis = 0
        module2_input_channel_axis = 1

        _equalize.cross_layer_equalization(module1, module2)

        mod_tensor1, mod_tensor2 = module1.weight, module2.weight

        self.checkChannelsEqualized(
            mod_tensor1,
            mod_tensor2,
            module1_output_channel_axis,
            module2_input_channel_axis,
        )

    def test_converged(self):
        """Sanity checks that _equalize.converged works:
        identical modules should return True;
        modules with a large difference in weights should return False.
        """
        module1 = nn.Linear(3, 3)
        module2 = nn.Linear(3, 3)

        module1.weight = nn.parameter.Parameter(torch.ones(module1.weight.size()))
        module2.weight = nn.parameter.Parameter(torch.zeros(module1.weight.size()))

        # converged takes dictionaries mapping names to modules as input
        dictionary_1 = {"linear1": module1}
        dictionary_2 = {"linear1": module2}
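        # identical dictionaries are trivially converged; all-ones vs.
        # all-zeros weights differ far more than the 1e-6 threshold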
        self.assertTrue(_equalize.converged(dictionary_1, dictionary_1, 1e-6))
        self.assertFalse(_equalize.converged(dictionary_1, dictionary_2, 1e-6))

    def test_equalize(self):
        """First checks that _equalize.equalize can handle multiple
        module pairs as input, then checks the correctness of the function
        by ensuring that the equalized and unequalized versions of the
        model yield the same output given the same input.
        """

        class ChainModule(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(3, 4)
                self.linear2 = nn.Linear(4, 5)
                self.linear3 = nn.Linear(5, 6)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear2(x)
                x = self.linear3(x)
                return x

        chain1 = ChainModule()
        chain2 = copy.deepcopy(chain1)

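        # equalize applies cross-layer equalization to each listed pair of
        # adjacent modules, iterating until the weights converge within the
        # given threshold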
        _equalize.equalize(
            chain1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
        )
        linear1 = self.getModule(chain1, "linear1")
        linear2 = self.getModule(chain1, "linear2")
        linear3 = self.getModule(chain1, "linear3")

        self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
        self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)

        input = torch.randn(20, 3)
        self.assertEqual(chain1(input), chain2(input))

    def test_equalize_fused_convrelu(self):
        """Checks that eager mode equalization supports fused
        ConvReLU2d models.

        A model with three ConvReLU2d blocks is constructed. Next, the conv2d
        and relu layers are fused together and adjacent conv2d layers have
        cross-layer equalization applied. Finally, we ensure that the channels
        have been equalized and that the equalized and unequalized versions of
        the model yield the same output given the same input.
        """

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv1 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
                self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float)
                self.conv2 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
                self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
                self.conv3 = nn.Conv2d(3, 3, 1).to(dtype=torch.float)
                self.relu3 = nn.ReLU(inplace=False).to(dtype=torch.float)

            def forward(self, x):
                x = self.conv1(x)
                x = self.relu1(x)
                x = self.conv2(x)
                x = self.relu2(x)
                x = self.conv3(x)
                x = self.relu3(x)
                return x

        model = M()

        fused_model1 = fuse_modules(
            model, [["conv1", "relu1"], ["conv2", "relu2"], ["conv3", "relu3"]]
        )
        fused_model2 = copy.deepcopy(fused_model1)

        _equalize.equalize(fused_model1, [["conv1", "conv2"], ["conv2", "conv3"]], 1e-6)
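        # after fusion each submodule is a ConvReLU2d; index [0] retrieves
        # the underlying Conv2d whose weights were equalized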
        conv1 = self.getModule(fused_model1, "conv1")[0]
        conv2 = self.getModule(fused_model1, "conv2")[0]
        conv3 = self.getModule(fused_model1, "conv3")[0]

        self.checkChannelsEqualized(conv1.weight, conv2.weight, 0, 1)
        self.checkChannelsEqualized(conv2.weight, conv3.weight, 0, 1)

        input = torch.randn(3, 3, 1, 1)
        self.assertEqual(fused_model1(input), fused_model2(input))
        self.assertEqual(fused_model1(input), model(input))

    def test_equalize_fused_linearrelu(self):
        """Checks that eager mode equalization supports fused
        LinearReLU models.

        A model with three LinearReLU blocks is constructed. Next, the linear
        and relu layers are fused together and adjacent linear layers have
        cross-layer equalization applied. Finally, we ensure that the channels
        have been equalized and that the equalized and unequalized versions of
        the model yield the same output given the same input.
        """

        class M(nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = nn.Linear(3, 4)
                self.relu1 = nn.ReLU(inplace=False).to(dtype=torch.float)
                self.linear2 = nn.Linear(4, 5)
                self.relu2 = nn.ReLU(inplace=False).to(dtype=torch.float)
                self.linear3 = nn.Linear(5, 6)
                self.relu3 = nn.ReLU(inplace=False).to(dtype=torch.float)

            def forward(self, x):
                x = self.linear1(x)
                x = self.relu1(x)
                x = self.linear2(x)
                x = self.relu2(x)
                x = self.linear3(x)
                x = self.relu3(x)
                return x

        model = M()

        fused_model1 = fuse_modules(
            model, [["linear1", "relu1"], ["linear2", "relu2"], ["linear3", "relu3"]]
        )
        fused_model2 = copy.deepcopy(fused_model1)

        _equalize.equalize(
            fused_model1, [["linear1", "linear2"], ["linear2", "linear3"]], 1e-6
        )
        linear1 = self.getModule(fused_model1, "linear1")[0]
        linear2 = self.getModule(fused_model1, "linear2")[0]
        linear3 = self.getModule(fused_model1, "linear3")[0]

        self.checkChannelsEqualized(linear1.weight, linear2.weight, 0, 1)
        self.checkChannelsEqualized(linear2.weight, linear3.weight, 0, 1)

        input = torch.randn(20, 3)
        self.assertEqual(fused_model1(input), fused_model2(input))
        self.assertEqual(fused_model1(input), model(input))


if __name__ == "__main__":
    raise_on_run_directly("test/test_quantization.py")