Summary:
This PR implements `BaseSparsifier.convert()`, which performs module swapping.
The modules and mappings will be merged in a future PR.
Test Plan:
`python test/test_ao_sparsity.py -- TestBaseSparsifier.test_convert`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97545
Approved by: https://github.com/jerryzh168
Summary:
This PR adds in support for LSTM Structured Pruning.
- Adds in LSTMSaliencyPruner, an implemented pruner that splits the packed weights, finds the appropriate mask for each piece individually based on saliency, and then combines to create an overall mask for the LSTM.
- Adds in pruning functions for LSTM pruning, which will split the weights, apply the masks, and then recombine the pruned weights. Works for both single and multiple-layer LSTMs.
Also added a basic pattern to the default set of of patterns for
LSTM -> Linear pruning
LSTM -> LayerNorm -> Linear pruning
Adds in test to check that LSTM pruning works, as well as for LSTMSaliencyPruner
Test Plan:
`python test/test_ao_sparsity.py -- TestBaseStructuredSparsifier.test_prune_lstm_linear_single_layer`
`python test/test_ao_sparsity.py -- TestBaseStructuredSparsifier.test_prune_lstm_linear_multiple_layer`
`python test/test_ao_sparsity.py -- TestBaseStructuredSparsifier.test_prune_lstm_layernorm_linear_single_layer`
`python test/test_ao_sparsity.py -- TestBaseStructuredSparsifier.test_prune_lstm_layernorm_linear_multiple_layer`
`python test/test_ao_sparsity.py -- TestSaliencyPruner.test_lstm_saliency_pruner_update_mask`
Reviewers:
Subscribers:
Tasks:
Tags:
Differential Revision: [D42199001](https://our.internmc.facebook.com/intern/diff/D42199001)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90801
Approved by: https://github.com/jerryzh168
Summary:
This PR implements `prune` in BaseStructuredSparsifier:
`prune` is a function that takes in a model with structured sparsity parametritizations (the result of `prepare`) and will return a resized model with the masked out weights removed.
`prune` is defined by a mapping from **patterns** to different **pruning functions**.
- **patterns** are just sequences of operations, for example `(nn.Linear, activation, nn.Linear)`
- **pruning functions** are functions that take in an matched pattern as args and will resize the appropriate layer sizes and weights.
```
def prune_linear_activation_linear(linear1, activation, linear2):
pass
```
- This is one line in the pattern config `(nn.Linear, activation, nn.Linear): prune_linear_activation_linear`
At a high level `prune` works by finding instances of the graph that match different patterns and then calling the mapped pruning functions on those matched patterns.
This is unlike the previous code which attempted to do both at the same time.
There may be some gaps in the patterns compared to the previous implementation, but the conversion functionality support should be the same.
Currently we have pruning functions for the following patterns:
- linear -> linear
- linear -> activation -> linear
- conv2d -> conv2d
- conv2d -> activation -> conv2d
- conv2d -> activation -> pool -> conv2d
- conv2d -> pool -> activation -> conv2d
- conv2d -> adaptive pool -> flatten -> linear
Added in MyPy type hints as well for the prune_functions.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89777
Approved by: https://github.com/vkuzo
Summary:
This PR implements `prune` in BaseStructuredSparsifier:
`prune` is a function that takes in a model with structured sparsity parametritizations (the result of `prepare`) and will return a resized model with the masked out weights removed.
`prune` is defined by a mapping from **patterns** to different **pruning functions**.
- **patterns** are just sequences of operations, for example `(nn.Linear, activation, nn.Linear)`
- **pruning functions** are functions that take in an matched pattern as args and will resize the appropriate layer sizes and weights.
```
def prune_linear_activation_linear(linear1, activation, linear2):
pass
```
- This is one line in the pattern config `(nn.Linear, activation, nn.Linear): prune_linear_activation_linear`
At a high level `prune` works by finding instances of the graph that match different patterns and then calling the mapped pruning functions on those matched patterns.
This is unlike the previous code which attempted to do both at the same time.
There may be some gaps in the patterns compared to the previous implementation, but the conversion functionality support should be the same.
Currently we have pruning functions for the following patterns:
- linear -> linear
- linear -> activation -> linear
- conv2d -> conv2d
- conv2d -> activation -> conv2d
- conv2d -> activation -> pool -> conv2d
- conv2d -> pool -> activation -> conv2d
- conv2d -> adaptive pool -> flatten -> linear
Added in MyPy type hints as well for the prune_functions.
Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89777
Approved by: https://github.com/vkuzo