[pruning][core][feature] Align BaseStructuredPruner with existing pruning flow (#88436)

Summary:

This PR aligns the "eager" mode of the structured pruning flow with the existing unstructured pruning flow.

The base pruner has been moved from and has been renamed from BasePruner to BaseStructuredPruner
`torch/ao/pruning/_experimental/pruner/base_pruner.py -> /torch/ao/pruning/_experimental/pruner/base_structured_pruner.py`

Support for pruning batchnorm modules in the config have been removed, so now the structured pruning code can use more of the BaseSparsifier logic and we don't need to override as many functions.

Since we aim to only support a single flow, we have only updated ZeroesParametrizations (FakeStructuredSparsity) and BiasHook.
The parameterizations have also been rewritten to use a bool mask tensor for keeping track of pruned rows, instead of using sets before.
This better aligns structured and unstructured sparsity.

The BaseStructuredSparsifier tests have also been updated to reflect the above changes. I also removed `squash_mask` tests because they were breaking CI and `squash_mask` is no longer used.

We will migrate the structured pruning code out of this folder in a later PR.

Test Plan:
```
python test/test_ao_sparsity -- TestBaseStructuredPruner
```

Reviewers:
z-a-f vkuzo

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88436
Approved by: https://github.com/vkuzo
This commit is contained in:
Jesse Cai 2022-12-02 13:44:13 -08:00 committed by PyTorch MergeBot
parent d3f20a20b8
commit 9a1c6fd506
6 changed files with 136 additions and 383 deletions

View File

@ -7,7 +7,7 @@ import logging
import torch
from torch import nn
from torch.ao.pruning._experimental.pruner import BasePruner, PruningParametrization, ZeroesParametrization
from torch.ao.pruning._experimental.pruner import BaseStructuredSparsifier, FakeStructuredSparsity
from torch.nn.utils import parametrize
from torch.testing._internal.common_utils import TestCase, skipIfTorchDynamo
@ -19,10 +19,6 @@ DEVICES = {
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
}
NEEDS_ZEROS = { # these layers should have pruned indices zero-ed, not removed
nn.BatchNorm2d
}
class Linear(nn.Module):
r"""Model with Linear layers, in Sequential and outside, without biases"""
@ -159,56 +155,30 @@ class Conv2dC(nn.Module):
return x
class Conv2dBN(nn.Module):
r"""Model with Conv2d layers and BatchNorms"""
def __init__(self):
super().__init__()
self.seq = nn.Sequential(
nn.Conv2d(1, 32, 3, 1, bias=True),
nn.BatchNorm2d(32)
)
self.conv2d = nn.Conv2d(32, 64, 3, 1, bias=True)
self.bn = nn.BatchNorm2d(64)
def forward(self, x):
x = self.seq(x)
x = self.conv2d(x)
x = self.bn(x)
return x
class SimplePruner(BasePruner):
class SimplePruner(BaseStructuredSparsifier):
def update_mask(self, module, tensor_name, **kwargs):
getattr(module.parametrizations, tensor_name)[0].pruned_outputs.add(1)
getattr(module.parametrizations, tensor_name)[0].mask[1] = False
class MultiplePruner(BasePruner):
class MultiplePruner(BaseStructuredSparsifier):
def update_mask(self, module, tensor_name, **kwargs):
getattr(module.parametrizations, tensor_name)[0].pruned_outputs.update([1, 2])
getattr(module.parametrizations, tensor_name)[0].mask[1] = False
getattr(module.parametrizations, tensor_name)[0].mask[2] = False
class TestBasePruner(TestCase):
class TestBaseStructuredSparsifier(TestCase):
def _check_pruner_prepared(self, model, pruner, device):
for config in pruner.groups:
modules = []
if type(config['module']) is tuple:
for module in config['module']:
modules.append(module)
else:
module = config['module']
modules.append(module)
for module in modules:
assert module.weight.device.type == device.type
# Check mask exists
assert hasattr(module, 'mask')
# Check parametrization exists and is correct
assert parametrize.is_parametrized(module)
assert hasattr(module, "parametrizations")
# Assume that this is the 1st/only parametrization
if isinstance(module, tuple(NEEDS_ZEROS)):
assert type(module.parametrizations.weight[0]) == ZeroesParametrization
else:
assert type(module.parametrizations.weight[0]) == PruningParametrization
module = config["module"]
assert module.weight.device.type == device.type
# Check mask exists
assert config["tensor_fqn"] in pruner.state
# Check parametrization exists and is correct
assert parametrize.is_parametrized(module)
assert hasattr(module, "parametrizations")
# Assume that this is the 1st/only parametrization
assert type(module.parametrizations.weight[0]) == FakeStructuredSparsity
def _check_pruner_mask_squashed(self, model, pruner, device):
for config in pruner.groups:
@ -222,7 +192,6 @@ class TestBasePruner(TestCase):
for module in modules:
assert module.weight.device.type == device.type
assert not hasattr(module, "parametrizations")
assert not hasattr(module, 'mask')
def _check_pruner_valid_before_step(self, model, pruner, device):
for config in pruner.groups:
@ -235,9 +204,9 @@ class TestBasePruner(TestCase):
modules.append(module)
for module in modules:
assert module.weight.device.type == device.type
assert module.parametrizations.weight[0].pruned_outputs == set()
assert module.parametrizations.weight[0].mask.dtype == torch.bool
def _check_pruner_valid_after_step(self, model, pruner, pruned_set, device):
def _check_pruner_valid_after_step(self, model, pruner, mask, device):
for config in pruner.groups:
modules = []
if type(config['module']) is tuple:
@ -248,11 +217,12 @@ class TestBasePruner(TestCase):
modules.append(module)
for module in modules:
assert module.weight.device.type == device.type
assert module.parametrizations.weight[0].pruned_outputs == pruned_set
total = module.parametrizations.weight[0].mask.numel()
assert module.parametrizations.weight[0].mask.count_nonzero() == total - mask
def _test_constructor_on_device(self, model, device):
self.assertRaisesRegex(TypeError, 'BasePruner .* update_mask',
BasePruner)
self.assertRaisesRegex(TypeError, 'BaseStructuredSparsifier.* update_mask',
BaseStructuredSparsifier)
model1 = copy.deepcopy(model).to(device)
pruner = SimplePruner(None)
pruner.prepare(model1, None)
@ -264,7 +234,7 @@ class TestBasePruner(TestCase):
# Can instantiate the model with configs
model2 = copy.deepcopy(model).to(device)
pruner = SimplePruner({'test': 3})
pruner.prepare(model2, [model2.linear])
pruner.prepare(model2, [{"tensor_fqn": "linear.weight"}])
assert len(pruner.groups) == 1
assert pruner.groups[0]['module_fqn'] == 'linear'
assert 'test' in pruner.groups[0]
@ -297,11 +267,9 @@ class TestBasePruner(TestCase):
assert model(x).shape == (1, 64, 24, 24)
def test_prepare_conv2d(self):
bn_model = Conv2dBN()
bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)]
models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
configs = [None, None, None, bn_config]
models = [Conv2dA(), Conv2dB(), Conv2dC()]
configs = [None, None, None]
for device in DEVICES:
for model, config in zip(models, configs):
model = model.to(device)
@ -332,11 +300,9 @@ class TestBasePruner(TestCase):
assert model(x).shape == (1, 64, 24, 24)
def test_squash_mask_conv2d(self):
bn_model = Conv2dBN()
bn_config = [(bn_model.seq[0], bn_model.seq[1]), (bn_model.conv2d, bn_model.bn)]
models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
configs = [None, None, None, bn_config]
models = [Conv2dA(), Conv2dB(), Conv2dC()]
configs = [None, None, None]
for device in DEVICES:
for model, config in zip(models, configs):
model = model.to(device)
@ -345,19 +311,19 @@ class TestBasePruner(TestCase):
def _test_step_linear_on_device(self, model, is_basic, device):
model = model.to(device)
if is_basic:
x = torch.ones(16, 16)
x = torch.ones(16, 16, device=device)
pruner = SimplePruner(None)
pruner.prepare(model, None)
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
self._check_pruner_valid_after_step(model, pruner, {1}, device)
self._check_pruner_valid_after_step(model, pruner, 1, device)
else:
x = torch.ones(7, 7)
x = torch.ones(7, 7, device=device)
pruner = MultiplePruner(None)
pruner.prepare(model, None)
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
self._check_pruner_valid_after_step(model, pruner, {1, 2}, device)
self._check_pruner_valid_after_step(model, pruner, 2, device)
def test_step_linear(self):
basic_models = [Linear(), LinearB()]
@ -375,20 +341,13 @@ class TestBasePruner(TestCase):
pruner.prepare(model, config)
self._check_pruner_valid_before_step(model, pruner, device)
pruner.step()
if type(model) is Conv2dBN:
assert pruner.get_module_pruned_outputs(model.seq[1]) == pruner.get_module_pruned_outputs(model.seq[0])
assert pruner.get_module_pruned_outputs(model.bn) == pruner.get_module_pruned_outputs(model.conv2d)
self._check_pruner_valid_after_step(model, pruner, {1}, device)
self._check_pruner_valid_after_step(model, pruner, 1, device)
assert model(x).shape == (1, 64, 24, 24)
@skipIfTorchDynamo("TorchDynamo fails with unknown reason")
def test_step_conv2d(self):
bn_model = Conv2dBN()
bn_config = [(bn_model.seq[0], bn_model.seq[1]),
(bn_model.conv2d, bn_model.bn)]
models = [Conv2dA(), Conv2dB(), Conv2dC(), bn_model]
configs = [None, None, None, None, bn_config]
models = [Conv2dA(), Conv2dB(), Conv2dC()]
configs = [None, None, None, None]
for device in DEVICES:
for model, config in zip(models, configs):
self._test_step_conv2d_on_device(model, config, torch.device(device))

