mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[pruning][core][feature] Align BaseStructuredPruner with existing pruning flow (#88436)
Summary: This PR aligns the "eager" mode of the structured pruning flow with the existing unstructured pruning flow. The base pruner has been moved from and has been renamed from BasePruner to BaseStructuredPruner `torch/ao/pruning/_experimental/pruner/base_pruner.py -> /torch/ao/pruning/_experimental/pruner/base_structured_pruner.py` Support for pruning batchnorm modules in the config have been removed, so now the structured pruning code can use more of the BaseSparsifier logic and we don't need to override as many functions. Since we aim to only support a single flow, we have only updated ZeroesParametrizations (FakeStructuredSparsity) and BiasHook. The parameterizations have also been rewritten to use a bool mask tensor for keeping track of pruned rows, instead of using sets before. This better aligns structured and unstructured sparsity. The BaseStructuredSparsifier tests have also been updated to reflect the above changes. I also removed `squash_mask` tests because they were breaking CI and `squash_mask` is no longer used. We will migrate the structured pruning code out of this folder in a later PR. Test Plan: ``` python test/test_ao_sparsity -- TestBaseStructuredPruner ``` Reviewers: z-a-f vkuzo Subscribers: Tasks: Tags: Pull Request resolved: https://github.com/pytorch/pytorch/pull/88436 Approved by: https://github.com/vkuzo
This commit is contained in:
parent
d3f20a20b8
commit
9a1c6fd506
|
|
@ -7,7 +7,7 @@ import logging
|
|||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.ao.pruning._experimental.pruner import BasePruner, PruningParametrization, ZeroesParametrization
|
||||
from torch.ao.pruning._experimental.pruner import BaseStructuredSparsifier, FakeStructuredSparsity
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
|
||||
|
|
@ -19,10 +19,6 @@ DEVICES = {
|
|||
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
}
|
||||
|
||||
NEEDS_ZEROS = { # these layers should have pruned indices zero-ed, not removed
|
||||
nn.BatchNorm2d
|
||||
}
|
||||
|
||||
|
||||
class Linear(nn.Module):
|
||||
r"""Model with Linear layers, in Sequential and outside, without biases"""
|
||||
|
|
@ -159,56 +155,30 @@ class Conv2dC(nn.Module):
|
|||
return x
|
||||
|
||||
|
||||
class Conv2dBN(nn.Module):
|
||||
r"""Model with Conv2d layers and BatchNorms"""
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.seq = nn.Sequential(
|
||||
nn.Conv2d(1, 32, 3, 1, bias=True),
|
||||
nn.BatchNorm2d(32)
|
||||
)
|
||||
self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=True)
|
||||
self.bn = nn.BatchNorm2d(64)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.seq(x)
|
||||
x = self.conv2d(x)
|
||||
x = self.bn(x)
|
||||
return x
|
||||
|
||||
|
||||
class SimplePruner(BasePruner):
|
||||
class SimplePruner(BaseStructuredSparsifier):
|
||||
def update_mask(self, module, tensor_name, **kwargs):
|
||||
getattr(module.parametrizations, tensor_name)[0].pruned_outputs.add(1)
|
||||
getattr(module.parametrizations, tensor_name)[0].mask[1] = False
|
||||
|
||||
|
||||
class MultiplePruner(BasePruner):
|
||||
class MultiplePruner(BaseStructuredSparsifier):
|
||||
def update_mask(self, module, tensor_name, **kwargs):
|
||||
getattr(module.parametrizations, tensor_name)[0].pruned_outputs.update([1, 2])
|
||||
getattr(module.parametrizations, tensor_name)[0].mask[1] = False
|
||||
getattr(module.parametrizations, tensor_name)[0].mask[2] = False
|
||||
|
||||
|
||||
class TestBasePruner(TestCase):
|
||||
class TestBaseStructuredSparsifier(TestCase):
|
||||
def _check_pruner_prepared(self, model, pruner, device):
|
||||
for config in pruner.groups:
|
||||
modules = []
|
||||
if type(config['module']) is tuple:
|
||||
for module in config['module']:
|
||||
modules.append(module)
|
||||
else:
|
||||
module = config['module']
|
||||
modules.append(module)
|
||||
for module in modules:
|
||||
assert module.weight.device.type == device.type
|
||||
# Check mask exists
|
||||
assert hasattr(module, 'mask')
|
||||
# Check parametrization exists and is correct
|
||||
assert parametrize.is_parametrized(module)
|
||||
assert hasattr(module, "parametrizations")
|
||||
# Assume that this is the 1st/only parametrization
|
||||
if isinstance(module, tuple(NEEDS_ZEROS)):
|
||||
assert type(module.parametrizations.weight[0]) == ZeroesParametrization
|
||||
else:
|
||||
assert type(module.parametrizations.weight[0]) == PruningParametrization
|
||||
module = config["module"]
|
||||
assert module.weight.device.type == device.type
|
||||
# Check mask exists
|
||||
assert config["tensor_fqn"] in pruner.state
|
||||
# Check parametrization exists and is correct
|
||||
assert parametrize.is_parametrized(module)
|
||||
assert hasattr(module, "parametrizations")
|
||||
# Assume that this is the 1st/only parametrization
|
||||
assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity
|
||||
|
||||
def _check_pruner_mask_squashed(self, model, pruner, device):
|
||||
for config in pruner.groups:
|
||||
|
|
@ -222,7 +192,6 @@ class TestBasePruner(TestCase):
|
|||
for module in modules:
|
||||
assert module.weight.device.type == device.type
|
||||
assert not hasattr(module, "parametrizations")
|
||||
assert not hasattr(module, 'mask')
|
||||
|
||||
def _check_pruner_valid_before_step(self, model, pruner, device):
|
||||
for config in pruner.groups:
|
||||
|
|
@ -235,9 +204,9 @@ class TestBasePruner(TestCase):
|
|||
modules.append(module)
|
||||
for module in modules:
|
||||
assert module.weight.device.type == device.type
|
||||
assert module.parametrizations.weight[0].pruned_outputs == set()
|
||||
assert module.parametrizations.weight[0].mask.dtype == torch.bool
|
||||
|
||||
def _check_pruner_valid_after_step(self, model, pruner, pruned_set, device):
|
||||
def _check_pruner_valid_after_step(self, model, pruner, mask, device):
|
||||
for config in pruner.groups:
|
||||
modules = []
|
||||
if type(config['module']) is tuple:
|
||||
|
|
@ -248,11 +217,12 @@ class TestBasePruner(TestCase):
|
|||
modules.append(module)
|
||||
for module in modules:
|
||||
assert module.weight.device.type == device.type
|
||||
assert module.parametrizations.weight[0].pruned_outputs == pruned_set
|
||||
total = module.parametrizations.weight[0].mask.numel()
|
||||
assert module.parametrizations.weight[0].mask.count_nonzero() == total - mask
|
||||
|
||||
def _test_constructor_on_device(self, model, device):
|
||||
self.assertRaisesRegex(TypeError, 'BasePruner .* update_mask',
|
||||
BasePruner)
|
||||
self.assertRaisesRegex(TypeError, 'BaseStructuredSparsifier.* update_mask',
|
||||
BaseStructuredSparsifier)
|
||||
model1 = copy.deepcopy(model).to(device)
|
||||
pruner = SimplePruner(None)
|
||||
pruner.prepare(model1, None)
|
||||
|
|
@ -264,7 +234,7 @@ class TestBasePruner(TestCase):
|
|||
# Can instantiate the model with configs
|
||||
model2 = copy.deepcopy(model).to(device)
|
||||
pruner = SimplePruner({'test': 3})
|
||||
pruner.prepare(model2, [model2.linear])
|
||||
pruner.prepare(model2, [{"tensor_fqn": "linear.weight"}])
|
||||
assert len(pruner.groups) == 1
|
||||
assert pruner.groups[0]['module_fqn'] == 'linear'
|
||||
assert 'test' in pruner.groups[0]
|
||||
|
|
@ -297,11 +267,9 @@ class TestBasePruner(TestCase):
|
|||
assert model(x).shape == (1, 64, 24, 24)
|
||||
|
||||
def test_prepare_conv2d(self):
|
||||
bn_model = Conv2dBN()
|
||||
bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)]
|
||||
|
||||
models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
|
||||
configs = [None, None, None, bn_config]
|
||||
models = [Conv2dA(), Conv2dB(), Conv2dC()]
|
||||
configs = [None, None, None]
|
||||
for device in DEVICES:
|
||||
for model, config in zip(models, configs):
|
||||
model = model.to(device)
|
||||
|
|
@ -332,11 +300,9 @@ class TestBasePruner(TestCase):
|
|||
assert model(x).shape == (1, 64, 24, 24)
|
||||
|
||||
def test_squash_mask_conv2d(self):
|
||||
bn_model = Conv2dBN()
|
||||
bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)]
|
||||
|
||||
models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
|
||||
configs = [None, None, None, bn_config]
|
||||
models = [Conv2dA(), Conv2dB(), Conv2dC()]
|
||||
configs = [None, None, None]
|
||||
for device in DEVICES:
|
||||
for model, config in zip(models, configs):
|
||||
model = model.to(device)
|
||||
|
|
@ -345,19 +311,19 @@ class TestBasePruner(TestCase):
|
|||
def _test_step_linear_on_device(self, model, is_basic, device):
|
||||
model = model.to(device)
|
||||
if is_basic:
|
||||
x = torch.ones(16, 16)
|
||||
x = torch.ones(16, 16, device=device)
|
||||
pruner = SimplePruner(None)
|
||||
pruner.prepare(model, None)
|
||||
self._check_pruner_valid_before_step(model, pruner, device)
|
||||
pruner.step()
|
||||
self._check_pruner_valid_after_step(model, pruner, {1}, device)
|
||||
self._check_pruner_valid_after_step(model, pruner, 1, device)
|
||||
else:
|
||||
x = torch.ones(7, 7)
|
||||
x = torch.ones(7, 7, device=device)
|
||||
pruner = MultiplePruner(None)
|
||||
pruner.prepare(model, None)
|
||||
self._check_pruner_valid_before_step(model, pruner, device)
|
||||
pruner.step()
|
||||
self._check_pruner_valid_after_step(model, pruner, {1, 2}, device)
|
||||
self._check_pruner_valid_after_step(model, pruner, 2, device)
|
||||
|
||||
def test_step_linear(self):
|
||||
basic_models = [Linear(), LinearB()]
|
||||
|
|
@ -375,20 +341,13 @@ class TestBasePruner(TestCase):
|
|||
pruner.prepare(model, config)
|
||||
self._check_pruner_valid_before_step(model, pruner, device)
|
||||
pruner.step()
|
||||
if type(model) is Conv2dBN:
|
||||
assert pruner.get_module_pruned_outputs(model.seq[1]) == pruner.get_module_pruned_outputs(model.seq[0])
|
||||
assert pruner.get_module_pruned_outputs(model.bn) == pruner.get_module_pruned_outputs(model.conv2d)
|
||||
self._check_pruner_valid_after_step(model, pruner, {1}, device)
|
||||
self._check_pruner_valid_after_step(model, pruner, 1, device)
|
||||
assert model(x).shape == (1, 64, 24, 24)
|
||||
|
||||
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
|
||||
def test_step_conv2d(self):
|
||||
bn_model = Conv2dBN()
|
||||
bn_config = [(bn_model.seq[0], bn_model.seq[1]),
|
||||
(bn_model.conv2d, bn_model.bn)]
|
||||
|
||||
models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
|
||||
configs = [None, None, None, None, bn_config]
|
||||
models = [Conv2dA(), Conv2dB(), Conv2dC()]
|
||||
configs = [None, None, None, None]
|
||||
for device in DEVICES:
|
||||
for model, config in zip(models, configs):
|
||||
self._test_step_conv2d_on_device(model, config, torch.device(device))
|
||||
|
|
@ -14,9 +14,7 @@ from ao.sparsity.test_parametrization import TestFakeSparsity # noqa: F401
|
|||
from ao.sparsity.test_sparsifier import TestBaseSparsifier # noqa: F401
|
||||
from ao.sparsity.test_sparsifier import TestWeightNormSparsifier # noqa: F401
|
||||
from ao.sparsity.test_sparsifier import TestNearlyDiagonalSparsifier # noqa: F401
|
||||
|
||||
# Pruner
|
||||
from ao.sparsity.test_pruner import TestBasePruner # noqa: F401
|
||||
from ao.sparsity.test_structured_sparsifier import TestBaseStructuredSparsifier # noqa: F401
|
||||
|
||||
# Scheduler
|
||||
from ao.sparsity.test_scheduler import TestScheduler # noqa: F401
|
||||
|
|
|
|||
|
|
@ -1,15 +1,11 @@
|
|||
from .base_pruner import BasePruner
|
||||
from .base_structured_sparsifier import BaseStructuredSparsifier
|
||||
from .parametrization import (
|
||||
ActivationReconstruction,
|
||||
FakeStructuredSparsity,
|
||||
BiasHook,
|
||||
PruningParametrization,
|
||||
ZeroesParametrization,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ActivationReconstruction",
|
||||
"BasePruner",
|
||||
"FakeStructuredSparsity",
|
||||
"BaseStructuredSparsifier",
|
||||
"BiasHook",
|
||||
"PruningParametrization",
|
||||
"ZeroesParametrization",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,247 +0,0 @@
|
|||
|
||||
import copy
|
||||
import warnings
|
||||
import abc
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
from torch.nn.modules.container import ModuleDict, ModuleList
|
||||
|
||||
from .parametrization import PruningParametrization, ZeroesParametrization, ActivationReconstruction, BiasHook
|
||||
|
||||
from torch.ao.pruning import BaseSparsifier, module_to_fqn, fqn_to_module
|
||||
from torch.ao.pruning.sparsifier.utils import get_arg_info_from_tensor_fqn
|
||||
|
||||
__all__ = ["BasePruner"]
|
||||
|
||||
SUPPORTED_MODULES = { # added to config if None given
|
||||
nn.Linear,
|
||||
nn.Conv2d,
|
||||
nn.BatchNorm2d, # will need manual update to match conv2d
|
||||
}
|
||||
|
||||
NEEDS_ZEROS = { # these layers should have pruned indices zero-ed, not removed
|
||||
nn.BatchNorm2d
|
||||
}
|
||||
|
||||
class BasePruner(BaseSparsifier):
|
||||
r"""Base class for all pruners.
|
||||
|
||||
Abstract methods that need to be implemented:
|
||||
|
||||
- update_mask: Function to compute a new mask for all keys in the
|
||||
`groups` attribute.
|
||||
|
||||
Args:
|
||||
- defaults [dict]: default configurations will be attached to the
|
||||
configuration. Only the keys that don't exist in the `config` will
|
||||
be updated.
|
||||
- also_prune_bias [bool]: whether to prune bias in addition to weights (to prune full output channel)
|
||||
or not; default=True.
|
||||
|
||||
"""
|
||||
def __init__(self, defaults, also_prune_bias=True):
|
||||
super().__init__(defaults)
|
||||
self.prune_bias = also_prune_bias
|
||||
|
||||
def _get_modules_and_tensor_names(self, config, use_path):
|
||||
modules = []
|
||||
tensor_names = []
|
||||
if use_path:
|
||||
if type(config['module']) is tuple: # (Conv2d, BN)
|
||||
for module_fqn, tensor_name in zip(config['module_fqn'], config['tensor_name']):
|
||||
module = fqn_to_module(self.model, module_fqn)
|
||||
modules.append(module)
|
||||
tensor_names.append(tensor_name)
|
||||
else:
|
||||
module = fqn_to_module(self.model, config['module_fqn'])
|
||||
modules.append(module)
|
||||
tensor_name = config['tensor_name']
|
||||
tensor_names.append(tensor_name)
|
||||
|
||||
else:
|
||||
if type(config['module']) is tuple:
|
||||
for module, tensor_name in zip(config['module'], config['tensor_name']):
|
||||
modules.append(module)
|
||||
tensor_names.append(tensor_name)
|
||||
else:
|
||||
module = config['module']
|
||||
modules.append(module)
|
||||
tensor_name = config['tensor_name']
|
||||
tensor_names.append(tensor_name)
|
||||
return modules, tensor_names
|
||||
|
||||
def _prepare(self, use_path=False, *args, **kwargs):
|
||||
r"""Adds mask parametrization to the layer weight
|
||||
"""
|
||||
self.activation_handles = [] # store removable hook handles
|
||||
self.bias_handles = []
|
||||
|
||||
for config in self.groups:
|
||||
modules, tensor_names = self._get_modules_and_tensor_names(config, use_path)
|
||||
|
||||
for module, tensor_name in zip(modules, tensor_names):
|
||||
if not isinstance(module, tuple(NEEDS_ZEROS)):
|
||||
# add pruning parametrization and forward hooks
|
||||
if getattr(module, 'mask', None) is None:
|
||||
module.register_buffer('mask', torch.tensor(getattr(module, tensor_name).shape[0]))
|
||||
param = config.get('parametrization', PruningParametrization)
|
||||
parametrize.register_parametrization(module, tensor_name, param(module.mask), unsafe=True)
|
||||
|
||||
assert isinstance(module.parametrizations, ModuleDict) # make mypy happy
|
||||
assert isinstance(module.parametrizations.weight, ModuleList)
|
||||
if isinstance(module, tuple(SUPPORTED_MODULES)):
|
||||
self.activation_handles.append(module.register_forward_hook(
|
||||
ActivationReconstruction(getattr(module.parametrizations, tensor_name)[0])
|
||||
))
|
||||
else:
|
||||
raise NotImplementedError("This module type is not supported yet.")
|
||||
|
||||
else: # needs zeros
|
||||
if getattr(module, 'mask', None) is None:
|
||||
module.register_buffer('mask', torch.tensor(getattr(module, tensor_name).shape[0]))
|
||||
param = config.get('parametrization', ZeroesParametrization)
|
||||
parametrize.register_parametrization(module, tensor_name, param(module.mask), unsafe=True)
|
||||
|
||||
if module.bias is not None:
|
||||
module.register_parameter('_bias', nn.Parameter(module.bias.detach()))
|
||||
module.bias = None
|
||||
self.bias_handles.append(module.register_forward_hook(BiasHook(module.parametrizations.weight[0], self.prune_bias)))
|
||||
|
||||
if len(modules) == 2: # (Conv2d, BN)
|
||||
# should have the same set of pruned outputs
|
||||
modules[1].parametrizations.weight[0].pruned_outputs = modules[0].parametrizations.weight[0].pruned_outputs
|
||||
|
||||
def make_config_from_model(self, model, SUPPORTED_MODULES=SUPPORTED_MODULES, NEEDS_ZEROS=NEEDS_ZEROS):
|
||||
self.config = []
|
||||
stack = [model]
|
||||
while stack:
|
||||
module = stack.pop()
|
||||
for name, child in module.named_children():
|
||||
if type(child) in SUPPORTED_MODULES:
|
||||
child_fqn = module_to_fqn(model, child)
|
||||
assert isinstance(child_fqn, str) # for mypy
|
||||
self.config.append({'tensor_fqn': child_fqn + '.weight'})
|
||||
else:
|
||||
if NEEDS_ZEROS is not None and type(child) in NEEDS_ZEROS and hasattr(self, "prune_bias") and self.prune_bias:
|
||||
# only useful for Pruner
|
||||
warnings.warn(f"Models with {type(child)} layers have config provided by user.")
|
||||
stack.append(child)
|
||||
|
||||
def prepare(self, model, config):
|
||||
r"""Prepares a model, by adding the parametrizations and forward post-hooks.
|
||||
Note::
|
||||
The model is modified inplace. If you need to preserve the original
|
||||
model, use copy.deepcopy.
|
||||
|
||||
Args:
|
||||
- model [nn.Module]: model to configure. The model itself is not saved
|
||||
but used for the state_dict saving / loading.
|
||||
- config [list]: configuration elements could either be instances of
|
||||
tuples of dict maps or dict maps. The dicts must have a key 'tensor_fqn' with the
|
||||
value being the fqn of the tensor to be pruned.
|
||||
"""
|
||||
self.model = model # TODO: Need to figure out how to load without this.
|
||||
self.config = config
|
||||
|
||||
# If no config -- try getting all the supported layers
|
||||
if self.config is None:
|
||||
# Add all models to the config
|
||||
self.make_config_from_model(self.model)
|
||||
|
||||
for module_config in self.config:
|
||||
if type(module_config) is tuple:
|
||||
first_layer, next_layer = module_config
|
||||
assert isinstance(first_layer, nn.Conv2d) and isinstance(next_layer, nn.BatchNorm2d)
|
||||
assert isinstance(module_config, tuple) # for mypy
|
||||
module_config = {'module': module_config}
|
||||
local_args = copy.deepcopy(self.defaults)
|
||||
local_args.update(module_config)
|
||||
module_fqn_list = []
|
||||
tensor_fqn_list = []
|
||||
tensor_name_list = []
|
||||
for module in local_args['module']:
|
||||
module_fqn = module_to_fqn(model, module)
|
||||
if module_fqn is None:
|
||||
module_fqn = ''
|
||||
if module_fqn and module_fqn[0] == '.':
|
||||
module_fqn = module_fqn[1:]
|
||||
module_fqn_list.append(module_fqn)
|
||||
tensor_fqn_list.append(module_fqn + '.weight')
|
||||
tensor_name_list.append('weight')
|
||||
|
||||
local_args['module_fqn'] = module_fqn_list
|
||||
local_args['tensor_fqn'] = tensor_fqn_list
|
||||
local_args['tensor_name'] = tensor_name_list
|
||||
else:
|
||||
if isinstance(module_config, nn.Module):
|
||||
module_config = {'module': module_config} # type: ignore[dict-item]
|
||||
|
||||
local_args = copy.deepcopy(self.defaults)
|
||||
local_args.update(module_config)
|
||||
|
||||
# now that we're working with a dict, does it have the new format?
|
||||
if local_args.get('tensor_fqn', None) is not None:
|
||||
tensor_fqn = local_args.get('tensor_fqn')
|
||||
assert isinstance(tensor_fqn, str) # for mypy
|
||||
info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
|
||||
|
||||
for key in info_from_tensor_fqn.keys():
|
||||
if key in local_args:
|
||||
# info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
|
||||
assert key == 'tensor_fqn' or info_from_tensor_fqn[key] == local_args[key], (
|
||||
"Given both `{}` and `tensor_fqn`, it is expected them to "
|
||||
"agree!".format(key)
|
||||
)
|
||||
local_args.update(info_from_tensor_fqn)
|
||||
else:
|
||||
module = local_args['module']
|
||||
module_fqn = module_to_fqn(model, module)
|
||||
if module_fqn and module_fqn[0] == '.':
|
||||
module_fqn = module_fqn[1:]
|
||||
local_args['module_fqn'] = module_fqn
|
||||
local_args['tensor_name'] = "weight"
|
||||
assert isinstance(module_fqn, str) # for mypy
|
||||
local_args['tensor_fqn'] = module_fqn + ".weight"
|
||||
self.groups.append(local_args)
|
||||
|
||||
self._prepare()
|
||||
|
||||
def squash_mask(self, use_path=False, *args, **kwargs):
|
||||
for config in self.groups:
|
||||
modules, tensor_names = self._get_modules_and_tensor_names(config, use_path)
|
||||
|
||||
for module, tensor_name in zip(modules, tensor_names):
|
||||
parametrize.remove_parametrizations(module, tensor_name,
|
||||
leave_parametrized=True)
|
||||
if getattr(module._parameters, 'mask', None):
|
||||
del module._parameters['mask']
|
||||
elif getattr(module._buffers, 'mask', None):
|
||||
del module._buffers['mask']
|
||||
delattr(module, 'mask')
|
||||
|
||||
def get_module_pruned_outputs(self, module, tensor_name='weight'):
|
||||
r"""Returns the set of pruned indices of module"""
|
||||
assert parametrize.is_parametrized(module) # can only get pruned indices of pruned module
|
||||
return getattr(module.parametrizations, tensor_name)[0].pruned_outputs # assume only one parametrization attached
|
||||
|
||||
def step(self, use_path=False):
|
||||
if not self.enable_mask_update:
|
||||
return
|
||||
with torch.no_grad():
|
||||
for config in self.groups:
|
||||
modules, tensor_names = self._get_modules_and_tensor_names(config, use_path)
|
||||
|
||||
untupled_args: dict = {}
|
||||
untupled_args.update()
|
||||
# only need to update the first module in modules if len(modules) > 1
|
||||
# since they should share the same set of pruned outputs
|
||||
untupled_args['module'] = modules[0]
|
||||
untupled_args['tensor_name'] = tensor_names[0]
|
||||
self.update_mask(**config)
|
||||
|
||||
@abc.abstractmethod
|
||||
def update_mask(self, module, tensor_name, **kwargs):
|
||||
pass
|
||||
|
|
@ -0,0 +1,74 @@
|
|||
from typing import Set, Type
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.utils import parametrize
|
||||
|
||||
from torch.ao.pruning import BaseSparsifier
|
||||
from .parametrization import FakeStructuredSparsity, BiasHook
|
||||
|
||||
__all__ = ["BaseStructuredSparsifier"]
|
||||
|
||||
SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given
|
||||
nn.Linear,
|
||||
nn.Conv2d,
|
||||
}
|
||||
|
||||
|
||||
class BaseStructuredSparsifier(BaseSparsifier):
|
||||
r"""Base class for structured pruning.
|
||||
|
||||
Abstract methods that need to be implemented:
|
||||
- update_mask: Function to compute a new mask for all keys in the
|
||||
`groups` attribute.
|
||||
|
||||
Args:
|
||||
- defaults [dict]: default configurations will be attached to the
|
||||
configuration. Only the keys that don't exist in the `config` will
|
||||
be updated.
|
||||
"""
|
||||
|
||||
def __init__(self, defaults):
|
||||
super().__init__(defaults)
|
||||
|
||||
def make_config_from_model(
|
||||
self,
|
||||
model: nn.Module,
|
||||
SUPPORTED_MODULES: Set[Type] = SUPPORTED_STRUCTURED_PRUNING_MODULES,
|
||||
) -> None:
|
||||
super().make_config_from_model(
|
||||
model, SUPPORTED_MODULES=SUPPORTED_STRUCTURED_PRUNING_MODULES
|
||||
)
|
||||
|
||||
def _prepare(self, *args, **kwargs) -> None:
|
||||
r"""This function will attach the FakeStructuredSparsity parameterizations
|
||||
and BiasHooks at the appropriate points in the model.
|
||||
"""
|
||||
self.bias_handles = []
|
||||
|
||||
for config in self.groups:
|
||||
module = config["module"]
|
||||
tensor_name = config["tensor_name"]
|
||||
parametrization = config.get("parametrization", FakeStructuredSparsity)
|
||||
tensor = getattr(module, tensor_name)
|
||||
|
||||
mask = config.get(
|
||||
"mask",
|
||||
torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device),
|
||||
)
|
||||
self.state[config["tensor_fqn"]]["mask"] = mask
|
||||
parametrize.register_parametrization(
|
||||
module, tensor_name, parametrization(mask), unsafe=True
|
||||
)
|
||||
|
||||
prune_bias = config.get("prune_bias", True)
|
||||
if prune_bias and module.bias is not None:
|
||||
module.register_parameter("_bias", nn.Parameter(module.bias.detach()))
|
||||
module.bias = None
|
||||
self.bias_handles.append(
|
||||
module.register_forward_hook(
|
||||
BiasHook(module.parametrizations.weight[0], prune_bias)
|
||||
)
|
||||
)
|
||||
|
||||
def convert(self):
|
||||
pass
|
||||
|
|
@ -1,72 +1,45 @@
|
|||
import torch
|
||||
from torch import nn
|
||||
from typing import Any, List
|
||||
|
||||
__all__ = ['PruningParametrization', 'ZeroesParametrization', 'ActivationReconstruction', 'BiasHook']
|
||||
__all__ = ['FakeStructuredSparsity', 'BiasHook']
|
||||
|
||||
class PruningParametrization(nn.Module):
|
||||
def __init__(self, original_outputs):
|
||||
|
||||
# Structured Pruning Parameterizations
|
||||
class FakeStructuredSparsity(nn.Module):
|
||||
r"""
|
||||
Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to
|
||||
the 'weight' or any other parameter that requires a mask.
|
||||
|
||||
Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask.
|
||||
"""
|
||||
|
||||
def __init__(self, mask):
|
||||
super().__init__()
|
||||
self.original_outputs = set(range(original_outputs.item()))
|
||||
self.pruned_outputs = set() # Will contain indicies of outputs to prune
|
||||
self.register_buffer("mask", mask)
|
||||
|
||||
def forward(self, x):
|
||||
valid_outputs = self.original_outputs - self.pruned_outputs
|
||||
return x[list(valid_outputs)]
|
||||
|
||||
|
||||
class ZeroesParametrization(nn.Module):
|
||||
r"""Zero out pruned channels instead of removing.
|
||||
E.g. used for Batch Norm pruning, which should match previous Conv2d layer."""
|
||||
def __init__(self, original_outputs):
|
||||
super().__init__()
|
||||
self.original_outputs = set(range(original_outputs.item()))
|
||||
self.pruned_outputs = set() # Will contain indicies of outputs to prune
|
||||
|
||||
def forward(self, x):
|
||||
x.data[list(self.pruned_outputs)] = 0
|
||||
return x
|
||||
|
||||
|
||||
class ActivationReconstruction:
|
||||
def __init__(self, parametrization):
|
||||
self.param = parametrization
|
||||
|
||||
def __call__(self, module, input, output):
|
||||
max_outputs = self.param.original_outputs
|
||||
pruned_outputs = self.param.pruned_outputs
|
||||
valid_columns = list(max_outputs - pruned_outputs)
|
||||
|
||||
# get size of reconstructed output
|
||||
sizes = list(output.shape)
|
||||
sizes[1] = len(max_outputs)
|
||||
|
||||
# get valid indices of reconstructed output
|
||||
indices: List[Any] = []
|
||||
for size in output.shape:
|
||||
indices.append(slice(0, size, 1))
|
||||
indices[1] = valid_columns
|
||||
|
||||
reconstructed_tensor = torch.zeros(sizes,
|
||||
dtype=output.dtype,
|
||||
device=output.device,
|
||||
layout=output.layout)
|
||||
reconstructed_tensor[indices] = output
|
||||
return reconstructed_tensor
|
||||
assert isinstance(self.mask, torch.Tensor)
|
||||
assert self.mask.shape[0] == x.shape[0]
|
||||
shape = [1] * len(x.shape)
|
||||
shape[0] = -1
|
||||
return self.mask.reshape(shape) * x
|
||||
|
||||
def state_dict(self, *args, **kwargs):
|
||||
# avoid double saving masks
|
||||
return {}
|
||||
|
||||
class BiasHook:
|
||||
|
||||
def __init__(self, parametrization, prune_bias):
|
||||
self.param = parametrization
|
||||
self.prune_bias = prune_bias
|
||||
|
||||
def __call__(self, module, input, output):
|
||||
pruned_outputs = self.param.pruned_outputs
|
||||
|
||||
if getattr(module, '_bias', None) is not None:
|
||||
bias = module._bias.data
|
||||
if self.prune_bias:
|
||||
bias[list(pruned_outputs)] = 0
|
||||
bias[~self.param.mask] = 0
|
||||
|
||||
# reshape bias to broadcast over output dimensions
|
||||
idx = [1] * len(output.shape)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user