Commit Graph

3 Commits

Author SHA1 Message Date
Jesse Cai
de016b3799 [pruning][core][feature] Implement prune for structured pruning (#89777)
Summary:

This PR implements `prune` in BaseStructuredSparsifier:

`prune` is a function that takes in a model with structured sparsity parametritizations (the result of `prepare`) and will return a resized model with the masked out weights removed.

`prune` is defined by a mapping from **patterns** to different **pruning functions**.
	- **patterns** are just sequences of operations, for example `(nn.Linear, activation, nn.Linear)`
	- **pruning functions** are functions that take in an matched pattern as args and will resize the appropriate layer sizes and weights.
	  ```
	  def prune_linear_activation_linear(linear1, activation, linear2):
		pass
	  ```
	- This is one line in the pattern config `(nn.Linear, activation, nn.Linear): prune_linear_activation_linear`

At a high level `prune` works by finding instances of the graph that match different patterns and then calling the mapped pruning functions on those matched patterns.
This is unlike the previous code which attempted to do both at the same time.

There may be some gaps in the patterns compared to the previous implementation, but the conversion functionality support should be the same.

Currently we have pruning functions for the following patterns:
    - linear -> linear
    - linear -> activation -> linear
    - conv2d -> conv2d
    - conv2d -> activation -> conv2d
    - conv2d -> activation -> pool -> conv2d
    - conv2d -> pool -> activation -> conv2d
    - conv2d -> adaptive pool -> flatten -> linear

Added in MyPy type hints as well for the prune_functions.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89777
Approved by: https://github.com/vkuzo
2022-12-08 07:13:24 +00:00
PyTorch MergeBot
1b1301f16a Revert "[pruning][core][feature] Implement prune for structured pruning (#89777)"
This reverts commit 3531e44307.

Reverted https://github.com/pytorch/pytorch/pull/89777 on behalf of https://github.com/clee2000 due to breaking test_ao_sparcity due to import 3531e44307 https://github.com/pytorch/pytorch/actions/runs/3641476330/jobs/6147830487, probably a landrace with 824641b083860df4d7ffef06a798ea2702bc4bde?
2022-12-07 19:41:15 +00:00
Jesse Cai
3531e44307 [pruning][core][feature] Implement prune for structured pruning (#89777)
Summary:

This PR implements `prune` in BaseStructuredSparsifier:

`prune` is a function that takes in a model with structured sparsity parametritizations (the result of `prepare`) and will return a resized model with the masked out weights removed.

`prune` is defined by a mapping from **patterns** to different **pruning functions**.
	- **patterns** are just sequences of operations, for example `(nn.Linear, activation, nn.Linear)`
	- **pruning functions** are functions that take in an matched pattern as args and will resize the appropriate layer sizes and weights.
	  ```
	  def prune_linear_activation_linear(linear1, activation, linear2):
		pass
	  ```
	- This is one line in the pattern config `(nn.Linear, activation, nn.Linear): prune_linear_activation_linear`

At a high level `prune` works by finding instances of the graph that match different patterns and then calling the mapped pruning functions on those matched patterns.
This is unlike the previous code which attempted to do both at the same time.

There may be some gaps in the patterns compared to the previous implementation, but the conversion functionality support should be the same.

Currently we have pruning functions for the following patterns:
    - linear -> linear
    - linear -> activation -> linear
    - conv2d -> conv2d
    - conv2d -> activation -> conv2d
    - conv2d -> activation -> pool -> conv2d
    - conv2d -> pool -> activation -> conv2d
    - conv2d -> adaptive pool -> flatten -> linear

Added in MyPy type hints as well for the prune_functions.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89777
Approved by: https://github.com/vkuzo
2022-12-07 17:52:01 +00:00