View File

@ -14,9 +14,7 @@ from ao.sparsity.test_parametrization import TestFakeSparsity # noqa: F401
from ao.sparsity.test_sparsifier import TestBaseSparsifier # noqa: F401
from ao.sparsity.test_sparsifier import TestWeightNormSparsifier # noqa: F401
from ao.sparsity.test_sparsifier import TestNearlyDiagonalSparsifier # noqa: F401
# Pruner
from ao.sparsity.test_pruner import TestBasePruner # noqa: F401
from ao.sparsity.test_structured_sparsifier import TestBaseStructuredSparsifier # noqa: F401
# Scheduler
from ao.sparsity.test_scheduler import TestScheduler # noqa: F401

View File

@ -1,15 +1,11 @@
from .base_pruner import BasePruner
from .base_structured_sparsifier import BaseStructuredSparsifier
from .parametrization import (
ActivationReconstruction,
FakeStructuredSparsity,
BiasHook,
PruningParametrization,
ZeroesParametrization,
)
__all__ = [
"ActivationReconstruction",
"BasePruner",
"FakeStructuredSparsity",
"BaseStructuredSparsifier",
"BiasHook",
"PruningParametrization",
"ZeroesParametrization",
]

View File

@ -1,247 +0,0 @@
import copy
import warnings
import abc
import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.nn.modules.container import ModuleDict, ModuleList
from .parametrization import PruningParametrization, ZeroesParametrization, ActivationReconstruction, BiasHook
from torch.ao.pruning import BaseSparsifier, module_to_fqn, fqn_to_module
from torch.ao.pruning.sparsifier.utils import get_arg_info_from_tensor_fqn
__all__ = ["BasePruner"]
SUPPORTED_MODULES = { # added to config if None given
nn.Linear,
nn.Conv2d,
nn.BatchNorm2d, # will need manual update to match conv2d
}
NEEDS_ZEROS = { # these layers should have pruned indices zero-ed, not removed
nn.BatchNorm2d
}
class BasePruner(BaseSparsifier):
r"""Base class for all pruners.
Abstract methods that need to be implemented:
- update_mask: Function to compute a new mask for all keys in the
`groups` attribute.
Args:
- defaults [dict]: default configurations will be attached to the
configuration. Only the keys that don't exist in the `config` will
be updated.
- also_prune_bias [bool]: whether to prune bias in addition to weights (to prune full output channel)
or not; default=True.
"""
def __init__(self, defaults, also_prune_bias=True):
super().__init__(defaults)
self.prune_bias = also_prune_bias
def _get_modules_and_tensor_names(self, config, use_path):
modules = []
tensor_names = []
if use_path:
if type(config['module']) is tuple: # (Conv2d, BN)
for module_fqn, tensor_name in zip(config['module_fqn'], config['tensor_name']):
module = fqn_to_module(self.model, module_fqn)
modules.append(module)
tensor_names.append(tensor_name)
else:
module = fqn_to_module(self.model, config['module_fqn'])
modules.append(module)
tensor_name = config['tensor_name']
tensor_names.append(tensor_name)
else:
if type(config['module']) is tuple:
for module, tensor_name in zip(config['module'], config['tensor_name']):
modules.append(module)
tensor_names.append(tensor_name)
else:
module = config['module']
modules.append(module)
tensor_name = config['tensor_name']
tensor_names.append(tensor_name)
return modules, tensor_names
def _prepare(self, use_path=False, *args, **kwargs):
r"""Adds mask parametrization to the layer weight
"""
self.activation_handles = [] # store removable hook handles
self.bias_handles = []
for config in self.groups:
modules, tensor_names = self._get_modules_and_tensor_names(config, use_path)
for module, tensor_name in zip(modules, tensor_names):
if not isinstance(module, tuple(NEEDS_ZEROS)):
# add pruning parametrization and forward hooks
if getattr(module, 'mask', None) is None:
module.register_buffer('mask', torch.tensor(getattr(module, tensor_name).shape[0]))
param = config.get('parametrization', PruningParametrization)
parametrize.register_parametrization(module, tensor_name, param(module.mask), unsafe=True)
assert isinstance(module.parametrizations, ModuleDict) # make mypy happy
assert isinstance(module.parametrizations.weight, ModuleList)
if isinstance(module, tuple(SUPPORTED_MODULES)):
self.activation_handles.append(module.register_forward_hook(
ActivationReconstruction(getattr(module.parametrizations, tensor_name)[0])
))
else:
raise NotImplementedError("This module type is not supported yet.")
else: # needs zeros
if getattr(module, 'mask', None) is None:
module.register_buffer('mask', torch.tensor(getattr(module, tensor_name).shape[0]))
param = config.get('parametrization', ZeroesParametrization)
parametrize.register_parametrization(module, tensor_name, param(module.mask), unsafe=True)
if module.bias is not None:
module.register_parameter('_bias', nn.Parameter(module.bias.detach()))
module.bias = None
self.bias_handles.append(module.register_forward_hook(BiasHook(module.parametrizations.weight[0], self.prune_bias)))
if len(modules) == 2: # (Conv2d, BN)
# should have the same set of pruned outputs
modules[1].parametrizations.weight[0].pruned_outputs = modules[0].parametrizations.weight[0].pruned_outputs
def make_config_from_model(self, model, SUPPORTED_MODULES=SUPPORTED_MODULES, NEEDS_ZEROS=NEEDS_ZEROS):
self.config = []
stack = [model]
while stack:
module = stack.pop()
for name, child in module.named_children():
if type(child) in SUPPORTED_MODULES:
child_fqn = module_to_fqn(model, child)
assert isinstance(child_fqn, str) # for mypy
self.config.append({'tensor_fqn': child_fqn + '.weight'})
else:
if NEEDS_ZEROS is not None and type(child) in NEEDS_ZEROS and hasattr(self, "prune_bias") and self.prune_bias:
# only useful for Pruner
warnings.warn(f"Models with {type(child)} layers have config provided by user.")
stack.append(child)
def prepare(self, model, config):
r"""Prepares a model, by adding the parametrizations and forward post-hooks.
Note::
The model is modified inplace. If you need to preserve the original
model, use copy.deepcopy.
Args:
- model [nn.Module]: model to configure. The model itself is not saved
but used for the state_dict saving / loading.
- config [list]: configuration elements could either be instances of
tuples of dict maps or dict maps. The dicts must have a key 'tensor_fqn' with the
value being the fqn of the tensor to be pruned.
"""
self.model = model # TODO: Need to figure out how to load without this.
self.config = config
# If no config -- try getting all the supported layers
if self.config is None:
# Add all models to the config
self.make_config_from_model(self.model)
for module_config in self.config:
if type(module_config) is tuple:
first_layer, next_layer = module_config
assert isinstance(first_layer, nn.Conv2d) and isinstance(next_layer, nn.BatchNorm2d)
assert isinstance(module_config, tuple) # for mypy
module_config = {'module': module_config}
local_args = copy.deepcopy(self.defaults)
local_args.update(module_config)
module_fqn_list = []
tensor_fqn_list = []
tensor_name_list = []
for module in local_args['module']:
module_fqn = module_to_fqn(model, module)
if module_fqn is None:
module_fqn = ''
if module_fqn and module_fqn[0] == '.':
module_fqn = module_fqn[1:]
module_fqn_list.append(module_fqn)
tensor_fqn_list.append(module_fqn + '.weight')
tensor_name_list.append('weight')
local_args['module_fqn'] = module_fqn_list
local_args['tensor_fqn'] = tensor_fqn_list
local_args['tensor_name'] = tensor_name_list
else:
if isinstance(module_config, nn.Module):
module_config = {'module': module_config} # type: ignore[dict-item]
local_args = copy.deepcopy(self.defaults)
local_args.update(module_config)
# now that we're working with a dict, does it have the new format?
if local_args.get('tensor_fqn', None) is not None:
tensor_fqn = local_args.get('tensor_fqn')
assert isinstance(tensor_fqn, str) # for mypy
info_from_tensor_fqn = get_arg_info_from_tensor_fqn(model, tensor_fqn)
for key in info_from_tensor_fqn.keys():
if key in local_args:
# info_from_tensor_fqn will chop leading '.' from tensor_fqn so ignore that
assert key == 'tensor_fqn' or info_from_tensor_fqn[key] == local_args[key], (
"Given both `{}` and `tensor_fqn`, it is expected them to "
"agree!".format(key)
)
local_args.update(info_from_tensor_fqn)
else:
module = local_args['module']
module_fqn = module_to_fqn(model, module)
if module_fqn and module_fqn[0] == '.':
module_fqn = module_fqn[1:]
local_args['module_fqn'] = module_fqn
local_args['tensor_name'] = "weight"
assert isinstance(module_fqn, str) # for mypy
local_args['tensor_fqn'] = module_fqn + ".weight"
self.groups.append(local_args)
self._prepare()
def squash_mask(self, use_path=False, *args, **kwargs):
for config in self.groups:
modules, tensor_names = self._get_modules_and_tensor_names(config, use_path)
for module, tensor_name in zip(modules, tensor_names):
parametrize.remove_parametrizations(module, tensor_name,
leave_parametrized=True)
if getattr(module._parameters, 'mask', None):
del module._parameters['mask']
elif getattr(module._buffers, 'mask', None):
del module._buffers['mask']
delattr(module, 'mask')
def get_module_pruned_outputs(self, module, tensor_name='weight'):
r"""Returns the set of pruned indices of module"""
assert parametrize.is_parametrized(module) # can only get pruned indices of pruned module
return getattr(module.parametrizations, tensor_name)[0].pruned_outputs # assume only one parametrization attached
def step(self, use_path=False):
if not self.enable_mask_update:
return
with torch.no_grad():
for config in self.groups:
modules, tensor_names = self._get_modules_and_tensor_names(config, use_path)
untupled_args: dict = {}
untupled_args.update()
# only need to update the first module in modules if len(modules) > 1
# since they should share the same set of pruned outputs
untupled_args['module'] = modules[0]
untupled_args['tensor_name'] = tensor_names[0]
self.update_mask(**config)
@abc.abstractmethod
def update_mask(self, module, tensor_name, **kwargs):
pass

View File

@ -0,0 +1,74 @@
from typing import Set, Type
import torch
from torch import nn
from torch.nn.utils import parametrize
from torch.ao.pruning import BaseSparsifier
from .parametrization import FakeStructuredSparsity, BiasHook
__all__ = ["BaseStructuredSparsifier"]
SUPPORTED_STRUCTURED_PRUNING_MODULES = { # added to config if None given
nn.Linear,
nn.Conv2d,
}
class BaseStructuredSparsifier(BaseSparsifier):
r"""Base class for structured pruning.
Abstract methods that need to be implemented:
- update_mask: Function to compute a new mask for all keys in the
`groups` attribute.
Args:
- defaults [dict]: default configurations will be attached to the
configuration. Only the keys that don't exist in the `config` will
be updated.
"""
def __init__(self, defaults):
super().__init__(defaults)
def make_config_from_model(
self,
model: nn.Module,
SUPPORTED_MODULES: Set[Type] = SUPPORTED_STRUCTURED_PRUNING_MODULES,
) -> None:
super().make_config_from_model(
model, SUPPORTED_MODULES=SUPPORTED_STRUCTURED_PRUNING_MODULES
)
def _prepare(self, *args, **kwargs) -> None:
r"""This function will attach the FakeStructuredSparsity parameterizations
and BiasHooks at the appropriate points in the model.
"""
self.bias_handles = []
for config in self.groups:
module = config["module"]
tensor_name = config["tensor_name"]
parametrization = config.get("parametrization", FakeStructuredSparsity)
tensor = getattr(module, tensor_name)
mask = config.get(
"mask",
torch.ones(tensor.shape[0], dtype=torch.bool, device=tensor.device),
)
self.state[config["tensor_fqn"]]["mask"] = mask
parametrize.register_parametrization(
module, tensor_name, parametrization(mask), unsafe=True
)
prune_bias = config.get("prune_bias", True)
if prune_bias and module.bias is not None:
module.register_parameter("_bias", nn.Parameter(module.bias.detach()))
module.bias = None
self.bias_handles.append(
module.register_forward_hook(
BiasHook(module.parametrizations.weight[0], prune_bias)
)
)
def convert(self):
pass

View File

@ -1,72 +1,45 @@
import torch
from torch import nn
from typing import Any, List
__all__ = ['PruningParametrization', 'ZeroesParametrization', 'ActivationReconstruction', 'BiasHook']
__all__ = ['FakeStructuredSparsity', 'BiasHook']
class PruningParametrization(nn.Module):
def __init__(self, original_outputs):
# Structured Pruning Parameterizations
class FakeStructuredSparsity(nn.Module):
r"""
Parametrization for Structured Pruning. Like FakeSparsity, this should be attached to
the 'weight' or any other parameter that requires a mask.
Instead of an element-wise bool mask, this parameterization uses a row-wise bool mask.
"""
def __init__(self, mask):
super().__init__()
self.original_outputs = set(range(original_outputs.item()))
self.pruned_outputs = set() # Will contain indicies of outputs to prune
self.register_buffer("mask", mask)
def forward(self, x):
valid_outputs = self.original_outputs - self.pruned_outputs
return x[list(valid_outputs)]
class ZeroesParametrization(nn.Module):
r"""Zero out pruned channels instead of removing.
E.g. used for Batch Norm pruning, which should match previous Conv2d layer."""
def __init__(self, original_outputs):
super().__init__()
self.original_outputs = set(range(original_outputs.item()))
self.pruned_outputs = set() # Will contain indicies of outputs to prune
def forward(self, x):
x.data[list(self.pruned_outputs)] = 0
return x
class ActivationReconstruction:
def __init__(self, parametrization):
self.param = parametrization
def __call__(self, module, input, output):
max_outputs = self.param.original_outputs
pruned_outputs = self.param.pruned_outputs
valid_columns = list(max_outputs - pruned_outputs)
# get size of reconstructed output
sizes = list(output.shape)
sizes[1] = len(max_outputs)
# get valid indices of reconstructed output
indices: List[Any] = []
for size in output.shape:
indices.append(slice(0, size, 1))
indices[1] = valid_columns
reconstructed_tensor = torch.zeros(sizes,
dtype=output.dtype,
device=output.device,
layout=output.layout)
reconstructed_tensor[indices] = output
return reconstructed_tensor
assert isinstance(self.mask, torch.Tensor)
assert self.mask.shape[0] == x.shape[0]
shape = [1] * len(x.shape)
shape[0] = -1
return self.mask.reshape(shape) * x
def state_dict(self, *args, **kwargs):
# avoid double saving masks
return {}
class BiasHook:
def __init__(self, parametrization, prune_bias):
self.param = parametrization
self.prune_bias = prune_bias
def __call__(self, module, input, output):
pruned_outputs = self.param.pruned_outputs
if getattr(module, '_bias', None) is not None:
bias = module._bias.data
if self.prune_bias:
bias[list(pruned_outputs)] = 0
bias[~self.param.mask] = 0
# reshape bias to broadcast over output dimensions
idx = [1] * len(output.shape)